git.madduck.net Git - etc/vim.git/blob - src/black/__init__.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access, can be arranged individually.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Bump peter-evans/create-or-update-comment from 2.1.0 to 2.1.1 (#3548)
[etc/vim.git] / src / black / __init__.py
1 import io
2 import json
3 import platform
4 import re
5 import sys
6 import tokenize
7 import traceback
8 from contextlib import contextmanager
9 from dataclasses import replace
10 from datetime import datetime
11 from enum import Enum
12 from json.decoder import JSONDecodeError
13 from pathlib import Path
14 from typing import (
15     Any,
16     Dict,
17     Generator,
18     Iterator,
19     List,
20     MutableMapping,
21     Optional,
22     Pattern,
23     Sequence,
24     Set,
25     Sized,
26     Tuple,
27     Union,
28 )
29
30 import click
31 from click.core import ParameterSource
32 from mypy_extensions import mypyc_attr
33 from pathspec import PathSpec
34 from pathspec.patterns.gitwildmatch import GitWildMatchPatternError
35
36 from _black_version import version as __version__
37 from black.cache import Cache, get_cache_info, read_cache, write_cache
38 from black.comments import normalize_fmt_off
39 from black.const import (
40     DEFAULT_EXCLUDES,
41     DEFAULT_INCLUDES,
42     DEFAULT_LINE_LENGTH,
43     STDIN_PLACEHOLDER,
44 )
45 from black.files import (
46     find_project_root,
47     find_pyproject_toml,
48     find_user_pyproject_toml,
49     gen_python_files,
50     get_gitignore,
51     normalize_path_maybe_ignore,
52     parse_pyproject_toml,
53     wrap_stream_for_windows,
54 )
55 from black.handle_ipynb_magics import (
56     PYTHON_CELL_MAGICS,
57     TRANSFORMED_MAGICS,
58     jupyter_dependencies_are_installed,
59     mask_cell,
60     put_trailing_semicolon_back,
61     remove_trailing_semicolon,
62     unmask_cell,
63 )
64 from black.linegen import LN, LineGenerator, transform_line
65 from black.lines import EmptyLineTracker, LinesBlock
66 from black.mode import (
67     FUTURE_FLAG_TO_FEATURE,
68     VERSION_TO_FEATURES,
69     Feature,
70     Mode,
71     TargetVersion,
72     supports_feature,
73 )
74 from black.nodes import (
75     STARS,
76     is_number_token,
77     is_simple_decorator_expression,
78     is_string_token,
79     syms,
80 )
81 from black.output import color_diff, diff, dump_to_file, err, ipynb_diff, out
82 from black.parsing import InvalidInput  # noqa F401
83 from black.parsing import lib2to3_parse, parse_ast, stringify_ast
84 from black.report import Changed, NothingChanged, Report
85 from black.trans import iter_fexpr_spans
86 from blib2to3.pgen2 import token
87 from blib2to3.pytree import Leaf, Node
88
# True when this module was loaded from a compiled native extension (a ".pyd"
# or ".so" file) rather than from plain Python source.
COMPILED = Path(__file__).suffix in {".pyd", ".so"}

# Type aliases that document intent in the signatures of this module.
FileContent = str
Encoding = str
NewLine = str
95
96
class WriteBack(Enum):
    """What to do with formatting results: write, diff, check, or nothing."""

    NO = 0
    YES = 1
    DIFF = 2
    CHECK = 3
    COLOR_DIFF = 4

    @classmethod
    def from_configuration(
        cls, *, check: bool, diff: bool, color: bool = False
    ) -> "WriteBack":
        """Pick the member matching the ``--check``/``--diff``/``--color`` flags.

        ``diff`` wins over ``check`` and yields DIFF, or COLOR_DIFF when
        ``color`` is also set. Without ``diff``, ``check`` yields CHECK and the
        absence of both yields YES (rewrite files). ``color`` alone has no effect.
        """
        if diff:
            return cls.COLOR_DIFF if color else cls.DIFF
        return cls.CHECK if check else cls.YES
115
116
# Legacy name, left for integrations.
# `FileMode` is kept as an alias of `Mode` so that external code importing the
# old name continues to work; new code should use `Mode` directly.
FileMode = Mode
119
120
def read_pyproject_toml(
    ctx: click.Context, param: click.Parameter, value: Optional[str]
) -> Optional[str]:
    """Inject Black configuration from "pyproject.toml" into defaults in `ctx`.

    Used as the Click callback of ``--config``. When `value` is empty, a
    pyproject.toml is looked up from the ``src`` arguments already present in
    ``ctx.params``. The parsed settings are layered on top of any existing
    ``ctx.default_map``; scalar values are stringified first so Click accepts
    them.

    Returns the path of the configuration file that was read, or None when no
    file was found or it yielded an empty configuration.

    Raises click.FileError if the file cannot be read or parsed, and
    click.BadOptionUsage if ``target-version`` is not a list.
    """
    config_path = value or find_pyproject_toml(ctx.params.get("src", ()))
    if config_path is None:
        return None

    try:
        raw_config = parse_pyproject_toml(config_path)
    except (OSError, ValueError) as exc:
        raise click.FileError(
            filename=config_path, hint=f"Error reading configuration file: {exc}"
        ) from None

    if not raw_config:
        return None

    # Sanitize the values to be Click friendly. For more information please see:
    # https://github.com/psf/black/issues/1458
    # https://github.com/pallets/click/issues/1567
    config = {
        key: val if isinstance(val, (list, dict)) else str(val)
        for key, val in raw_config.items()
    }

    target_version = config.get("target_version")
    if target_version is not None and not isinstance(target_version, list):
        raise click.BadOptionUsage(
            "target-version", "Config key target-version must be a list"
        )

    # Values from the file override any defaults that were already installed.
    default_map: Dict[str, Any] = {**(ctx.default_map or {}), **config}
    ctx.default_map = default_map
    return config_path
165
166
def target_version_option_callback(
    c: click.Context, p: Union[click.Option, click.Parameter], v: Tuple[str, ...]
) -> List[TargetVersion]:
    """Turn the lowercase ``--target-version`` choices into `TargetVersion` members.

    Each choice is upper-cased and looked up by member name. This is a named
    function rather than a lambda because mypy could not infer the lambda's
    type, which caused trouble for mypyc.
    """
    return [TargetVersion[choice.upper()] for choice in v]
176
177
def re_compile_maybe_verbose(regex: str) -> Pattern[str]:
    """Compile `regex`, switching to verbose mode when it spans several lines.

    A pattern containing a newline gets the inline ``(?x)`` flag prepended, so
    whitespace and ``#`` comments inside it are ignored by the regex engine.
    """
    flag_prefix = "(?x)" if "\n" in regex else ""
    compiled: Pattern[str] = re.compile(flag_prefix + regex)
    return compiled
187
188
def validate_regex(
    ctx: click.Context,
    param: click.Parameter,
    value: Optional[str],
) -> Optional[Pattern[str]]:
    """Click callback that compiles a regex option into a pattern.

    ``None`` (option not supplied) is passed through unchanged. An expression
    that fails to compile is reported to the user as click.BadParameter.
    """
    if value is None:
        return None
    try:
        return re_compile_maybe_verbose(value)
    except re.error as exc:
        raise click.BadParameter(f"Not a valid regular expression: {exc}") from None
198
199
@click.command(
    context_settings={"help_option_names": ["-h", "--help"]},
    # While Click does set this field automatically using the docstring, mypyc
    # (annoyingly) strips 'em so we need to set it here too.
    help="The uncompromising code formatter.",
)
@click.option("-c", "--code", type=str, help="Format the code passed in as a string.")
@click.option(
    "-l",
    "--line-length",
    type=int,
    default=DEFAULT_LINE_LENGTH,
    help="How many characters per line to allow.",
    show_default=True,
)
@click.option(
    "-t",
    "--target-version",
    type=click.Choice([v.name.lower() for v in TargetVersion]),
    # Converts the chosen names into TargetVersion members.
    callback=target_version_option_callback,
    multiple=True,
    help=(
        "Python versions that should be supported by Black's output. By default, Black"
        " will try to infer this from the project metadata in pyproject.toml. If this"
        " does not yield conclusive results, Black will use per-file auto-detection."
    ),
)
@click.option(
    "--pyi",
    is_flag=True,
    help=(
        "Format all input files like typing stubs regardless of file extension (useful"
        " when piping source on standard input)."
    ),
)
@click.option(
    "--ipynb",
    is_flag=True,
    help=(
        "Format all input files like Jupyter Notebooks regardless of file extension "
        "(useful when piping source on standard input)."
    ),
)
@click.option(
    "--python-cell-magics",
    multiple=True,
    help=(
        "When processing Jupyter Notebooks, add the given magic to the list"
        f" of known python-magics ({', '.join(sorted(PYTHON_CELL_MAGICS))})."
        " Useful for formatting cells with custom python magics."
    ),
    default=[],
)
@click.option(
    "-x",
    "--skip-source-first-line",
    is_flag=True,
    help="Skip the first line of the source code.",
)
@click.option(
    "-S",
    "--skip-string-normalization",
    is_flag=True,
    help="Don't normalize string quotes or prefixes.",
)
@click.option(
    "-C",
    "--skip-magic-trailing-comma",
    is_flag=True,
    help="Don't use trailing commas as a reason to split lines.",
)
@click.option(
    "--experimental-string-processing",
    is_flag=True,
    hidden=True,
    help="(DEPRECATED and now included in --preview) Normalize string literals.",
)
@click.option(
    "--preview",
    is_flag=True,
    help=(
        "Enable potentially disruptive style changes that may be added to Black's main"
        " functionality in the next major release."
    ),
)
@click.option(
    "--check",
    is_flag=True,
    help=(
        "Don't write the files back, just return the status. Return code 0 means"
        " nothing would change. Return code 1 means some files would be reformatted."
        " Return code 123 means there was an internal error."
    ),
)
@click.option(
    "--diff",
    is_flag=True,
    help="Don't write the files back, just output a diff for each file on stdout.",
)
@click.option(
    "--color/--no-color",
    is_flag=True,
    help="Show colored diff. Only applies when `--diff` is given.",
)
@click.option(
    "--fast/--safe",
    is_flag=True,
    help="If --fast given, skip temporary sanity checks. [default: --safe]",
)
@click.option(
    "--required-version",
    type=str,
    help=(
        "Require a specific version of Black to be running (useful for unifying results"
        " across many environments e.g. with a pyproject.toml file). It can be"
        " either a major version number or an exact version."
    ),
)
@click.option(
    "--include",
    type=str,
    default=DEFAULT_INCLUDES,
    # validate_regex compiles the string into a Pattern (or rejects it).
    callback=validate_regex,
    help=(
        "A regular expression that matches files and directories that should be"
        " included on recursive searches. An empty value means all files are included"
        " regardless of the name. Use forward slashes for directories on all platforms"
        " (Windows, too). Exclusions are calculated first, inclusions later."
    ),
    show_default=True,
)
@click.option(
    "--exclude",
    type=str,
    callback=validate_regex,
    help=(
        "A regular expression that matches files and directories that should be"
        " excluded on recursive searches. An empty value means no paths are excluded."
        " Use forward slashes for directories on all platforms (Windows, too)."
        " Exclusions are calculated first, inclusions later. [default:"
        f" {DEFAULT_EXCLUDES}]"
    ),
    # No Click default: get_sources() substitutes DEFAULT_EXCLUDES when this is
    # None and uses that to decide whether .gitignore files are honoured.
    show_default=False,
)
@click.option(
    "--extend-exclude",
    type=str,
    callback=validate_regex,
    help=(
        "Like --exclude, but adds additional files and directories on top of the"
        " excluded ones. (Useful if you simply want to add to the default)"
    ),
)
@click.option(
    "--force-exclude",
    type=str,
    callback=validate_regex,
    help=(
        "Like --exclude, but files and directories matching this regex will be "
        "excluded even when they are passed explicitly as arguments."
    ),
)
@click.option(
    "--stdin-filename",
    type=str,
    help=(
        "The name of the file when passing it through stdin. Useful to make "
        "sure Black will respect --force-exclude option on some "
        "editors that rely on using stdin."
    ),
)
@click.option(
    "-W",
    "--workers",
    type=click.IntRange(min=1),
    default=None,
    help="Number of parallel workers [default: number of CPUs in the system]",
)
@click.option(
    "-q",
    "--quiet",
    is_flag=True,
    help=(
        "Don't emit non-error messages to stderr. Errors are still emitted; silence"
        " those with 2>/dev/null."
    ),
)
@click.option(
    "-v",
    "--verbose",
    is_flag=True,
    help=(
        "Also emit messages to stderr about files that were not changed or were ignored"
        " due to exclusion patterns."
    ),
)
@click.version_option(
    version=__version__,
    message=(
        f"%(prog)s, %(version)s (compiled: {'yes' if COMPILED else 'no'})\n"
        f"Python ({platform.python_implementation()}) {platform.python_version()}"
    ),
)
@click.argument(
    "src",
    nargs=-1,
    type=click.Path(
        exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
    ),
    # NOTE: eager, presumably so `src` is already in ctx.params when the --config
    # callback (read_pyproject_toml) looks it up to locate pyproject.toml.
    is_eager=True,
    metavar="SRC ...",
)
@click.option(
    "--config",
    type=click.Path(
        exists=True,
        file_okay=True,
        dir_okay=False,
        readable=True,
        allow_dash=False,
        path_type=str,
    ),
    is_eager=True,
    # read_pyproject_toml merges the file's settings into ctx.default_map, which
    # then supplies defaults for the other options.
    callback=read_pyproject_toml,
    help="Read configuration from FILE path.",
)
@click.pass_context
def main(  # noqa: C901
    ctx: click.Context,
    code: Optional[str],
    line_length: int,
    target_version: List[TargetVersion],
    check: bool,
    diff: bool,
    color: bool,
    fast: bool,
    pyi: bool,
    ipynb: bool,
    python_cell_magics: Sequence[str],
    skip_source_first_line: bool,
    skip_string_normalization: bool,
    skip_magic_trailing_comma: bool,
    experimental_string_processing: bool,
    preview: bool,
    quiet: bool,
    verbose: bool,
    required_version: Optional[str],
    include: Pattern[str],
    exclude: Optional[Pattern[str]],
    extend_exclude: Optional[Pattern[str]],
    force_exclude: Optional[Pattern[str]],
    stdin_filename: Optional[str],
    workers: Optional[int],
    src: Tuple[str, ...],
    config: Optional[str],
) -> None:
    """The uncompromising code formatter."""
    # ctx.obj is a dict used to hand the project root to get_sources().
    ctx.ensure_object(dict)

    # Exactly one of SRC paths or --code must be given.
    if src and code is not None:
        out(
            main.get_usage(ctx)
            + "\n\n'SRC' and 'code' cannot be passed simultaneously."
        )
        ctx.exit(1)
    if not src and code is None:
        out(main.get_usage(ctx) + "\n\nOne of 'SRC' or 'code' is required.")
        ctx.exit(1)

    # A project root is only looked up when formatting paths, not --code input.
    root, method = (
        find_project_root(src, stdin_filename) if code is None else (None, None)
    )
    ctx.obj["root"] = root

    if verbose:
        if root:
            out(
                f"Identified `{root}` as project root containing a {method}.",
                fg="blue",
            )

            # Pair each source with its root-relative form ("-" is kept as-is);
            # a falsy normalized path is shown in red as "skipping - invalid".
            normalized = [
                (
                    (source, source)
                    if source == "-"
                    else (normalize_path_maybe_ignore(Path(source), root), source)
                )
                for source in src
            ]
            srcs_string = ", ".join(
                [
                    (
                        f'"{_norm}"'
                        if _norm
                        else f'\033[31m"{source} (skipping - invalid)"\033[34m'
                    )
                    for _norm, source in normalized
                ]
            )
            out(f"Sources to be formatted: {srcs_string}", fg="blue")

        if config:
            # Describe where the configuration came from, then dump its values.
            config_source = ctx.get_parameter_source("config")
            user_level_config = str(find_user_pyproject_toml())
            if config == user_level_config:
                out(
                    (
                        "Using configuration from user-level config at "
                        f"'{user_level_config}'."
                    ),
                    fg="blue",
                )
            elif config_source in (
                ParameterSource.DEFAULT,
                ParameterSource.DEFAULT_MAP,
            ):
                out("Using configuration from project root.", fg="blue")
            else:
                out(f"Using configuration in '{config}'.", fg="blue")
            if ctx.default_map:
                for param, value in ctx.default_map.items():
                    out(f"{param}: {value}")

    error_msg = "Oh no! 💥 💔 💥"
    # --required-version may be the exact version or only its major component
    # (the part before the first ".").
    if (
        required_version
        and required_version != __version__
        and required_version != __version__.split(".")[0]
    ):
        err(
            f"{error_msg} The required version `{required_version}` does not match"
            f" the running version `{__version__}`!"
        )
        ctx.exit(1)
    if ipynb and pyi:
        err("Cannot pass both `pyi` and `ipynb` flags!")
        ctx.exit(1)

    write_back = WriteBack.from_configuration(check=check, diff=diff, color=color)
    if target_version:
        versions = set(target_version)
    else:
        # We'll autodetect later.
        versions = set()
    mode = Mode(
        target_versions=versions,
        line_length=line_length,
        is_pyi=pyi,
        is_ipynb=ipynb,
        skip_source_first_line=skip_source_first_line,
        string_normalization=not skip_string_normalization,
        magic_trailing_comma=not skip_magic_trailing_comma,
        experimental_string_processing=experimental_string_processing,
        preview=preview,
        python_cell_magics=set(python_cell_magics),
    )

    if code is not None:
        # Run in quiet mode by default with -c; the extra output isn't useful.
        # You can still pass -v to get verbose output.
        quiet = True

    report = Report(check=check, diff=diff, quiet=quiet, verbose=verbose)

    if code is not None:
        reformat_code(
            content=code, fast=fast, write_back=write_back, mode=mode, report=report
        )
    else:
        try:
            sources = get_sources(
                ctx=ctx,
                src=src,
                quiet=quiet,
                verbose=verbose,
                include=include,
                exclude=exclude,
                extend_exclude=extend_exclude,
                force_exclude=force_exclude,
                report=report,
                stdin_filename=stdin_filename,
            )
        except GitWildMatchPatternError:
            # NOTE(review): presumably raised by an invalid .gitignore pattern
            # while collecting sources; exit with status 1.
            ctx.exit(1)

        # Exits with status 0 if nothing was collected.
        path_empty(
            sources,
            "No Python files are present to be formatted. Nothing to do 😴",
            quiet,
            verbose,
            ctx,
        )

        # One source is formatted in this process; several are handed to
        # reformat_many (imported lazily), which also receives `workers`.
        if len(sources) == 1:
            reformat_one(
                src=sources.pop(),
                fast=fast,
                write_back=write_back,
                mode=mode,
                report=report,
            )
        else:
            from black.concurrency import reformat_many

            reformat_many(
                sources=sources,
                fast=fast,
                write_back=write_back,
                mode=mode,
                report=report,
                workers=workers,
            )

    # Print the closing summary unless quiet; the exit status comes from report.
    if verbose or not quiet:
        if code is None and (verbose or report.change_count or report.failure_count):
            out()
        out(error_msg if report.return_code else "All done! ✨ 🍰 ✨")
        if code is None:
            click.echo(str(report), err=True)
    ctx.exit(report.return_code)
620
621
def get_sources(
    *,
    ctx: click.Context,
    src: Tuple[str, ...],
    quiet: bool,
    verbose: bool,
    include: Pattern[str],
    exclude: Optional[Pattern[str]],
    extend_exclude: Optional[Pattern[str]],
    force_exclude: Optional[Pattern[str]],
    report: "Report",
    stdin_filename: Optional[str],
) -> Set[Path]:
    """Compute the set of files to be formatted.

    Every entry of `src` is interpreted relative to the project root stored in
    ``ctx.obj["root"]``:

    * an existing file, or ``-`` when `stdin_filename` is given, is added unless
      it matches `force_exclude`. Stdin entries are tagged with
      ``STDIN_PLACEHOLDER`` so the user-supplied name survives. ``.ipynb``
      files are dropped when the Jupyter dependencies are not installed.
    * a directory is walked with :func:`gen_python_files`, honouring the
      include/exclude regexes (and ``.gitignore`` files while the default
      exclude is in effect).
    * a bare ``-`` is added as-is.
    * anything else is reported as an invalid path.

    Files and directories for which :func:`normalize_path_maybe_ignore` returns
    None are skipped rather than aborting the run.
    """
    sources: Set[Path] = set()
    root = ctx.obj["root"]

    # .gitignore files are only honoured when the user did not pass --exclude.
    using_default_exclude = exclude is None
    exclude = re_compile_maybe_verbose(DEFAULT_EXCLUDES) if exclude is None else exclude
    gitignore: Optional[Dict[Path, PathSpec]] = None
    root_gitignore = get_gitignore(root)

    for s in src:
        if s == "-" and stdin_filename:
            p = Path(stdin_filename)
            is_stdin = True
        else:
            p = Path(s)
            is_stdin = False

        if is_stdin or p.is_file():
            normalized_path = normalize_path_maybe_ignore(p, root, report)
            if normalized_path is None:
                continue

            normalized_path = "/" + normalized_path
            # Hard-exclude any files that match the `--force-exclude` regex.
            if force_exclude:
                force_exclude_match = force_exclude.search(normalized_path)
            else:
                force_exclude_match = None
            if force_exclude_match and force_exclude_match.group(0):
                report.path_ignored(p, "matches the --force-exclude regular expression")
                continue

            if is_stdin:
                p = Path(f"{STDIN_PLACEHOLDER}{str(p)}")

            if p.suffix == ".ipynb" and not jupyter_dependencies_are_installed(
                verbose=verbose, quiet=quiet
            ):
                continue

            sources.add(p)
        elif p.is_dir():
            normalized_dir = normalize_path_maybe_ignore(p, root, report)
            if normalized_dir is None:
                # The directory could not be mapped into the project root. Skip it,
                # as the file branch does, instead of crashing on `root / None`.
                continue
            p = root / normalized_dir
            if using_default_exclude:
                gitignore = {
                    root: root_gitignore,
                    p: get_gitignore(p),
                }
            sources.update(
                gen_python_files(
                    p.iterdir(),
                    root,
                    include,
                    exclude,
                    extend_exclude,
                    force_exclude,
                    report,
                    gitignore,
                    verbose=verbose,
                    quiet=quiet,
                )
            )
        elif s == "-":
            sources.add(p)
        else:
            err(f"invalid path: {s}")
    return sources
702
703
def path_empty(
    src: Sized, msg: str, quiet: bool, verbose: bool, ctx: click.Context
) -> None:
    """Exit with status 0 when `src` is empty; otherwise return normally.

    Before exiting, `msg` is printed unless output is silenced by `quiet`
    (`verbose` overrides `quiet`).
    """
    if src:
        return
    if verbose or not quiet:
        out(msg)
    ctx.exit(0)
714
715
def reformat_code(
    content: str, fast: bool, write_back: WriteBack, mode: Mode, report: Report
) -> None:
    """Reformat the string `content` and emit the result on stdout.

    This is the ``--code`` counterpart of `reformat_one` and spawns no child
    processes. `fast`, `write_back` and `mode` are forwarded to
    :func:`format_stdin_to_stdout`. The outcome is recorded in `report` under
    the pseudo-path ``<string>``; any ``Exception`` is recorded as a failure
    instead of propagating.
    """
    path = Path("<string>")
    try:
        was_changed = format_stdin_to_stdout(
            content=content, fast=fast, write_back=write_back, mode=mode
        )
        report.done(path, Changed.YES if was_changed else Changed.NO)
    except Exception as exc:
        if report.verbose:
            traceback.print_exc()
        report.failed(path, str(exc))
738
739
# diff-shades depends on being able to monkeypatch this function to operate. I know
# it's not ideal, but this shouldn't cause any issues ... hopefully. ~ichard26
@mypyc_attr(patchable=True)
def reformat_one(
    src: Path, fast: bool, write_back: WriteBack, mode: Mode, report: "Report"
) -> None:
    """Reformat the single path `src` in-process and record the outcome in `report`.

    `src` counts as standard input when it is ``-`` or starts with
    ``STDIN_PLACEHOLDER``. The prefix is stripped to recover the user-supplied
    filename. Such input goes through :func:`format_stdin_to_stdout`, and a
    ``.pyi``/``.ipynb`` suffix selects stub or notebook mode.

    Real files go through :func:`format_file_in_place`. The cache is consulted
    first unless a diff was requested, and it is updated after a write-back or
    a clean check. Any ``Exception`` is recorded via ``report.failed`` instead
    of being raised.
    """
    try:
        changed = Changed.NO
        src_text = str(src)
        if src_text != "-" and src_text.startswith(STDIN_PLACEHOLDER):
            # Use the original name again in case we want to print something
            # to the user.
            src = Path(src_text[len(STDIN_PLACEHOLDER) :])
            is_stdin = True
        else:
            is_stdin = src_text == "-"

        if is_stdin:
            if src.suffix == ".pyi":
                mode = replace(mode, is_pyi=True)
            elif src.suffix == ".ipynb":
                mode = replace(mode, is_ipynb=True)
            if format_stdin_to_stdout(fast=fast, write_back=write_back, mode=mode):
                changed = Changed.YES
        else:
            cache: Cache = {}
            wants_diff = write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF)
            if not wants_diff:
                cache = read_cache(mode)
                resolved = src.resolve()
                key = str(resolved)
                if key in cache and cache[key] == get_cache_info(resolved):
                    changed = Changed.CACHED
            if changed is not Changed.CACHED and format_file_in_place(
                src, fast=fast, write_back=write_back, mode=mode
            ):
                changed = Changed.YES
            # After a write-back, or a check that found nothing to change, the
            # file on disk is formatted, so it can be cached.
            after_write = write_back is WriteBack.YES and changed is not Changed.CACHED
            after_clean_check = write_back is WriteBack.CHECK and changed is Changed.NO
            if after_write or after_clean_check:
                write_cache(cache, [src], mode)
        report.done(src, changed)
    except Exception as exc:
        if report.verbose:
            traceback.print_exc()
        report.failed(src, str(exc))
