git.madduck.net Git - etc/vim.git/blob - src/black/__init__.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Move coverage configurations to `pyproject.toml` (#3858)
[etc/vim.git] / src / black / __init__.py
1 import io
2 import json
3 import platform
4 import re
5 import sys
6 import tokenize
7 import traceback
8 from contextlib import contextmanager
9 from dataclasses import replace
10 from datetime import datetime, timezone
11 from enum import Enum
12 from json.decoder import JSONDecodeError
13 from pathlib import Path
14 from typing import (
15     Any,
16     Dict,
17     Generator,
18     Iterator,
19     List,
20     MutableMapping,
21     Optional,
22     Pattern,
23     Sequence,
24     Set,
25     Sized,
26     Tuple,
27     Union,
28 )
29
30 import click
31 from click.core import ParameterSource
32 from mypy_extensions import mypyc_attr
33 from pathspec import PathSpec
34 from pathspec.patterns.gitwildmatch import GitWildMatchPatternError
35
36 from _black_version import version as __version__
37 from black.cache import Cache
38 from black.comments import normalize_fmt_off
39 from black.const import (
40     DEFAULT_EXCLUDES,
41     DEFAULT_INCLUDES,
42     DEFAULT_LINE_LENGTH,
43     STDIN_PLACEHOLDER,
44 )
45 from black.files import (
46     find_project_root,
47     find_pyproject_toml,
48     find_user_pyproject_toml,
49     gen_python_files,
50     get_gitignore,
51     normalize_path_maybe_ignore,
52     parse_pyproject_toml,
53     wrap_stream_for_windows,
54 )
55 from black.handle_ipynb_magics import (
56     PYTHON_CELL_MAGICS,
57     TRANSFORMED_MAGICS,
58     jupyter_dependencies_are_installed,
59     mask_cell,
60     put_trailing_semicolon_back,
61     remove_trailing_semicolon,
62     unmask_cell,
63 )
64 from black.linegen import LN, LineGenerator, transform_line
65 from black.lines import EmptyLineTracker, LinesBlock
66 from black.mode import (
67     FUTURE_FLAG_TO_FEATURE,
68     VERSION_TO_FEATURES,
69     Feature,
70     Mode,
71     TargetVersion,
72     supports_feature,
73 )
74 from black.nodes import (
75     STARS,
76     is_number_token,
77     is_simple_decorator_expression,
78     is_string_token,
79     syms,
80 )
81 from black.output import color_diff, diff, dump_to_file, err, ipynb_diff, out
82 from black.parsing import InvalidInput  # noqa F401
83 from black.parsing import lib2to3_parse, parse_ast, stringify_ast
84 from black.report import Changed, NothingChanged, Report
85 from black.trans import iter_fexpr_spans
86 from blib2to3.pgen2 import token
87 from blib2to3.pytree import Leaf, Node
88
# True when this module is a native extension (built by mypyc) rather than a
# plain ``.py`` file; reported by ``--version``.
COMPILED = Path(__file__).suffix in {".pyd", ".so"}

# Type aliases used in signatures throughout this module.
FileContent = str
Encoding = str
NewLine = str
95
96
class WriteBack(Enum):
    """Where (and whether) formatting results are emitted."""

    NO = 0
    YES = 1
    DIFF = 2
    CHECK = 3
    COLOR_DIFF = 4

    @classmethod
    def from_configuration(
        cls, *, check: bool, diff: bool, color: bool = False
    ) -> "WriteBack":
        """Translate the ``--check`` / ``--diff`` / ``--color`` flags into a member.

        ``diff`` takes precedence over ``check``; ``color`` only matters when
        ``diff`` is set. With neither ``diff`` nor ``check``, files are written
        back (YES).
        """
        if diff:
            return cls.COLOR_DIFF if color else cls.DIFF
        if check:
            return cls.CHECK
        return cls.YES
115
116
# Legacy name, left for integrations.
# ``FileMode`` is simply another name bound to the same ``Mode`` class, so code
# that still imports the old name keeps working.
FileMode = Mode
119
120
def read_pyproject_toml(
    ctx: click.Context, param: click.Parameter, value: Optional[str]
) -> Optional[str]:
    """Click callback that injects Black settings from a "pyproject.toml" file.

    When ``value`` (the ``--config`` path) is empty, a file is searched for
    with :func:`find_pyproject_toml`, starting from the ``src`` and
    ``stdin_filename`` values already present in ``ctx.params``. The settings
    that are found are layered over ``ctx.default_map`` so they become the
    defaults of the remaining command-line options.

    Returns the path of the configuration file that was read, or None if no
    file was found or it held no Black settings.

    Raises:
        click.FileError: the file could not be read or parsed.
        click.BadOptionUsage: ``target_version`` is not a list, or
            ``exclude`` / ``extend_exclude`` is not a string.
    """
    if not value:
        value = find_pyproject_toml(
            ctx.params.get("src", ()), ctx.params.get("stdin_filename", None)
        )
        if value is None:
            return None

    try:
        parsed = parse_pyproject_toml(value)
    except (OSError, ValueError) as e:
        raise click.FileError(
            filename=value, hint=f"Error reading configuration file: {e}"
        ) from None

    if not parsed:
        return None

    # Sanitize the values to be Click friendly: scalars become strings while
    # lists and dicts pass through untouched. For more information please see:
    # https://github.com/psf/black/issues/1458
    # https://github.com/pallets/click/issues/1567
    config: Dict[str, Any] = {}
    for key, raw in parsed.items():
        config[key] = raw if isinstance(raw, (list, dict)) else str(raw)

    target_version = config.get("target_version")
    if target_version is not None and not isinstance(target_version, list):
        raise click.BadOptionUsage(
            "target-version", "Config key target-version must be a list"
        )

    exclude = config.get("exclude")
    if exclude is not None and not isinstance(exclude, str):
        raise click.BadOptionUsage("exclude", "Config key exclude must be a string")

    extend_exclude = config.get("extend_exclude")
    if extend_exclude is not None and not isinstance(extend_exclude, str):
        raise click.BadOptionUsage(
            "extend-exclude", "Config key extend-exclude must be a string"
        )

    # Settings from the file override whatever defaults the context already had.
    merged: Dict[str, Any] = {**(ctx.default_map or {}), **config}
    ctx.default_map = merged
    return value
177
178
def target_version_option_callback(
    c: click.Context, p: Union[click.Option, click.Parameter], v: Tuple[str, ...]
) -> List[TargetVersion]:
    """Turn the raw ``--target-version`` choices into TargetVersion members.

    Each value is upper-cased and looked up by member name. This is a named
    function rather than a lambda because mypy could not infer the lambda's
    type correctly, which caused trouble for mypyc.
    """
    versions: List[TargetVersion] = []
    for choice in v:
        versions.append(TargetVersion[choice.upper()])
    return versions
188
189
def re_compile_maybe_verbose(regex: str) -> Pattern[str]:
    """Compile ``regex``, turning on verbose mode for multi-line patterns.

    A pattern containing a newline gets the inline ``(?x)`` flag prepended, so
    its whitespace and ``#`` comments are ignored by the regex engine.
    """
    flag = "(?x)" if "\n" in regex else ""
    compiled: Pattern[str] = re.compile(flag + regex)
    return compiled
199
200
def validate_regex(
    ctx: click.Context,
    param: click.Parameter,
    value: Optional[str],
) -> Optional[Pattern[str]]:
    """Click callback: compile a regex option value, passing None through.

    An invalid pattern is surfaced to the user as a click.BadParameter.
    """
    if value is None:
        return None
    try:
        return re_compile_maybe_verbose(value)
    except re.error as e:
        raise click.BadParameter(f"Not a valid regular expression: {e}") from None
210
211
@click.command(
    context_settings={"help_option_names": ["-h", "--help"]},
    # While Click does set this field automatically using the docstring, mypyc
    # (annoyingly) strips 'em so we need to set it here too.
    help="The uncompromising code formatter.",
)
@click.option("-c", "--code", type=str, help="Format the code passed in as a string.")
@click.option(
    "-l",
    "--line-length",
    type=int,
    default=DEFAULT_LINE_LENGTH,
    help="How many characters per line to allow.",
    show_default=True,
)
@click.option(
    "-t",
    "--target-version",
    type=click.Choice([v.name.lower() for v in TargetVersion]),
    # Converts the lower-cased choice names back into TargetVersion members.
    callback=target_version_option_callback,
    multiple=True,
    help=(
        "Python versions that should be supported by Black's output. By default, Black"
        " will try to infer this from the project metadata in pyproject.toml. If this"
        " does not yield conclusive results, Black will use per-file auto-detection."
    ),
)
@click.option(
    "--pyi",
    is_flag=True,
    help=(
        "Format all input files like typing stubs regardless of file extension (useful"
        " when piping source on standard input)."
    ),
)
@click.option(
    "--ipynb",
    is_flag=True,
    help=(
        "Format all input files like Jupyter Notebooks regardless of file extension "
        "(useful when piping source on standard input)."
    ),
)
@click.option(
    "--python-cell-magics",
    multiple=True,
    help=(
        "When processing Jupyter Notebooks, add the given magic to the list"
        f" of known python-magics ({', '.join(sorted(PYTHON_CELL_MAGICS))})."
        " Useful for formatting cells with custom python magics."
    ),
    default=[],
)
@click.option(
    "-x",
    "--skip-source-first-line",
    is_flag=True,
    help="Skip the first line of the source code.",
)
@click.option(
    "-S",
    "--skip-string-normalization",
    is_flag=True,
    help="Don't normalize string quotes or prefixes.",
)
@click.option(
    "-C",
    "--skip-magic-trailing-comma",
    is_flag=True,
    help="Don't use trailing commas as a reason to split lines.",
)
@click.option(
    "--experimental-string-processing",
    is_flag=True,
    hidden=True,
    help="(DEPRECATED and now included in --preview) Normalize string literals.",
)
@click.option(
    "--preview",
    is_flag=True,
    help=(
        "Enable potentially disruptive style changes that may be added to Black's main"
        " functionality in the next major release."
    ),
)
@click.option(
    "--check",
    is_flag=True,
    help=(
        "Don't write the files back, just return the status. Return code 0 means"
        " nothing would change. Return code 1 means some files would be reformatted."
        " Return code 123 means there was an internal error."
    ),
)
@click.option(
    "--diff",
    is_flag=True,
    help="Don't write the files back, just output a diff for each file on stdout.",
)
@click.option(
    "--color/--no-color",
    is_flag=True,
    help="Show colored diff. Only applies when `--diff` is given.",
)
@click.option(
    "--fast/--safe",
    is_flag=True,
    help="If --fast given, skip temporary sanity checks. [default: --safe]",
)
@click.option(
    "--required-version",
    type=str,
    help=(
        "Require a specific version of Black to be running (useful for unifying results"
        " across many environments e.g. with a pyproject.toml file). It can be"
        " either a major version number or an exact version."
    ),
)
@click.option(
    "--include",
    type=str,
    default=DEFAULT_INCLUDES,
    callback=validate_regex,
    help=(
        "A regular expression that matches files and directories that should be"
        " included on recursive searches. An empty value means all files are included"
        " regardless of the name. Use forward slashes for directories on all platforms"
        " (Windows, too). Exclusions are calculated first, inclusions later."
    ),
    show_default=True,
)
@click.option(
    "--exclude",
    type=str,
    # No Click default: a None value makes get_sources fall back to
    # DEFAULT_EXCLUDES (and to pass gitignore specs along).
    callback=validate_regex,
    help=(
        "A regular expression that matches files and directories that should be"
        " excluded on recursive searches. An empty value means no paths are excluded."
        " Use forward slashes for directories on all platforms (Windows, too)."
        " Exclusions are calculated first, inclusions later. [default:"
        f" {DEFAULT_EXCLUDES}]"
    ),
    show_default=False,
)
@click.option(
    "--extend-exclude",
    type=str,
    callback=validate_regex,
    help=(
        "Like --exclude, but adds additional files and directories on top of the"
        " excluded ones. (Useful if you simply want to add to the default)"
    ),
)
@click.option(
    "--force-exclude",
    type=str,
    callback=validate_regex,
    help=(
        "Like --exclude, but files and directories matching this regex will be "
        "excluded even when they are passed explicitly as arguments."
    ),
)
@click.option(
    "--stdin-filename",
    type=str,
    # Eager, presumably so the value is already in ctx.params when
    # read_pyproject_toml (the --config callback) looks it up.
    is_eager=True,
    help=(
        "The name of the file when passing it through stdin. Useful to make "
        "sure Black will respect --force-exclude option on some "
        "editors that rely on using stdin."
    ),
)
@click.option(
    "-W",
    "--workers",
    type=click.IntRange(min=1),
    default=None,
    help=(
        "Number of parallel workers [default: BLACK_NUM_WORKERS environment variable "
        "or number of CPUs in the system]"
    ),
)
@click.option(
    "-q",
    "--quiet",
    is_flag=True,
    help=(
        "Don't emit non-error messages to stderr. Errors are still emitted; silence"
        " those with 2>/dev/null."
    ),
)
@click.option(
    "-v",
    "--verbose",
    is_flag=True,
    help=(
        "Also emit messages to stderr about files that were not changed or were ignored"
        " due to exclusion patterns."
    ),
)
@click.version_option(
    version=__version__,
    message=(
        f"%(prog)s, %(version)s (compiled: {'yes' if COMPILED else 'no'})\n"
        f"Python ({platform.python_implementation()}) {platform.python_version()}"
    ),
)
@click.argument(
    "src",
    nargs=-1,
    type=click.Path(
        exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
    ),
    # Eager, presumably so the paths are in ctx.params before the --config
    # callback (read_pyproject_toml) searches for a pyproject.toml from them.
    is_eager=True,
    metavar="SRC ...",
)
@click.option(
    "--config",
    type=click.Path(
        exists=True,
        file_okay=True,
        dir_okay=False,
        readable=True,
        allow_dash=False,
        path_type=str,
    ),
    is_eager=True,
    # Loads pyproject.toml settings into ctx.default_map for the other options.
    callback=read_pyproject_toml,
    help="Read configuration from FILE path.",
)
@click.pass_context
def main(  # noqa: C901
    ctx: click.Context,
    code: Optional[str],
    line_length: int,
    target_version: List[TargetVersion],
    check: bool,
    diff: bool,
    color: bool,
    fast: bool,
    pyi: bool,
    ipynb: bool,
    python_cell_magics: Sequence[str],
    skip_source_first_line: bool,
    skip_string_normalization: bool,
    skip_magic_trailing_comma: bool,
    experimental_string_processing: bool,
    preview: bool,
    quiet: bool,
    verbose: bool,
    required_version: Optional[str],
    include: Pattern[str],
    exclude: Optional[Pattern[str]],
    extend_exclude: Optional[Pattern[str]],
    force_exclude: Optional[Pattern[str]],
    stdin_filename: Optional[str],
    workers: Optional[int],
    src: Tuple[str, ...],
    config: Optional[str],
) -> None:
    """The uncompromising code formatter."""
    ctx.ensure_object(dict)

    # SRC paths and --code are mutually exclusive, and one of them is required.
    if src and code is not None:
        out(
            main.get_usage(ctx)
            + "\n\n'SRC' and 'code' cannot be passed simultaneously."
        )
        ctx.exit(1)
    if not src and code is None:
        out(main.get_usage(ctx) + "\n\nOne of 'SRC' or 'code' is required.")
        ctx.exit(1)

    # No project root is looked up for --code input. Within this function,
    # `method` (how the root was identified) is only used in the verbose message.
    root, method = (
        find_project_root(src, stdin_filename) if code is None else (None, None)
    )
    ctx.obj["root"] = root

    if verbose:
        if root:
            out(
                f"Identified `{root}` as project root containing a {method}.",
                fg="blue",
            )

        if config:
            # Say where the configuration came from: the user-level file, a
            # file whose parameter source is DEFAULT / DEFAULT_MAP (i.e. not an
            # explicit --config), or an explicitly given path.
            config_source = ctx.get_parameter_source("config")
            user_level_config = str(find_user_pyproject_toml())
            if config == user_level_config:
                out(
                    "Using configuration from user-level config at "
                    f"'{user_level_config}'.",
                    fg="blue",
                )
            elif config_source in (
                ParameterSource.DEFAULT,
                ParameterSource.DEFAULT_MAP,
            ):
                out("Using configuration from project root.", fg="blue")
            else:
                out(f"Using configuration in '{config}'.", fg="blue")
            # Dump every setting loaded into Click's default map.
            if ctx.default_map:
                for param, value in ctx.default_map.items():
                    out(f"{param}: {value}")

    error_msg = "Oh no! 💥 💔 💥"
    # --required-version matches either the full version string or just its
    # major component (the text before the first ".").
    if (
        required_version
        and required_version != __version__
        and required_version != __version__.split(".")[0]
    ):
        err(
            f"{error_msg} The required version `{required_version}` does not match"
            f" the running version `{__version__}`!"
        )
        ctx.exit(1)
    # --pyi and --ipynb select mutually exclusive formatting modes.
    if ipynb and pyi:
        err("Cannot pass both `pyi` and `ipynb` flags!")
        ctx.exit(1)

    write_back = WriteBack.from_configuration(check=check, diff=diff, color=color)
    if target_version:
        versions = set(target_version)
    else:
        # We'll autodetect later.
        versions = set()
    mode = Mode(
        target_versions=versions,
        line_length=line_length,
        is_pyi=pyi,
        is_ipynb=ipynb,
        skip_source_first_line=skip_source_first_line,
        string_normalization=not skip_string_normalization,
        magic_trailing_comma=not skip_magic_trailing_comma,
        experimental_string_processing=experimental_string_processing,
        preview=preview,
        python_cell_magics=set(python_cell_magics),
    )

    if code is not None:
        # Run in quiet mode by default with -c; the extra output isn't useful.
        # You can still pass -v to get verbose output.
        quiet = True

    report = Report(check=check, diff=diff, quiet=quiet, verbose=verbose)

    if code is not None:
        reformat_code(
            content=code, fast=fast, write_back=write_back, mode=mode, report=report
        )
    else:
        assert root is not None  # root is only None if code is not None
        try:
            sources = get_sources(
                root=root,
                src=src,
                quiet=quiet,
                verbose=verbose,
                include=include,
                exclude=exclude,
                extend_exclude=extend_exclude,
                force_exclude=force_exclude,
                report=report,
                stdin_filename=stdin_filename,
            )
        except GitWildMatchPatternError:
            # NOTE(review): presumably raised for an invalid gitignore pattern;
            # nothing is printed here before exiting.
            ctx.exit(1)

        # Exits with status 0 (after an optional message) if nothing remains.
        path_empty(
            sources,
            "No Python files are present to be formatted. Nothing to do 😴",
            quiet,
            verbose,
            ctx,
        )

        # One source is formatted in this process; several are handed to
        # black.concurrency.reformat_many (imported only on this path), which
        # is the only place `workers` is used.
        if len(sources) == 1:
            reformat_one(
                src=sources.pop(),
                fast=fast,
                write_back=write_back,
                mode=mode,
                report=report,
            )
        else:
            from black.concurrency import reformat_many

            reformat_many(
                sources=sources,
                fast=fast,
                write_back=write_back,
                mode=mode,
                report=report,
                workers=workers,
            )

    # Final summary unless silenced; the exit status always comes from the report.
    if verbose or not quiet:
        if code is None and (verbose or report.change_count or report.failure_count):
            out()
        out(error_msg if report.return_code else "All done! ✨ 🍰 ✨")
        if code is None:
            click.echo(str(report), err=True)
    ctx.exit(report.return_code)
615
616
def get_sources(
    *,
    root: Path,
    src: Tuple[str, ...],
    quiet: bool,
    verbose: bool,
    include: Pattern[str],
    exclude: Optional[Pattern[str]],
    extend_exclude: Optional[Pattern[str]],
    force_exclude: Optional[Pattern[str]],
    report: "Report",
    stdin_filename: Optional[str],
) -> Set[Path]:
    """Compute the set of files to be formatted.

    Each entry of ``src`` is handled as follows:

    * A file, or ``-`` when ``stdin_filename`` is given: the path is normalized
      relative to ``root`` and dropped if it matches ``force_exclude``.
      Otherwise it is added. Stdin input is stored with the
      ``STDIN_PLACEHOLDER`` prefix so later stages can recover the name the
      user supplied. ``.ipynb`` files are skipped when the Jupyter
      dependencies are not installed.
    * A directory: expanded with ``gen_python_files`` using the
      include/exclude patterns. When ``exclude`` is None, ``DEFAULT_EXCLUDES``
      is used and the gitignore specs of ``root`` and of the directory are
      passed along as well.
    * A bare ``-``: added as-is (plain stdin).
    * Anything else: reported as an invalid path.
    """
    sources: Set[Path] = set()

    using_default_exclude = exclude is None
    exclude = re_compile_maybe_verbose(DEFAULT_EXCLUDES) if exclude is None else exclude
    gitignore: Optional[Dict[Path, PathSpec]] = None
    root_gitignore = get_gitignore(root)

    for s in src:
        if s == "-" and stdin_filename:
            p = Path(stdin_filename)
            is_stdin = True
        else:
            p = Path(s)
            is_stdin = False

        if is_stdin or p.is_file():
            normalized_path: Optional[str] = normalize_path_maybe_ignore(
                p, root, report
            )
            if normalized_path is None:
                # normalized_path is None here, so name the original path instead
                # of printing "None".
                if verbose:
                    out(f'Skipping invalid source: "{p}"', fg="red")
                continue
            if verbose:
                out(f'Found input source: "{normalized_path}"', fg="blue")

            normalized_path = "/" + normalized_path
            # Hard-exclude any files that matches the `--force-exclude` regex.
            if force_exclude:
                force_exclude_match = force_exclude.search(normalized_path)
            else:
                force_exclude_match = None
            if force_exclude_match and force_exclude_match.group(0):
                report.path_ignored(p, "matches the --force-exclude regular expression")
                continue

            if is_stdin:
                # Tag stdin input so reformat_one can tell it apart from a real
                # file while still knowing the user-supplied name.
                p = Path(f"{STDIN_PLACEHOLDER}{p}")

            if p.suffix == ".ipynb" and not jupyter_dependencies_are_installed(
                warn=verbose or not quiet
            ):
                continue

            sources.add(p)
        elif p.is_dir():
            p_relative = normalize_path_maybe_ignore(p, root, report)
            # NOTE(review): relies on every directory argument resolving to a
            # location under `root`; confirm against find_project_root.
            assert p_relative is not None
            p = root / p_relative
            if verbose:
                out(f'Found input source directory: "{p}"', fg="blue")

            if using_default_exclude:
                gitignore = {
                    root: root_gitignore,
                    p: get_gitignore(p),
                }
            sources.update(
                gen_python_files(
                    p.iterdir(),
                    root,
                    include,
                    exclude,
                    extend_exclude,
                    force_exclude,
                    report,
                    gitignore,
                    verbose=verbose,
                    quiet=quiet,
                )
            )
        elif s == "-":
            if verbose:
                out("Found input source stdin", fg="blue")
            sources.add(p)
        else:
            err(f"invalid path: {s}")

    return sources
710
711
def path_empty(
    src: Sized, msg: str, quiet: bool, verbose: bool, ctx: click.Context
) -> None:
    """Exit successfully (status 0) when ``src`` is empty.

    ``msg`` is printed first unless output is silenced, i.e. ``quiet`` is set
    without ``verbose``. When ``src`` is non-empty this is a no-op.
    """
    if src:
        return
    if not quiet or verbose:
        out(msg)
    ctx.exit(0)
722
723
def reformat_code(
    content: str, fast: bool, write_back: WriteBack, mode: Mode, report: Report
) -> None:
    """Format the in-memory source ``content`` and emit the result.

    This is the string counterpart of `reformat_one` and never spawns child
    processes. ``fast``, ``write_back`` and ``mode`` are forwarded to
    :func:`format_stdin_to_stdout`. The outcome is recorded in ``report``
    under the pseudo-path ``<string>``. Any exception is recorded as a failure
    instead of propagating, with a traceback printed when the report is verbose.
    """
    path = Path("<string>")
    try:
        was_reformatted = format_stdin_to_stdout(
            content=content, fast=fast, write_back=write_back, mode=mode
        )
        report.done(path, Changed.YES if was_reformatted else Changed.NO)
    except Exception as exc:
        if report.verbose:
            traceback.print_exc()
        report.failed(path, str(exc))
746
747
# diff-shades depends on being able to monkeypatch this function to operate. I
# know it's not ideal, but this shouldn't cause any issues ... hopefully. ~ichard26
@mypyc_attr(patchable=True)
def reformat_one(
    src: Path, fast: bool, write_back: WriteBack, mode: Mode, report: "Report"
) -> None:
    """Reformat the single source ``src`` in this process and record the result.

    ``src`` is treated as standard input when it is ``-`` or starts with
    ``STDIN_PLACEHOLDER``. In the placeholder case, the prefix is stripped to
    recover the user-supplied name, whose ``.pyi`` / ``.ipynb`` suffix then
    selects the mode.

    Real files consult the :class:`Cache`. Outside the diff modes, a file the
    cache reports as unchanged is skipped. In YES mode every non-cached file
    is written to the cache, and in CHECK mode every file that needed no
    changes is.

    ``fast``, ``write_back`` and ``mode`` are passed on to
    :func:`format_file_in_place` or :func:`format_stdin_to_stdout`. Any
    exception is recorded via ``report.failed`` rather than raised.
    """
    try:
        changed = Changed.NO

        name = str(src)
        is_stdin = name == "-"
        if not is_stdin and name.startswith(STDIN_PLACEHOLDER):
            is_stdin = True
            # Use the original name again in case we want to print something
            # to the user.
            src = Path(name[len(STDIN_PLACEHOLDER) :])

        if is_stdin:
            if src.suffix == ".pyi":
                mode = replace(mode, is_pyi=True)
            elif src.suffix == ".ipynb":
                mode = replace(mode, is_ipynb=True)
            if format_stdin_to_stdout(fast=fast, write_back=write_back, mode=mode):
                changed = Changed.YES
        else:
            cache = Cache.read(mode)
            diffing = write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF)
            if not diffing and not cache.is_changed(src):
                changed = Changed.CACHED
            if changed is not Changed.CACHED and format_file_in_place(
                src, fast=fast, write_back=write_back, mode=mode
            ):
                changed = Changed.YES
            cache_after_write = (
                write_back is WriteBack.YES and changed is not Changed.CACHED
            )
            cache_after_check = write_back is WriteBack.CHECK and changed is Changed.NO
            if cache_after_write or cache_after_check:
                cache.write([src])
        report.done(src, changed)
    except Exception as exc:
        if report.verbose:
            traceback.print_exc()
        report.failed(src, str(exc))
