git.madduck.net Git - etc/vim.git/blob - src/black/__init__.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access, can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

3451c86b50893859be638e1a608a58e1ebf15881
[etc/vim.git] / src / black / __init__.py
1 import io
2 import json
3 import platform
4 import re
5 import sys
6 import tokenize
7 import traceback
8 from contextlib import contextmanager
9 from dataclasses import replace
10 from datetime import datetime, timezone
11 from enum import Enum
12 from json.decoder import JSONDecodeError
13 from pathlib import Path
14 from typing import (
15     Any,
16     Dict,
17     Generator,
18     Iterator,
19     List,
20     MutableMapping,
21     Optional,
22     Pattern,
23     Sequence,
24     Set,
25     Sized,
26     Tuple,
27     Union,
28 )
29
30 import click
31 from click.core import ParameterSource
32 from mypy_extensions import mypyc_attr
33 from pathspec import PathSpec
34 from pathspec.patterns.gitwildmatch import GitWildMatchPatternError
35
36 from _black_version import version as __version__
37 from black.cache import Cache, get_cache_info, read_cache, write_cache
38 from black.comments import normalize_fmt_off
39 from black.const import (
40     DEFAULT_EXCLUDES,
41     DEFAULT_INCLUDES,
42     DEFAULT_LINE_LENGTH,
43     STDIN_PLACEHOLDER,
44 )
45 from black.files import (
46     find_project_root,
47     find_pyproject_toml,
48     find_user_pyproject_toml,
49     gen_python_files,
50     get_gitignore,
51     normalize_path_maybe_ignore,
52     parse_pyproject_toml,
53     wrap_stream_for_windows,
54 )
55 from black.handle_ipynb_magics import (
56     PYTHON_CELL_MAGICS,
57     TRANSFORMED_MAGICS,
58     jupyter_dependencies_are_installed,
59     mask_cell,
60     put_trailing_semicolon_back,
61     remove_trailing_semicolon,
62     unmask_cell,
63 )
64 from black.linegen import LN, LineGenerator, transform_line
65 from black.lines import EmptyLineTracker, LinesBlock
66 from black.mode import (
67     FUTURE_FLAG_TO_FEATURE,
68     VERSION_TO_FEATURES,
69     Feature,
70     Mode,
71     TargetVersion,
72     supports_feature,
73 )
74 from black.nodes import (
75     STARS,
76     is_number_token,
77     is_simple_decorator_expression,
78     is_string_token,
79     syms,
80 )
81 from black.output import color_diff, diff, dump_to_file, err, ipynb_diff, out
82 from black.parsing import InvalidInput  # noqa F401
83 from black.parsing import lib2to3_parse, parse_ast, stringify_ast
84 from black.report import Changed, NothingChanged, Report
85 from black.trans import iter_fexpr_spans
86 from blib2to3.pgen2 import token
87 from blib2to3.pytree import Leaf, Node
88
# True when this module's own file has a compiled-extension suffix
# (".pyd" or ".so") rather than a plain source suffix.
COMPILED = Path(__file__).suffix in (".pyd", ".so")

# Type aliases. All three are plain `str`; the distinct names only document
# which role a string plays in a signature.
FileContent = str
Encoding = str
NewLine = str
95
96
class WriteBack(Enum):
    """Output modes that can be selected through :meth:`from_configuration`."""

    NO = 0
    YES = 1
    DIFF = 2
    CHECK = 3
    COLOR_DIFF = 4

    @classmethod
    def from_configuration(
        cls, *, check: bool, diff: bool, color: bool = False
    ) -> "WriteBack":
        """Map the `check` / `diff` / `color` flags onto a member.

        `diff` takes precedence over `check`: whenever `diff` is set the result
        is ``COLOR_DIFF`` (with `color`) or ``DIFF`` (without). Otherwise
        `check` selects ``CHECK`` and the fallback is ``YES``. `color` has no
        effect unless `diff` is set.
        """
        if diff:
            return cls.COLOR_DIFF if color else cls.DIFF

        return cls.CHECK if check else cls.YES
115
116
# Legacy alias: `FileMode` is the very same object as `Mode`, left in place so
# integrations that still import the old name keep working.
FileMode = Mode
119
120
def read_pyproject_toml(
    ctx: click.Context, param: click.Parameter, value: Optional[str]
) -> Optional[str]:
    """Merge Black settings from a TOML configuration file into `ctx.default_map`.

    `value` is the explicitly requested configuration path. When it is empty,
    the file is looked up via `find_pyproject_toml` using the ``src`` and
    ``stdin_filename`` entries of `ctx.params`.

    Returns the path of the configuration file that was read and applied, or
    None when no file was found or the file held no configuration.

    Raises `click.FileError` when the file cannot be read or parsed, and
    `click.BadOptionUsage` when ``target_version`` is present but not a list.
    """
    config_path = value
    if not config_path:
        config_path = find_pyproject_toml(
            ctx.params.get("src", ()), ctx.params.get("stdin_filename", None)
        )
    if config_path is None:
        return None

    try:
        raw_config = parse_pyproject_toml(config_path)
    except (OSError, ValueError) as e:
        raise click.FileError(
            filename=config_path, hint=f"Error reading configuration file: {e}"
        ) from None

    if not raw_config:
        return None

    # Sanitize the values to be Click friendly: everything except lists and
    # dicts is turned into a string. For more information please see:
    # https://github.com/psf/black/issues/1458
    # https://github.com/pallets/click/issues/1567
    config: Dict[str, Any] = {}
    for key, setting in raw_config.items():
        config[key] = setting if isinstance(setting, (list, dict)) else str(setting)

    target_version = config.get("target_version")
    if not (target_version is None or isinstance(target_version, list)):
        raise click.BadOptionUsage(
            "target-version", "Config key target-version must be a list"
        )

    # Values from the file win over anything already present in the default map.
    merged: Dict[str, Any] = dict(ctx.default_map) if ctx.default_map else {}
    merged.update(config)
    ctx.default_map = merged
    return config_path
167
168
def target_version_option_callback(
    c: click.Context, p: Union[click.Option, click.Parameter], v: Tuple[str, ...]
) -> List[TargetVersion]:
    """Translate the values of a --target-version flag into `TargetVersion` members.

    Each value is upper-cased and looked up by name on `TargetVersion`. `c` and
    `p` are not used here.

    Kept as a named function (not a lambda) because mypy couldn't infer the
    type correctly for a lambda, which caused mypyc trouble.
    """
    versions: List[TargetVersion] = []
    for name in v:
        versions.append(TargetVersion[name.upper()])
    return versions