792
793
def format_file_in_place(
    src: Path,
    fast: bool,
    mode: Mode,
    write_back: WriteBack = WriteBack.NO,
    lock: Any = None,  # multiprocessing.Manager().Lock() is some crazy proxy
) -> bool:
    """Format the file at `src` and return True if the formatted code differs.

    A ``.pyi`` suffix switches `mode` to stub formatting and ``.ipynb`` to
    notebook formatting. `mode` and `fast` are then forwarded to
    :func:`format_file_contents`. With ``mode.skip_source_first_line`` set, the
    first line is kept verbatim and excluded from formatting.

    With `write_back` YES the file is overwritten, keeping its detected encoding
    and newline style. With DIFF or COLOR_DIFF a diff (colorized for COLOR_DIFF)
    is written to stdout while holding `lock`, if one is given. Other modes
    write nothing.

    Raises ValueError when a notebook cannot be decoded as JSON.
    """
    if src.suffix == ".pyi":
        mode = replace(mode, is_pyi=True)
    elif src.suffix == ".ipynb":
        mode = replace(mode, is_ipynb=True)

    # Taken before any write so the diff header shows the original mtime.
    then = datetime.utcfromtimestamp(src.stat().st_mtime)
    with open(src, "rb") as buf:
        header = buf.readline() if mode.skip_source_first_line else b""
        src_contents, encoding, newline = decode_bytes(buf.read())

    try:
        dst_contents = format_file_contents(src_contents, fast=fast, mode=mode)
    except NothingChanged:
        return False
    except JSONDecodeError:
        raise ValueError(
            f"File '{src}' cannot be parsed as valid Jupyter notebook."
        ) from None

    # Re-attach the skipped first line (empty unless skip_source_first_line).
    first_line = header.decode(encoding)
    src_contents = first_line + src_contents
    dst_contents = first_line + dst_contents

    if write_back == WriteBack.YES:
        with open(src, "w", encoding=encoding, newline=newline) as out_file:
            out_file.write(dst_contents)
    elif write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
        now = datetime.utcnow()
        old_label = f"{src}\t{then} +0000"
        new_label = f"{src}\t{now} +0000"
        if mode.is_ipynb:
            diff_contents = ipynb_diff(src_contents, dst_contents, old_label, new_label)
        else:
            diff_contents = diff(src_contents, dst_contents, old_label, new_label)
        if write_back == WriteBack.COLOR_DIFF:
            diff_contents = color_diff(diff_contents)

        # Hold `lock` (if any) while writing, presumably so that diffs from
        # concurrent workers do not interleave on stdout.
        with lock or nullcontext():
            stream = io.TextIOWrapper(
                sys.stdout.buffer,
                encoding=encoding,
                newline=newline,
                write_through=True,
            )
            stream = wrap_stream_for_windows(stream)
            stream.write(diff_contents)
            # Detach so the wrapper releases sys.stdout.buffer instead of owning it.
            stream.detach()

    return True