797
798
def format_file_in_place(
    src: Path,
    fast: bool,
    mode: Mode,
    write_back: WriteBack = WriteBack.NO,
    lock: Any = None,  # multiprocessing.Manager().Lock() is some crazy proxy
) -> bool:
    """Format the file at ``src``; return True if its contents changed.

    A ``.pyi`` or ``.ipynb`` suffix switches ``mode`` to stub or notebook
    formatting. With ``mode.skip_source_first_line`` the first line is set
    aside before formatting and re-attached afterwards.

    Output depends on ``write_back``:

    * YES overwrites the file, keeping the detected encoding and newline.
    * DIFF / COLOR_DIFF print a diff (colored for COLOR_DIFF) to stdout,
      holding ``lock`` while writing if one is given.
    * Any other value only computes the result.

    ``fast`` and ``mode`` are forwarded to :func:`format_file_contents`.
    Raises ValueError if a notebook cannot be decoded as JSON.
    """
    suffix = src.suffix
    if suffix == ".pyi":
        mode = replace(mode, is_pyi=True)
    elif suffix == ".ipynb":
        mode = replace(mode, is_ipynb=True)

    then = datetime.fromtimestamp(src.stat().st_mtime, timezone.utc)
    with open(src, "rb") as buf:
        header = buf.readline() if mode.skip_source_first_line else b""
        src_contents, encoding, newline = decode_bytes(buf.read())
    try:
        dst_contents = format_file_contents(src_contents, fast=fast, mode=mode)
    except NothingChanged:
        return False
    except JSONDecodeError:
        raise ValueError(
            f"File '{src}' cannot be parsed as valid Jupyter notebook."
        ) from None

    # Re-attach the skipped first line (empty unless skip_source_first_line).
    header_text = header.decode(encoding)
    src_contents = header_text + src_contents
    dst_contents = header_text + dst_contents

    if write_back == WriteBack.YES:
        with open(src, "w", encoding=encoding, newline=newline) as out_file:
            out_file.write(dst_contents)
    elif write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
        now = datetime.now(timezone.utc)
        src_name = f"{src}\t{then}"
        dst_name = f"{src}\t{now}"
        make_diff = ipynb_diff if mode.is_ipynb else diff
        diff_contents = make_diff(src_contents, dst_contents, src_name, dst_name)
        if write_back == WriteBack.COLOR_DIFF:
            diff_contents = color_diff(diff_contents)

        with lock or nullcontext():
            stream = io.TextIOWrapper(
                sys.stdout.buffer,
                encoding=encoding,
                newline=newline,
                write_through=True,
            )
            stream = wrap_stream_for_windows(stream)
            stream.write(diff_contents)
            # Detach so closing the wrapper never closes sys.stdout's buffer.
            stream.detach()

    return True