178
179
def re_compile_maybe_verbose(regex: str) -> Pattern[str]:
    """Return the compiled form of the regular expression string `regex`.

    A pattern that contains a newline gets the inline ``(?x)`` flag prepended,
    i.e. it is compiled in verbose mode; any other pattern is compiled as is.
    """
    source = f"(?x){regex}" if "\n" in regex else regex
    return re.compile(source)
189
190
def validate_regex(
    ctx: click.Context,
    param: click.Parameter,
    value: Optional[str],
) -> Optional[Pattern[str]]:
    """Compile an option value into a pattern.

    None is passed through unchanged. Any other value is compiled with
    `re_compile_maybe_verbose`; a `re.error` from compilation is reported as
    `click.BadParameter`. `ctx` and `param` are not used here.
    """
    if value is None:
        return None

    try:
        return re_compile_maybe_verbose(value)
    except re.error as e:
        raise click.BadParameter(f"Not a valid regular expression: {e}") from None
200
201
# Command-line entry point. Every option/argument declared in the decorator
# stack below arrives in `main` as the keyword parameter of the same name
# (dashes replaced by underscores).
@click.command(
    context_settings={"help_option_names": ["-h", "--help"]},
    # While Click does set this field automatically using the docstring, mypyc
    # (annoyingly) strips 'em so we need to set it here too.
    help="The uncompromising code formatter.",
)
@click.option("-c", "--code", type=str, help="Format the code passed in as a string.")
@click.option(
    "-l",
    "--line-length",
    type=int,
    default=DEFAULT_LINE_LENGTH,
    help="How many characters per line to allow.",
    show_default=True,
)
@click.option(
    "-t",
    "--target-version",
    type=click.Choice([v.name.lower() for v in TargetVersion]),
    callback=target_version_option_callback,
    multiple=True,
    help=(
        "Python versions that should be supported by Black's output. By default, Black"
        " will try to infer this from the project metadata in pyproject.toml. If this"
        " does not yield conclusive results, Black will use per-file auto-detection."
    ),
)
@click.option(
    "--pyi",
    is_flag=True,
    help=(
        "Format all input files like typing stubs regardless of file extension (useful"
        " when piping source on standard input)."
    ),
)
@click.option(
    "--ipynb",
    is_flag=True,
    help=(
        "Format all input files like Jupyter Notebooks regardless of file extension "
        "(useful when piping source on standard input)."
    ),
)
@click.option(
    "--python-cell-magics",
    multiple=True,
    help=(
        "When processing Jupyter Notebooks, add the given magic to the list"
        f" of known python-magics ({', '.join(sorted(PYTHON_CELL_MAGICS))})."
        " Useful for formatting cells with custom python magics."
    ),
    default=[],
)
@click.option(
    "-x",
    "--skip-source-first-line",
    is_flag=True,
    help="Skip the first line of the source code.",
)
@click.option(
    "-S",
    "--skip-string-normalization",
    is_flag=True,
    help="Don't normalize string quotes or prefixes.",
)
@click.option(
    "-C",
    "--skip-magic-trailing-comma",
    is_flag=True,
    help="Don't use trailing commas as a reason to split lines.",
)
@click.option(
    "--experimental-string-processing",
    is_flag=True,
    hidden=True,
    help="(DEPRECATED and now included in --preview) Normalize string literals.",
)
@click.option(
    "--preview",
    is_flag=True,
    help=(
        "Enable potentially disruptive style changes that may be added to Black's main"
        " functionality in the next major release."
    ),
)
@click.option(
    "--check",
    is_flag=True,
    help=(
        "Don't write the files back, just return the status. Return code 0 means"
        " nothing would change. Return code 1 means some files would be reformatted."
        " Return code 123 means there was an internal error."
    ),
)
@click.option(
    "--diff",
    is_flag=True,
    help="Don't write the files back, just output a diff for each file on stdout.",
)
@click.option(
    "--color/--no-color",
    is_flag=True,
    help="Show colored diff. Only applies when `--diff` is given.",
)
@click.option(
    "--fast/--safe",
    is_flag=True,
    help="If --fast given, skip temporary sanity checks. [default: --safe]",
)
@click.option(
    "--required-version",
    type=str,
    help=(
        "Require a specific version of Black to be running (useful for unifying results"
        " across many environments e.g. with a pyproject.toml file). It can be"
        " either a major version number or an exact version."
    ),
)
# The four regex options below are compiled by `validate_regex`, so `main`
# receives compiled patterns (or None), not strings.
@click.option(
    "--include",
    type=str,
    default=DEFAULT_INCLUDES,
    callback=validate_regex,
    help=(
        "A regular expression that matches files and directories that should be"
        " included on recursive searches. An empty value means all files are included"
        " regardless of the name. Use forward slashes for directories on all platforms"
        " (Windows, too). Exclusions are calculated first, inclusions later."
    ),
    show_default=True,
)
@click.option(
    "--exclude",
    type=str,
    callback=validate_regex,
    help=(
        "A regular expression that matches files and directories that should be"
        " excluded on recursive searches. An empty value means no paths are excluded."
        " Use forward slashes for directories on all platforms (Windows, too)."
        " Exclusions are calculated first, inclusions later. [default:"
        f" {DEFAULT_EXCLUDES}]"
    ),
    show_default=False,
)
@click.option(
    "--extend-exclude",
    type=str,
    callback=validate_regex,
    help=(
        "Like --exclude, but adds additional files and directories on top of the"
        " excluded ones. (Useful if you simply want to add to the default)"
    ),
)
@click.option(
    "--force-exclude",
    type=str,
    callback=validate_regex,
    help=(
        "Like --exclude, but files and directories matching this regex will be "
        "excluded even when they are passed explicitly as arguments."
    ),
)
# NOTE(review): `--stdin-filename`, `src` and `--config` are all marked eager.
# The `--config` callback (`read_pyproject_toml`) reads "src" and
# "stdin_filename" from `ctx.params`, so presumably the eager flags and the
# declaration order exist to make those values available to it -- confirm
# against Click's parameter-processing rules before reordering.
@click.option(
    "--stdin-filename",
    type=str,
    is_eager=True,
    help=(
        "The name of the file when passing it through stdin. Useful to make "
        "sure Black will respect --force-exclude option on some "
        "editors that rely on using stdin."
    ),
)
@click.option(
    "-W",
    "--workers",
    type=click.IntRange(min=1),
    default=None,
    help=(
        "Number of parallel workers [default: BLACK_NUM_WORKERS environment variable "
        "or number of CPUs in the system]"
    ),
)
@click.option(
    "-q",
    "--quiet",
    is_flag=True,
    help=(
        "Don't emit non-error messages to stderr. Errors are still emitted; silence"
        " those with 2>/dev/null."
    ),
)
@click.option(
    "-v",
    "--verbose",
    is_flag=True,
    help=(
        "Also emit messages to stderr about files that were not changed or were ignored"
        " due to exclusion patterns."
    ),
)
@click.version_option(
    version=__version__,
    message=(
        f"%(prog)s, %(version)s (compiled: {'yes' if COMPILED else 'no'})\n"
        f"Python ({platform.python_implementation()}) {platform.python_version()}"
    ),
)
@click.argument(
    "src",
    nargs=-1,
    type=click.Path(
        exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
    ),
    is_eager=True,
    metavar="SRC ...",
)
@click.option(
    "--config",
    type=click.Path(
        exists=True,
        file_okay=True,
        dir_okay=False,
        readable=True,
        allow_dash=False,
        path_type=str,
    ),
    is_eager=True,
    callback=read_pyproject_toml,
    help="Read configuration from FILE path.",
)
@click.pass_context
def main(  # noqa: C901
    ctx: click.Context,
    code: Optional[str],
    line_length: int,
    target_version: List[TargetVersion],
    check: bool,
    diff: bool,
    color: bool,
    fast: bool,
    pyi: bool,
    ipynb: bool,
    python_cell_magics: Sequence[str],
    skip_source_first_line: bool,
    skip_string_normalization: bool,
    skip_magic_trailing_comma: bool,
    experimental_string_processing: bool,
    preview: bool,
    quiet: bool,
    verbose: bool,
    required_version: Optional[str],
    include: Pattern[str],
    exclude: Optional[Pattern[str]],
    extend_exclude: Optional[Pattern[str]],
    force_exclude: Optional[Pattern[str]],
    stdin_filename: Optional[str],
    workers: Optional[int],
    src: Tuple[str, ...],
    config: Optional[str],
) -> None:
    """The uncompromising code formatter."""
    # Ensure `ctx.obj` is a dict; the project root is stored on it below and
    # read back by `get_sources`.
    ctx.ensure_object(dict)

    # Exactly one input kind is accepted: SRC paths or a `--code` string.
    if src and code is not None:
        out(
            main.get_usage(ctx)
            + "\n\n'SRC' and 'code' cannot be passed simultaneously."
        )
        ctx.exit(1)
    if not src and code is None:
        out(main.get_usage(ctx) + "\n\nOne of 'SRC' or 'code' is required.")
        ctx.exit(1)

    # No project root is looked up when formatting a `--code` string.
    root, method = (
        find_project_root(src, stdin_filename) if code is None else (None, None)
    )
    ctx.obj["root"] = root

    if verbose:
        if root:
            out(
                f"Identified `{root}` as project root containing a {method}.",
                fg="blue",
            )

            # Pair every source with its normalized form; "-" (stdin) is kept
            # verbatim. A falsy normalized value marks the source as invalid
            # in the listing below.
            normalized = [
                (
                    (source, source)
                    if source == "-"
                    else (normalize_path_maybe_ignore(Path(source), root), source)
                )
                for source in src
            ]
            srcs_string = ", ".join(
                [
                    (
                        f'"{_norm}"'
                        if _norm
                        else f'\033[31m"{source} (skipping - invalid)"\033[34m'
                    )
                    for _norm, source in normalized
                ]
            )
            out(f"Sources to be formatted: {srcs_string}", fg="blue")

        if config:
            # Report where the configuration came from: the user-level file, a
            # file that was not given explicitly (parameter source is DEFAULT
            # or DEFAULT_MAP), or an explicitly supplied path.
            config_source = ctx.get_parameter_source("config")
            user_level_config = str(find_user_pyproject_toml())
            if config == user_level_config:
                out(
                    "Using configuration from user-level config at "
                    f"'{user_level_config}'.",
                    fg="blue",
                )
            elif config_source in (
                ParameterSource.DEFAULT,
                ParameterSource.DEFAULT_MAP,
            ):
                out("Using configuration from project root.", fg="blue")
            else:
                out(f"Using configuration in '{config}'.", fg="blue")
            if ctx.default_map:
                for param, value in ctx.default_map.items():
                    out(f"{param}: {value}")

    error_msg = "Oh no! 💥 💔 💥"
    # `--required-version` is satisfied by either the full running version or
    # just its first dot-separated component.
    if (
        required_version
        and required_version != __version__
        and required_version != __version__.split(".")[0]
    ):
        err(
            f"{error_msg} The required version `{required_version}` does not match"
            f" the running version `{__version__}`!"
        )
        ctx.exit(1)
    if ipynb and pyi:
        err("Cannot pass both `pyi` and `ipynb` flags!")
        ctx.exit(1)

    write_back = WriteBack.from_configuration(check=check, diff=diff, color=color)
    if target_version:
        versions = set(target_version)
    else:
        # We'll autodetect later.
        versions = set()
    # Note the inversion: the CLI exposes "skip" flags while `Mode` takes the
    # positive `string_normalization` / `magic_trailing_comma` settings.
    mode = Mode(
        target_versions=versions,
        line_length=line_length,
        is_pyi=pyi,
        is_ipynb=ipynb,
        skip_source_first_line=skip_source_first_line,
        string_normalization=not skip_string_normalization,
        magic_trailing_comma=not skip_magic_trailing_comma,
        experimental_string_processing=experimental_string_processing,
        preview=preview,
        python_cell_magics=set(python_cell_magics),
    )

    if code is not None:
        # Run in quiet mode by default with -c; the extra output isn't useful.
        # You can still pass -v to get verbose output.
        quiet = True

    report = Report(check=check, diff=diff, quiet=quiet, verbose=verbose)

    if code is not None:
        reformat_code(
            content=code, fast=fast, write_back=write_back, mode=mode, report=report
        )
    else:
        try:
            sources = get_sources(
                ctx=ctx,
                src=src,
                quiet=quiet,
                verbose=verbose,
                include=include,
                exclude=exclude,
                extend_exclude=extend_exclude,
                force_exclude=force_exclude,
                report=report,
                stdin_filename=stdin_filename,
            )
        except GitWildMatchPatternError:
            ctx.exit(1)

        # Exits with status 0 when nothing was collected.
        path_empty(
            sources,
            "No Python files are present to be formatted. Nothing to do 😴",
            quiet,
            verbose,
            ctx,
        )

        # A single source goes through `reformat_one`; several sources are
        # handed to `reformat_many`, whose module is only imported when needed.
        if len(sources) == 1:
            reformat_one(
                src=sources.pop(),
                fast=fast,
                write_back=write_back,
                mode=mode,
                report=report,
            )
        else:
            from black.concurrency import reformat_many

            reformat_many(
                sources=sources,
                fast=fast,
                write_back=write_back,
                mode=mode,
                report=report,
                workers=workers,
            )

    # Closing summary. The blank separator line and the report itself are only
    # printed for file-based runs, never for `--code`.
    if verbose or not quiet:
        if code is None and (verbose or report.change_count or report.failure_count):
            out()
        out(error_msg if report.return_code else "All done! ✨ 🍰 ✨")
        if code is None:
            click.echo(str(report), err=True)
    ctx.exit(report.return_code)