856
857
def format_stdin_to_stdout(
    fast: bool,
    *,
    content: Optional[str] = None,
    write_back: WriteBack = WriteBack.NO,
    mode: Mode,
) -> bool:
    """Format source from stdin (or `content`) and return True if it changed.

    When `content` is None the source is read from ``sys.stdin`` and its
    encoding and newline style are detected. Otherwise `content` is treated as
    UTF-8 text. `fast` and `mode` are forwarded to :func:`format_file_contents`.

    With `write_back` YES the (possibly reformatted) code is echoed to stdout,
    and a trailing newline is added to non-empty output that lacks one. With
    DIFF or COLOR_DIFF a diff is written instead. Output happens in a
    ``finally`` clause, so if formatting raises, the unmodified source is still
    emitted before the exception propagates.
    """
    then = datetime.utcnow()

    if content is not None:
        src, encoding, newline = content, "utf-8", ""
    else:
        src, encoding, newline = decode_bytes(sys.stdin.buffer.read())

    # Stays equal to the input if formatting fails or changes nothing.
    dst = src
    try:
        dst = format_file_contents(src, fast=fast, mode=mode)
        return True
    except NothingChanged:
        return False
    finally:
        stream = io.TextIOWrapper(
            sys.stdout.buffer, encoding=encoding, newline=newline, write_through=True
        )
        if write_back == WriteBack.YES:
            # Make sure there's a newline after the content.
            if dst and not dst.endswith("\n"):
                dst += "\n"
            stream.write(dst)
        elif write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
            now = datetime.utcnow()
            diff_text = diff(src, dst, f"STDIN\t{then} +0000", f"STDOUT\t{now} +0000")
            if write_back == WriteBack.COLOR_DIFF:
                diff_text = color_diff(diff_text)
                stream = wrap_stream_for_windows(stream)
            stream.write(diff_text)
        stream.detach()