861
862
def format_stdin_to_stdout(
    fast: bool,
    *,
    content: Optional[str] = None,
    write_back: WriteBack = WriteBack.NO,
    mode: Mode,
) -> bool:
    """Format source from stdin (or from ``content``); return True if it changed.

    When ``content`` is None, the bytes are read from ``sys.stdin`` and
    decoded. Otherwise ``content`` is used as-is, written back as UTF-8 with
    ``newline=""``.

    Output goes to stdout according to ``write_back``:

    * YES echoes the resulting code and guarantees a trailing newline.
    * DIFF / COLOR_DIFF print a diff between input and output.

    Output is produced even when formatting raises; then the unmodified input
    is what gets echoed or diffed, and the exception propagates afterwards.
    ``mode`` is passed to :func:`format_file_contents`.
    """
    then = datetime.now(timezone.utc)

    if content is not None:
        src, encoding, newline = content, "utf-8", ""
    else:
        src, encoding, newline = decode_bytes(sys.stdin.buffer.read())

    # Until formatting succeeds, the "output" is the untouched input.
    dst = src
    try:
        dst = format_file_contents(src, fast=fast, mode=mode)
        return True

    except NothingChanged:
        return False

    finally:
        stream = io.TextIOWrapper(
            sys.stdout.buffer, encoding=encoding, newline=newline, write_through=True
        )
        if write_back == WriteBack.YES:
            # Make sure there's a newline after the content
            if dst and not dst.endswith("\n"):
                dst += "\n"
            stream.write(dst)
        elif write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
            now = datetime.now(timezone.utc)
            src_name = f"STDIN\t{then}"
            dst_name = f"STDOUT\t{now}"
            diff_text = diff(src, dst, src_name, dst_name)
            if write_back == WriteBack.COLOR_DIFF:
                diff_text = color_diff(diff_text)
                stream = wrap_stream_for_windows(stream)
            stream.write(diff_text)
        stream.detach()
