git.madduck.net Git - etc/vim.git/blob - src/black/__init__.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Migrate mypy config to pyproject.toml (#3936)
[etc/vim.git] / src / black / __init__.py
1 import io
2 import json
3 import platform
4 import re
5 import sys
6 import tokenize
7 import traceback
8 from contextlib import contextmanager
9 from dataclasses import replace
10 from datetime import datetime, timezone
11 from enum import Enum
12 from json.decoder import JSONDecodeError
13 from pathlib import Path
14 from typing import (
15     Any,
16     Dict,
17     Generator,
18     Iterator,
19     List,
20     MutableMapping,
21     Optional,
22     Pattern,
23     Sequence,
24     Set,
25     Sized,
26     Tuple,
27     Union,
28 )
29
30 import click
31 from click.core import ParameterSource
32 from mypy_extensions import mypyc_attr
33 from pathspec import PathSpec
34 from pathspec.patterns.gitwildmatch import GitWildMatchPatternError
35
36 from _black_version import version as __version__
37 from black.cache import Cache
38 from black.comments import normalize_fmt_off
39 from black.const import (
40     DEFAULT_EXCLUDES,
41     DEFAULT_INCLUDES,
42     DEFAULT_LINE_LENGTH,
43     STDIN_PLACEHOLDER,
44 )
45 from black.files import (
46     find_project_root,
47     find_pyproject_toml,
48     find_user_pyproject_toml,
49     gen_python_files,
50     get_gitignore,
51     normalize_path_maybe_ignore,
52     parse_pyproject_toml,
53     wrap_stream_for_windows,
54 )
55 from black.handle_ipynb_magics import (
56     PYTHON_CELL_MAGICS,
57     TRANSFORMED_MAGICS,
58     jupyter_dependencies_are_installed,
59     mask_cell,
60     put_trailing_semicolon_back,
61     remove_trailing_semicolon,
62     unmask_cell,
63 )
64 from black.linegen import LN, LineGenerator, transform_line
65 from black.lines import EmptyLineTracker, LinesBlock
66 from black.mode import FUTURE_FLAG_TO_FEATURE, VERSION_TO_FEATURES, Feature
67 from black.mode import Mode as Mode  # re-exported
68 from black.mode import TargetVersion, supports_feature
69 from black.nodes import (
70     STARS,
71     is_number_token,
72     is_simple_decorator_expression,
73     is_string_token,
74     syms,
75 )
76 from black.output import color_diff, diff, dump_to_file, err, ipynb_diff, out
77 from black.parsing import InvalidInput  # noqa F401
78 from black.parsing import lib2to3_parse, parse_ast, stringify_ast
79 from black.report import Changed, NothingChanged, Report
80 from black.trans import iter_fexpr_spans
81 from blib2to3.pgen2 import token
82 from blib2to3.pytree import Leaf, Node
83
# Whether this module was loaded from a compiled extension module (".pyd" on
# Windows, ".so" elsewhere) rather than a plain ".py" source file. Shown by
# --version as "compiled: yes/no".
COMPILED = Path(__file__).suffix in {".pyd", ".so"}

# Type aliases that document intent in signatures; each is plain `str`.
FileContent = str
Encoding = str
NewLine = str
90
91
class WriteBack(Enum):
    """What to do with the result of formatting a source.

    NO         -- write nothing; only report whether the source changed.
    YES        -- write the reformatted code back (to the file, or to stdout
                  for standard input).
    DIFF       -- write a diff to stdout instead of the reformatted code.
    CHECK      -- write nothing; used for ``--check`` status reporting.
    COLOR_DIFF -- like DIFF, but the diff is colorized.
    """

    NO = 0
    YES = 1
    DIFF = 2
    CHECK = 3
    COLOR_DIFF = 4

    @classmethod
    def from_configuration(
        cls, *, check: bool, diff: bool, color: bool = False
    ) -> "WriteBack":
        """Map the ``--check`` / ``--diff`` / ``--color`` flags to a member.

        ``--diff`` wins over ``--check``. ``--color`` only matters together
        with ``--diff``. With neither flag, files are written back.
        """
        if diff:
            return cls.COLOR_DIFF if color else cls.DIFF
        return cls.CHECK if check else cls.YES
110
111
# Legacy name, left for integrations: `FileMode` is simply an alias of `Mode`,
# kept so that code still importing the old name keeps working.
FileMode = Mode
114
115
def read_pyproject_toml(
    ctx: click.Context, param: click.Parameter, value: Optional[str]
) -> Optional[str]:
    """Inject Black configuration from "pyproject.toml" into defaults in `ctx`.

    This is the callback of the ``--config`` option. When no path was given,
    a configuration file is looked up from the ``src`` and ``stdin_filename``
    values already present in ``ctx.params``.

    The values read are merged over any existing ``ctx.default_map``, so they
    act as defaults for the remaining options.

    Returns the path to a successfully found and read configuration file, or
    None when no file was found or it yielded an empty configuration.

    Raises:
        click.FileError: the configuration file could not be read or parsed.
        click.BadOptionUsage: ``target_version`` is not a list, or ``exclude``
            / ``extend_exclude`` is not a string.
    """
    if not value:
        value = find_pyproject_toml(
            ctx.params.get("src", ()), ctx.params.get("stdin_filename", None)
        )
        if value is None:
            return None

    try:
        parsed = parse_pyproject_toml(value)
    except (OSError, ValueError) as exc:
        raise click.FileError(
            filename=value, hint=f"Error reading configuration file: {exc}"
        ) from None

    if not parsed:
        return None

    # Sanitize the values to be Click friendly: everything but lists and dicts
    # is stringified. For more information please see:
    # https://github.com/psf/black/issues/1458
    # https://github.com/pallets/click/issues/1567
    config = {
        key: item if isinstance(item, (list, dict)) else str(item)
        for key, item in parsed.items()
    }

    # (config key, required type, wording used in the error message)
    type_checks: Tuple[Tuple[str, type, str], ...] = (
        ("target_version", list, "a list"),
        ("exclude", str, "a string"),
        ("extend_exclude", str, "a string"),
    )
    for key, expected_type, description in type_checks:
        found = config.get(key)
        if found is not None and not isinstance(found, expected_type):
            option_name = key.replace("_", "-")
            raise click.BadOptionUsage(
                option_name, f"Config key {option_name} must be {description}"
            )

    merged: Dict[str, Any] = dict(ctx.default_map) if ctx.default_map else {}
    merged.update(config)
    ctx.default_map = merged
    return value
172
173
def target_version_option_callback(
    c: click.Context, p: Union[click.Option, click.Parameter], v: Tuple[str, ...]
) -> List[TargetVersion]:
    """Compute the target versions from a --target-version flag.

    Each value is upper-cased and looked up by name on `TargetVersion`, keeping
    the given order. The values are expected to be lower-cased member names
    (the --target-version option restricts them with ``click.Choice``).

    This is its own function because mypy couldn't infer the type correctly
    when it was a lambda, causing mypyc trouble.
    """
    versions: List[TargetVersion] = []
    for name in v:
        versions.append(TargetVersion[name.upper()])
    return versions
183
184
def re_compile_maybe_verbose(regex: str) -> Pattern[str]:
    """Compile `regex` and return the resulting pattern.

    A pattern containing a newline is compiled in verbose mode by prefixing
    the inline ``(?x)`` flag. In that mode, whitespace (including the
    newlines) and ``#`` comments inside the pattern are ignored.
    """
    verbose_flag = "(?x)" if "\n" in regex else ""
    return re.compile(verbose_flag + regex)
194
195
def validate_regex(
    ctx: click.Context,
    param: click.Parameter,
    value: Optional[str],
) -> Optional[Pattern[str]]:
    """Click callback that compiles a regular-expression option value.

    Returns None when the option was not given. Otherwise it returns the
    pattern built by :func:`re_compile_maybe_verbose`. An invalid expression
    is reported to Click as ``click.BadParameter``, so the CLI shows a usage
    error instead of a traceback.
    """
    if value is None:
        return None
    try:
        return re_compile_maybe_verbose(value)
    except re.error as exc:
        raise click.BadParameter(f"Not a valid regular expression: {exc}") from None
205
206
# Command-line entry point. Click passes every option/argument declared below
# to `main` as the keyword parameter with the same name.
@click.command(
    context_settings={"help_option_names": ["-h", "--help"]},
    # While Click does set this field automatically using the docstring, mypyc
    # (annoyingly) strips 'em so we need to set it here too.
    help="The uncompromising code formatter.",
)
@click.option("-c", "--code", type=str, help="Format the code passed in as a string.")
@click.option(
    "-l",
    "--line-length",
    type=int,
    default=DEFAULT_LINE_LENGTH,
    help="How many characters per line to allow.",
    show_default=True,
)
@click.option(
    "-t",
    "--target-version",
    type=click.Choice([v.name.lower() for v in TargetVersion]),
    # Converts the chosen lower-case names into TargetVersion members.
    callback=target_version_option_callback,
    multiple=True,
    help=(
        "Python versions that should be supported by Black's output. By default, Black"
        " will try to infer this from the project metadata in pyproject.toml. If this"
        " does not yield conclusive results, Black will use per-file auto-detection."
    ),
)
@click.option(
    "--pyi",
    is_flag=True,
    help=(
        "Format all input files like typing stubs regardless of file extension (useful"
        " when piping source on standard input)."
    ),
)
@click.option(
    "--ipynb",
    is_flag=True,
    help=(
        "Format all input files like Jupyter Notebooks regardless of file extension "
        "(useful when piping source on standard input)."
    ),
)
@click.option(
    "--python-cell-magics",
    multiple=True,
    help=(
        "When processing Jupyter Notebooks, add the given magic to the list"
        f" of known python-magics ({', '.join(sorted(PYTHON_CELL_MAGICS))})."
        " Useful for formatting cells with custom python magics."
    ),
    default=[],
)
@click.option(
    "-x",
    "--skip-source-first-line",
    is_flag=True,
    help="Skip the first line of the source code.",
)
@click.option(
    "-S",
    "--skip-string-normalization",
    is_flag=True,
    help="Don't normalize string quotes or prefixes.",
)
@click.option(
    "-C",
    "--skip-magic-trailing-comma",
    is_flag=True,
    help="Don't use trailing commas as a reason to split lines.",
)
@click.option(
    "--experimental-string-processing",
    is_flag=True,
    hidden=True,
    help="(DEPRECATED and now included in --preview) Normalize string literals.",
)
@click.option(
    "--preview",
    is_flag=True,
    help=(
        "Enable potentially disruptive style changes that may be added to Black's main"
        " functionality in the next major release."
    ),
)
@click.option(
    "--check",
    is_flag=True,
    help=(
        "Don't write the files back, just return the status. Return code 0 means"
        " nothing would change. Return code 1 means some files would be reformatted."
        " Return code 123 means there was an internal error."
    ),
)
@click.option(
    "--diff",
    is_flag=True,
    help="Don't write the files back, just output a diff for each file on stdout.",
)
@click.option(
    "--color/--no-color",
    is_flag=True,
    help="Show colored diff. Only applies when `--diff` is given.",
)
@click.option(
    "--fast/--safe",
    is_flag=True,
    help="If --fast given, skip temporary sanity checks. [default: --safe]",
)
@click.option(
    "--required-version",
    type=str,
    help=(
        "Require a specific version of Black to be running (useful for unifying results"
        " across many environments e.g. with a pyproject.toml file). It can be"
        " either a major version number or an exact version."
    ),
)
@click.option(
    "--include",
    type=str,
    default=DEFAULT_INCLUDES,
    callback=validate_regex,
    help=(
        "A regular expression that matches files and directories that should be"
        " included on recursive searches. An empty value means all files are included"
        " regardless of the name. Use forward slashes for directories on all platforms"
        " (Windows, too). Exclusions are calculated first, inclusions later."
    ),
    show_default=True,
)
@click.option(
    "--exclude",
    type=str,
    callback=validate_regex,
    help=(
        "A regular expression that matches files and directories that should be"
        " excluded on recursive searches. An empty value means no paths are excluded."
        " Use forward slashes for directories on all platforms (Windows, too)."
        " Exclusions are calculated first, inclusions later. [default:"
        f" {DEFAULT_EXCLUDES}]"
    ),
    # No Click default: an unset --exclude reaches get_sources as None, which
    # then substitutes DEFAULT_EXCLUDES (the help text spells the default out).
    show_default=False,
)
@click.option(
    "--extend-exclude",
    type=str,
    callback=validate_regex,
    help=(
        "Like --exclude, but adds additional files and directories on top of the"
        " excluded ones. (Useful if you simply want to add to the default)"
    ),
)
@click.option(
    "--force-exclude",
    type=str,
    callback=validate_regex,
    help=(
        "Like --exclude, but files and directories matching this regex will be "
        "excluded even when they are passed explicitly as arguments."
    ),
)
@click.option(
    "--stdin-filename",
    type=str,
    # Eager, like SRC: read_pyproject_toml looks it up in ctx.params.
    is_eager=True,
    help=(
        "The name of the file when passing it through stdin. Useful to make "
        "sure Black will respect --force-exclude option on some "
        "editors that rely on using stdin."
    ),
)
@click.option(
    "-W",
    "--workers",
    type=click.IntRange(min=1),
    default=None,
    help=(
        "Number of parallel workers [default: BLACK_NUM_WORKERS environment variable "
        "or number of CPUs in the system]"
    ),
)
@click.option(
    "-q",
    "--quiet",
    is_flag=True,
    help=(
        "Don't emit non-error messages to stderr. Errors are still emitted; silence"
        " those with 2>/dev/null."
    ),
)
@click.option(
    "-v",
    "--verbose",
    is_flag=True,
    help=(
        "Also emit messages to stderr about files that were not changed or were ignored"
        " due to exclusion patterns."
    ),
)
@click.version_option(
    version=__version__,
    message=(
        f"%(prog)s, %(version)s (compiled: {'yes' if COMPILED else 'no'})\n"
        f"Python ({platform.python_implementation()}) {platform.python_version()}"
    ),
)
@click.argument(
    "src",
    nargs=-1,
    type=click.Path(
        exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
    ),
    # Eager because read_pyproject_toml reads ctx.params["src"] to locate the
    # configuration file.
    is_eager=True,
    metavar="SRC ...",
)
@click.option(
    "--config",
    type=click.Path(
        exists=True,
        file_okay=True,
        dir_okay=False,
        readable=True,
        allow_dash=False,
        path_type=str,
    ),
    is_eager=True,
    # Loads the configuration file into ctx.default_map, so its values act as
    # defaults for the other options.
    callback=read_pyproject_toml,
    help="Read configuration from FILE path.",
)
@click.pass_context
def main(  # noqa: C901
    ctx: click.Context,
    code: Optional[str],
    line_length: int,
    target_version: List[TargetVersion],
    check: bool,
    diff: bool,
    color: bool,
    fast: bool,
    pyi: bool,
    ipynb: bool,
    python_cell_magics: Sequence[str],
    skip_source_first_line: bool,
    skip_string_normalization: bool,
    skip_magic_trailing_comma: bool,
    experimental_string_processing: bool,
    preview: bool,
    quiet: bool,
    verbose: bool,
    required_version: Optional[str],
    include: Pattern[str],
    exclude: Optional[Pattern[str]],
    extend_exclude: Optional[Pattern[str]],
    force_exclude: Optional[Pattern[str]],
    stdin_filename: Optional[str],
    workers: Optional[int],
    src: Tuple[str, ...],
    config: Optional[str],
) -> None:
    """The uncompromising code formatter."""
    # ctx.obj is used as a dict; the detected project root is stored in it below.
    ctx.ensure_object(dict)

    # Exactly one input mode is allowed: paths (SRC) or a code string (--code).
    if src and code is not None:
        out(
            main.get_usage(ctx)
            + "\n\n'SRC' and 'code' cannot be passed simultaneously."
        )
        ctx.exit(1)
    if not src and code is None:
        out(main.get_usage(ctx) + "\n\nOne of 'SRC' or 'code' is required.")
        ctx.exit(1)

    # find_project_root also reports how the root was identified (`method`),
    # which the verbose message below shows. With --code there is no root.
    root, method = (
        find_project_root(src, stdin_filename) if code is None else (None, None)
    )
    ctx.obj["root"] = root

    if verbose:
        if root:
            out(
                f"Identified `{root}` as project root containing a {method}.",
                fg="blue",
            )

        if config:
            # Say where the configuration came from: the user-level file, a
            # file not given explicitly (parameter source DEFAULT/DEFAULT_MAP),
            # or an explicit --config path.
            config_source = ctx.get_parameter_source("config")
            user_level_config = str(find_user_pyproject_toml())
            if config == user_level_config:
                out(
                    "Using configuration from user-level config at "
                    f"'{user_level_config}'.",
                    fg="blue",
                )
            elif config_source in (
                ParameterSource.DEFAULT,
                ParameterSource.DEFAULT_MAP,
            ):
                out("Using configuration from project root.", fg="blue")
            else:
                out(f"Using configuration in '{config}'.", fg="blue")
            if ctx.default_map:
                # Dump the configured values one per line.
                for param, value in ctx.default_map.items():
                    out(f"{param}: {value}")

    error_msg = "Oh no! 💥 💔 💥"
    # --required-version accepts the exact running version, or only its major
    # component (the text before the first ".").
    if (
        required_version
        and required_version != __version__
        and required_version != __version__.split(".")[0]
    ):
        err(
            f"{error_msg} The required version `{required_version}` does not match"
            f" the running version `{__version__}`!"
        )
        ctx.exit(1)
    if ipynb and pyi:
        err("Cannot pass both `pyi` and `ipynb` flags!")
        ctx.exit(1)

    write_back = WriteBack.from_configuration(check=check, diff=diff, color=color)
    if target_version:
        versions = set(target_version)
    else:
        # We'll autodetect later.
        versions = set()
    mode = Mode(
        target_versions=versions,
        line_length=line_length,
        is_pyi=pyi,
        is_ipynb=ipynb,
        skip_source_first_line=skip_source_first_line,
        string_normalization=not skip_string_normalization,
        magic_trailing_comma=not skip_magic_trailing_comma,
        experimental_string_processing=experimental_string_processing,
        preview=preview,
        python_cell_magics=set(python_cell_magics),
    )

    if code is not None:
        # Run in quiet mode by default with -c; the extra output isn't useful.
        # You can still pass -v to get verbose output.
        quiet = True

    report = Report(check=check, diff=diff, quiet=quiet, verbose=verbose)

    if code is not None:
        reformat_code(
            content=code, fast=fast, write_back=write_back, mode=mode, report=report
        )
    else:
        assert root is not None  # root is only None if code is not None
        try:
            sources = get_sources(
                root=root,
                src=src,
                quiet=quiet,
                verbose=verbose,
                include=include,
                exclude=exclude,
                extend_exclude=extend_exclude,
                force_exclude=force_exclude,
                report=report,
                stdin_filename=stdin_filename,
            )
        except GitWildMatchPatternError:
            # NOTE(review): presumably raised by pathspec while parsing a
            # .gitignore during source collection; exit with status 1.
            ctx.exit(1)

        # Exits with status 0 when no sources were collected.
        path_empty(
            sources,
            "No Python files are present to be formatted. Nothing to do 😴",
            quiet,
            verbose,
            ctx,
        )

        # A single source is formatted in-process. Several sources go through
        # black.concurrency.reformat_many, which is imported lazily here.
        if len(sources) == 1:
            reformat_one(
                src=sources.pop(),
                fast=fast,
                write_back=write_back,
                mode=mode,
                report=report,
            )
        else:
            from black.concurrency import reformat_many

            reformat_many(
                sources=sources,
                fast=fast,
                write_back=write_back,
                mode=mode,
                report=report,
                workers=workers,
            )

    # Summary: unless quiet (verbose overrides quiet), print the outcome and,
    # for path inputs, the report text to stderr.
    if verbose or not quiet:
        if code is None and (verbose or report.change_count or report.failure_count):
            out()
        out(error_msg if report.return_code else "All done! ✨ 🍰 ✨")
        if code is None:
            click.echo(str(report), err=True)
    # The process exit status is the report's return code.
    ctx.exit(report.return_code)
610
611
def get_sources(
    *,
    root: Path,
    src: Tuple[str, ...],
    quiet: bool,
    verbose: bool,
    include: Pattern[str],
    exclude: Optional[Pattern[str]],
    extend_exclude: Optional[Pattern[str]],
    force_exclude: Optional[Pattern[str]],
    report: "Report",
    stdin_filename: Optional[str],
) -> Set[Path]:
    """Compute the set of files to be formatted.

    Each entry of `src` is handled as follows:

    * ``-`` with `stdin_filename` set: standard input, checked against
      `force_exclude` under the name `stdin_filename` and returned as
      ``STDIN_PLACEHOLDER + stdin_filename``.
    * an existing file: returned unless it matches `force_exclude`.
    * an existing directory: expanded with :func:`gen_python_files`, which
      receives `include`, `exclude`, `extend_exclude` and `force_exclude`.
      The gitignore specs of `root` and of the directory are passed along
      only when `exclude` was not given. In that case DEFAULT_EXCLUDES is
      used as `exclude`.
    * ``-`` without `stdin_filename`: returned as ``Path("-")``.
    * anything else: reported with ``err`` as an invalid path.

    Files and directories for which :func:`normalize_path_maybe_ignore`
    returns None are skipped. With `verbose`, a "Skipping invalid source"
    message names the offending input path. ``.ipynb`` files are skipped when
    :func:`jupyter_dependencies_are_installed` returns False.
    """
    sources: Set[Path] = set()

    using_default_exclude = exclude is None
    exclude = re_compile_maybe_verbose(DEFAULT_EXCLUDES) if exclude is None else exclude
    gitignore: Optional[Dict[Path, PathSpec]] = None
    root_gitignore = get_gitignore(root)

    for s in src:
        if s == "-" and stdin_filename:
            p = Path(stdin_filename)
            is_stdin = True
        else:
            p = Path(s)
            is_stdin = False

        if is_stdin or p.is_file():
            normalized_path: Optional[str] = normalize_path_maybe_ignore(
                p, root, report
            )
            if normalized_path is None:
                if verbose:
                    # Name the input path: `normalized_path` is None here, so
                    # interpolating it would only ever print "None".
                    out(f'Skipping invalid source: "{p}"', fg="red")
                continue
            if verbose:
                out(f'Found input source: "{normalized_path}"', fg="blue")

            # The regex is matched against the normalized path with a leading "/".
            normalized_path = "/" + normalized_path
            # Hard-exclude any files that matches the `--force-exclude` regex.
            if force_exclude:
                force_exclude_match = force_exclude.search(normalized_path)
            else:
                force_exclude_match = None
            # An empty match does not count as an exclusion.
            if force_exclude_match and force_exclude_match.group(0):
                report.path_ignored(p, "matches the --force-exclude regular expression")
                continue

            if is_stdin:
                # Tag the path so reformat_one reads standard input for it.
                p = Path(f"{STDIN_PLACEHOLDER}{str(p)}")

            if p.suffix == ".ipynb" and not jupyter_dependencies_are_installed(
                warn=verbose or not quiet
            ):
                continue

            sources.add(p)
        elif p.is_dir():
            p_relative = normalize_path_maybe_ignore(p, root, report)
            if p_relative is None:
                # The helper ignored this path. Skip it like an invalid file
                # instead of failing an `assert` (or, under -O, crashing on
                # `root / None`).
                if verbose:
                    out(f'Skipping invalid source: "{p}"', fg="red")
                continue
            p = root / p_relative
            if verbose:
                out(f'Found input source directory: "{p}"', fg="blue")

            if using_default_exclude:
                gitignore = {
                    root: root_gitignore,
                    p: get_gitignore(p),
                }
            sources.update(
                gen_python_files(
                    p.iterdir(),
                    root,
                    include,
                    exclude,
                    extend_exclude,
                    force_exclude,
                    report,
                    gitignore,
                    verbose=verbose,
                    quiet=quiet,
                )
            )
        elif s == "-":
            if verbose:
                out("Found input source stdin", fg="blue")
            sources.add(p)
        else:
            err(f"invalid path: {s}")

    return sources
705
706
def path_empty(
    src: Sized, msg: str, quiet: bool, verbose: bool, ctx: click.Context
) -> None:
    """Exit successfully (status 0) when there is nothing to format.

    A non-empty `src` makes this a no-op. For an empty `src`, `msg` is printed
    unless running quietly (`verbose` overrides `quiet`), and then
    ``ctx.exit(0)`` is called.
    """
    if src:
        return
    if verbose or not quiet:
        out(msg)
    ctx.exit(0)
717
718
def reformat_code(
    content: str, fast: bool, write_back: WriteBack, mode: Mode, report: Report
) -> None:
    """Format the string `content` and record the outcome on `report`.

    This is the string-input counterpart of `reformat_one` (used for --code).
    It never spawns child processes. Formatting and output are delegated to
    :func:`format_stdin_to_stdout`, which receives `fast`, `write_back` and
    `mode`. The input is identified in the report by the pseudo-path
    ``<string>``. Any exception is recorded as a failure rather than raised,
    and the traceback is printed first when the report is verbose.
    """
    pseudo_path = Path("<string>")
    try:
        was_changed = format_stdin_to_stdout(
            content=content, fast=fast, write_back=write_back, mode=mode
        )
        report.done(pseudo_path, Changed.YES if was_changed else Changed.NO)
    except Exception as exc:
        if report.verbose:
            traceback.print_exc()
        report.failed(pseudo_path, str(exc))
741
742
# diff-shades depends on being able to monkeypatch this function to operate. I know
# it's not ideal, but this shouldn't cause any issues ... hopefully. ~ichard26
@mypyc_attr(patchable=True)
def reformat_one(
    src: Path, fast: bool, write_back: WriteBack, mode: Mode, report: "Report"
) -> None:
    """Reformat a single source without spawning child processes.

    `src` is one of three things:

    * a real file path, formatted via :func:`format_file_in_place` and the
      on-disk cache;
    * ``-``, meaning plain standard input;
    * a path prefixed with ``STDIN_PLACEHOLDER``, meaning standard input with
      a user-supplied file name (as built by :func:`get_sources`).

    Standard input goes through :func:`format_stdin_to_stdout`. For it, a
    ``.pyi`` or ``.ipynb`` suffix on the file name switches `mode` to stub or
    notebook formatting. `fast`, `write_back` and `mode` are passed on to
    those helpers. The outcome is recorded on `report`. Exceptions are
    recorded as failures, not raised.
    """
    try:
        changed = Changed.NO
        raw_name = str(src)
        if raw_name.startswith(STDIN_PLACEHOLDER) and raw_name != "-":
            # Restore the user-facing file name for anything reported later.
            src = Path(raw_name[len(STDIN_PLACEHOLDER) :])
            is_stdin = True
        else:
            is_stdin = raw_name == "-"

        if is_stdin:
            extension = src.suffix
            if extension == ".pyi":
                mode = replace(mode, is_pyi=True)
            elif extension == ".ipynb":
                mode = replace(mode, is_ipynb=True)
            if format_stdin_to_stdout(fast=fast, write_back=write_back, mode=mode):
                changed = Changed.YES
        else:
            cache = Cache.read(mode)
            # Diffs are always produced, so the cache is only consulted otherwise.
            showing_diff = write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF)
            if not showing_diff and not cache.is_changed(src):
                changed = Changed.CACHED
            if changed is not Changed.CACHED and format_file_in_place(
                src, fast=fast, write_back=write_back, mode=mode
            ):
                changed = Changed.YES
            # Record the file as formatted: after a (non-cached) write-back, or
            # when --check found nothing to change.
            wrote_back = write_back is WriteBack.YES and changed is not Changed.CACHED
            checked_clean = write_back is WriteBack.CHECK and changed is Changed.NO
            if wrote_back or checked_clean:
                cache.write([src])
        report.done(src, changed)
    except Exception as exc:
        if report.verbose:
            traceback.print_exc()
        report.failed(src, str(exc))