907
908
def check_stability_and_equivalence(
    src_contents: str, dst_contents: str, *, mode: Mode
) -> None:
    """Verify that turning `src_contents` into `dst_contents` is a safe reformat.

    First checks that both versions are equivalent (:func:`assert_equivalent`),
    then that formatting is stable under `mode` (:func:`assert_stable`).
    Raises AssertionError when either check fails.
    """
    assert_equivalent(src_contents, dst_contents)
    assert_stable(src_contents, dst_contents, mode=mode)
920
921
def format_file_contents(src_contents: str, *, fast: bool, mode: Mode) -> FileContent:
    """Return the reformatted version of a whole file's contents.

    Notebooks (``mode.is_ipynb``) are handled by :func:`format_ipynb_string`;
    all other input goes through :func:`format_str`.

    Unless `fast` is True, non-notebook output is additionally verified with
    :func:`check_stability_and_equivalence`. Notebook output is not re-checked
    here because :func:`format_cell` already verifies each cell when `fast`
    is False.

    Raises:
        NothingChanged: if formatting leaves the contents untouched.
        AssertionError: if a safety check fails.
    """
    if not mode.is_ipynb:
        dst_contents = format_str(src_contents, mode=mode)
    else:
        dst_contents = format_ipynb_string(src_contents, fast=fast, mode=mode)

    if dst_contents == src_contents:
        raise NothingChanged

    # Notebook cells were already checked one by one inside format_cell().
    needs_checks = not (fast or mode.is_ipynb)
    if needs_checks:
        check_stability_and_equivalence(src_contents, dst_contents, mode=mode)
    return dst_contents