912
913
def check_stability_and_equivalence(
    src_contents: str, dst_contents: str, *, mode: Mode
) -> None:
    """Verify that ``dst_contents`` is a safe reformatting of ``src_contents``.

    Two checks run, in this order:

    1. :func:`assert_equivalent`: the output must be equivalent to the input.
    2. :func:`assert_stable`: another formatting pass under ``mode`` must not
       change the output any further.

    Raises AssertionError when either check fails.
    """
    # Equivalence of input and output.
    assert_equivalent(src_contents, dst_contents)
    # Stability of the output under a second pass.
    assert_stable(src_contents, dst_contents, mode=mode)
925
926
def format_file_contents(src_contents: str, *, fast: bool, mode: Mode) -> FileContent:
    """Return the reformatted version of a file's contents.

    Jupyter notebooks (``mode.is_ipynb``) are handled by
    :func:`format_ipynb_string`, and everything else by :func:`format_str`.
    Raises NothingChanged when the result is identical to the input.

    Unless `fast` is set, non-notebook output is additionally validated with
    :func:`check_stability_and_equivalence`. Notebook cells are validated one
    by one while they are formatted.
    """
    is_notebook = mode.is_ipynb
    if is_notebook:
        dst_contents = format_ipynb_string(src_contents, fast=fast, mode=mode)
    else:
        dst_contents = format_str(src_contents, mode=mode)

    if dst_contents == src_contents:
        raise NothingChanged

    # Notebooks skip this: format_cell() already checked each cell.
    if not (fast or is_notebook):
        check_stability_and_equivalence(src_contents, dst_contents, mode=mode)
    return dst_contents
