git.madduck.net Git - etc/vim.git/blob - src/black/__init__.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access, can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

dc06eab8dd0b9ca7204c7dd31ccea842d97f492d
[etc/vim.git] / src / black / __init__.py
1 import io
2 import json
3 import platform
4 import re
5 import sys
6 import tokenize
7 import traceback
8 from contextlib import contextmanager
9 from dataclasses import replace
10 from datetime import datetime, timezone
11 from enum import Enum
12 from json.decoder import JSONDecodeError
13 from pathlib import Path
14 from typing import (
15     Any,
16     Dict,
17     Generator,
18     Iterator,
19     List,
20     MutableMapping,
21     Optional,
22     Pattern,
23     Sequence,
24     Set,
25     Sized,
26     Tuple,
27     Union,
28 )
29
30 import click
31 from click.core import ParameterSource
32 from mypy_extensions import mypyc_attr
33 from pathspec import PathSpec
34 from pathspec.patterns.gitwildmatch import GitWildMatchPatternError
35
36 from _black_version import version as __version__
37 from black.cache import Cache
38 from black.comments import normalize_fmt_off
39 from black.const import (
40     DEFAULT_EXCLUDES,
41     DEFAULT_INCLUDES,
42     DEFAULT_LINE_LENGTH,
43     STDIN_PLACEHOLDER,
44 )
45 from black.files import (
46     find_project_root,
47     find_pyproject_toml,
48     find_user_pyproject_toml,
49     gen_python_files,
50     get_gitignore,
51     normalize_path_maybe_ignore,
52     parse_pyproject_toml,
53     wrap_stream_for_windows,
54 )
55 from black.handle_ipynb_magics import (
56     PYTHON_CELL_MAGICS,
57     TRANSFORMED_MAGICS,
58     jupyter_dependencies_are_installed,
59     mask_cell,
60     put_trailing_semicolon_back,
61     remove_trailing_semicolon,
62     unmask_cell,
63 )
64 from black.linegen import LN, LineGenerator, transform_line
65 from black.lines import EmptyLineTracker, LinesBlock
66 from black.mode import (
67     FUTURE_FLAG_TO_FEATURE,
68     VERSION_TO_FEATURES,
69     Feature,
70     Mode,
71     TargetVersion,
72     supports_feature,
73 )
74 from black.nodes import (
75     STARS,
76     is_number_token,
77     is_simple_decorator_expression,
78     is_string_token,
79     syms,
80 )
81 from black.output import color_diff, diff, dump_to_file, err, ipynb_diff, out
82 from black.parsing import InvalidInput  # noqa F401
83 from black.parsing import lib2to3_parse, parse_ast, stringify_ast
84 from black.report import Changed, NothingChanged, Report
85 from black.trans import iter_fexpr_spans
86 from blib2to3.pgen2 import token
87 from blib2to3.pytree import Leaf, Node
88
# True when this module's own file has a compiled-extension suffix (".pyd" or
# ".so") rather than being plain source; reported in the --version message.
COMPILED = Path(__file__).suffix in (".pyd", ".so")

# types
# Aliases of `str` that name the role a string plays in a signature.
FileContent = str
Encoding = str
NewLine = str
95
96
class WriteBack(Enum):
    """How the result of formatting should be delivered."""

    NO = 0
    YES = 1
    DIFF = 2
    CHECK = 3
    COLOR_DIFF = 4

    @classmethod
    def from_configuration(
        cls, *, check: bool, diff: bool, color: bool = False
    ) -> "WriteBack":
        """Translate the `check`, `diff` and `color` flags into a member.

        Without `diff` the answer is CHECK when `check` is set and YES
        otherwise. With `diff` the answer is COLOR_DIFF when `color` is set and
        DIFF otherwise; `check` has no influence once `diff` is requested.
        """
        if not diff:
            return cls.CHECK if check else cls.YES
        return cls.COLOR_DIFF if color else cls.DIFF
115
116
# Legacy name, left for integrations.
# `FileMode` is bound to the very same class object as `Mode`.
FileMode = Mode
119
120
def read_pyproject_toml(
    ctx: click.Context, param: click.Parameter, value: Optional[str]
) -> Optional[str]:
    """Load Black settings from a "pyproject.toml" into `ctx.default_map`.

    When `value` is empty, the file is looked up with `find_pyproject_toml`
    using the `src` and `stdin_filename` entries of `ctx.params`.

    Returns the path of the configuration file that was read, or None when no
    file was found or the file held no configuration. Raises `click.FileError`
    when the file cannot be read or parsed and `click.BadOptionUsage` when
    `target_version`, `exclude` or `extend_exclude` has an unexpected type.
    """
    if not value:
        params = ctx.params
        value = find_pyproject_toml(
            params.get("src", ()), params.get("stdin_filename", None)
        )
        if value is None:
            return None

    try:
        raw_config = parse_pyproject_toml(value)
    except (OSError, ValueError) as e:
        raise click.FileError(
            filename=value, hint=f"Error reading configuration file: {e}"
        ) from None

    if not raw_config:
        return None

    # Sanitize the values to be Click friendly. For more information please see:
    # https://github.com/psf/black/issues/1458
    # https://github.com/pallets/click/issues/1567
    config: Dict[str, Any] = {}
    for key, item in raw_config.items():
        config[key] = item if isinstance(item, (list, dict)) else str(item)

    versions = config.get("target_version")
    if not (versions is None or isinstance(versions, list)):
        raise click.BadOptionUsage(
            "target-version", "Config key target-version must be a list"
        )

    # Both exclusion settings must be plain strings when present.
    for key, option in (
        ("exclude", "exclude"),
        ("extend_exclude", "extend-exclude"),
    ):
        pattern = config.get(key)
        if pattern is not None and not isinstance(pattern, str):
            raise click.BadOptionUsage(
                option, f"Config key {option} must be a string"
            )

    # Values from the file override whatever defaults were already present.
    inherited = ctx.default_map or {}
    ctx.default_map = {**inherited, **config}
    return value
177
178
def target_version_option_callback(
    c: click.Context, p: Union[click.Option, click.Parameter], v: Tuple[str, ...]
) -> List[TargetVersion]:
    """Turn the values of a --target-version flag into `TargetVersion` members.

    Kept as a named function rather than a lambda: mypy could not infer the
    lambda's type correctly, which caused trouble for mypyc.
    """
    targets: List[TargetVersion] = []
    for name in v:
        # Members are looked up by their upper-cased name.
        targets.append(TargetVersion[name.upper()])
    return targets