940
941
def validate_cell(src: str, mode: Mode) -> None:
    """Reject notebook cells that cannot be safely round-tripped.

    A cell is rejected (by raising ``NothingChanged``) when either:

    - it already contains text that ``TransformerManager`` would produce when
      transforming an IPython magic (any entry of ``TRANSFORMED_MAGICS``), or
    - it starts with a ``%%`` cell magic that is neither in
      ``PYTHON_CELL_MAGICS`` nor in ``mode.python_cell_magics``.

    Why the first rule exists: a cell containing ``!ls`` is transformed into
    ``get_ipython().system('ls')``, but a cell that literally contains
    ``get_ipython().system('ls')`` transforms to the very same text:

        >>> TransformerManager().transform_cell("get_ipython().system('ls')")
        "get_ipython().system('ls')\n"
        >>> TransformerManager().transform_cell("!ls")
        "get_ipython().system('ls')\n"

    Since the two cannot be told apart afterwards, cells containing such
    transformed magics are left alone.
    """
    for transformed_magic in TRANSFORMED_MAGICS:
        if transformed_magic in src:
            raise NothingChanged

    if not src.startswith("%%"):
        return

    # The cell magic name is the first whitespace-separated word minus "%%".
    cell_magic_name = src.split()[0][2:]
    if cell_magic_name not in PYTHON_CELL_MAGICS | mode.python_cell_magics:
        raise NothingChanged