792
793
def format_file_in_place(
    src: Path,
    fast: bool,
    mode: Mode,
    write_back: WriteBack = WriteBack.NO,
    lock: Any = None,  # multiprocessing.Manager().Lock() is some crazy proxy
) -> bool:
    """Format the file at `src`. Return True if its contents changed.

    The file extension refines `mode`: ``.pyi`` enables stub formatting and
    ``.ipynb`` enables notebook formatting. When ``mode.skip_source_first_line``
    is set, the first line is read separately and re-attached, untouched, to
    both the original and the formatted text.

    What happens with the result depends on `write_back`:

    * YES -- the reformatted code is written back to `src`, using the
      detected encoding and newline style;
    * DIFF / COLOR_DIFF -- a (colored) diff is written to stdout, holding
      `lock` while writing when one is given;
    * anything else -- nothing is written; only the return value reports it.

    `mode` and `fast` are passed to :func:`format_file_contents`.

    Raises:
        ValueError: formatting raised JSONDecodeError, i.e. the file cannot be
            parsed as a Jupyter notebook.
    """
    extension = src.suffix
    if extension == ".pyi":
        mode = replace(mode, is_pyi=True)
    elif extension == ".ipynb":
        mode = replace(mode, is_ipynb=True)

    # The file's modification time is the "before" timestamp in diff headers.
    then = datetime.fromtimestamp(src.stat().st_mtime, timezone.utc)
    with open(src, "rb") as source_file:
        header = source_file.readline() if mode.skip_source_first_line else b""
        src_contents, encoding, newline = decode_bytes(source_file.read())

    try:
        dst_contents = format_file_contents(src_contents, fast=fast, mode=mode)
    except NothingChanged:
        return False
    except JSONDecodeError:
        raise ValueError(
            f"File '{src}' cannot be parsed as valid Jupyter notebook."
        ) from None

    # Re-attach the skipped first line (empty unless skip_source_first_line).
    prefix = header.decode(encoding)
    src_contents = prefix + src_contents
    dst_contents = prefix + dst_contents

    if write_back == WriteBack.YES:
        with open(src, "w", encoding=encoding, newline=newline) as target:
            target.write(dst_contents)
        return True

    if write_back not in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
        return True

    now = datetime.now(timezone.utc)
    src_name = f"{src}\t{then}"
    dst_name = f"{src}\t{now}"
    if mode.is_ipynb:
        diff_contents = ipynb_diff(src_contents, dst_contents, src_name, dst_name)
    else:
        diff_contents = diff(src_contents, dst_contents, src_name, dst_name)
    if write_back == WriteBack.COLOR_DIFF:
        diff_contents = color_diff(diff_contents)

    # Keep output from concurrent workers from interleaving when a lock is given.
    with lock or nullcontext():
        stream = io.TextIOWrapper(
            sys.stdout.buffer,
            encoding=encoding,
            newline=newline,
            write_through=True,
        )
        stream = wrap_stream_for_windows(stream)
        stream.write(diff_contents)
        # Detach so discarding the wrapper does not close sys.stdout's buffer.
        stream.detach()
    return True