624
625
def get_sources(
    *,
    ctx: click.Context,
    src: Tuple[str, ...],
    quiet: bool,
    verbose: bool,
    include: Pattern[str],
    exclude: Optional[Pattern[str]],
    extend_exclude: Optional[Pattern[str]],
    force_exclude: Optional[Pattern[str]],
    report: "Report",
    stdin_filename: Optional[str],
) -> Set[Path]:
    """Compute the set of files to be formatted.

    Each entry of `src` is handled according to what it names:

    * a file, or ``-`` combined with `stdin_filename`: added directly unless
      its normalized path is None or matches `force_exclude`. Stdin entries
      are stored with `STDIN_PLACEHOLDER` prepended to the file name, and
      ``.ipynb`` files are dropped when the Jupyter dependencies are missing;
    * a directory: expanded through `gen_python_files`. Gitignore rules are
      only passed along when `exclude` was not given, i.e. the default
      exclusion regex is in effect;
    * a bare ``-`` without `stdin_filename`: added as is;
    * anything else: reported through `err` as an invalid path.

    The project root is read from ``ctx.obj["root"]``.
    """
    sources: Set[Path] = set()
    root = ctx.obj["root"]

    using_default_exclude = exclude is None
    exclude = re_compile_maybe_verbose(DEFAULT_EXCLUDES) if exclude is None else exclude
    gitignore: Optional[Dict[Path, PathSpec]] = None
    root_gitignore = get_gitignore(root)

    for s in src:
        if s == "-" and stdin_filename:
            p = Path(stdin_filename)
            is_stdin = True
        else:
            p = Path(s)
            is_stdin = False

        if is_stdin or p.is_file():
            normalized_path = normalize_path_maybe_ignore(p, root, report)
            if normalized_path is None:
                continue

            normalized_path = "/" + normalized_path
            # Hard-exclude any files that matches the `--force-exclude` regex.
            if force_exclude:
                force_exclude_match = force_exclude.search(normalized_path)
            else:
                force_exclude_match = None
            # An empty match (group(0) == "") does not count as an exclusion.
            if force_exclude_match and force_exclude_match.group(0):
                report.path_ignored(p, "matches the --force-exclude regular expression")
                continue

            if is_stdin:
                p = Path(f"{STDIN_PLACEHOLDER}{str(p)}")

            if p.suffix == ".ipynb" and not jupyter_dependencies_are_installed(
                verbose=verbose, quiet=quiet
            ):
                continue

            sources.add(p)
        elif p.is_dir():
            normalized_dir = normalize_path_maybe_ignore(p, root, report)
            if normalized_dir is None:
                # Same treatment as for files above: a path that could not be
                # normalized is skipped. Joining None onto `root` would raise
                # TypeError instead.
                continue

            p = root / normalized_dir
            if using_default_exclude:
                gitignore = {
                    root: root_gitignore,
                    p: get_gitignore(p),
                }
            sources.update(
                gen_python_files(
                    p.iterdir(),
                    root,
                    include,
                    exclude,
                    extend_exclude,
                    force_exclude,
                    report,
                    gitignore,
                    verbose=verbose,
                    quiet=quiet,
                )
            )
        elif s == "-":
            sources.add(p)
        else:
            err(f"invalid path: {s}")
    return sources