966
967
def format_cell(src: str, *, fast: bool, mode: Mode) -> str:
    """Format the source of a single Jupyter notebook code cell.

    Steps:

      - reject cells that cannot be round-tripped (:func:`validate_cell`);
      - strip a trailing semicolon, remembering whether there was one;
      - mask IPython magics so the cell parses as plain Python;
      - format the masked code (and, unless `fast`, verify it);
      - restore the magics and the trailing semicolon;
      - drop trailing newlines.

    Cells that fail to mask with a ``SyntaxError`` are skipped, since they may
    be automagics or multi-line magics, which are not supported.

    Raises:
        NothingChanged: if the cell is skipped or formatting changes nothing.
    """
    validate_cell(src, mode)
    stripped_src, had_semicolon = remove_trailing_semicolon(src)

    try:
        masked_src, replacements = mask_cell(stripped_src)
    except SyntaxError:
        # Possibly an automagic or multi-line magic: leave the cell as is.
        raise NothingChanged from None

    masked_dst = format_str(masked_src, mode=mode)
    if not fast:
        check_stability_and_equivalence(masked_src, masked_dst, mode=mode)

    unmasked_dst = unmask_cell(masked_dst, replacements)
    dst = put_trailing_semicolon_back(unmasked_dst, had_semicolon).rstrip("\n")
    if dst == src:
        raise NothingChanged from None
    return dst
1003
1004
def validate_metadata(nb: MutableMapping[str, Any]) -> None:
    """Refuse to format notebooks whose declared language is not Python.

    The language is read from ``nb["metadata"]["language_info"]["name"]``.
    Every level is optional in the notebook format (see
    https://nbformat.readthedocs.io/en/latest/format_description.html), so a
    missing value means "unknown" and the notebook is still processed.

    Raises:
        NothingChanged: if a language is declared and it is not ``"python"``.
    """
    metadata = nb.get("metadata", {})
    language_info = metadata.get("language_info", {})
    language = language_info.get("name", None)
    if language is None or language == "python":
        return
    raise NothingChanged from None
1015
1016
def format_ipynb_string(src_contents: str, *, fast: bool, mode: Mode) -> FileContent:
    """Format a Jupyter notebook given as its JSON text.

    Only cells with ``cell_type == "code"`` are formatted, one at a time via
    :func:`format_cell`; notebooks declaring a non-Python language are
    rejected by :func:`validate_metadata`. The result is re-serialized with
    ``json.dumps(..., indent=1, ensure_ascii=False)``, and a trailing newline
    in the input is kept in the output.

    Raises:
        NothingChanged: if the input is empty or no code cell changed.
    """
    if not src_contents:
        raise NothingChanged

    had_trailing_newline = src_contents.endswith("\n")
    nb = json.loads(src_contents)
    validate_metadata(nb)

    any_cell_changed = False
    for cell in nb["cells"]:
        if cell.get("cell_type", None) != "code":
            continue
        try:
            new_source = format_cell("".join(cell["source"]), fast=fast, mode=mode)
        except NothingChanged:
            continue
        cell["source"] = new_source.splitlines(keepends=True)
        any_cell_changed = True

    if not any_cell_changed:
        raise NothingChanged

    dst_contents = json.dumps(nb, indent=1, ensure_ascii=False)
    if had_trailing_newline:
        dst_contents += "\n"
    return dst_contents
1047
1048
def format_str(src_contents: str, *, mode: Mode) -> str:
    """Reformat a string of Python source and return the new contents.

    `mode` controls formatting options such as the allowed line length.
    Example:

    >>> import black
    >>> print(black.format_str("def f(arg:str='')->None:...", mode=black.Mode()))
    def f(arg: str = "") -> None:
        ...

    A more complex example:

    >>> print(
    ...   black.format_str(
    ...     "def f(arg:str='')->None: hey",
    ...     mode=black.Mode(
    ...       target_versions={black.TargetVersion.PY36},
    ...       line_length=10,
    ...       string_normalization=False,
    ...       is_pyi=False,
    ...     ),
    ...   ),
    ... )
    def f(
        arg: str = '',
    ) -> None:
        hey

    """
    first_pass = _format_str_once(src_contents, mode=mode)
    if first_pass == src_contents:
        return first_pass
    # A second pass is forced whenever the first one changed something:
    # optional trailing commas can become forced (magic) trailing commas on
    # the next pass and interact differently with optional parentheses.
    # Admittedly ugly.
    return _format_str_once(first_pass, mode=mode)
1086
1087
def _format_str_once(src_contents: str, *, mode: Mode) -> str:
    """Run exactly one formatting pass over `src_contents`.

    Leading whitespace of the source is stripped before parsing. When
    ``mode.target_versions`` is empty, the target versions are inferred from
    the features (and ``__future__`` imports) the source uses.

    If formatting yields no lines at all, the result is the source's newline
    sequence when the (decoded) source contains a newline, otherwise ``""``.
    """
    src_node = lib2to3_parse(src_contents.lstrip(), mode.target_versions)

    if mode.target_versions:
        versions = mode.target_versions
    else:
        versions = detect_target_versions(
            src_node, future_imports=get_future_imports(src_node)
        )

    context_manager_features: Set[Feature] = set()
    if supports_feature(versions, Feature.PARENTHESIZED_CONTEXT_MANAGERS):
        context_manager_features.add(Feature.PARENTHESIZED_CONTEXT_MANAGERS)

    # Must run after version detection above, which inspects the tree.
    normalize_fmt_off(src_node)
    line_generator = LineGenerator(mode=mode, features=context_manager_features)
    empty_line_tracker = EmptyLineTracker(mode=mode)
    split_line_features = {
        feature
        for feature in (Feature.TRAILING_COMMA_IN_CALL, Feature.TRAILING_COMMA_IN_DEF)
        if supports_feature(versions, feature)
    }

    blocks: List[LinesBlock] = []
    for logical_line in line_generator.visit(src_node):
        current_block = empty_line_tracker.maybe_empty_lines(logical_line)
        blocks.append(current_block)
        for physical_line in transform_line(
            logical_line, mode=mode, features=split_line_features
        ):
            current_block.content_lines.append(str(physical_line))

    if blocks:
        # No blank lines are emitted after the final block.
        blocks[-1].after = 0
    output_lines = [text for blk in blocks for text in blk.all_lines()]
    if output_lines:
        return "".join(output_lines)

    # Nothing survived formatting: use decode_bytes to recover the source's
    # newline style (CRLF or LF) and whether the source had any line break.
    normalized_content, _, newline = decode_bytes(src_contents.encode("utf-8"))
    return newline if "\n" in normalized_content else ""