856
857
def format_stdin_to_stdout(
    fast: bool,
    *,
    content: Optional[str] = None,
    write_back: WriteBack = WriteBack.NO,
    mode: Mode,
) -> bool:
    """Format file on stdin. Return True if changed.

    If content is None, it's read from sys.stdin and decoded with
    `decode_bytes`, which also yields the encoding and newline style used for
    the output. Otherwise `content` is used as-is, and the output is UTF-8
    with ``newline=""`` (no newline translation).

    If `write_back` is YES, write reformatted code back to stdout. If it is DIFF
    or COLOR_DIFF, write a (colored) diff to stdout. The `mode` argument is
    passed to :func:`format_file_contents`.

    Output is produced in a ``finally`` clause, so it also happens when
    formatting changes nothing or raises. In those cases `dst` is still the
    unmodified source, so YES echoes the input back (with a trailing newline
    ensured). Any exception other than NothingChanged propagates after that
    output has been written.

    Returns False when :func:`format_file_contents` raises NothingChanged.
    """
    then = datetime.now(timezone.utc)  # "before" timestamp for diff headers

    if content is None:
        src, encoding, newline = decode_bytes(sys.stdin.buffer.read())
    else:
        src, encoding, newline = content, "utf-8", ""

    # Fallback output when formatting changes nothing or fails.
    dst = src
    try:
        dst = format_file_contents(src, fast=fast, mode=mode)
        return True

    except NothingChanged:
        return False

    finally:
        # Runs on every exit path above, including unexpected exceptions.
        f = io.TextIOWrapper(
            sys.stdout.buffer, encoding=encoding, newline=newline, write_through=True
        )
        if write_back == WriteBack.YES:
            # Make sure there's a newline after the content
            if dst and dst[-1] != "\n":
                dst += "\n"
            f.write(dst)
        elif write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
            now = datetime.now(timezone.utc)
            src_name = f"STDIN\t{then}"
            dst_name = f"STDOUT\t{now}"
            d = diff(src, dst, src_name, dst_name)
            if write_back == WriteBack.COLOR_DIFF:
                d = color_diff(d)
                f = wrap_stream_for_windows(f)
            f.write(d)
        # Detach so the wrapper does not close sys.stdout's underlying buffer.
        f.detach()