188
189
def re_compile_maybe_verbose(regex: str) -> Pattern[str]:
    """Compile the regular expression string `regex`.

    A pattern that contains a newline is compiled in verbose mode by
    prefixing it with the inline ``(?x)`` flag; any other pattern is compiled
    unchanged.
    """
    source = f"(?x){regex}" if "\n" in regex else regex
    return re.compile(source)
199
200
def validate_regex(
    ctx: click.Context,
    param: click.Parameter,
    value: Optional[str],
) -> Optional[Pattern[str]]:
    """Click callback that compiles a regular-expression option value.

    Returns None when no value was given, otherwise the compiled pattern.
    Raises `click.BadParameter` when the value is not a valid regular
    expression.
    """
    if value is None:
        return None
    try:
        return re_compile_maybe_verbose(value)
    except re.error as e:
        raise click.BadParameter(f"Not a valid regular expression: {e}") from None
210
211
@click.command(
    context_settings={"help_option_names": ["-h", "--help"]},
    # While Click does set this field automatically using the docstring, mypyc
    # (annoyingly) strips 'em so we need to set it here too.
    help="The uncompromising code formatter.",
)
@click.option("-c", "--code", type=str, help="Format the code passed in as a string.")
@click.option(
    "-l",
    "--line-length",
    type=int,
    default=DEFAULT_LINE_LENGTH,
    help="How many characters per line to allow.",
    show_default=True,
)
@click.option(
    "-t",
    "--target-version",
    type=click.Choice([v.name.lower() for v in TargetVersion]),
    callback=target_version_option_callback,
    multiple=True,
    help=(
        "Python versions that should be supported by Black's output. By default, Black"
        " will try to infer this from the project metadata in pyproject.toml. If this"
        " does not yield conclusive results, Black will use per-file auto-detection."
    ),
)
@click.option(
    "--pyi",
    is_flag=True,
    help=(
        "Format all input files like typing stubs regardless of file extension (useful"
        " when piping source on standard input)."
    ),
)
@click.option(
    "--ipynb",
    is_flag=True,
    help=(
        "Format all input files like Jupyter Notebooks regardless of file extension "
        "(useful when piping source on standard input)."
    ),
)
@click.option(
    "--python-cell-magics",
    multiple=True,
    help=(
        "When processing Jupyter Notebooks, add the given magic to the list"
        f" of known python-magics ({', '.join(sorted(PYTHON_CELL_MAGICS))})."
        " Useful for formatting cells with custom python magics."
    ),
    default=[],
)
@click.option(
    "-x",
    "--skip-source-first-line",
    is_flag=True,
    help="Skip the first line of the source code.",
)
@click.option(
    "-S",
    "--skip-string-normalization",
    is_flag=True,
    help="Don't normalize string quotes or prefixes.",
)
@click.option(
    "-C",
    "--skip-magic-trailing-comma",
    is_flag=True,
    help="Don't use trailing commas as a reason to split lines.",
)
@click.option(
    "--experimental-string-processing",
    is_flag=True,
    hidden=True,
    help="(DEPRECATED and now included in --preview) Normalize string literals.",
)
@click.option(
    "--preview",
    is_flag=True,
    help=(
        "Enable potentially disruptive style changes that may be added to Black's main"
        " functionality in the next major release."
    ),
)
@click.option(
    "--check",
    is_flag=True,
    help=(
        "Don't write the files back, just return the status. Return code 0 means"
        " nothing would change. Return code 1 means some files would be reformatted."
        " Return code 123 means there was an internal error."
    ),
)
@click.option(
    "--diff",
    is_flag=True,
    help="Don't write the files back, just output a diff for each file on stdout.",
)
@click.option(
    "--color/--no-color",
    is_flag=True,
    help="Show colored diff. Only applies when `--diff` is given.",
)
@click.option(
    "--fast/--safe",
    is_flag=True,
    help="If --fast given, skip temporary sanity checks. [default: --safe]",
)
@click.option(
    "--required-version",
    type=str,
    help=(
        "Require a specific version of Black to be running (useful for unifying results"
        " across many environments e.g. with a pyproject.toml file). It can be"
        " either a major version number or an exact version."
    ),
)
@click.option(
    "--include",
    type=str,
    default=DEFAULT_INCLUDES,
    callback=validate_regex,
    help=(
        "A regular expression that matches files and directories that should be"
        " included on recursive searches. An empty value means all files are included"
        " regardless of the name. Use forward slashes for directories on all platforms"
        " (Windows, too). Exclusions are calculated first, inclusions later."
    ),
    show_default=True,
)
# --exclude has no Click-level default: get_sources() treats a missing value
# (None) as "use DEFAULT_EXCLUDES", so the default is spelled out in the help text.
@click.option(
    "--exclude",
    type=str,
    callback=validate_regex,
    help=(
        "A regular expression that matches files and directories that should be"
        " excluded on recursive searches. An empty value means no paths are excluded."
        " Use forward slashes for directories on all platforms (Windows, too)."
        " Exclusions are calculated first, inclusions later. [default:"
        f" {DEFAULT_EXCLUDES}]"
    ),
    show_default=False,
)
@click.option(
    "--extend-exclude",
    type=str,
    callback=validate_regex,
    help=(
        "Like --exclude, but adds additional files and directories on top of the"
        " excluded ones. (Useful if you simply want to add to the default)"
    ),
)
@click.option(
    "--force-exclude",
    type=str,
    callback=validate_regex,
    help=(
        "Like --exclude, but files and directories matching this regex will be "
        "excluded even when they are passed explicitly as arguments."
    ),
)
# Eager like `src` and `--config` below: the --config callback
# (read_pyproject_toml) reads "stdin_filename" and "src" from ctx.params.
@click.option(
    "--stdin-filename",
    type=str,
    is_eager=True,
    help=(
        "The name of the file when passing it through stdin. Useful to make "
        "sure Black will respect --force-exclude option on some "
        "editors that rely on using stdin."
    ),
)
@click.option(
    "-W",
    "--workers",
    type=click.IntRange(min=1),
    default=None,
    help=(
        "Number of parallel workers [default: BLACK_NUM_WORKERS environment variable "
        "or number of CPUs in the system]"
    ),
)
@click.option(
    "-q",
    "--quiet",
    is_flag=True,
    help=(
        "Don't emit non-error messages to stderr. Errors are still emitted; silence"
        " those with 2>/dev/null."
    ),
)
@click.option(
    "-v",
    "--verbose",
    is_flag=True,
    help=(
        "Also emit messages to stderr about files that were not changed or were ignored"
        " due to exclusion patterns."
    ),
)
@click.version_option(
    version=__version__,
    message=(
        f"%(prog)s, %(version)s (compiled: {'yes' if COMPILED else 'no'})\n"
        f"Python ({platform.python_implementation()}) {platform.python_version()}"
    ),
)
@click.argument(
    "src",
    nargs=-1,
    type=click.Path(
        exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
    ),
    is_eager=True,
    metavar="SRC ...",
)
# The callback injects the configuration file's values into ctx.default_map,
# where they act as defaults for the other options.
@click.option(
    "--config",
    type=click.Path(
        exists=True,
        file_okay=True,
        dir_okay=False,
        readable=True,
        allow_dash=False,
        path_type=str,
    ),
    is_eager=True,
    callback=read_pyproject_toml,
    help="Read configuration from FILE path.",
)
@click.pass_context
def main(  # noqa: C901
    ctx: click.Context,
    code: Optional[str],
    line_length: int,
    target_version: List[TargetVersion],
    check: bool,
    diff: bool,
    color: bool,
    fast: bool,
    pyi: bool,
    ipynb: bool,
    python_cell_magics: Sequence[str],
    skip_source_first_line: bool,
    skip_string_normalization: bool,
    skip_magic_trailing_comma: bool,
    experimental_string_processing: bool,
    preview: bool,
    quiet: bool,
    verbose: bool,
    required_version: Optional[str],
    include: Pattern[str],
    exclude: Optional[Pattern[str]],
    extend_exclude: Optional[Pattern[str]],
    force_exclude: Optional[Pattern[str]],
    stdin_filename: Optional[str],
    workers: Optional[int],
    src: Tuple[str, ...],
    config: Optional[str],
) -> None:
    """The uncompromising code formatter."""
    # Command-line entry point: validates the option combination, builds a
    # `Mode` and a `Report`, formats either the --code string or the collected
    # sources, and always finishes by calling ctx.exit() with a status code.

    # Make sure ctx.obj is a dict so the project root can be stored on it.
    ctx.ensure_object(dict)

    # Exactly one of SRC and --code has to be given.
    if src and code is not None:
        out(
            main.get_usage(ctx)
            + "\n\n'SRC' and 'code' cannot be passed simultaneously."
        )
        ctx.exit(1)
    if not src and code is None:
        out(main.get_usage(ctx) + "\n\nOne of 'SRC' or 'code' is required.")
        ctx.exit(1)

    # No project root is looked up when formatting a --code string.
    root, method = (
        find_project_root(src, stdin_filename) if code is None else (None, None)
    )
    # get_sources() reads the root back from ctx.obj.
    ctx.obj["root"] = root

    if verbose:
        if root:
            out(
                f"Identified `{root}` as project root containing a {method}.",
                fg="blue",
            )

        if config:
            # Report where the configuration came from: the user-level file, a
            # file picked up by default, or a path given explicitly.
            config_source = ctx.get_parameter_source("config")
            user_level_config = str(find_user_pyproject_toml())
            if config == user_level_config:
                out(
                    "Using configuration from user-level config at "
                    f"'{user_level_config}'.",
                    fg="blue",
                )
            elif config_source in (
                ParameterSource.DEFAULT,
                ParameterSource.DEFAULT_MAP,
            ):
                out("Using configuration from project root.", fg="blue")
            else:
                out(f"Using configuration in '{config}'.", fg="blue")
            if ctx.default_map:
                for param, value in ctx.default_map.items():
                    out(f"{param}: {value}")

    error_msg = "Oh no! 💥 💔 💥"
    # --required-version is satisfied by either the full version string or by
    # its first dot-separated component alone.
    if (
        required_version
        and required_version != __version__
        and required_version != __version__.split(".")[0]
    ):
        err(
            f"{error_msg} The required version `{required_version}` does not match"
            f" the running version `{__version__}`!"
        )
        ctx.exit(1)
    if ipynb and pyi:
        err("Cannot pass both `pyi` and `ipynb` flags!")
        ctx.exit(1)

    write_back = WriteBack.from_configuration(check=check, diff=diff, color=color)
    if target_version:
        versions = set(target_version)
    else:
        # We'll autodetect later.
        versions = set()
    # The two "skip" flags are stored inverted on the Mode.
    mode = Mode(
        target_versions=versions,
        line_length=line_length,
        is_pyi=pyi,
        is_ipynb=ipynb,
        skip_source_first_line=skip_source_first_line,
        string_normalization=not skip_string_normalization,
        magic_trailing_comma=not skip_magic_trailing_comma,
        experimental_string_processing=experimental_string_processing,
        preview=preview,
        python_cell_magics=set(python_cell_magics),
    )

    if code is not None:
        # Run in quiet mode by default with -c; the extra output isn't useful.
        # You can still pass -v to get verbose output.
        quiet = True

    report = Report(check=check, diff=diff, quiet=quiet, verbose=verbose)

    if code is not None:
        reformat_code(
            content=code, fast=fast, write_back=write_back, mode=mode, report=report
        )
    else:
        try:
            sources = get_sources(
                ctx=ctx,
                src=src,
                quiet=quiet,
                verbose=verbose,
                include=include,
                exclude=exclude,
                extend_exclude=extend_exclude,
                force_exclude=force_exclude,
                report=report,
                stdin_filename=stdin_filename,
            )
        except GitWildMatchPatternError:
            # NOTE(review): no message is printed here; presumably the failing
            # gitignore was already reported where it was parsed. Confirm.
            ctx.exit(1)

        # Exits with status 0 when nothing was collected.
        path_empty(
            sources,
            "No Python files are present to be formatted. Nothing to do 😴",
            quiet,
            verbose,
            ctx,
        )

        # A single source is formatted in this process; several sources are
        # handed to reformat_many, which is only imported when it is needed.
        if len(sources) == 1:
            reformat_one(
                src=sources.pop(),
                fast=fast,
                write_back=write_back,
                mode=mode,
                report=report,
            )
        else:
            from black.concurrency import reformat_many

            reformat_many(
                sources=sources,
                fast=fast,
                write_back=write_back,
                mode=mode,
                report=report,
                workers=workers,
            )

    if verbose or not quiet:
        # The blank separator line and the report summary are only emitted for
        # file runs, never for --code.
        if code is None and (verbose or report.change_count or report.failure_count):
            out()
        out(error_msg if report.return_code else "All done! ✨ 🍰 ✨")
        if code is None:
            click.echo(str(report), err=True)
    ctx.exit(report.return_code)