945
946
def validate_cell(src: str, mode: Mode) -> None:
    """Raise NothingChanged for notebook cells that must be left untouched.

    Two kinds of cell are rejected:

    * Cells that already contain one of the ``TRANSFORMED_MAGICS`` snippets.
      IPython's TransformerManager turns magics into calls such as
      ``get_ipython().system('ls')``, so a cell containing that text literally
      cannot be told apart from one where it came from ``!ls``:

        >>> TransformerManager().transform_cell("get_ipython().system('ls')")
        "get_ipython().system('ls')\n"
        >>> TransformerManager().transform_cell("!ls")
        "get_ipython().system('ls')\n"

      Such cells cannot be round-tripped safely and are skipped.

    * Cells starting with a ``%%`` cell magic that is neither in
      ``PYTHON_CELL_MAGICS`` nor in ``mode.python_cell_magics``. Their bodies
      may not be Python, and the indentation could break tokenize_rt.
    """
    if any(magic in src for magic in TRANSFORMED_MAGICS):
        raise NothingChanged
    if not src.startswith("%%"):
        return
    # Name of the cell magic, e.g. "time" for a cell starting with "%%time".
    cell_magic = src.split()[0][2:]
    if cell_magic not in PYTHON_CELL_MAGICS | mode.python_cell_magics:
        raise NothingChanged