907
908
def check_stability_and_equivalence(
    src_contents: str, dst_contents: str, *, mode: Mode
) -> None:
    """Verify that `dst_contents` is a safe reformatting of `src_contents`.

    Two checks run, in this order:

    1. :func:`assert_equivalent` -- source and destination must be equivalent;
    2. :func:`assert_stable` -- a second formatting pass with `mode` must not
       format the content differently.

    Raises:
        AssertionError: if either check fails.
    """
    before, after = src_contents, dst_contents
    assert_equivalent(before, after)
    assert_stable(before, after, mode=mode)
920
921
def format_file_contents(src_contents: str, *, fast: bool, mode: Mode) -> FileContent:
    """Return the reformatted version of one file's contents.

    Notebooks (``mode.is_ipynb``) go through :func:`format_ipynb_string`.
    Everything else goes through :func:`format_str`. Unless ``fast`` is set,
    non-notebook results are then checked with :func:`assert_equivalent` and
    :func:`assert_stable`. Notebook cells get those checks individually while
    they are formatted.

    Raises:
        NothingChanged: if formatting leaves the contents unchanged.
        AssertionError: if a safety check fails.
    """
    is_notebook = mode.is_ipynb
    if is_notebook:
        formatted = format_ipynb_string(src_contents, fast=fast, mode=mode)
    else:
        formatted = format_str(src_contents, mode=mode)

    if formatted == src_contents:
        raise NothingChanged

    needs_safety_check = not (fast or is_notebook)
    if needs_safety_check:
        check_stability_and_equivalence(src_contents, formatted, mode=mode)
    return formatted