614
615
def get_sources(
    *,
    ctx: click.Context,
    src: Tuple[str, ...],
    quiet: bool,
    verbose: bool,
    include: Pattern[str],
    exclude: Optional[Pattern[str]],
    extend_exclude: Optional[Pattern[str]],
    force_exclude: Optional[Pattern[str]],
    report: "Report",
    stdin_filename: Optional[str],
) -> Set[Path]:
    """Compute the set of files to be formatted.

    Each entry of `src` is handled according to what it names:

    - a file, or "-" combined with a `stdin_filename`, is added directly
      unless it matches `force_exclude`; `include`, `exclude` and
      `extend_exclude` are not consulted for it;
    - a directory is expanded with `gen_python_files`, which receives all of
      the include/exclude patterns;
    - a bare "-" without `stdin_filename` is added as-is to stand for stdin;
    - anything else is reported through `err` as an invalid path.

    The project root is read from `ctx.obj["root"]`. When `exclude` is None,
    DEFAULT_EXCLUDES is used instead and gitignore specs (for the root and for
    the directory itself) are passed along to `gen_python_files`.

    A file or directory for which `normalize_path_maybe_ignore` returns None
    is skipped; with `verbose` a "Skipping invalid source" message names it.
    """
    sources: Set[Path] = set()
    root = ctx.obj["root"]

    using_default_exclude = exclude is None
    exclude = re_compile_maybe_verbose(DEFAULT_EXCLUDES) if exclude is None else exclude
    gitignore: Optional[Dict[Path, PathSpec]] = None
    root_gitignore = get_gitignore(root)

    for s in src:
        if s == "-" and stdin_filename:
            p = Path(stdin_filename)
            is_stdin = True
        else:
            p = Path(s)
            is_stdin = False

        if is_stdin or p.is_file():
            normalized_path: Optional[str] = normalize_path_maybe_ignore(
                p, root, report
            )
            if normalized_path is None:
                if verbose:
                    # `normalized_path` is None on this branch, so name the
                    # path that was rejected rather than printing "None".
                    out(f'Skipping invalid source: "{p}"', fg="red")
                continue
            if verbose:
                out(f'Found input source: "{normalized_path}"', fg="blue")

            # The leading slash lets the regex anchor on the root-relative path.
            normalized_path = "/" + normalized_path
            # Hard-exclude any files that matches the `--force-exclude` regex.
            if force_exclude:
                force_exclude_match = force_exclude.search(normalized_path)
            else:
                force_exclude_match = None
            # An empty match (e.g. from an empty pattern) does not exclude.
            if force_exclude_match and force_exclude_match.group(0):
                report.path_ignored(p, "matches the --force-exclude regular expression")
                continue

            if is_stdin:
                # Mark the path so reformat_one() knows to read from stdin.
                p = Path(f"{STDIN_PLACEHOLDER}{str(p)}")

            if p.suffix == ".ipynb" and not jupyter_dependencies_are_installed(
                warn=verbose or not quiet
            ):
                continue

            sources.add(p)
        elif p.is_dir():
            normalized_dir = normalize_path_maybe_ignore(p, root, report)
            if normalized_dir is None:
                # Joining `root` with None would raise TypeError; skip the
                # directory just like an unusable file above.
                if verbose:
                    out(f'Skipping invalid source: "{p}"', fg="red")
                continue
            p = root / normalized_dir
            if verbose:
                out(f'Found input source directory: "{p}"', fg="blue")

            if using_default_exclude:
                gitignore = {
                    root: root_gitignore,
                    p: get_gitignore(p),
                }
            sources.update(
                gen_python_files(
                    p.iterdir(),
                    root,
                    include,
                    exclude,
                    extend_exclude,
                    force_exclude,
                    report,
                    gitignore,
                    verbose=verbose,
                    quiet=quiet,
                )
            )
        elif s == "-":
            if verbose:
                out("Found input source stdin", fg="blue")
            sources.add(p)
        else:
            err(f"invalid path: {s}")

    return sources