971
972
def format_cell(src: str, *, fast: bool, mode: Mode) -> str:
    """Format the source of a single Jupyter code cell and return it.

    The pipeline is:

      - bail out on cells :func:`validate_cell` rejects;
      - strip a trailing semicolon, remembering whether there was one;
      - mask IPython magics so the rest parses as Python;
      - format with :func:`format_str` and, unless `fast`, verify the result;
      - unmask the magics and restore the trailing semicolon;
      - strip trailing newlines.

    Raises NothingChanged in three cases: the cell is rejected, masking raises
    SyntaxError, or the final result equals the input. A SyntaxError may mean
    automagics or multi-line magics, which are not supported.
    """
    validate_cell(src, mode)
    stripped_src, had_semicolon = remove_trailing_semicolon(src)
    try:
        masked_src, replacements = mask_cell(stripped_src)
    except SyntaxError:
        raise NothingChanged from None

    masked_dst = format_str(masked_src, mode=mode)
    if not fast:
        check_stability_and_equivalence(masked_src, masked_dst, mode=mode)

    result = put_trailing_semicolon_back(
        unmask_cell(masked_dst, replacements), had_semicolon
    ).rstrip("\n")
    if result == src:
        raise NothingChanged from None
    return result
1008
1009
def validate_metadata(nb: MutableMapping[str, Any]) -> None:
    """Raise NothingChanged if the notebook declares a non-Python language.

    nbformat makes every metadata field optional (see
    https://nbformat.readthedocs.io/en/latest/format_description.html). A
    notebook without ``metadata.language_info.name`` is therefore treated as
    Python, and formatting goes ahead.
    """
    language_info = nb.get("metadata", {}).get("language_info", {})
    language = language_info.get("name")
    if language is None or language == "python":
        return
    raise NothingChanged from None
1020
1021
def format_ipynb_string(src_contents: str, *, fast: bool, mode: Mode) -> FileContent:
    """Format a Jupyter notebook given as JSON text and return the new JSON.

    Only code cells are touched, each through :func:`format_cell`. A cell for
    which format_cell raises NothingChanged keeps its original source.
    Notebooks whose metadata names a non-Python language are rejected by
    :func:`validate_metadata`.

    The result is re-serialized with one-space indentation and non-ASCII
    characters kept as-is. A trailing newline is added only if the input
    ended with one.

    Raises NothingChanged for empty input or when no cell changed.
    """
    if not src_contents:
        raise NothingChanged

    had_trailing_newline = src_contents.endswith("\n")
    notebook = json.loads(src_contents)
    validate_metadata(notebook)

    changed = False
    for cell in notebook["cells"]:
        if cell.get("cell_type") != "code":
            continue
        try:
            new_source = format_cell("".join(cell["source"]), fast=fast, mode=mode)
        except NothingChanged:
            continue
        cell["source"] = new_source.splitlines(keepends=True)
        changed = True

    if not changed:
        raise NothingChanged

    dst_contents = json.dumps(notebook, indent=1, ensure_ascii=False)
    return dst_contents + "\n" if had_trailing_newline else dst_contents
1052
1053
def format_str(src_contents: str, *, mode: Mode) -> str:
    """Reformat a string and return the new contents.

    `mode` selects the formatting options, such as how many characters a line
    may hold.  Example:

    >>> import black
    >>> print(black.format_str("def f(arg:str='')->None:...", mode=black.Mode()))
    def f(arg: str = "") -> None:
        ...

    A more complex example:

    >>> print(
    ...   black.format_str(
    ...     "def f(arg:str='')->None: hey",
    ...     mode=black.Mode(
    ...       target_versions={black.TargetVersion.PY36},
    ...       line_length=10,
    ...       string_normalization=False,
    ...       is_pyi=False,
    ...     ),
    ...   ),
    ... )
    def f(
        arg: str = '',
    ) -> None:
        hey

    """
    first_pass = _format_str_once(src_contents, mode=mode)
    if first_pass == src_contents:
        return first_pass
    # Something changed, so run one more pass. Optional trailing commas can
    # become forced ("magic") trailing commas on the second pass and interact
    # differently with optional parentheses. Admittedly ugly.
    return _format_str_once(first_pass, mode=mode)