706
707
def path_empty(
    src: Sized, msg: str, quiet: bool, verbose: bool, ctx: click.Context
) -> None:
    """Exit with status 0 through `ctx` when `src` is empty.

    `msg` is printed first unless output is silenced (`quiet` set and
    `verbose` not set). A non-empty `src` makes this a no-op.
    """
    if src:
        return

    should_announce = verbose or not quiet
    if should_announce:
        out(msg)
    ctx.exit(0)
718
719
def reformat_code(
    content: str, fast: bool, write_back: WriteBack, mode: Mode, report: Report
) -> None:
    """Format the string `content` and record the outcome on `report`.

    String counterpart of `reformat_one`: `content`, `fast`, `write_back` and
    `mode` are handed to :func:`format_stdin_to_stdout`, and the result is
    reported under the pseudo-path ``<string>``. Any exception is recorded as
    a failure on `report` rather than propagated; the traceback is printed
    only when the report is verbose.
    """
    path = Path("<string>")
    try:
        was_changed = format_stdin_to_stdout(
            content=content, fast=fast, write_back=write_back, mode=mode
        )
        report.done(path, Changed.YES if was_changed else Changed.NO)
    except Exception as exc:
        if report.verbose:
            traceback.print_exc()
        report.failed(path, str(exc))
742
743
# diff-shades depends on being able to monkeypatch this function to operate.
# I know it's not ideal, but this shouldn't cause any issues ... hopefully.
# ~ichard26
@mypyc_attr(patchable=True)
def reformat_one(
    src: Path, fast: bool, write_back: WriteBack, mode: Mode, report: "Report"
) -> None:
    """Format the single source `src` and record the outcome on `report`.

    `src` is either a real file, ``-`` (stdin), or a stdin file name prefixed
    with `STDIN_PLACEHOLDER`. Stdin input goes through
    :func:`format_stdin_to_stdout`, real files through
    :func:`format_file_in_place`; `fast`, `write_back` and `mode` are passed
    along. Any exception is recorded as a failure on `report` rather than
    propagated.
    """
    try:
        changed = Changed.NO

        src_text = str(src)
        is_stdin = src_text == "-"
        if not is_stdin and src_text.startswith(STDIN_PLACEHOLDER):
            is_stdin = True
            # Use the original name again in case we want to print something
            # to the user
            src = Path(src_text[len(STDIN_PLACEHOLDER) :])

        if is_stdin:
            # The stdin file name can still select stub or notebook mode.
            suffix = src.suffix
            if suffix == ".pyi":
                mode = replace(mode, is_pyi=True)
            elif suffix == ".ipynb":
                mode = replace(mode, is_ipynb=True)
            if format_stdin_to_stdout(fast=fast, write_back=write_back, mode=mode):
                changed = Changed.YES
        else:
            cache: Cache = {}
            is_cached = False
            # The cache is not consulted when only a diff is requested.
            if write_back not in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
                cache = read_cache(mode)
                resolved = src.resolve()
                cache_key = str(resolved)
                is_cached = cache_key in cache and cache[
                    cache_key
                ] == get_cache_info(resolved)

            if is_cached:
                changed = Changed.CACHED
            elif format_file_in_place(
                src, fast=fast, write_back=write_back, mode=mode
            ):
                changed = Changed.YES

            wrote_file = write_back is WriteBack.YES and not is_cached
            checked_clean = write_back is WriteBack.CHECK and changed is Changed.NO
            if wrote_file or checked_clean:
                write_cache(cache, [src], mode)
        report.done(src, changed)
    except Exception as exc:
        if report.verbose:
            traceback.print_exc()
        report.failed(src, str(exc))