1131
1132
def decode_bytes(src: bytes) -> Tuple[FileContent, Encoding, NewLine]:
    """Decode raw source bytes into ``(decoded_contents, encoding, newline)``.

    The encoding is detected with :func:`tokenize.detect_encoding`. `newline`
    is ``"\\r\\n"`` if the first line ends with CRLF and ``"\\n"`` otherwise.
    The decoded text uses universal newlines, so it only ever contains LF.
    Empty input yields ``("", <detected encoding>, "\\n")``.
    """
    buffer = io.BytesIO(src)
    encoding, first_lines = tokenize.detect_encoding(buffer.readline)
    if not first_lines:
        return "", encoding, "\n"

    newline = "\r\n" if first_lines[0].endswith(b"\r\n") else "\n"
    buffer.seek(0)
    with io.TextIOWrapper(buffer, encoding) as wrapper:
        return wrapper.read(), encoding, newline
1148
1149
def get_features_used(  # noqa: C901
    node: Node, *, future_imports: Optional[Set[str]] = None
) -> Set[Feature]:
    """Return a set of (relatively) new Python features used in this file.

    Currently looking for:
    - f-strings (any casing/order of the ``f`` and ``r`` prefixes);
    - self-documenting expressions in f-strings (f"{x=}");
    - underscores in numeric literals;
    - trailing commas after * or ** in function signatures and calls;
    - positional only arguments in function signatures and lambdas;
    - assignment expression;
    - relaxed decorator syntax;
    - usage of __future__ flags (annotations);
    - unpacking (``*x``) in return / yield values;
    - tuples with unpacking on the right-hand side of annotated assignments;
    - parenthesized context managers;
    - match statements;
    - except* clause;
    - variadic generics;

    `future_imports`, when given, is mapped through FUTURE_FLAG_TO_FEATURE;
    names without a mapping are ignored.
    """
    features: Set[Feature] = set()
    if future_imports:
        features |= {
            FUTURE_FLAG_TO_FEATURE[future_import]
            for future_import in future_imports
            if future_import in FUTURE_FLAG_TO_FEATURE
        }

    for n in node.pre_order():
        if is_string_token(n):
            # String prefixes are case-insensitive and "r"/"f" may appear in
            # either order with mixed case (Rf"", rF"", fR"", Fr""), so compare
            # the lower-cased head of the token.
            value_head = n.value[:2].lower()
            if value_head in {'f"', "f'", "rf", "fr"}:
                features.add(Feature.F_STRINGS)
                if Feature.DEBUG_F_STRINGS not in features:
                    for span_beg, span_end in iter_fexpr_spans(n.value):
                        if n.value[span_beg : span_end - 1].rstrip().endswith("="):
                            features.add(Feature.DEBUG_F_STRINGS)
                            break

        elif is_number_token(n):
            if "_" in n.value:
                features.add(Feature.NUMERIC_UNDERSCORES)

        elif n.type == token.SLASH:
            if n.parent and n.parent.type in {
                syms.typedargslist,
                syms.arglist,
                syms.varargslist,
            }:
                features.add(Feature.POS_ONLY_ARGUMENTS)

        elif n.type == token.COLONEQUAL:
            features.add(Feature.ASSIGNMENT_EXPRESSIONS)

        elif n.type == syms.decorator:
            if len(n.children) > 1 and not is_simple_decorator_expression(
                n.children[1]
            ):
                features.add(Feature.RELAXED_DECORATORS)

        elif (
            n.type in {syms.typedargslist, syms.arglist}
            and n.children
            and n.children[-1].type == token.COMMA
        ):
            if n.type == syms.typedargslist:
                feature = Feature.TRAILING_COMMA_IN_DEF
            else:
                feature = Feature.TRAILING_COMMA_IN_CALL

            for ch in n.children:
                if ch.type in STARS:
                    features.add(feature)

                if ch.type == syms.argument:
                    for argch in ch.children:
                        if argch.type in STARS:
                            features.add(feature)

        elif (
            n.type in {syms.return_stmt, syms.yield_expr}
            and len(n.children) >= 2
            and n.children[1].type == syms.testlist_star_expr
            and any(child.type == syms.star_expr for child in n.children[1].children)
        ):
            features.add(Feature.UNPACKING_ON_FLOW)

        elif (
            n.type == syms.annassign
            and len(n.children) >= 4
            and n.children[3].type == syms.testlist_star_expr
        ):
            features.add(Feature.ANN_ASSIGN_EXTENDED_RHS)

        elif (
            n.type == syms.with_stmt
            and len(n.children) > 2
            and n.children[1].type == syms.atom
        ):
            atom_children = n.children[1].children
            if (
                len(atom_children) == 3
                and atom_children[0].type == token.LPAR
                and atom_children[1].type == syms.testlist_gexp
                and atom_children[2].type == token.RPAR
            ):
                features.add(Feature.PARENTHESIZED_CONTEXT_MANAGERS)

        elif n.type == syms.match_stmt:
            features.add(Feature.PATTERN_MATCHING)

        elif (
            n.type == syms.except_clause
            and len(n.children) >= 2
            and n.children[1].type == token.STAR
        ):
            features.add(Feature.EXCEPT_STAR)

        elif n.type in {syms.subscriptlist, syms.trailer} and any(
            child.type == syms.star_expr for child in n.children
        ):
            features.add(Feature.VARIADIC_GENERICS)

        elif (
            n.type == syms.tname_star
            and len(n.children) == 3
            and n.children[2].type == syms.star_expr
        ):
            features.add(Feature.VARIADIC_GENERICS)

    return features