940
941
def validate_cell(src: str, mode: Mode) -> None:
    """Reject notebook cells that cannot be safely round-tripped.

    Raises ``NothingChanged`` (meaning "leave this cell alone") in two cases.

    * The cell already contains code of the form IPython's
      ``TransformerManager`` produces from magics (any entry of
      ``TRANSFORMED_MAGICS``). For instance, ``!ls`` is rewritten to
      ``get_ipython().system('ls')``, while a cell that already reads
      ``get_ipython().system('ls')`` is left as is. The two cannot be told
      apart after masking, and such a cell may also break tokenize-rt because
      of indentation.
    * The cell starts with a ``%%`` cell magic whose name is neither in
      ``PYTHON_CELL_MAGICS`` nor in ``mode.python_cell_magics``.
    """
    if any(magic in src for magic in TRANSFORMED_MAGICS):
        raise NothingChanged
    if not src.startswith("%%"):
        return
    magic_name = src.split()[0][2:]
    allowed_magics = PYTHON_CELL_MAGICS | mode.python_cell_magics
    if magic_name not in allowed_magics:
        raise NothingChanged
966
967
def format_cell(src: str, *, fast: bool, mode: Mode) -> str:
    """Format the source of a single Jupyter notebook code cell.

    The steps are:

      1. validate the cell (see :func:`validate_cell`);
      2. remove a trailing semicolon, remembering whether one was there;
      3. mask IPython magics so that the rest parses as Python;
      4. format the masked code and, unless ``fast``, verify the result;
      5. unmask the magics and restore the trailing semicolon;
      6. strip trailing newlines.

    If masking hits a syntax error, the cell is left untouched. Such cells
    could be automagics or multi-line magics, which are not supported.

    Raises:
        NothingChanged: if the cell is skipped or the result equals ``src``.
    """
    validate_cell(src, mode)
    stripped_src, had_semicolon = remove_trailing_semicolon(src)
    try:
        masked_src, replacements = mask_cell(stripped_src)
    except SyntaxError:
        raise NothingChanged from None

    masked_dst = format_str(masked_src, mode=mode)
    if not fast:
        check_stability_and_equivalence(masked_src, masked_dst, mode=mode)

    unmasked_dst = unmask_cell(masked_dst, replacements)
    result = put_trailing_semicolon_back(unmasked_dst, had_semicolon).rstrip("\n")
    if result == src:
        raise NothingChanged from None
    return result