796
797
def format_file_in_place(
    src: Path,
    fast: bool,
    mode: Mode,
    write_back: WriteBack = WriteBack.NO,
    lock: Any = None,  # multiprocessing.Manager().Lock() is some crazy proxy
) -> bool:
    """Format the file at `src`; return True when its contents changed.

    With `write_back` YES the formatted code is written back to the file. With
    DIFF or COLOR_DIFF a diff is written to stdout instead, under `lock` when
    one is given. For any other value nothing is written. `mode` and `fast`
    are passed to :func:`format_file_contents`; a ``.pyi`` or ``.ipynb``
    suffix switches `mode` to stub or notebook formatting.

    Raises ValueError when a JSONDecodeError occurs while formatting, i.e. the
    file is not a valid Jupyter notebook.
    """
    suffix = src.suffix
    if suffix == ".pyi":
        mode = replace(mode, is_pyi=True)
    elif suffix == ".ipynb":
        mode = replace(mode, is_ipynb=True)

    # Modification time of the original file, used in the diff header.
    then = datetime.fromtimestamp(src.stat().st_mtime, timezone.utc)
    with open(src, "rb") as buf:
        # A skipped first line is set aside as raw bytes and re-attached to
        # both source and result after formatting.
        header = buf.readline() if mode.skip_source_first_line else b""
        src_contents, encoding, newline = decode_bytes(buf.read())

    try:
        dst_contents = format_file_contents(src_contents, fast=fast, mode=mode)
    except NothingChanged:
        return False
    except JSONDecodeError:
        raise ValueError(
            f"File '{src}' cannot be parsed as valid Jupyter notebook."
        ) from None

    header_text = header.decode(encoding)
    src_contents = header_text + src_contents
    dst_contents = header_text + dst_contents

    if write_back == WriteBack.YES:
        with open(src, "w", encoding=encoding, newline=newline) as f:
            f.write(dst_contents)
        return True

    if write_back not in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
        return True

    now = datetime.now(timezone.utc)
    src_name = f"{src}\t{then}"
    dst_name = f"{src}\t{now}"
    make_diff = ipynb_diff if mode.is_ipynb else diff
    diff_contents = make_diff(src_contents, dst_contents, src_name, dst_name)
    if write_back == WriteBack.COLOR_DIFF:
        diff_contents = color_diff(diff_contents)

    with lock or nullcontext():
        stream = io.TextIOWrapper(
            sys.stdout.buffer,
            encoding=encoding,
            newline=newline,
            write_through=True,
        )
        stream = wrap_stream_for_windows(stream)
        stream.write(diff_contents)
        # Detach instead of closing so sys.stdout's buffer stays usable.
        stream.detach()

    return True
860
861
def format_stdin_to_stdout(
    fast: bool,
    *,
    content: Optional[str] = None,
    write_back: WriteBack = WriteBack.NO,
    mode: Mode,
) -> bool:
    """Format `content`, or stdin when `content` is None; return True if changed.

    With `write_back` YES the resulting code is written to stdout, with a
    trailing newline guaranteed for non-empty output. With DIFF or COLOR_DIFF
    a diff is written to stdout instead. `mode` and `fast` are passed to
    :func:`format_file_contents`.

    The output step lives in a ``finally`` clause, so it also runs when
    nothing changed or when formatting raised; in those cases the unmodified
    source is what gets written or diffed.
    """
    started = datetime.now(timezone.utc)

    if content is not None:
        source_text, encoding, newline = content, "utf-8", ""
    else:
        source_text, encoding, newline = decode_bytes(sys.stdin.buffer.read())

    result_text = source_text
    try:
        result_text = format_file_contents(source_text, fast=fast, mode=mode)
        return True

    except NothingChanged:
        return False

    finally:
        stream = io.TextIOWrapper(
            sys.stdout.buffer, encoding=encoding, newline=newline, write_through=True
        )
        if write_back == WriteBack.YES:
            # Make sure there's a newline after the content
            if result_text and not result_text.endswith("\n"):
                result_text += "\n"
            stream.write(result_text)
        elif write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
            finished = datetime.now(timezone.utc)
            src_name = f"STDIN\t{started}"
            dst_name = f"STDOUT\t{finished}"
            diff_text = diff(source_text, result_text, src_name, dst_name)
            if write_back == WriteBack.COLOR_DIFF:
                diff_text = color_diff(diff_text)
                stream = wrap_stream_for_windows(stream)
            stream.write(diff_text)
        # Detach instead of closing so sys.stdout's buffer stays usable.
        stream.detach()
911
912
def check_stability_and_equivalence(
    src_contents: str, dst_contents: str, *, mode: Mode
) -> None:
    """Run both safety checks on a formatting result.

    First :func:`assert_equivalent` compares `src_contents` with
    `dst_contents`, then :func:`assert_stable` is run with `mode`. An
    AssertionError is raised when the two contents are not equivalent, or
    when a second formatting pass would format the content differently.
    """
    assert_equivalent(src_contents, dst_contents)
    assert_stable(src_contents, dst_contents, mode=mode)
924
925
def format_file_contents(src_contents: str, *, fast: bool, mode: Mode) -> FileContent:
    """Return the reformatted version of `src_contents`.

    Notebook sources (``mode.is_ipynb``) go through
    :func:`format_ipynb_string`; everything else goes through
    :func:`format_str`, which receives `mode`.

    Unless `fast` is set, non-notebook output is additionally validated with
    :func:`assert_equivalent` and :func:`assert_stable`.

    Raise NothingChanged when formatting leaves the contents untouched.
    """
    is_notebook = mode.is_ipynb
    formatted = (
        format_ipynb_string(src_contents, fast=fast, mode=mode)
        if is_notebook
        else format_str(src_contents, mode=mode)
    )
    if formatted == src_contents:
        raise NothingChanged

    # Notebook cells were already checked one by one while being formatted.
    if not (fast or is_notebook):
        check_stability_and_equivalence(src_contents, formatted, mode=mode)
    return formatted