708
709
def path_empty(
    src: Sized, msg: str, quiet: bool, verbose: bool, ctx: click.Context
) -> None:
    """
    Exit with status 0 when `src` is empty.

    `msg` is printed first unless output is quiet and not verbose. Nothing
    happens when `src` has at least one item.
    """
    if src:
        return
    if verbose or not quiet:
        out(msg)
    ctx.exit(0)
720
721
def reformat_code(
    content: str, fast: bool, write_back: WriteBack, mode: Mode, report: Report
) -> None:
    """
    Reformat the string `content` and print it, without child processes.

    This is the string counterpart of `reformat_one`. `fast`, `write_back` and
    `mode` are forwarded to :func:`format_stdin_to_stdout`. The outcome is
    recorded on `report` under the path "<string>"; any exception is recorded
    as a failure instead of being propagated.
    """
    path = Path("<string>")
    try:
        was_changed = format_stdin_to_stdout(
            content=content, fast=fast, write_back=write_back, mode=mode
        )
        report.done(path, Changed.YES if was_changed else Changed.NO)
    except Exception as exc:
        if report.verbose:
            traceback.print_exc()
        report.failed(path, str(exc))
744
745
# diff-shades depends on being able to monkeypatch this function to operate. I
# know it's not ideal, but this shouldn't cause any issues ... hopefully. ~ichard26
@mypyc_attr(patchable=True)
def reformat_one(
    src: Path, fast: bool, write_back: WriteBack, mode: Mode, report: "Report"
) -> None:
    """Reformat the single file `src` without spawning child processes.

    A `src` of "-" or one prefixed with STDIN_PLACEHOLDER is read from stdin
    and handled by :func:`format_stdin_to_stdout`; everything else goes through
    :func:`format_file_in_place`, using the cache to skip unchanged files.
    `fast`, `write_back` and `mode` are forwarded to those functions. The
    outcome, or any exception raised on the way, is recorded on `report`.
    """
    try:
        changed = Changed.NO

        src_name = str(src)
        has_placeholder = src_name != "-" and src_name.startswith(STDIN_PLACEHOLDER)
        is_stdin = src_name == "-" or has_placeholder
        if has_placeholder:
            # Use the original name again in case we want to print something
            # to the user
            src = Path(src_name[len(STDIN_PLACEHOLDER) :])

        if is_stdin:
            # With stdin, the (stand-in) file name alone selects stub or
            # notebook formatting.
            if src.suffix == ".pyi":
                mode = replace(mode, is_pyi=True)
            elif src.suffix == ".ipynb":
                mode = replace(mode, is_ipynb=True)
            if format_stdin_to_stdout(fast=fast, write_back=write_back, mode=mode):
                changed = Changed.YES
        else:
            cache = Cache.read(mode)
            wants_diff = write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF)
            # The cache is never consulted when a diff was requested.
            if not wants_diff and not cache.is_changed(src):
                changed = Changed.CACHED
            elif format_file_in_place(
                src, fast=fast, write_back=write_back, mode=mode
            ):
                changed = Changed.YES

            wrote_file = write_back is WriteBack.YES and changed is not Changed.CACHED
            checked_clean = write_back is WriteBack.CHECK and changed is Changed.NO
            if wrote_file or checked_clean:
                cache.write([src])
        report.done(src, changed)
    except Exception as exc:
        if report.verbose:
            traceback.print_exc()
        report.failed(src, str(exc))