1091
1092
def _format_str_once(src_contents: str, *, mode: Mode) -> str:
    """Run a single formatting pass over `src_contents` and return the result.

    :func:`format_str` calls this up to twice. :func:`assert_stable` calls it
    exactly once, so that output which keeps changing between passes is not
    hidden by a second pass.
    """
    # Leading whitespace is stripped before parsing.
    src_node = lib2to3_parse(src_contents.lstrip(), mode.target_versions)
    dst_blocks: List[LinesBlock] = []
    if mode.target_versions:
        versions = mode.target_versions
    else:
        # No explicit targets: infer the compatible versions from the features
        # the source already uses, including `from __future__` flags.
        future_imports = get_future_imports(src_node)
        versions = detect_target_versions(src_node, future_imports=future_imports)

    # Features the line generator may use, limited to those that
    # supports_feature() accepts for the selected target versions.
    context_manager_features = {
        feature
        for feature in {Feature.PARENTHESIZED_CONTEXT_MANAGERS}
        if supports_feature(versions, feature)
    }
    # NOTE(review): runs after version detection above and before line
    # generation below; presumably rewrites `# fmt: off` regions of the tree
    # in place (see black.comments) — keep this ordering.
    normalize_fmt_off(src_node)
    lines = LineGenerator(mode=mode, features=context_manager_features)
    elt = EmptyLineTracker(mode=mode)
    # Trailing-comma features transform_line() may rely on when splitting,
    # again limited by the target versions.
    split_line_features = {
        feature
        for feature in {Feature.TRAILING_COMMA_IN_CALL, Feature.TRAILING_COMMA_IN_DEF}
        if supports_feature(versions, feature)
    }
    # Declared up front so `block` has one annotated type across both loops
    # below (presumably for mypy/mypyc).
    block: Optional[LinesBlock] = None
    for current_line in lines.visit(src_node):
        # For each logical line, first ask the tracker for its block (blank
        # line bookkeeping), then append the physical line(s) transform_line()
        # produces for it.
        block = elt.maybe_empty_lines(current_line)
        dst_blocks.append(block)
        for line in transform_line(
            current_line, mode=mode, features=split_line_features
        ):
            block.content_lines.append(str(line))
    if dst_blocks:
        # NOTE(review): `after` presumably counts blank lines emitted after a
        # block; zeroing it on the last block avoids trailing blank lines.
        dst_blocks[-1].after = 0
    dst_contents: List[str] = []
    for block in dst_blocks:
        dst_contents.extend(block.all_lines())
    if not dst_contents:
        # Nothing was generated: the result is either empty or a single newline.
        # Use decode_bytes to retrieve the correct source newline (CRLF or LF),
        # and check if normalized_content has more than one line
        normalized_content, _, newline = decode_bytes(src_contents.encode("utf-8"))
        if "\n" in normalized_content:
            return newline
        return ""
    return "".join(dst_contents)
1136
1137
def decode_bytes(src: bytes) -> Tuple[FileContent, Encoding, NewLine]:
    """Decode raw source bytes and report how they were encoded.

    Returns ``(decoded_contents, encoding, newline)``. The encoding comes from
    :func:`tokenize.detect_encoding` (BOM or coding cookie, else UTF-8).
    `newline` is ``"\\r\\n"`` if the first line ends with CRLF and ``"\\n"``
    otherwise. The contents are decoded with universal newlines, so they only
    ever contain LF. Empty input yields ``("", encoding, "\\n")``.
    """
    buffer = io.BytesIO(src)
    encoding, first_lines = tokenize.detect_encoding(buffer.readline)
    if not first_lines:
        return "", encoding, "\n"

    newline = "\r\n" if first_lines[0].endswith(b"\r\n") else "\n"
    # Rewind: detect_encoding() consumed up to two lines from the buffer.
    buffer.seek(0)
    with io.TextIOWrapper(buffer, encoding) as reader:
        decoded = reader.read()
    return decoded, encoding, newline
1153
1154
def get_features_used(  # noqa: C901
    node: Node, *, future_imports: Optional[Set[str]] = None
) -> Set[Feature]:
    """Return a set of (relatively) new Python features used in this file.

    Currently looking for:
    - f-strings (the prefix is matched case-insensitively, so ``Rf""`` and
      ``fR""`` count as well);
    - self-documenting expressions in f-strings (f"{x=}");
    - underscores in numeric literals;
    - trailing commas after * or ** in function signatures and calls;
    - positional only arguments in function signatures and lambdas;
    - assignment expression;
    - relaxed decorator syntax;
    - unparenthesized star-unpacking in return / yield;
    - a starred tuple as the value of an annotated assignment;
    - usage of __future__ flags (annotations);
    - parenthesized context managers;
    - match statements;
    - except* clause;
    - variadic generics;
    - type parameters and ``type`` statements.

    `future_imports`, when given, contributes the features listed for those
    flags in FUTURE_FLAG_TO_FEATURE; flags without an entry are ignored.
    """
    features: Set[Feature] = set()
    if future_imports:
        features |= {
            FUTURE_FLAG_TO_FEATURE[future_import]
            for future_import in future_imports
            if future_import in FUTURE_FLAG_TO_FEATURE
        }

    for n in node.pre_order():
        if is_string_token(n):
            # Python accepts every casing of the f-string prefix, including
            # mixed forms such as ``Rf`` or ``fR``, so compare in lower case.
            # Byte-string prefixes ("rb", "br") never match.
            value_head = n.value[:2].lower()
            if value_head in {'f"', "f'", "rf", "fr"}:
                features.add(Feature.F_STRINGS)
                if Feature.DEBUG_F_STRINGS not in features:
                    # An expression span whose text (minus its last character,
                    # presumably the closing brace) ends in "=" is a
                    # self-documenting expression like {x=}.
                    for span_beg, span_end in iter_fexpr_spans(n.value):
                        if n.value[span_beg : span_end - 1].rstrip().endswith("="):
                            features.add(Feature.DEBUG_F_STRINGS)
                            break

        elif is_number_token(n):
            if "_" in n.value:
                features.add(Feature.NUMERIC_UNDERSCORES)

        elif n.type == token.SLASH:
            # A bare "/" inside a parameter/argument list marks positional-only
            # parameters.
            if n.parent and n.parent.type in {
                syms.typedargslist,
                syms.arglist,
                syms.varargslist,
            }:
                features.add(Feature.POS_ONLY_ARGUMENTS)

        elif n.type == token.COLONEQUAL:
            features.add(Feature.ASSIGNMENT_EXPRESSIONS)

        elif n.type == syms.decorator:
            if len(n.children) > 1 and not is_simple_decorator_expression(
                n.children[1]
            ):
                features.add(Feature.RELAXED_DECORATORS)

        elif (
            n.type in {syms.typedargslist, syms.arglist}
            and n.children
            and n.children[-1].type == token.COMMA
        ):
            # Trailing comma in a def (typedargslist) or a call (arglist); it
            # only counts as a feature when * or ** unpacking is also present.
            if n.type == syms.typedargslist:
                feature = Feature.TRAILING_COMMA_IN_DEF
            else:
                feature = Feature.TRAILING_COMMA_IN_CALL

            for ch in n.children:
                if ch.type in STARS:
                    features.add(feature)

                if ch.type == syms.argument:
                    for argch in ch.children:
                        if argch.type in STARS:
                            features.add(feature)

        elif (
            n.type in {syms.return_stmt, syms.yield_expr}
            and len(n.children) >= 2
            and n.children[1].type == syms.testlist_star_expr
            and any(child.type == syms.star_expr for child in n.children[1].children)
        ):
            features.add(Feature.UNPACKING_ON_FLOW)

        elif (
            n.type == syms.annassign
            and len(n.children) >= 4
            and n.children[3].type == syms.testlist_star_expr
        ):
            features.add(Feature.ANN_ASSIGN_EXTENDED_RHS)

        elif (
            n.type == syms.with_stmt
            and len(n.children) > 2
            and n.children[1].type == syms.atom
        ):
            # `with (a, b):` parses as an atom "(" testlist_gexp ")".
            atom_children = n.children[1].children
            if (
                len(atom_children) == 3
                and atom_children[0].type == token.LPAR
                and atom_children[1].type == syms.testlist_gexp
                and atom_children[2].type == token.RPAR
            ):
                features.add(Feature.PARENTHESIZED_CONTEXT_MANAGERS)

        elif n.type == syms.match_stmt:
            features.add(Feature.PATTERN_MATCHING)

        elif (
            n.type == syms.except_clause
            and len(n.children) >= 2
            and n.children[1].type == token.STAR
        ):
            features.add(Feature.EXCEPT_STAR)

        elif n.type in {syms.subscriptlist, syms.trailer} and any(
            child.type == syms.star_expr for child in n.children
        ):
            features.add(Feature.VARIADIC_GENERICS)

        elif (
            n.type == syms.tname_star
            and len(n.children) == 3
            and n.children[2].type == syms.star_expr
        ):
            features.add(Feature.VARIADIC_GENERICS)

        elif n.type in (syms.type_stmt, syms.typeparams):
            features.add(Feature.TYPE_PARAMS)

    return features