1003
1004
def validate_metadata(nb: MutableMapping[str, Any]) -> None:
    """Skip notebooks whose metadata names a language other than Python.

    Every notebook metadata field is optional
    (https://nbformat.readthedocs.io/en/latest/format_description.html).
    A notebook without ``metadata.language_info.name`` is therefore assumed to
    be Python and is processed normally.

    Raises:
        NothingChanged: if ``metadata.language_info.name`` is present and is
            not ``"python"``.
    """
    metadata = nb.get("metadata", {})
    language_info = metadata.get("language_info", {})
    language = language_info.get("name")
    if language not in (None, "python"):
        raise NothingChanged from None
1015
1016
def format_ipynb_string(src_contents: str, *, fast: bool, mode: Mode) -> FileContent:
    """Format the code cells of a Jupyter notebook given as JSON text.

    Only cells whose ``cell_type`` is ``"code"`` are formatted, and only when
    :func:`validate_metadata` accepts the notebook. A changed notebook is
    serialized again with ``json.dumps(indent=1, ensure_ascii=False)``. If the
    input ended with a newline, so does the output.

    Raises:
        NothingChanged: if the input is empty, the notebook is rejected, or no
            cell changed.
    """
    if not src_contents:
        raise NothingChanged

    ends_with_newline = src_contents.endswith("\n")
    notebook = json.loads(src_contents)
    validate_metadata(notebook)

    any_cell_changed = False
    for cell in notebook["cells"]:
        if cell.get("cell_type", None) != "code":
            continue
        try:
            new_source = format_cell("".join(cell["source"]), fast=fast, mode=mode)
        except NothingChanged:
            continue
        cell["source"] = new_source.splitlines(keepends=True)
        any_cell_changed = True

    if not any_cell_changed:
        raise NothingChanged

    serialized = json.dumps(notebook, indent=1, ensure_ascii=False)
    return serialized + "\n" if ends_with_newline else serialized
1047
1048
def format_str(src_contents: str, *, mode: Mode) -> str:
    """Return ``src_contents`` formatted according to ``mode``.

    ``mode`` carries the formatting options, such as the maximum number of
    characters per line.  Example:

    >>> import black
    >>> print(black.format_str("def f(arg:str='')->None:...", mode=black.Mode()))
    def f(arg: str = "") -> None:
        ...

    A more complex example:

    >>> print(
    ...   black.format_str(
    ...     "def f(arg:str='')->None: hey",
    ...     mode=black.Mode(
    ...       target_versions={black.TargetVersion.PY36},
    ...       line_length=10,
    ...       string_normalization=False,
    ...       is_pyi=False,
    ...     ),
    ...   ),
    ... )
    def f(
        arg: str = '',
    ) -> None:
        hey

    """
    first_pass = _format_str_once(src_contents, mode=mode)
    if first_pass == src_contents:
        return first_pass
    # The first pass changed something, so run a second one. This works around
    # optional trailing commas (which become forced trailing commas on the next
    # pass) interacting differently with optional parentheses.
    return _format_str_once(first_pass, mode=mode)