795
796
def format_file_in_place(
    src: Path,
    fast: bool,
    mode: Mode,
    write_back: WriteBack = WriteBack.NO,
    lock: Any = None,  # multiprocessing.Manager().Lock() is some crazy proxy
) -> bool:
    """Format the file at `src`. Return True if its contents would change.

    A ".pyi" or ".ipynb" suffix switches `mode` to stub or notebook
    formatting. With `write_back` YES the reformatted code is written to the
    file; with DIFF or COLOR_DIFF a diff is written to stdout (while holding
    `lock`, when one is given). `mode` and `fast` are forwarded to
    :func:`format_file_contents`.

    Raises ValueError when a notebook cannot be decoded as JSON.
    """
    suffix = src.suffix
    if suffix == ".pyi":
        mode = replace(mode, is_pyi=True)
    elif suffix == ".ipynb":
        mode = replace(mode, is_ipynb=True)

    # Modification time of the input, used as the "before" stamp in diffs.
    then = datetime.fromtimestamp(src.stat().st_mtime, timezone.utc)
    with open(src, "rb") as buf:
        # A skipped first line is set aside and glued back on after formatting.
        header = buf.readline() if mode.skip_source_first_line else b""
        src_contents, encoding, newline = decode_bytes(buf.read())

    try:
        dst_contents = format_file_contents(src_contents, fast=fast, mode=mode)
    except NothingChanged:
        return False
    except JSONDecodeError:
        raise ValueError(
            f"File '{src}' cannot be parsed as valid Jupyter notebook."
        ) from None

    header_text = header.decode(encoding)
    src_contents = header_text + src_contents
    dst_contents = header_text + dst_contents

    if write_back == WriteBack.YES:
        with open(src, "w", encoding=encoding, newline=newline) as f:
            f.write(dst_contents)
        return True

    if write_back not in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
        return True

    now = datetime.now(timezone.utc)
    src_name = f"{src}\t{then}"
    dst_name = f"{src}\t{now}"
    make_diff = ipynb_diff if mode.is_ipynb else diff
    diff_contents = make_diff(src_contents, dst_contents, src_name, dst_name)
    if write_back == WriteBack.COLOR_DIFF:
        diff_contents = color_diff(diff_contents)

    with lock or nullcontext():
        stream = io.TextIOWrapper(
            sys.stdout.buffer,
            encoding=encoding,
            newline=newline,
            write_through=True,
        )
        stream = wrap_stream_for_windows(stream)
        stream.write(diff_contents)
        # Detach so that dropping the wrapper does not close stdout's buffer.
        stream.detach()

    return True
859
860
def format_stdin_to_stdout(
    fast: bool,
    *,
    content: Optional[str] = None,
    write_back: WriteBack = WriteBack.NO,
    mode: Mode,
) -> bool:
    """Format file on stdin. Return True if changed.

    If content is None, it's read from sys.stdin.

    If `write_back` is YES, write reformatted code back to stdout. If it is DIFF,
    write a diff to stdout. The `mode` argument is passed to
    :func:`format_file_contents`.

    The output step lives in a `finally` clause, so it also runs when nothing
    changed and when formatting raises; in both of those cases the text that
    is written (or diffed against) is the unmodified input.
    """
    # Timestamp used for the "STDIN" side of the diff header.
    then = datetime.now(timezone.utc)

    if content is None:
        src, encoding, newline = decode_bytes(sys.stdin.buffer.read())
    else:
        # A string passed in directly is emitted as UTF-8 with newline
        # translation disabled.
        src, encoding, newline = content, "utf-8", ""

    # Start from the input so the `finally` block has something to emit even
    # if format_file_contents() does not return.
    dst = src
    try:
        dst = format_file_contents(src, fast=fast, mode=mode)
        return True

    except NothingChanged:
        return False

    finally:
        f = io.TextIOWrapper(
            sys.stdout.buffer, encoding=encoding, newline=newline, write_through=True
        )
        if write_back == WriteBack.YES:
            # Make sure there's a newline after the content
            if dst and dst[-1] != "\n":
                dst += "\n"
            f.write(dst)
        elif write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
            now = datetime.now(timezone.utc)
            src_name = f"STDIN\t{then}"
            dst_name = f"STDOUT\t{now}"
            d = diff(src, dst, src_name, dst_name)
            if write_back == WriteBack.COLOR_DIFF:
                d = color_diff(d)
                # Only colored output is passed through the Windows wrapper here.
                f = wrap_stream_for_windows(f)
            f.write(d)
        # Detach so that dropping the wrapper does not close stdout's buffer.
        f.detach()
910
911
def check_stability_and_equivalence(
    src_contents: str, dst_contents: str, *, mode: Mode
) -> None:
    """Run the equivalence check followed by the stability check.

    An AssertionError is raised when `dst_contents` is not equivalent to
    `src_contents`, or when formatting a second time would produce different
    output.
    """
    # Equivalence comes first: source and result must mean the same thing.
    assert_equivalent(src_contents, dst_contents)
    # Then stability: a further pass under `mode` must leave the result alone.
    assert_stable(src_contents, dst_contents, mode=mode)
923
924
def format_file_contents(src_contents: str, *, fast: bool, mode: Mode) -> FileContent:
    """Return the reformatted version of `src_contents`.

    Notebooks (``mode.is_ipynb``) go through :func:`format_ipynb_string`; all
    other input goes through :func:`format_str`.  Raise NothingChanged when the
    formatted result equals the input.

    Unless `fast` is set, non-notebook results are additionally validated with
    :func:`check_stability_and_equivalence` before being returned.
    """
    if mode.is_ipynb:
        formatted = format_ipynb_string(src_contents, fast=fast, mode=mode)
    else:
        formatted = format_str(src_contents, mode=mode)

    if formatted == src_contents:
        raise NothingChanged

    # Jupyter notebooks have already been checked cell by cell above.
    if not (fast or mode.is_ipynb):
        check_stability_and_equivalence(src_contents, formatted, mode=mode)
    return formatted