944
945
def validate_cell(src: str, mode: Mode) -> None:
    """Reject cells that cannot be safely round-tripped through the formatter.

    A cell is rejected (by raising NothingChanged) when it already contains
    the output of a TransformerManager transformation, or when it starts with
    a cell magic that is not known to hold Python code; such cells might cause
    tokenizer_rt to break because of indentations.

    The first case exists because ``!ls`` is transformed to
    ``get_ipython().system('ls')``, and a cell that originally contained
    ``get_ipython().system('ls')`` is transformed to the very same text:

        >>> TransformerManager().transform_cell("get_ipython().system('ls')")
        "get_ipython().system('ls')\n"
        >>> TransformerManager().transform_cell("!ls")
        "get_ipython().system('ls')\n"

    Since the two cannot be told apart afterwards, cells containing
    transformed magics are ignored.
    """
    for marker in TRANSFORMED_MAGICS:
        if marker in src:
            raise NothingChanged

    if src.startswith("%%"):
        # The magic name is the first whitespace-separated word minus "%%".
        magic_name = src.split()[0][2:]
        if magic_name not in PYTHON_CELL_MAGICS | mode.python_cell_magics:
            raise NothingChanged
970
971
def format_cell(src: str, *, fast: bool, mode: Mode) -> str:
    """Return the formatted source of a single Jupyter notebook code cell.

    The steps are:

      - validate the cell (see :func:`validate_cell`);
      - strip a trailing semicolon, remembering whether there was one;
      - mask IPython magics;
      - format the masked code (and, unless `fast`, check the result);
      - restore the magics and the trailing semicolon;
      - drop trailing newlines.

    Cells that fail to parse are left alone, since they may be automagics or
    multi-line magics, neither of which is currently supported.

    Raise NothingChanged if the cell is skipped or comes out unchanged.
    """
    validate_cell(src, mode)
    bare_src, had_semicolon = remove_trailing_semicolon(src)
    try:
        masked_src, replacements = mask_cell(bare_src)
    except SyntaxError:
        raise NothingChanged from None

    masked_dst = format_str(masked_src, mode=mode)
    if not fast:
        check_stability_and_equivalence(masked_src, masked_dst, mode=mode)

    unmasked = unmask_cell(masked_dst, replacements)
    result = put_trailing_semicolon_back(unmasked, had_semicolon).rstrip("\n")
    if result == src:
        raise NothingChanged from None
    return result
1007
1008
def validate_metadata(nb: MutableMapping[str, Any]) -> None:
    """Raise NothingChanged for notebooks explicitly marked as non-Python.

    Every notebook metadata field is optional (see
    https://nbformat.readthedocs.io/en/latest/format_description.html), so a
    notebook that does not declare a language is still accepted.
    """
    metadata = nb.get("metadata", {})
    language_info = metadata.get("language_info", {})
    language = language_info.get("name", None)
    if language is None or language == "python":
        return
    raise NothingChanged from None
1019
1020
def format_ipynb_string(src_contents: str, *, fast: bool, mode: Mode) -> FileContent:
    """Return the serialized notebook with its Python code cells reformatted.

    Only cells whose ``cell_type`` is ``"code"`` are touched, and only in
    notebooks that pass :func:`validate_metadata`.  A trailing newline in the
    original ``.ipynb`` text is preserved.

    Raise NothingChanged if the input is empty or no cell was modified.
    """
    if not src_contents:
        raise NothingChanged

    ends_with_newline = src_contents.endswith("\n")
    any_cell_changed = False
    notebook = json.loads(src_contents)
    validate_metadata(notebook)

    for cell in notebook["cells"]:
        if cell.get("cell_type", None) != "code":
            continue
        try:
            cell_src = "".join(cell["source"])
            cell_dst = format_cell(cell_src, fast=fast, mode=mode)
        except NothingChanged:
            # Leave cells that are skipped or already formatted as they are.
            continue
        cell["source"] = cell_dst.splitlines(keepends=True)
        any_cell_changed = True

    if not any_cell_changed:
        raise NothingChanged

    serialized = json.dumps(notebook, indent=1, ensure_ascii=False)
    if ends_with_newline:
        serialized = serialized + "\n"
    return serialized
1051
1052
def format_str(src_contents: str, *, mode: Mode) -> str:
    """Reformat a string and return new contents.

    `mode` determines formatting options, such as how many characters per line are
    allowed.  Example:

    >>> import black
    >>> print(black.format_str("def f(arg:str='')->None:...", mode=black.Mode()))
    def f(arg: str = "") -> None:
        ...

    A more complex example:

    >>> print(
    ...   black.format_str(
    ...     "def f(arg:str='')->None: hey",
    ...     mode=black.Mode(
    ...       target_versions={black.TargetVersion.PY36},
    ...       line_length=10,
    ...       string_normalization=False,
    ...       is_pyi=False,
    ...     ),
    ...   ),
    ... )
    def f(
        arg: str = '',
    ) -> None:
        hey

    """
    first_pass = _format_str_once(src_contents, mode=mode)
    if first_pass == src_contents:
        return first_pass

    # The first pass changed something, so run a second one.  Optional trailing
    # commas added by pass 1 become forced trailing commas on pass 2, and those
    # interact differently with optional parentheses.  Admittedly ugly.
    return _format_str_once(first_pass, mode=mode)
1090
1091
def _format_str_once(src_contents: str, *, mode: Mode) -> str:
    """Run a single formatting pass over `src_contents` and return the result.

    :func:`format_str` calls this once or twice; :func:`assert_stable` calls
    it directly so that it can observe the effect of exactly one extra pass.

    If the pass produces no output lines, the return value is the source's
    newline sequence when the decoded source contains a line break, and the
    empty string otherwise.
    """
    # Leading whitespace is stripped before the source is parsed.
    src_node = lib2to3_parse(src_contents.lstrip(), mode.target_versions)
    dst_blocks: List[LinesBlock] = []
    # Explicitly configured target versions win; otherwise they are inferred
    # from the syntax features (and __future__ imports) used by the source.
    if mode.target_versions:
        versions = mode.target_versions
    else:
        future_imports = get_future_imports(src_node)
        versions = detect_target_versions(src_node, future_imports=future_imports)

    # Features handed to the line generator; only kept if supported by the
    # target versions determined above.
    context_manager_features = {
        feature
        for feature in {Feature.PARENTHESIZED_CONTEXT_MANAGERS}
        if supports_feature(versions, feature)
    }
    normalize_fmt_off(src_node)
    lines = LineGenerator(mode=mode, features=context_manager_features)
    elt = EmptyLineTracker(mode=mode)
    # Features handed to transform_line() for each generated line.
    split_line_features = {
        feature
        for feature in {Feature.TRAILING_COMMA_IN_CALL, Feature.TRAILING_COMMA_IN_DEF}
        if supports_feature(versions, feature)
    }
    block: Optional[LinesBlock] = None
    for current_line in lines.visit(src_node):
        # One block per logical line; the transformed output lines are
        # collected as strings in the block's `content_lines`.
        block = elt.maybe_empty_lines(current_line)
        dst_blocks.append(block)
        for line in transform_line(
            current_line, mode=mode, features=split_line_features
        ):
            block.content_lines.append(str(line))
    if dst_blocks:
        # NOTE(review): presumably this drops blank lines that would otherwise
        # follow the last block -- confirm against LinesBlock.all_lines().
        dst_blocks[-1].after = 0
    dst_contents = []
    for block in dst_blocks:
        dst_contents.extend(block.all_lines())
    if not dst_contents:
        # Use decode_bytes to retrieve the correct source newline (CRLF or LF),
        # and check if normalized_content has more than one line
        normalized_content, _, newline = decode_bytes(src_contents.encode("utf-8"))
        if "\n" in normalized_content:
            return newline
        return ""
    return "".join(dst_contents)