1086
1087
def _format_str_once(src_contents: str, *, mode: Mode) -> str:
    """Run a single formatting pass over ``src_contents`` and return the result.

    Target versions come from ``mode.target_versions``. If that is empty, they
    are detected from the parsed tree and its ``__future__`` imports. Each
    logical line yielded by ``LineGenerator`` first gets its surrounding empty
    lines from ``EmptyLineTracker``, then is split by ``transform_line``.

    If the result is empty, the function returns a single newline (in the
    source's newline style) when the source contained a newline, and ""
    otherwise.
    """
    # Leading whitespace is stripped before parsing.
    src_node = lib2to3_parse(src_contents.lstrip(), mode.target_versions)
    dst_blocks: List[LinesBlock] = []
    if mode.target_versions:
        versions = mode.target_versions
    else:
        # No explicit targets: infer them from the syntax used in the source.
        future_imports = get_future_imports(src_node)
        versions = detect_target_versions(src_node, future_imports=future_imports)

    # Only PARENTHESIZED_CONTEXT_MANAGERS is forwarded to the line generator,
    # and only when supports_feature() accepts it for the chosen versions.
    context_manager_features = {
        feature
        for feature in {Feature.PARENTHESIZED_CONTEXT_MANAGERS}
        if supports_feature(versions, feature)
    }
    normalize_fmt_off(src_node)
    lines = LineGenerator(mode=mode, features=context_manager_features)
    elt = EmptyLineTracker(mode=mode)
    # Trailing-comma features allowed when splitting lines.
    split_line_features = {
        feature
        for feature in {Feature.TRAILING_COMMA_IN_CALL, Feature.TRAILING_COMMA_IN_DEF}
        if supports_feature(versions, feature)
    }
    block: Optional[LinesBlock] = None
    for current_line in lines.visit(src_node):
        # The empty-line decision must be made for the unsplit line, before
        # transform_line splits it into one or more output lines.
        block = elt.maybe_empty_lines(current_line)
        dst_blocks.append(block)
        for line in transform_line(
            current_line, mode=mode, features=split_line_features
        ):
            block.content_lines.append(str(line))
    if dst_blocks:
        # NOTE(review): `after` presumably counts the empty lines emitted after
        # a block; zeroing it on the last block avoids trailing blank lines.
        dst_blocks[-1].after = 0
    dst_contents = []
    for block in dst_blocks:
        dst_contents.extend(block.all_lines())
    if not dst_contents:
        # Use decode_bytes to retrieve the correct source newline (CRLF or LF),
        # and check if normalized_content has more than one line
        normalized_content, _, newline = decode_bytes(src_contents.encode("utf-8"))
        if "\n" in normalized_content:
            return newline
        return ""
    return "".join(dst_contents)
1131
1132
def decode_bytes(src: bytes) -> Tuple[FileContent, Encoding, NewLine]:
    """Decode raw source bytes into ``(decoded_contents, encoding, newline)``.

    The encoding is found with :func:`tokenize.detect_encoding`, which reads a
    PEP 263 coding cookie or a UTF-8 BOM and otherwise defaults to UTF-8.

    ``newline`` is ``"\\r\\n"`` when the first line ends in CRLF and ``"\\n"``
    otherwise. ``decoded_contents`` is read with universal newlines, so it only
    ever contains LF line endings.

    When there is no content at all (empty input, or nothing but a BOM), the
    result is ``("", encoding, "\\n")``.
    """
    buffer = io.BytesIO(src)
    encoding, first_lines = tokenize.detect_encoding(buffer.readline)
    if not first_lines:
        return "", encoding, "\n"

    newline = "\r\n" if first_lines[0].endswith(b"\r\n") else "\n"
    buffer.seek(0)
    with io.TextIOWrapper(buffer, encoding) as reader:
        contents = reader.read()
    return contents, encoding, newline
1148
1149
def get_features_used(  # noqa: C901
    node: Node, *, future_imports: Optional[Set[str]] = None
) -> Set[Feature]:
    """Return a set of (relatively) new Python features used in this file.

    Currently looking for:
    - f-strings (``f``, ``rf`` and ``fr`` prefixes in any letter case);
    - self-documenting expressions in f-strings (f"{x=}");
    - underscores in numeric literals;
    - trailing commas after * or ** in function signatures and calls;
    - positional only arguments in function signatures and lambdas;
    - assignment expression;
    - relaxed decorator syntax;
    - usage of __future__ flags that map to a feature via
      FUTURE_FLAG_TO_FEATURE;
    - unpacking (``*x``) in return / yield statements;
    - starred items in the right-hand side of an annotated assignment;
    - parenthesized context managers;
    - match statements;
    - except* clause;
    - variadic generics;
    - type parameter lists and ``type`` statements.
    """
    features: Set[Feature] = set()
    if future_imports:
        features |= {
            FUTURE_FLAG_TO_FEATURE[future_import]
            for future_import in future_imports
            if future_import in FUTURE_FLAG_TO_FEATURE
        }

    for n in node.pre_order():
        if is_string_token(n):
            # String prefixes are case-insensitive, and "r" may come before or
            # after "f". Comparing the lowercased head also catches Rf"", rF"",
            # fR"" and Fr"", which a fixed case-sensitive set would miss.
            value_head = n.value[:2].lower()
            if value_head in {'f"', "f'", "rf", "fr"}:
                features.add(Feature.F_STRINGS)
                if Feature.DEBUG_F_STRINGS not in features:
                    for span_beg, span_end in iter_fexpr_spans(n.value):
                        if n.value[span_beg : span_end - 1].rstrip().endswith("="):
                            features.add(Feature.DEBUG_F_STRINGS)
                            break

        elif is_number_token(n):
            if "_" in n.value:
                features.add(Feature.NUMERIC_UNDERSCORES)

        elif n.type == token.SLASH:
            if n.parent and n.parent.type in {
                syms.typedargslist,
                syms.arglist,
                syms.varargslist,
            }:
                features.add(Feature.POS_ONLY_ARGUMENTS)

        elif n.type == token.COLONEQUAL:
            features.add(Feature.ASSIGNMENT_EXPRESSIONS)

        elif n.type == syms.decorator:
            if len(n.children) > 1 and not is_simple_decorator_expression(
                n.children[1]
            ):
                features.add(Feature.RELAXED_DECORATORS)

        elif (
            n.type in {syms.typedargslist, syms.arglist}
            and n.children
            and n.children[-1].type == token.COMMA
        ):
            # A trailing comma after a starred argument is only legal in newer
            # Python versions; distinguish signatures from calls.
            if n.type == syms.typedargslist:
                feature = Feature.TRAILING_COMMA_IN_DEF
            else:
                feature = Feature.TRAILING_COMMA_IN_CALL

            for ch in n.children:
                if ch.type in STARS:
                    features.add(feature)

                if ch.type == syms.argument:
                    for argch in ch.children:
                        if argch.type in STARS:
                            features.add(feature)

        elif (
            n.type in {syms.return_stmt, syms.yield_expr}
            and len(n.children) >= 2
            and n.children[1].type == syms.testlist_star_expr
            and any(child.type == syms.star_expr for child in n.children[1].children)
        ):
            features.add(Feature.UNPACKING_ON_FLOW)

        elif (
            n.type == syms.annassign
            and len(n.children) >= 4
            and n.children[3].type == syms.testlist_star_expr
        ):
            features.add(Feature.ANN_ASSIGN_EXTENDED_RHS)

        elif (
            n.type == syms.with_stmt
            and len(n.children) > 2
            and n.children[1].type == syms.atom
        ):
            atom_children = n.children[1].children
            if (
                len(atom_children) == 3
                and atom_children[0].type == token.LPAR
                and atom_children[1].type == syms.testlist_gexp
                and atom_children[2].type == token.RPAR
            ):
                features.add(Feature.PARENTHESIZED_CONTEXT_MANAGERS)

        elif n.type == syms.match_stmt:
            features.add(Feature.PATTERN_MATCHING)

        elif (
            n.type == syms.except_clause
            and len(n.children) >= 2
            and n.children[1].type == token.STAR
        ):
            features.add(Feature.EXCEPT_STAR)

        elif n.type in {syms.subscriptlist, syms.trailer} and any(
            child.type == syms.star_expr for child in n.children
        ):
            features.add(Feature.VARIADIC_GENERICS)

        elif (
            n.type == syms.tname_star
            and len(n.children) == 3
            and n.children[2].type == syms.star_expr
        ):
            features.add(Feature.VARIADIC_GENERICS)

        elif n.type in (syms.type_stmt, syms.typeparams):
            features.add(Feature.TYPE_PARAMS)

    return features