943
944
def validate_cell(src: str, mode: Mode) -> None:
    """Raise NothingChanged for cells that cannot be safely round-tripped.

    Two kinds of cells are rejected:

    - cells that already contain the output of a TransformerManager
      transformation (any entry of ``TRANSFORMED_MAGICS``);
    - cells starting with a ``%%`` cell magic that is neither in
      ``PYTHON_CELL_MAGICS`` nor in ``mode.python_cell_magics``.

    The first rule exists because a cell containing ``!ls`` is transformed to
    ``get_ipython().system('ls')``, yet a cell that originally contained
    ``get_ipython().system('ls')`` is transformed in exactly the same way:

        >>> TransformerManager().transform_cell("get_ipython().system('ls')")
        "get_ipython().system('ls')\n"
        >>> TransformerManager().transform_cell("!ls")
        "get_ipython().system('ls')\n"

    The two cannot be told apart afterwards, so such cells are left untouched.
    """
    for transformed_magic in TRANSFORMED_MAGICS:
        if transformed_magic in src:
            raise NothingChanged

    if src.startswith("%%"):
        # The magic name is the first whitespace-delimited word, minus the "%%".
        magic_name = src.split()[0][2:]
        python_magics = PYTHON_CELL_MAGICS | mode.python_cell_magics
        if magic_name not in python_magics:
            raise NothingChanged
969
970
def format_cell(src: str, *, fast: bool, mode: Mode) -> str:
    """Format the code of a single Jupyter notebook cell.

    The steps are:

      - validate the cell (see :func:`validate_cell`);
      - remove a trailing semicolon, remembering whether there was one;
      - mask IPython magics;
      - format the masked code (and verify it unless `fast` is set);
      - unmask the magics;
      - put the trailing semicolon back if it was originally present;
      - strip trailing newlines.

    Raise NothingChanged if the cell is rejected by validation, if masking
    fails with a SyntaxError (such cells could be automagics or multi-line
    magics, which are currently not supported), or if the result equals `src`.
    """
    validate_cell(src, mode)
    stripped_src, had_semicolon = remove_trailing_semicolon(src)
    try:
        masked_src, replacements = mask_cell(stripped_src)
    except SyntaxError:
        raise NothingChanged from None

    masked_dst = format_str(masked_src, mode=mode)
    if not fast:
        check_stability_and_equivalence(masked_src, masked_dst, mode=mode)

    unmasked_dst = unmask_cell(masked_dst, replacements)
    dst = put_trailing_semicolon_back(unmasked_dst, had_semicolon).rstrip("\n")
    if dst == src:
        raise NothingChanged from None
    return dst
1006
1007
def validate_metadata(nb: MutableMapping[str, Any]) -> None:
    """Raise NothingChanged if the notebook declares a non-Python language.

    The language is looked up at ``metadata.language_info.name``.  All notebook
    metadata fields are optional, see
    https://nbformat.readthedocs.io/en/latest/format_description.html, so a
    notebook without that entry is accepted and formatting is attempted anyway.
    """
    metadata = nb.get("metadata", {})
    language_info = metadata.get("language_info", {})
    language = language_info.get("name", None)
    if language is None or language == "python":
        return
    raise NothingChanged from None
1018
1019
def format_ipynb_string(src_contents: str, *, fast: bool, mode: Mode) -> FileContent:
    """Format the JSON text of a Jupyter notebook.

    Only cells whose ``cell_type`` is ``"code"`` are formatted, one at a time
    via :func:`format_cell`, and only if :func:`validate_metadata` accepts the
    notebook.  If the input ended with a newline, so does the output.

    Raise NothingChanged if the input is empty or no cell was modified.
    """
    if not src_contents:
        raise NothingChanged

    ends_with_newline = src_contents.endswith("\n")
    notebook = json.loads(src_contents)
    validate_metadata(notebook)

    any_cell_changed = False
    for cell in notebook["cells"]:
        if cell.get("cell_type", None) != "code":
            continue
        try:
            cell_src = "".join(cell["source"])
            cell_dst = format_cell(cell_src, fast=fast, mode=mode)
        except NothingChanged:
            # Leave this cell as it is and move on to the next one.
            continue
        cell["source"] = cell_dst.splitlines(keepends=True)
        any_cell_changed = True

    if not any_cell_changed:
        raise NothingChanged

    serialized = json.dumps(notebook, indent=1, ensure_ascii=False)
    return serialized + "\n" if ends_with_newline else serialized
1050
1051
def format_str(src_contents: str, *, mode: Mode) -> str:
    """Return `src_contents` reformatted according to `mode`.

    `mode` carries the formatting options, such as how many characters per line
    are allowed.  Example:

    >>> import black
    >>> print(black.format_str("def f(arg:str='')->None:...", mode=black.Mode()))
    def f(arg: str = "") -> None:
        ...

    A more complex example:

    >>> print(
    ...   black.format_str(
    ...     "def f(arg:str='')->None: hey",
    ...     mode=black.Mode(
    ...       target_versions={black.TargetVersion.PY36},
    ...       line_length=10,
    ...       string_normalization=False,
    ...       is_pyi=False,
    ...     ),
    ...   ),
    ... )
    def f(
        arg: str = '',
    ) -> None:
        hey

    """
    first_pass = _format_str_once(src_contents, mode=mode)
    if first_pass == src_contents:
        return first_pass
    # Forced second pass to work around optional trailing commas (becoming
    # forced trailing commas on pass 2) interacting differently with optional
    # parentheses.  Admittedly ugly.
    return _format_str_once(first_pass, mode=mode)