1135
1136
def decode_bytes(src: bytes) -> Tuple[FileContent, Encoding, NewLine]:
    """Decode `src` and return ``(decoded_contents, encoding, newline)``.

    The encoding is detected with :func:`tokenize.detect_encoding`.
    `newline` is CRLF if the first line of `src` ends with CRLF, and LF
    otherwise (including for empty input).  `decoded_contents` is read with
    universal newlines, so it only ever contains LF.
    """
    buffer = io.BytesIO(src)
    encoding, first_lines = tokenize.detect_encoding(buffer.readline)
    if not first_lines:
        return "", encoding, "\n"

    newline = "\r\n" if first_lines[0].endswith(b"\r\n") else "\n"
    # Encoding detection consumed part of the buffer; rewind to decode it all.
    buffer.seek(0)
    with io.TextIOWrapper(buffer, encoding) as reader:
        text = reader.read()
    return text, encoding, newline
1152
1153
def get_features_used(  # noqa: C901
    node: Node, *, future_imports: Optional[Set[str]] = None
) -> Set[Feature]:
    """Return a set of (relatively) new Python features used in this file.

    `node` is walked in pre-order.  `future_imports`, when given, contributes
    the features mapped to those ``__future__`` flags.

    Currently looking for:
    - f-strings;
    - self-documenting expressions in f-strings (f"{x=}");
    - underscores in numeric literals;
    - trailing commas after * or ** in function signatures and calls;
    - positional only arguments in function signatures and lambdas;
    - assignment expression;
    - relaxed decorator syntax;
    - usage of __future__ flags (annotations);
    - print / exec statements;
    - star-unpacking in return / yield statements;
    - unparenthesized tuples on the right-hand side of annotated assignments;
    - parenthesized context managers;
    - match statements;
    - except* clause;
    - variadic generics;
    - type statements and type parameters;
    """
    features: Set[Feature] = set()
    if future_imports:
        features |= {
            FUTURE_FLAG_TO_FEATURE[future_import]
            for future_import in future_imports
            if future_import in FUTURE_FLAG_TO_FEATURE
        }

    for n in node.pre_order():
        if is_string_token(n):
            # String prefixes are case-insensitive, so compare in lower case;
            # otherwise mixed-case prefixes such as `Rf` or `fR` are missed.
            value_head = n.value[:2].lower()
            if value_head in {'f"', "f'", "rf", "fr"}:
                features.add(Feature.F_STRINGS)
                if Feature.DEBUG_F_STRINGS not in features:
                    for span_beg, span_end in iter_fexpr_spans(n.value):
                        # A replacement field whose content (minus the last
                        # character of the span) ends with "=" is a debug
                        # expression.
                        if n.value[span_beg : span_end - 1].rstrip().endswith("="):
                            features.add(Feature.DEBUG_F_STRINGS)
                            break

        elif is_number_token(n):
            if "_" in n.value:
                features.add(Feature.NUMERIC_UNDERSCORES)

        elif n.type == token.SLASH:
            # A slash only marks positional-only arguments inside an
            # argument list; elsewhere it is the division operator.
            if n.parent and n.parent.type in {
                syms.typedargslist,
                syms.arglist,
                syms.varargslist,
            }:
                features.add(Feature.POS_ONLY_ARGUMENTS)

        elif n.type == token.COLONEQUAL:
            features.add(Feature.ASSIGNMENT_EXPRESSIONS)

        elif n.type == syms.decorator:
            if len(n.children) > 1 and not is_simple_decorator_expression(
                n.children[1]
            ):
                features.add(Feature.RELAXED_DECORATORS)

        elif (
            n.type in {syms.typedargslist, syms.arglist}
            and n.children
            and n.children[-1].type == token.COMMA
        ):
            # Argument list with a trailing comma: only counts as a feature
            # when the list also contains a * or ** argument.
            if n.type == syms.typedargslist:
                feature = Feature.TRAILING_COMMA_IN_DEF
            else:
                feature = Feature.TRAILING_COMMA_IN_CALL

            for ch in n.children:
                if ch.type in STARS:
                    features.add(feature)

                if ch.type == syms.argument:
                    for argch in ch.children:
                        if argch.type in STARS:
                            features.add(feature)

        elif (
            n.type in {syms.return_stmt, syms.yield_expr}
            and len(n.children) >= 2
            and n.children[1].type == syms.testlist_star_expr
            and any(child.type == syms.star_expr for child in n.children[1].children)
        ):
            features.add(Feature.UNPACKING_ON_FLOW)

        elif (
            n.type == syms.annassign
            and len(n.children) >= 4
            and n.children[3].type == syms.testlist_star_expr
        ):
            features.add(Feature.ANN_ASSIGN_EXTENDED_RHS)

        elif (
            n.type == syms.with_stmt
            and len(n.children) > 2
            and n.children[1].type == syms.atom
        ):
            atom_children = n.children[1].children
            if (
                len(atom_children) == 3
                and atom_children[0].type == token.LPAR
                and atom_children[1].type == syms.testlist_gexp
                and atom_children[2].type == token.RPAR
            ):
                features.add(Feature.PARENTHESIZED_CONTEXT_MANAGERS)

        elif n.type == syms.match_stmt:
            features.add(Feature.PATTERN_MATCHING)

        elif (
            n.type == syms.except_clause
            and len(n.children) >= 2
            and n.children[1].type == token.STAR
        ):
            features.add(Feature.EXCEPT_STAR)

        elif n.type in {syms.subscriptlist, syms.trailer} and any(
            child.type == syms.star_expr for child in n.children
        ):
            features.add(Feature.VARIADIC_GENERICS)

        elif (
            n.type == syms.tname_star
            and len(n.children) == 3
            and n.children[2].type == syms.star_expr
        ):
            features.add(Feature.VARIADIC_GENERICS)

        elif n.type in (syms.type_stmt, syms.typeparams):
            features.add(Feature.TYPE_PARAMS)

    return features