1284
1285
def detect_target_versions(
    node: Node, *, future_imports: Optional[Set[str]] = None
) -> Set[TargetVersion]:
    """Return every target version able to handle the syntax used in ``node``.

    A version qualifies when each feature reported by
    :func:`get_features_used` (including those implied by ``future_imports``)
    is in that version's ``VERSION_TO_FEATURES`` entry.
    """
    used_features = get_features_used(node, future_imports=future_imports)
    compatible: Set[TargetVersion] = set()
    for version in TargetVersion:
        if used_features <= VERSION_TO_FEATURES[version]:
            compatible.add(version)
    return compatible
1294
1295
def get_future_imports(node: Node) -> Set[str]:
    """Collect the names imported from ``__future__`` at the top of a module.

    Only the leading run of simple statements is inspected. Docstring
    statements are skipped, and each ``from __future__ import ...`` contributes
    the original (pre-``as``) names it imports. The scan stops at the first
    statement of any other kind.

    Raises:
        AssertionError: if an import list has an unexpected parse-tree shape.
    """
    found: Set[str] = set()

    def _names_from(nodes: List[LN]) -> Generator[str, None, None]:
        # Yield the original name of every item in an import list. Punctuation
        # leaves such as parentheses and commas are ignored.
        for item in nodes:
            if isinstance(item, Leaf):
                if item.type == token.NAME:
                    yield item.value
                continue

            if item.type == syms.import_as_name:
                original = item.children[0]
                assert isinstance(original, Leaf), "Invalid syntax parsing imports"
                assert original.type == token.NAME, "Invalid syntax parsing imports"
                yield original.value
            elif item.type == syms.import_as_names:
                yield from _names_from(item.children)
            else:
                raise AssertionError("Invalid syntax parsing imports")

    for stmt in node.children:
        if stmt.type != syms.simple_stmt:
            break

        head = stmt.children[0]
        if isinstance(head, Leaf):
            is_docstring = (
                len(stmt.children) == 2
                and head.type == token.STRING
                and stmt.children[1].type == token.NEWLINE
            )
            if is_docstring:
                continue
            break

        if head.type != syms.import_from:
            break

        module = head.children[1]
        if not isinstance(module, Leaf) or module.value != "__future__":
            break
        # Children after "from", the module name and "import".
        found.update(_names_from(head.children[3:]))

    return found
1344
1345
def assert_equivalent(src: str, dst: str) -> None:
    """Check that ``src`` and ``dst`` parse to the same AST.

    Raises:
        AssertionError: in three cases.
            - ``src`` cannot be parsed; the parse error is chained.
            - ``dst`` cannot be parsed; the traceback and ``dst`` are dumped
              to a log file whose path is in the message.
            - The stringified ASTs differ; their diff is dumped to a log file.
    """
    try:
        src_ast = parse_ast(src)
    except Exception as exc:
        raise AssertionError(
            "cannot use --safe with this file; failed to parse source file AST: "
            f"{exc}\n"
            "This could be caused by running Black with an older Python version "
            "that does not support new syntax used in your source file."
        ) from exc

    try:
        dst_ast = parse_ast(dst)
    except Exception as exc:
        tb_text = "".join(traceback.format_tb(exc.__traceback__))
        log = dump_to_file(tb_text, dst)
        raise AssertionError(
            f"INTERNAL ERROR: Black produced invalid code: {exc}. "
            "Please report a bug on https://github.com/psf/black/issues.  "
            f"This invalid output might be helpful: {log}"
        ) from None

    src_dump, dst_dump = ("\n".join(stringify_ast(tree)) for tree in (src_ast, dst_ast))
    if src_dump == dst_dump:
        return

    log = dump_to_file(diff(src_dump, dst_dump, "src", "dst"))
    raise AssertionError(
        "INTERNAL ERROR: Black produced code that is not equivalent to the"
        " source.  Please report a bug on "
        f"https://github.com/psf/black/issues.  This diff might be helpful: {log}"
    ) from None
1377
1378
def assert_stable(src: str, dst: str, mode: Mode) -> None:
    """Check that formatting ``dst`` one more time leaves it unchanged.

    This deliberately uses :func:`_format_str_once` rather than
    :func:`format_str`. ``format_str`` already formats twice, which could hide
    output that bounces back and forth between two versions.

    Raises:
        AssertionError: if the extra pass changes ``dst``. The mode and both
            diffs (source -> first pass, first pass -> second pass) are dumped
            to a log file.
    """
    second_pass = _format_str_once(dst, mode=mode)
    if second_pass == dst:
        return

    log = dump_to_file(
        str(mode),
        diff(src, dst, "source", "first pass"),
        diff(dst, second_pass, "first pass", "second pass"),
    )
    raise AssertionError(
        "INTERNAL ERROR: Black produced different code on the second pass of the"
        " formatter.  Please report a bug on https://github.com/psf/black/issues."
        f"  This diff might be helpful: {log}"
    ) from None
1396
1397
@contextmanager
def nullcontext() -> Iterator[None]:
    """Provide a context manager that does nothing.

    Entering it binds ``None`` to the ``as`` target, and exceptions raised in
    the body propagate unchanged. It behaves like :func:`contextlib.nullcontext`
    from the Python 3.7 standard library.
    """
    yield None
1405
1406
def patched_main() -> None:
    """Run :func:`main`, first preparing multiprocessing in frozen builds.

    PyInstaller patches multiprocessing so that ``freeze_support()`` is needed
    even outside Windows. It is therefore called whenever the interpreter
    reports itself as frozen (``sys.frozen``).
    """
    is_frozen = getattr(sys, "frozen", False)
    if is_frozen:
        import multiprocessing

        multiprocessing.freeze_support()

    main()
1416
1417
# When this file is executed directly as a script, run patched_main().
if __name__ == "__main__":
    patched_main()