1089
1090
def _format_str_once(src_contents: str, *, mode: Mode) -> str:
    """Run a single formatting pass over `src_contents` and return the result.

    The source (with leading whitespace stripped) is parsed with
    :func:`lib2to3_parse`.  Target versions come from `mode` when set and are
    otherwise detected from the parsed tree.  Each line produced by
    LineGenerator is given its surrounding blank lines by EmptyLineTracker and
    is then split by :func:`transform_line`.
    """
    src_node = lib2to3_parse(src_contents.lstrip(), mode.target_versions)
    blocks: List[LinesBlock] = []

    versions = mode.target_versions
    if not versions:
        future_imports = get_future_imports(src_node)
        versions = detect_target_versions(src_node, future_imports=future_imports)

    context_manager_features = {
        candidate
        for candidate in {Feature.PARENTHESIZED_CONTEXT_MANAGERS}
        if supports_feature(versions, candidate)
    }
    normalize_fmt_off(src_node)
    line_generator = LineGenerator(mode=mode, features=context_manager_features)
    empty_line_tracker = EmptyLineTracker(mode=mode)
    split_line_features = {
        candidate
        for candidate in {Feature.TRAILING_COMMA_IN_CALL, Feature.TRAILING_COMMA_IN_DEF}
        if supports_feature(versions, candidate)
    }

    for current_line in line_generator.visit(src_node):
        current_block = empty_line_tracker.maybe_empty_lines(current_line)
        blocks.append(current_block)
        current_block.content_lines.extend(
            str(split_line)
            for split_line in transform_line(
                current_line, mode=mode, features=split_line_features
            )
        )

    if blocks:
        # No blank lines are emitted after the final block.
        blocks[-1].after = 0

    output_lines = [text for blk in blocks for text in blk.all_lines()]
    if output_lines:
        return "".join(output_lines)

    # Nothing was produced.  Use decode_bytes to retrieve the correct source
    # newline (CRLF or LF), and check if the normalized content has more than
    # one line.
    normalized_content, _, newline = decode_bytes(src_contents.encode("utf-8"))
    if "\n" in normalized_content:
        return newline
    return ""
1134
1135
def decode_bytes(src: bytes) -> Tuple[FileContent, Encoding, NewLine]:
    """Decode `src` and return ``(decoded_contents, encoding, newline)``.

    The encoding is determined with ``tokenize.detect_encoding``.  `newline` is
    CRLF if the first line ends with CRLF and LF otherwise, while
    `decoded_contents` is read with universal newlines (i.e. only contains LF).
    For input in which no line is found, the contents are empty and `newline`
    is LF.
    """
    buffer = io.BytesIO(src)
    encoding, first_lines = tokenize.detect_encoding(buffer.readline)
    if not first_lines:
        return "", encoding, "\n"

    newline = "\r\n" if first_lines[0].endswith(b"\r\n") else "\n"
    # detect_encoding consumed part of the buffer; rewind before decoding it all.
    buffer.seek(0)
    with io.TextIOWrapper(buffer, encoding) as reader:
        return reader.read(), encoding, newline
1151
1152
def get_features_used(  # noqa: C901
    node: Node, *, future_imports: Optional[Set[str]] = None
) -> Set[Feature]:
    """Return a set of (relatively) new Python features used in this file.

    `node` is walked in pre-order.  `future_imports`, when given, contributes
    the features mapped to those names in ``FUTURE_FLAG_TO_FEATURE``.

    Currently looking for:
    - f-strings (with any capitalization of the string prefix);
    - self-documenting expressions in f-strings (f"{x=}");
    - underscores in numeric literals;
    - trailing commas after * or ** in function signatures and calls;
    - positional only arguments in function signatures and lambdas;
    - assignment expression;
    - relaxed decorator syntax;
    - usage of __future__ flags (annotations);
    - unpacking on return and yield;
    - extended right-hand side of annotated assignments;
    - parenthesized context managers;
    - match statements;
    - except* clause;
    - variadic generics;
    - type statements and type parameters;
    """
    features: Set[Feature] = set()
    if future_imports:
        features |= {
            FUTURE_FLAG_TO_FEATURE[future_import]
            for future_import in future_imports
            if future_import in FUTURE_FLAG_TO_FEATURE
        }

    for n in node.pre_order():
        if is_string_token(n):
            # String prefixes are case-insensitive, so fold the case before
            # comparing; otherwise mixed-case prefixes such as Rf"..." or
            # fR'...' would not be recognized as f-strings.
            value_head = n.value[:2].lower()
            if value_head in {'f"', "f'", "rf", "fr"}:
                features.add(Feature.F_STRINGS)
                if Feature.DEBUG_F_STRINGS not in features:
                    for span_beg, span_end in iter_fexpr_spans(n.value):
                        # A replacement field whose text before the final
                        # character ends with "=" is a debug expression.
                        if n.value[span_beg : span_end - 1].rstrip().endswith("="):
                            features.add(Feature.DEBUG_F_STRINGS)
                            break

        elif is_number_token(n):
            if "_" in n.value:
                features.add(Feature.NUMERIC_UNDERSCORES)

        elif n.type == token.SLASH:
            # A slash directly inside an argument list marks positional-only
            # parameters.
            if n.parent and n.parent.type in {
                syms.typedargslist,
                syms.arglist,
                syms.varargslist,
            }:
                features.add(Feature.POS_ONLY_ARGUMENTS)

        elif n.type == token.COLONEQUAL:
            features.add(Feature.ASSIGNMENT_EXPRESSIONS)

        elif n.type == syms.decorator:
            if len(n.children) > 1 and not is_simple_decorator_expression(
                n.children[1]
            ):
                features.add(Feature.RELAXED_DECORATORS)

        elif (
            n.type in {syms.typedargslist, syms.arglist}
            and n.children
            and n.children[-1].type == token.COMMA
        ):
            # The argument list ends with a comma; that only counts as a
            # feature when a star argument is present in the list.
            if n.type == syms.typedargslist:
                feature = Feature.TRAILING_COMMA_IN_DEF
            else:
                feature = Feature.TRAILING_COMMA_IN_CALL

            for ch in n.children:
                if ch.type in STARS:
                    features.add(feature)

                if ch.type == syms.argument:
                    for argch in ch.children:
                        if argch.type in STARS:
                            features.add(feature)

        elif (
            n.type in {syms.return_stmt, syms.yield_expr}
            and len(n.children) >= 2
            and n.children[1].type == syms.testlist_star_expr
            and any(child.type == syms.star_expr for child in n.children[1].children)
        ):
            features.add(Feature.UNPACKING_ON_FLOW)

        elif (
            n.type == syms.annassign
            and len(n.children) >= 4
            and n.children[3].type == syms.testlist_star_expr
        ):
            features.add(Feature.ANN_ASSIGN_EXTENDED_RHS)

        elif (
            n.type == syms.with_stmt
            and len(n.children) > 2
            and n.children[1].type == syms.atom
        ):
            # Only an atom of the exact shape "(" testlist_gexp ")" counts as
            # parenthesized context managers.
            atom_children = n.children[1].children
            if (
                len(atom_children) == 3
                and atom_children[0].type == token.LPAR
                and atom_children[1].type == syms.testlist_gexp
                and atom_children[2].type == token.RPAR
            ):
                features.add(Feature.PARENTHESIZED_CONTEXT_MANAGERS)

        elif n.type == syms.match_stmt:
            features.add(Feature.PATTERN_MATCHING)

        elif (
            n.type == syms.except_clause
            and len(n.children) >= 2
            and n.children[1].type == token.STAR
        ):
            features.add(Feature.EXCEPT_STAR)

        elif n.type in {syms.subscriptlist, syms.trailer} and any(
            child.type == syms.star_expr for child in n.children
        ):
            features.add(Feature.VARIADIC_GENERICS)

        elif (
            n.type == syms.tname_star
            and len(n.children) == 3
            and n.children[2].type == syms.star_expr
        ):
            features.add(Feature.VARIADIC_GENERICS)

        elif n.type in (syms.type_stmt, syms.typeparams):
            features.add(Feature.TYPE_PARAMS)

    return features