1281
1282
def detect_target_versions(
    node: Node, *, future_imports: Optional[Set[str]] = None
) -> Set[TargetVersion]:
    """Return every target version whose feature set covers what `node` uses.

    The used features come from :func:`get_features_used`; a version is
    kept when all of them appear in ``VERSION_TO_FEATURES[version]``.
    """
    used = get_features_used(node, future_imports=future_imports)
    compatible: Set[TargetVersion] = set()
    for version in TargetVersion:
        if used.issubset(VERSION_TO_FEATURES[version]):
            compatible.add(version)
    return compatible
1291
1292
def get_future_imports(node: Node) -> Set[str]:
    """Collect the names imported from ``__future__`` at the top of a module.

    Only the leading run of simple statements is scanned: module docstrings
    are skipped, and scanning stops at the first statement that is neither a
    docstring nor a ``from __future__ import ...``. For ``x as y`` the
    original name ``x`` is recorded.

    Raises:
        AssertionError: if an import statement has an unexpected tree shape.
    """
    future_names: Set[str] = set()

    def _imported_names(children: List[LN]) -> Generator[str, None, None]:
        # Yield original names from the tail of an import_from node.
        for child in children:
            if isinstance(child, Leaf):
                # Non-NAME leaves are skipped.
                if child.type == token.NAME:
                    yield child.value
                continue

            if child.type == syms.import_as_names:
                yield from _imported_names(child.children)
                continue

            if child.type != syms.import_as_name:
                raise AssertionError("Invalid syntax parsing imports")

            orig_name = child.children[0]
            assert isinstance(orig_name, Leaf), "Invalid syntax parsing imports"
            assert orig_name.type == token.NAME, "Invalid syntax parsing imports"
            yield orig_name.value

    for stmt in node.children:
        if stmt.type != syms.simple_stmt:
            break

        head = stmt.children[0]
        if isinstance(head, Leaf):
            is_docstring = (
                len(stmt.children) == 2
                and head.type == token.STRING
                and stmt.children[1].type == token.NEWLINE
            )
            # Keep scanning past a docstring; any other leaf ends the search.
            if not is_docstring:
                break
            continue

        if head.type != syms.import_from:
            break

        module_name = head.children[1]
        if not isinstance(module_name, Leaf) or module_name.value != "__future__":
            break

        # children[3:] is everything after "from __future__ import".
        future_names.update(_imported_names(head.children[3:]))

    return future_names
1341
1342
def assert_equivalent(src: str, dst: str) -> None:
    """Raise AssertionError unless `src` and `dst` have equivalent ASTs.

    Both strings are parsed with ``parse_ast`` and compared through their
    ``stringify_ast`` dumps. A source that fails to parse is reported as
    unsupported for ``--safe``. A destination that fails to parse, or a dump
    mismatch, is reported as an internal error. In those cases the traceback
    or AST diff is saved with ``dump_to_file`` and its location is included
    in the message.
    """
    try:
        src_ast = parse_ast(src)
    except Exception as exc:
        raise AssertionError(
            "cannot use --safe with this file; failed to parse source file AST: "
            f"{exc}\n"
            "This could be caused by running Black with an older Python version "
            "that does not support new syntax used in your source file."
        ) from exc

    try:
        dst_ast = parse_ast(dst)
    except Exception as exc:
        traceback_text = "".join(traceback.format_tb(exc.__traceback__))
        log = dump_to_file(traceback_text, dst)
        raise AssertionError(
            f"INTERNAL ERROR: Black produced invalid code: {exc}. "
            "Please report a bug on https://github.com/psf/black/issues.  "
            f"This invalid output might be helpful: {log}"
        ) from None

    src_dump = "\n".join(stringify_ast(src_ast))
    dst_dump = "\n".join(stringify_ast(dst_ast))
    if src_dump == dst_dump:
        return

    log = dump_to_file(diff(src_dump, dst_dump, "src", "dst"))
    raise AssertionError(
        "INTERNAL ERROR: Black produced code that is not equivalent to the"
        " source.  Please report a bug on "
        f"https://github.com/psf/black/issues.  This diff might be helpful: {log}"
    ) from None
1374
1375
def assert_stable(src: str, dst: str, mode: Mode) -> None:
    """Raise AssertionError if formatting `dst` once more would change it.

    On failure, the mode and two diffs (source vs. first pass, first pass
    vs. second pass) are saved with ``dump_to_file`` and the log location is
    put in the error message.
    """
    # Deliberately a single pass: format_str() formats twice, which could
    # mask a bug where output flip-flops between two versions.
    second_pass = _format_str_once(dst, mode=mode)
    if second_pass == dst:
        return

    log = dump_to_file(
        str(mode),
        diff(src, dst, "source", "first pass"),
        diff(dst, second_pass, "first pass", "second pass"),
    )
    raise AssertionError(
        "INTERNAL ERROR: Black produced different code on the second pass of the"
        " formatter.  Please report a bug on https://github.com/psf/black/issues."
        f"  This diff might be helpful: {log}"
    ) from None
1393
1394
@contextmanager
def nullcontext() -> Iterator[None]:
    """Context manager that does nothing and binds ``None`` to ``as``.

    A local stand-in for :func:`contextlib.nullcontext`, which is only
    available from Python 3.7 onward.
    """
    yield None
1402
1403
def patch_click() -> None:
    """Stop Click from crashing on Python 3.6 with LANG=C.

    In some misconfigured environments Python 3 picks ASCII as the default
    encoding, which restricts the paths the process can access. Click refuses
    to run in that situation and raises a RuntimeError.

    For Black, non-ASCII file paths are unlikely since it handles Python
    source code, and on Python 3.7 the crash was spurious anyway thanks to
    PEP 538 and PEP 540. So every environment-verification hook Click exposes
    is replaced with a no-op.
    """
    patchable: List[Any] = []

    try:
        from click import core
    except ImportError:
        pass
    else:
        patchable.append(core)

    # click._unicodefun was removed in Click 8.1.0; it is still patched for
    # users who have older versions installed.
    try:
        from click import _unicodefun  # type: ignore
    except ImportError:
        pass
    else:
        patchable.append(_unicodefun)

    hook_names = ("_verify_python3_env", "_verify_python_env")
    for module in patchable:
        for hook_name in hook_names:
            if hasattr(module, hook_name):
                setattr(module, hook_name, lambda: None)
1436
1437
def patched_main() -> None:
    """Apply environment workarounds, then run :func:`main`."""
    is_frozen = getattr(sys, "frozen", False)
    if is_frozen:
        # PyInstaller patches multiprocessing so that freeze_support() is
        # needed even outside Windows; when frozen, always call it.
        from multiprocessing import freeze_support

        freeze_support()

    patch_click()
    main()
1448
1449
# When this file is executed as a script, go through the same patched entry
# point (environment workarounds + main()).
if __name__ == "__main__":
    patched_main()