1289
1290
def detect_target_versions(
    node: Node, *, future_imports: Optional[Set[str]] = None
) -> Set[TargetVersion]:
    """Return every target version that supports all features `node` uses.

    A version qualifies when the features found by :func:`get_features_used`
    form a subset of that version's entry in VERSION_TO_FEATURES.
    """
    used = get_features_used(node, future_imports=future_imports)
    compatible: Set[TargetVersion] = set()
    for version in TargetVersion:
        if used <= VERSION_TO_FEATURES[version]:
            compatible.add(version)
    return compatible
1299
1300
def get_future_imports(node: Node) -> Set[str]:
    """Collect the names imported with ``from __future__ import ...``.

    Only the module's leading statements are examined. String-only statements
    (docstrings) are skipped, and scanning stops at the first statement that is
    neither a docstring nor a ``from __future__`` import.
    """
    imports: Set[str] = set()

    def names_in(children: List[LN]) -> Generator[str, None, None]:
        """Yield imported names, resolving ``as`` aliases and nested name lists."""
        for child in children:
            if isinstance(child, Leaf):
                # Plain names are yielded; other leaves (e.g. punctuation) skipped.
                if child.type == token.NAME:
                    yield child.value
                continue
            if child.type == syms.import_as_name:
                # `flag as alias`: the future flag is the original name.
                original = child.children[0]
                assert isinstance(original, Leaf), "Invalid syntax parsing imports"
                assert original.type == token.NAME, "Invalid syntax parsing imports"
                yield original.value
            elif child.type == syms.import_as_names:
                yield from names_in(child.children)
            else:
                raise AssertionError("Invalid syntax parsing imports")

    for stmt in node.children:
        if stmt.type != syms.simple_stmt:
            break

        head = stmt.children[0]
        if isinstance(head, Leaf):
            is_docstring = (
                len(stmt.children) == 2
                and head.type == token.STRING
                and stmt.children[1].type == token.NEWLINE
            )
            # Keep scanning past a docstring; any other leaf ends the search.
            if not is_docstring:
                break
            continue

        if head.type != syms.import_from:
            break
        module_name = head.children[1]
        if not isinstance(module_name, Leaf) or module_name.value != "__future__":
            break
        # children: "from", module, "import", then the imported names.
        imports |= set(names_in(head.children[3:]))

    return imports
1349
1350
def assert_equivalent(src: str, dst: str) -> None:
    """Check that `src` and `dst` parse to equivalent ASTs.

    Raises AssertionError in three cases:

    * the source cannot be parsed, perhaps because it uses syntax the running
      Python does not support;
    * Black's output cannot be parsed;
    * the two stringified ASTs differ.

    In the last two cases, debugging details are written with dump_to_file()
    and the log location is included in the message.
    """
    try:
        src_ast = parse_ast(src)
    except Exception as exc:
        raise AssertionError(
            "cannot use --safe with this file; failed to parse source file AST: "
            f"{exc}\n"
            "This could be caused by running Black with an older Python version "
            "that does not support new syntax used in your source file."
        ) from exc

    try:
        dst_ast = parse_ast(dst)
    except Exception as exc:
        tb_text = "".join(traceback.format_tb(exc.__traceback__))
        log = dump_to_file(tb_text, dst)
        raise AssertionError(
            f"INTERNAL ERROR: Black produced invalid code: {exc}. "
            "Please report a bug on https://github.com/psf/black/issues.  "
            f"This invalid output might be helpful: {log}"
        ) from None

    src_ast_str, dst_ast_str = (
        "\n".join(stringify_ast(tree)) for tree in (src_ast, dst_ast)
    )
    if src_ast_str == dst_ast_str:
        return

    log = dump_to_file(diff(src_ast_str, dst_ast_str, "src", "dst"))
    raise AssertionError(
        "INTERNAL ERROR: Black produced code that is not equivalent to the"
        " source.  Please report a bug on "
        f"https://github.com/psf/black/issues.  This diff might be helpful: {log}"
    ) from None
1382
1383
def assert_stable(src: str, dst: str, mode: Mode) -> None:
    """Check that formatting `dst` once more under `mode` leaves it unchanged.

    On failure, the mode and two diffs (source to first pass, first pass to
    second pass) are written with dump_to_file(). An AssertionError pointing
    at that log is then raised.
    """
    # Deliberately a single pass: format_str() formats twice, which could hide
    # output that flips back and forth between two versions.
    second_pass = _format_str_once(dst, mode=mode)
    if second_pass == dst:
        return

    log = dump_to_file(
        str(mode),
        diff(src, dst, "source", "first pass"),
        diff(dst, second_pass, "first pass", "second pass"),
    )
    raise AssertionError(
        "INTERNAL ERROR: Black produced different code on the second pass of the"
        " formatter.  Please report a bug on https://github.com/psf/black/issues."
        f"  This diff might be helpful: {log}"
    ) from None
1401
1402
@contextmanager
def nullcontext() -> Iterator[None]:
    """Context manager that does nothing; entering it gives None.

    A local stand-in for :func:`contextlib.nullcontext`, which was added in
    Python 3.7. Exceptions raised inside the ``with`` body propagate unchanged.
    """
    yield None
1410
1411
def patched_main() -> None:
    """Run :func:`main`, calling ``freeze_support()`` first in frozen builds."""
    # PyInstaller patches multiprocessing so that freeze_support() is required
    # even outside Windows; call it whenever the interpreter reports being
    # frozen.
    is_frozen = getattr(sys, "frozen", False)
    if is_frozen:
        from multiprocessing import freeze_support

        freeze_support()

    main()
1421
1422
# Entry point when this file is executed directly as a script.
if __name__ == "__main__":
    patched_main()