1287
1288
def detect_target_versions(
    node: Node, *, future_imports: Optional[Set[str]] = None
) -> Set[TargetVersion]:
    """Return every target version that supports all features used in `node`.

    The used features are collected by :func:`get_features_used`; a version is
    included when that set is a subset of ``VERSION_TO_FEATURES[version]``.
    """
    used_features = get_features_used(node, future_imports=future_imports)
    compatible: Set[TargetVersion] = set()
    for version in TargetVersion:
        if used_features <= VERSION_TO_FEATURES[version]:
            compatible.add(version)
    return compatible
1297
1298
def get_future_imports(node: Node) -> Set[str]:
    """Return the set of names imported from ``__future__`` in the file.

    Only the leading run of simple statements is examined: docstring-like
    string statements are skipped, ``from __future__ import ...`` statements
    are collected, and the scan stops at the first statement of any other kind.
    """
    found: Set[str] = set()

    def imported_names(children: List[LN]) -> Generator[str, None, None]:
        # Yield the original (pre-``as``) name of every imported item.
        for item in children:
            if isinstance(item, Leaf):
                # Leaves other than names (such as commas) are skipped.
                if item.type == token.NAME:
                    yield item.value
                continue

            if item.type == syms.import_as_name:
                original = item.children[0]
                assert isinstance(original, Leaf), "Invalid syntax parsing imports"
                assert original.type == token.NAME, "Invalid syntax parsing imports"
                yield original.value
            elif item.type == syms.import_as_names:
                yield from imported_names(item.children)
            else:
                raise AssertionError("Invalid syntax parsing imports")

    for statement in node.children:
        if statement.type != syms.simple_stmt:
            break

        head = statement.children[0]
        if isinstance(head, Leaf):
            # Continue looking if we see a docstring; otherwise stop.
            is_docstring = (
                len(statement.children) == 2
                and head.type == token.STRING
                and statement.children[1].type == token.NEWLINE
            )
            if not is_docstring:
                break
            continue

        if head.type != syms.import_from:
            break

        module = head.children[1]
        if not (isinstance(module, Leaf) and module.value == "__future__"):
            break

        found.update(imported_names(head.children[3:]))

    return found
1347
1348
def assert_equivalent(src: str, dst: str) -> None:
    """Raise AssertionError if `src` and `dst` aren't equivalent.

    Both strings are parsed with :func:`parse_ast` and the stringified ASTs are
    compared.  An AssertionError is also raised when either string fails to
    parse; when `dst` fails to parse, or when the ASTs differ, diagnostic
    output is written with :func:`dump_to_file` and referenced in the message.
    """
    try:
        src_ast = parse_ast(src)
    except Exception as exc:
        raise AssertionError(
            "cannot use --safe with this file; failed to parse source file AST: "
            f"{exc}\n"
            "This could be caused by running Black with an older Python version "
            "that does not support new syntax used in your source file."
        ) from exc

    try:
        dst_ast = parse_ast(dst)
    except Exception as exc:
        traceback_text = "".join(traceback.format_tb(exc.__traceback__))
        log = dump_to_file(traceback_text, dst)
        raise AssertionError(
            f"INTERNAL ERROR: Black produced invalid code: {exc}. "
            "Please report a bug on https://github.com/psf/black/issues.  "
            f"This invalid output might be helpful: {log}"
        ) from None

    src_dump = "\n".join(stringify_ast(src_ast))
    dst_dump = "\n".join(stringify_ast(dst_ast))
    if src_dump == dst_dump:
        return

    log = dump_to_file(diff(src_dump, dst_dump, "src", "dst"))
    raise AssertionError(
        "INTERNAL ERROR: Black produced code that is not equivalent to the"
        " source.  Please report a bug on "
        f"https://github.com/psf/black/issues.  This diff might be helpful: {log}"
    ) from None
1380
1381
def assert_stable(src: str, dst: str, mode: Mode) -> None:
    """Raise AssertionError if formatting `dst` again changes it.

    On failure, the mode and the diffs source -> first pass and first pass ->
    second pass are written with :func:`dump_to_file` and referenced in the
    error message.
    """
    # We shouldn't call format_str() here, because that formats the string
    # twice and may hide a bug where we bounce back and forth between two
    # versions.
    second_pass = _format_str_once(dst, mode=mode)
    if second_pass == dst:
        return

    mode_text = str(mode)
    first_pass_diff = diff(src, dst, "source", "first pass")
    second_pass_diff = diff(dst, second_pass, "first pass", "second pass")
    log = dump_to_file(mode_text, first_pass_diff, second_pass_diff)
    raise AssertionError(
        "INTERNAL ERROR: Black produced different code on the second pass of the"
        " formatter.  Please report a bug on https://github.com/psf/black/issues."
        f"  This diff might be helpful: {log}"
    ) from None
1399
1400
@contextmanager
def nullcontext() -> Iterator[None]:
    """Return a context manager that does nothing on entry or exit.

    The ``with`` target is bound to None.  To be used like `nullcontext` in
    Python 3.7.
    """
    yield None
1408
1409
def patched_main() -> None:
    """Run :func:`main`, calling ``freeze_support()`` first in frozen builds."""
    # PyInstaller patches multiprocessing to need freeze_support() even in non-Windows
    # environments so just assume we always need to call it if frozen.
    is_frozen = getattr(sys, "frozen", False)
    if is_frozen:
        import multiprocessing

        multiprocessing.freeze_support()

    main()
1419
1420
# Run the entry point only when this module is executed as a script.
if __name__ == "__main__":
    patched_main()