1288
1289
def detect_target_versions(
    node: Node, *, future_imports: Optional[Set[str]] = None
) -> Set[TargetVersion]:
    """Return every target version that supports all features used in `node`.

    The features are collected with :func:`get_features_used`, which also
    receives `future_imports`.
    """
    used = get_features_used(node, future_imports=future_imports)
    compatible: Set[TargetVersion] = set()
    for candidate in TargetVersion:
        if used <= VERSION_TO_FEATURES[candidate]:
            compatible.add(candidate)
    return compatible
1298
1299
def get_future_imports(node: Node) -> Set[str]:
    """Collect the names imported from ``__future__`` at the top of the file.

    Only the leading run of simple statements is inspected: docstring-like
    string statements are skipped, ``from __future__ import ...`` statements
    are collected, and anything else ends the scan.
    """
    found: Set[str] = set()

    def imported_names(children: List[LN]) -> Generator[str, None, None]:
        # Yield the original (pre-``as``) name of every imported item.
        for item in children:
            if isinstance(item, Leaf):
                if item.type == token.NAME:
                    yield item.value
                continue

            if item.type == syms.import_as_name:
                original = item.children[0]
                assert isinstance(original, Leaf), "Invalid syntax parsing imports"
                assert original.type == token.NAME, "Invalid syntax parsing imports"
                yield original.value
            elif item.type == syms.import_as_names:
                yield from imported_names(item.children)
            else:
                raise AssertionError("Invalid syntax parsing imports")

    for stmt in node.children:
        if stmt.type != syms.simple_stmt:
            break

        head = stmt.children[0]
        if isinstance(head, Leaf):
            # A lone string followed by a newline is a docstring: keep going.
            is_docstring = (
                len(stmt.children) == 2
                and head.type == token.STRING
                and stmt.children[1].type == token.NEWLINE
            )
            if is_docstring:
                continue
            break

        if head.type != syms.import_from:
            break

        module = head.children[1]
        if not isinstance(module, Leaf) or module.value != "__future__":
            break

        found |= set(imported_names(head.children[3:]))

    return found
1348
1349
def assert_equivalent(src: str, dst: str) -> None:
    """Raise AssertionError unless `src` and `dst` have equivalent ASTs.

    Both strings are parsed with ``parse_ast`` and their stringified ASTs are
    compared.  A source that cannot be parsed, a destination that cannot be
    parsed, and a mismatch between the two each produce their own message.
    """
    try:
        src_tree = parse_ast(src)
    except Exception as exc:
        message = (
            "cannot use --safe with this file; failed to parse source file AST: "
            f"{exc}\n"
            "This could be caused by running Black with an older Python version "
            "that does not support new syntax used in your source file."
        )
        raise AssertionError(message) from exc

    try:
        dst_tree = parse_ast(dst)
    except Exception as exc:
        # Keep the traceback and the invalid output around for a bug report.
        trace = "".join(traceback.format_tb(exc.__traceback__))
        log = dump_to_file(trace, dst)
        raise AssertionError(
            f"INTERNAL ERROR: Black produced invalid code: {exc}. "
            "Please report a bug on https://github.com/psf/black/issues.  "
            f"This invalid output might be helpful: {log}"
        ) from None

    src_dump = "\n".join(stringify_ast(src_tree))
    dst_dump = "\n".join(stringify_ast(dst_tree))
    if src_dump == dst_dump:
        return

    log = dump_to_file(diff(src_dump, dst_dump, "src", "dst"))
    raise AssertionError(
        "INTERNAL ERROR: Black produced code that is not equivalent to the"
        " source.  Please report a bug on "
        f"https://github.com/psf/black/issues.  This diff might be helpful: {log}"
    ) from None
1381
1382
def assert_stable(src: str, dst: str, mode: Mode) -> None:
    """Raise AssertionError if formatting `dst` once more changes it.

    `src` is only used to build the diagnostic dump attached to the error.
    """
    # Deliberately a single pass: format_str() formats twice, which could hide
    # a bug where the output bounces back and forth between two versions.
    second_pass = _format_str_once(dst, mode=mode)
    if second_pass == dst:
        return

    mode_description = str(mode)
    first_diff = diff(src, dst, "source", "first pass")
    second_diff = diff(dst, second_pass, "first pass", "second pass")
    log = dump_to_file(mode_description, first_diff, second_diff)
    raise AssertionError(
        "INTERNAL ERROR: Black produced different code on the second pass of the"
        " formatter.  Please report a bug on https://github.com/psf/black/issues."
        f"  This diff might be helpful: {log}"
    ) from None
1400
1401
@contextmanager
def nullcontext() -> Iterator[None]:
    """Return a context manager that does nothing on entry or exit.

    Meant as a stand-in for `nullcontext` from Python 3.7's contextlib;
    the value bound by ``with ... as`` is None.
    """
    yield None
1409
1410
def patch_click() -> None:
    """Make Click not crash on Python 3.6 with LANG=C.

    On certain misconfigured environments, Python 3 selects the ASCII encoding
    as the default, which restricts the paths it can access during the lifetime
    of the application.  Click refuses to work in this scenario by raising a
    RuntimeError.

    In case of Black the likelihood that non-ASCII characters are going to be
    used in file paths is minimal since it's Python source code.  Moreover,
    this crash was spurious on Python 3.7 thanks to PEP 538 and PEP 540.

    The patch replaces Click's environment verification hooks, wherever they
    are found, with functions that do nothing.
    """
    patchable: List[Any] = []
    try:
        from click import core

        patchable.append(core)
    except ImportError:
        pass
    try:
        # Removed in Click 8.1.0 and newer; we keep this around for users who have
        # older versions installed.
        from click import _unicodefun  # type: ignore

        patchable.append(_unicodefun)
    except ImportError:
        pass

    for module in patchable:
        for hook_name in ("_verify_python3_env", "_verify_python_env"):
            if hasattr(module, hook_name):
                setattr(module, hook_name, lambda: None)
1443
1444
def patched_main() -> None:
    """Run :func:`main` after applying the start-up workarounds.

    In a frozen build ``multiprocessing.freeze_support()`` is called first;
    then Click is patched via :func:`patch_click` and :func:`main` is run.
    """
    # PyInstaller patches multiprocessing to need freeze_support() even in
    # non-Windows environments, so call it whenever the app is frozen.
    is_frozen = getattr(sys, "frozen", False)
    if is_frozen:
        from multiprocessing import freeze_support

        freeze_support()

    patch_click()
    main()
1455
1456
# Script entry point: when this module is executed directly, run the CLI
# through patched_main() so that Click is patched before main() starts.
if __name__ == "__main__":
    patched_main()