git.madduck.net Git - etc/vim.git/blob - src/black/__init__.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access, can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Document pre-commit mirror (#3828)
[etc/vim.git] / src / black / __init__.py
1 import io
2 import json
3 import platform
4 import re
5 import sys
6 import tokenize
7 import traceback
8 from contextlib import contextmanager
9 from dataclasses import replace
10 from datetime import datetime, timezone
11 from enum import Enum
12 from json.decoder import JSONDecodeError
13 from pathlib import Path
14 from typing import (
15     Any,
16     Dict,
17     Generator,
18     Iterator,
19     List,
20     MutableMapping,
21     Optional,
22     Pattern,
23     Sequence,
24     Set,
25     Sized,
26     Tuple,
27     Union,
28 )
29
30 import click
31 from click.core import ParameterSource
32 from mypy_extensions import mypyc_attr
33 from pathspec import PathSpec
34 from pathspec.patterns.gitwildmatch import GitWildMatchPatternError
35
36 from _black_version import version as __version__
37 from black.cache import Cache, get_cache_info, read_cache, write_cache
38 from black.comments import normalize_fmt_off
39 from black.const import (
40     DEFAULT_EXCLUDES,
41     DEFAULT_INCLUDES,
42     DEFAULT_LINE_LENGTH,
43     STDIN_PLACEHOLDER,
44 )
45 from black.files import (
46     find_project_root,
47     find_pyproject_toml,
48     find_user_pyproject_toml,
49     gen_python_files,
50     get_gitignore,
51     normalize_path_maybe_ignore,
52     parse_pyproject_toml,
53     wrap_stream_for_windows,
54 )
55 from black.handle_ipynb_magics import (
56     PYTHON_CELL_MAGICS,
57     TRANSFORMED_MAGICS,
58     jupyter_dependencies_are_installed,
59     mask_cell,
60     put_trailing_semicolon_back,
61     remove_trailing_semicolon,
62     unmask_cell,
63 )
64 from black.linegen import LN, LineGenerator, transform_line
65 from black.lines import EmptyLineTracker, LinesBlock
66 from black.mode import (
67     FUTURE_FLAG_TO_FEATURE,
68     VERSION_TO_FEATURES,
69     Feature,
70     Mode,
71     TargetVersion,
72     supports_feature,
73 )
74 from black.nodes import (
75     STARS,
76     is_number_token,
77     is_simple_decorator_expression,
78     is_string_token,
79     syms,
80 )
81 from black.output import color_diff, diff, dump_to_file, err, ipynb_diff, out
82 from black.parsing import InvalidInput  # noqa F401
83 from black.parsing import lib2to3_parse, parse_ast, stringify_ast
84 from black.report import Changed, NothingChanged, Report
85 from black.trans import iter_fexpr_spans
86 from blib2to3.pgen2 import token
87 from blib2to3.pytree import Leaf, Node
88
# True when this module's own file has a compiled-extension suffix (".pyd" or
# ".so") instead of a plain Python source suffix.
COMPILED = Path(__file__).suffix in (".pyd", ".so")

# types
# Plain `str` aliases that give string values a descriptive name in signatures.
FileContent = str
Encoding = str
NewLine = str
95
96
class WriteBack(Enum):
    """What to do with the formatted result of a source."""

    NO = 0
    YES = 1
    DIFF = 2
    CHECK = 3
    COLOR_DIFF = 4

    @classmethod
    def from_configuration(
        cls, *, check: bool, diff: bool, color: bool = False
    ) -> "WriteBack":
        """Pick the member matching the `check` / `diff` / `color` flags.

        `diff` wins over `check`; `color` only matters together with `diff`.
        With neither `check` nor `diff` set, the result is `YES`.
        """
        if diff:
            return cls.COLOR_DIFF if color else cls.DIFF

        return cls.CHECK if check else cls.YES
115
116
# Legacy name, left for integrations: `FileMode` is kept as an alias of `Mode`
# so that code still referring to the old name continues to work.
FileMode = Mode
119
120
def read_pyproject_toml(
    ctx: click.Context, param: click.Parameter, value: Optional[str]
) -> Optional[str]:
    """Load Black settings from a "pyproject.toml" into the defaults of `ctx`.

    When `value` is empty, the configuration file is looked up from the "src"
    and "stdin_filename" parameters already present on `ctx`.

    Returns the path of the configuration file that was read, or None when no
    file was found or the file held no configuration.

    Raises `click.FileError` if the file cannot be read or parsed, and
    `click.BadOptionUsage` if "target_version", "exclude" or "extend_exclude"
    has the wrong type.
    """
    if not value:
        value = find_pyproject_toml(
            ctx.params.get("src", ()), ctx.params.get("stdin_filename", None)
        )
        if value is None:
            return None

    try:
        parsed = parse_pyproject_toml(value)
    except (OSError, ValueError) as e:
        raise click.FileError(
            filename=value, hint=f"Error reading configuration file: {e}"
        ) from None

    if not parsed:
        return None

    # Sanitize the values to be Click friendly: everything except lists and
    # dicts is turned into a string. For more information please see:
    # https://github.com/psf/black/issues/1458
    # https://github.com/pallets/click/issues/1567
    config = {
        key: item if isinstance(item, (list, dict)) else str(item)
        for key, item in parsed.items()
    }

    target_version = config.get("target_version")
    if not (target_version is None or isinstance(target_version, list)):
        raise click.BadOptionUsage(
            "target-version", "Config key target-version must be a list"
        )

    exclude = config.get("exclude")
    if not (exclude is None or isinstance(exclude, str)):
        raise click.BadOptionUsage("exclude", "Config key exclude must be a string")

    extend_exclude = config.get("extend_exclude")
    if not (extend_exclude is None or isinstance(extend_exclude, str)):
        raise click.BadOptionUsage(
            "extend-exclude", "Config key extend-exclude must be a string"
        )

    # Start from whatever defaults the context already carries, then let the
    # file's settings override them.
    merged: Dict[str, Any] = dict(ctx.default_map) if ctx.default_map else {}
    merged.update(config)
    ctx.default_map = merged

    return value
177
178
def target_version_option_callback(
    c: click.Context, p: Union[click.Option, click.Parameter], v: Tuple[str, ...]
) -> List[TargetVersion]:
    """Turn the raw --target-version strings into `TargetVersion` members.

    Each value is upper-cased and looked up by name. This stays a named
    function rather than a lambda because mypy couldn't infer the lambda's
    type correctly, which caused mypyc trouble.
    """
    versions: List[TargetVersion] = []
    for name in v:
        versions.append(TargetVersion[name.upper()])
    return versions
188
189
def re_compile_maybe_verbose(regex: str) -> Pattern[str]:
    """Return the compiled form of the regular expression string `regex`.

    A pattern spanning several lines (one that contains a newline) gets the
    inline verbose flag prepended before it is compiled.
    """
    source = "(?x)" + regex if "\n" in regex else regex
    return re.compile(source)
199
200
def validate_regex(
    ctx: click.Context,
    param: click.Parameter,
    value: Optional[str],
) -> Optional[Pattern[str]]:
    """Click callback that compiles a regex option value.

    Returns None when no value was given, otherwise the compiled pattern.
    Raises `click.BadParameter` if the value is not a valid regular expression.
    """
    if value is None:
        return None

    try:
        return re_compile_maybe_verbose(value)
    except re.error as e:
        raise click.BadParameter(f"Not a valid regular expression: {e}") from None
210
211
@click.command(
    context_settings={"help_option_names": ["-h", "--help"]},
    # While Click does set this field automatically using the docstring, mypyc
    # (annoyingly) strips 'em so we need to set it here too.
    help="The uncompromising code formatter.",
)
@click.option("-c", "--code", type=str, help="Format the code passed in as a string.")
@click.option(
    "-l",
    "--line-length",
    type=int,
    default=DEFAULT_LINE_LENGTH,
    help="How many characters per line to allow.",
    show_default=True,
)
@click.option(
    "-t",
    "--target-version",
    type=click.Choice([v.name.lower() for v in TargetVersion]),
    callback=target_version_option_callback,
    multiple=True,
    help=(
        "Python versions that should be supported by Black's output. By default, Black"
        " will try to infer this from the project metadata in pyproject.toml. If this"
        " does not yield conclusive results, Black will use per-file auto-detection."
    ),
)
@click.option(
    "--pyi",
    is_flag=True,
    help=(
        "Format all input files like typing stubs regardless of file extension (useful"
        " when piping source on standard input)."
    ),
)
@click.option(
    "--ipynb",
    is_flag=True,
    help=(
        "Format all input files like Jupyter Notebooks regardless of file extension "
        "(useful when piping source on standard input)."
    ),
)
@click.option(
    "--python-cell-magics",
    multiple=True,
    help=(
        "When processing Jupyter Notebooks, add the given magic to the list"
        f" of known python-magics ({', '.join(sorted(PYTHON_CELL_MAGICS))})."
        " Useful for formatting cells with custom python magics."
    ),
    default=[],
)
@click.option(
    "-x",
    "--skip-source-first-line",
    is_flag=True,
    help="Skip the first line of the source code.",
)
@click.option(
    "-S",
    "--skip-string-normalization",
    is_flag=True,
    help="Don't normalize string quotes or prefixes.",
)
@click.option(
    "-C",
    "--skip-magic-trailing-comma",
    is_flag=True,
    help="Don't use trailing commas as a reason to split lines.",
)
@click.option(
    "--experimental-string-processing",
    is_flag=True,
    hidden=True,
    help="(DEPRECATED and now included in --preview) Normalize string literals.",
)
@click.option(
    "--preview",
    is_flag=True,
    help=(
        "Enable potentially disruptive style changes that may be added to Black's main"
        " functionality in the next major release."
    ),
)
@click.option(
    "--check",
    is_flag=True,
    help=(
        "Don't write the files back, just return the status. Return code 0 means"
        " nothing would change. Return code 1 means some files would be reformatted."
        " Return code 123 means there was an internal error."
    ),
)
@click.option(
    "--diff",
    is_flag=True,
    help="Don't write the files back, just output a diff for each file on stdout.",
)
@click.option(
    "--color/--no-color",
    is_flag=True,
    help="Show colored diff. Only applies when `--diff` is given.",
)
@click.option(
    "--fast/--safe",
    is_flag=True,
    help="If --fast given, skip temporary sanity checks. [default: --safe]",
)
@click.option(
    "--required-version",
    type=str,
    help=(
        "Require a specific version of Black to be running (useful for unifying results"
        " across many environments e.g. with a pyproject.toml file). It can be"
        " either a major version number or an exact version."
    ),
)
@click.option(
    "--include",
    type=str,
    default=DEFAULT_INCLUDES,
    callback=validate_regex,
    help=(
        "A regular expression that matches files and directories that should be"
        " included on recursive searches. An empty value means all files are included"
        " regardless of the name. Use forward slashes for directories on all platforms"
        " (Windows, too). Exclusions are calculated first, inclusions later."
    ),
    show_default=True,
)
# --exclude has no Click-level default: the main body and get_sources() receive
# None when it is not given, and the default is only described in the help text.
@click.option(
    "--exclude",
    type=str,
    callback=validate_regex,
    help=(
        "A regular expression that matches files and directories that should be"
        " excluded on recursive searches. An empty value means no paths are excluded."
        " Use forward slashes for directories on all platforms (Windows, too)."
        " Exclusions are calculated first, inclusions later. [default:"
        f" {DEFAULT_EXCLUDES}]"
    ),
    show_default=False,
)
@click.option(
    "--extend-exclude",
    type=str,
    callback=validate_regex,
    help=(
        "Like --exclude, but adds additional files and directories on top of the"
        " excluded ones. (Useful if you simply want to add to the default)"
    ),
)
@click.option(
    "--force-exclude",
    type=str,
    callback=validate_regex,
    help=(
        "Like --exclude, but files and directories matching this regex will be "
        "excluded even when they are passed explicitly as arguments."
    ),
)
# Eager, like SRC and --config below: read_pyproject_toml() reads
# "stdin_filename" and "src" from ctx.params.
@click.option(
    "--stdin-filename",
    type=str,
    is_eager=True,
    help=(
        "The name of the file when passing it through stdin. Useful to make "
        "sure Black will respect --force-exclude option on some "
        "editors that rely on using stdin."
    ),
)
@click.option(
    "-W",
    "--workers",
    type=click.IntRange(min=1),
    default=None,
    help=(
        "Number of parallel workers [default: BLACK_NUM_WORKERS environment variable "
        "or number of CPUs in the system]"
    ),
)
@click.option(
    "-q",
    "--quiet",
    is_flag=True,
    help=(
        "Don't emit non-error messages to stderr. Errors are still emitted; silence"
        " those with 2>/dev/null."
    ),
)
@click.option(
    "-v",
    "--verbose",
    is_flag=True,
    help=(
        "Also emit messages to stderr about files that were not changed or were ignored"
        " due to exclusion patterns."
    ),
)
@click.version_option(
    version=__version__,
    message=(
        f"%(prog)s, %(version)s (compiled: {'yes' if COMPILED else 'no'})\n"
        f"Python ({platform.python_implementation()}) {platform.python_version()}"
    ),
)
@click.argument(
    "src",
    nargs=-1,
    type=click.Path(
        exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
    ),
    is_eager=True,
    metavar="SRC ...",
)
@click.option(
    "--config",
    type=click.Path(
        exists=True,
        file_okay=True,
        dir_okay=False,
        readable=True,
        allow_dash=False,
        path_type=str,
    ),
    is_eager=True,
    callback=read_pyproject_toml,
    help="Read configuration from FILE path.",
)
@click.pass_context
def main(  # noqa: C901
    ctx: click.Context,
    code: Optional[str],
    line_length: int,
    target_version: List[TargetVersion],
    check: bool,
    diff: bool,
    color: bool,
    fast: bool,
    pyi: bool,
    ipynb: bool,
    python_cell_magics: Sequence[str],
    skip_source_first_line: bool,
    skip_string_normalization: bool,
    skip_magic_trailing_comma: bool,
    experimental_string_processing: bool,
    preview: bool,
    quiet: bool,
    verbose: bool,
    required_version: Optional[str],
    include: Pattern[str],
    exclude: Optional[Pattern[str]],
    extend_exclude: Optional[Pattern[str]],
    force_exclude: Optional[Pattern[str]],
    stdin_filename: Optional[str],
    workers: Optional[int],
    src: Tuple[str, ...],
    config: Optional[str],
) -> None:
    """The uncompromising code formatter."""
    # Command-line entry point: validates the option combination, builds the
    # Mode and Report, formats either the --code string or the collected
    # sources, and always finishes through ctx.exit().

    # Make sure ctx.obj is a dict; the project root is stored on it below.
    ctx.ensure_object(dict)

    # Exactly one of SRC and --code must be supplied.
    if src and code is not None:
        out(
            main.get_usage(ctx)
            + "\n\n'SRC' and 'code' cannot be passed simultaneously."
        )
        ctx.exit(1)
    if not src and code is None:
        out(main.get_usage(ctx) + "\n\nOne of 'SRC' or 'code' is required.")
        ctx.exit(1)

    # No project root is looked up when formatting a string passed via --code.
    root, method = (
        find_project_root(src, stdin_filename) if code is None else (None, None)
    )
    ctx.obj["root"] = root

    if verbose:
        if root:
            out(
                f"Identified `{root}` as project root containing a {method}.",
                fg="blue",
            )

        if config:
            # Report where the configuration came from: the user-level file, a
            # file picked up as a default, or a file named explicitly.
            config_source = ctx.get_parameter_source("config")
            user_level_config = str(find_user_pyproject_toml())
            if config == user_level_config:
                out(
                    "Using configuration from user-level config at "
                    f"'{user_level_config}'.",
                    fg="blue",
                )
            elif config_source in (
                ParameterSource.DEFAULT,
                ParameterSource.DEFAULT_MAP,
            ):
                out("Using configuration from project root.", fg="blue")
            else:
                out(f"Using configuration in '{config}'.", fg="blue")
            if ctx.default_map:
                for param, value in ctx.default_map.items():
                    out(f"{param}: {value}")

    error_msg = "Oh no! 💥 💔 💥"
    # --required-version is satisfied by either the exact running version or
    # just its first dot-separated component.
    if (
        required_version
        and required_version != __version__
        and required_version != __version__.split(".")[0]
    ):
        err(
            f"{error_msg} The required version `{required_version}` does not match"
            f" the running version `{__version__}`!"
        )
        ctx.exit(1)
    if ipynb and pyi:
        err("Cannot pass both `pyi` and `ipynb` flags!")
        ctx.exit(1)

    write_back = WriteBack.from_configuration(check=check, diff=diff, color=color)
    if target_version:
        versions = set(target_version)
    else:
        # We'll autodetect later.
        versions = set()
    # The "skip" flags are negative on the command line but positive on Mode,
    # hence the `not`.
    mode = Mode(
        target_versions=versions,
        line_length=line_length,
        is_pyi=pyi,
        is_ipynb=ipynb,
        skip_source_first_line=skip_source_first_line,
        string_normalization=not skip_string_normalization,
        magic_trailing_comma=not skip_magic_trailing_comma,
        experimental_string_processing=experimental_string_processing,
        preview=preview,
        python_cell_magics=set(python_cell_magics),
    )

    if code is not None:
        # Run in quiet mode by default with -c; the extra output isn't useful.
        # You can still pass -v to get verbose output.
        quiet = True

    report = Report(check=check, diff=diff, quiet=quiet, verbose=verbose)

    if code is not None:
        reformat_code(
            content=code, fast=fast, write_back=write_back, mode=mode, report=report
        )
    else:
        try:
            sources = get_sources(
                ctx=ctx,
                src=src,
                quiet=quiet,
                verbose=verbose,
                include=include,
                exclude=exclude,
                extend_exclude=extend_exclude,
                force_exclude=force_exclude,
                report=report,
                stdin_filename=stdin_filename,
            )
        except GitWildMatchPatternError:
            ctx.exit(1)

        # Exits with status 0 when nothing was found to format.
        path_empty(
            sources,
            "No Python files are present to be formatted. Nothing to do 😴",
            quiet,
            verbose,
            ctx,
        )

        # A single source is formatted in-process; black.concurrency is only
        # imported when there are several sources.
        if len(sources) == 1:
            reformat_one(
                src=sources.pop(),
                fast=fast,
                write_back=write_back,
                mode=mode,
                report=report,
            )
        else:
            from black.concurrency import reformat_many

            reformat_many(
                sources=sources,
                fast=fast,
                write_back=write_back,
                mode=mode,
                report=report,
                workers=workers,
            )

    # The closing summary is printed unless quiet is set without verbose; the
    # blank separator line and the full report are skipped for --code runs.
    if verbose or not quiet:
        if code is None and (verbose or report.change_count or report.failure_count):
            out()
        out(error_msg if report.return_code else "All done! ✨ 🍰 ✨")
        if code is None:
            click.echo(str(report), err=True)
    ctx.exit(report.return_code)
614
615
def get_sources(
    *,
    ctx: click.Context,
    src: Tuple[str, ...],
    quiet: bool,
    verbose: bool,
    include: Pattern[str],
    exclude: Optional[Pattern[str]],
    extend_exclude: Optional[Pattern[str]],
    force_exclude: Optional[Pattern[str]],
    report: "Report",
    stdin_filename: Optional[str],
) -> Set[Path]:
    """Compute the set of files to be formatted.

    Each entry of `src` is handled according to what it is:

    - a file (or "-" together with `stdin_filename`) is added directly unless
      it matches `force_exclude`; stdin-backed entries are stored with the
      `STDIN_PLACEHOLDER` prefix;
    - a directory is walked with `gen_python_files`, applying `include`,
      `exclude`, `extend_exclude` and `force_exclude`;
    - a bare "-" without `stdin_filename` is added as is;
    - anything else is reported as an invalid path.

    When `exclude` is None, `DEFAULT_EXCLUDES` is used and gitignore specs for
    the project root and the walked directory are handed to `gen_python_files`.
    The project root is read from `ctx.obj["root"]`.

    Sources for which `normalize_path_maybe_ignore` returns None are skipped,
    both for files and for directories.
    """
    sources: Set[Path] = set()
    root = ctx.obj["root"]

    using_default_exclude = exclude is None
    exclude = re_compile_maybe_verbose(DEFAULT_EXCLUDES) if exclude is None else exclude
    gitignore: Optional[Dict[Path, PathSpec]] = None
    root_gitignore = get_gitignore(root)

    for s in src:
        if s == "-" and stdin_filename:
            p = Path(stdin_filename)
            is_stdin = True
        else:
            p = Path(s)
            is_stdin = False

        if is_stdin or p.is_file():
            normalized_path: Optional[str] = normalize_path_maybe_ignore(
                p, root, report
            )
            if normalized_path is None:
                if verbose:
                    # Name the offending source; `normalized_path` is None here
                    # and would only ever print as "None".
                    out(f'Skipping invalid source: "{p}"', fg="red")
                continue
            if verbose:
                out(f'Found input source: "{normalized_path}"', fg="blue")

            normalized_path = "/" + normalized_path
            # Hard-exclude any files that matches the `--force-exclude` regex.
            if force_exclude:
                force_exclude_match = force_exclude.search(normalized_path)
            else:
                force_exclude_match = None
            # A match on the empty string does not count as an exclusion.
            if force_exclude_match and force_exclude_match.group(0):
                report.path_ignored(p, "matches the --force-exclude regular expression")
                continue

            if is_stdin:
                p = Path(f"{STDIN_PLACEHOLDER}{str(p)}")

            if p.suffix == ".ipynb" and not jupyter_dependencies_are_installed(
                warn=verbose or not quiet
            ):
                continue

            sources.add(p)
        elif p.is_dir():
            normalized_dir = normalize_path_maybe_ignore(p, root, report)
            if normalized_dir is None:
                # Same treatment as for files: skip the source instead of
                # failing on `root / None`.
                if verbose:
                    out(f'Skipping invalid source: "{p}"', fg="red")
                continue
            p = root / normalized_dir
            if verbose:
                out(f'Found input source directory: "{p}"', fg="blue")

            if using_default_exclude:
                gitignore = {
                    root: root_gitignore,
                    p: get_gitignore(p),
                }
            sources.update(
                gen_python_files(
                    p.iterdir(),
                    root,
                    include,
                    exclude,
                    extend_exclude,
                    force_exclude,
                    report,
                    gitignore,
                    verbose=verbose,
                    quiet=quiet,
                )
            )
        elif s == "-":
            if verbose:
                out("Found input source stdin", fg="blue")
            sources.add(p)
        else:
            err(f"invalid path: {s}")

    return sources
708
709
def path_empty(
    src: Sized, msg: str, quiet: bool, verbose: bool, ctx: click.Context
) -> None:
    """
    Exit with status 0 when `src` is empty.

    `msg` is printed first unless `quiet` is set without `verbose`.
    Does nothing when `src` has any content.
    """
    if src:
        return

    if verbose or not quiet:
        out(msg)
    ctx.exit(0)
720
721
def reformat_code(
    content: str, fast: bool, write_back: WriteBack, mode: Mode, report: Report
) -> None:
    """
    Format the string `content`, print the result and record it on `report`.

    The string counterpart of `reformat_one`; no child processes are spawned.
    `fast`, `write_back`, and `mode` are forwarded to
    :func:`format_stdin_to_stdout`. Any exception is recorded as a failure on
    `report` (with a traceback when the report is verbose) instead of being
    raised.
    """
    # Placeholder path under which the outcome is reported.
    pseudo_path = Path("<string>")
    try:
        was_changed = format_stdin_to_stdout(
            content=content, fast=fast, write_back=write_back, mode=mode
        )
        report.done(pseudo_path, Changed.YES if was_changed else Changed.NO)
    except Exception as exc:
        if report.verbose:
            traceback.print_exc()
        report.failed(pseudo_path, str(exc))
744
745
# diff-shades depends on being able to monkeypatch this function to operate. I know
# it's not ideal, but this shouldn't cause any issues ... hopefully. ~ichard26
@mypyc_attr(patchable=True)
def reformat_one(
    src: Path, fast: bool, write_back: WriteBack, mode: Mode, report: "Report"
) -> None:
    """Format the single source `src` in this process and record the outcome.

    `src` may be a real path, "-" for stdin, or a path carrying the
    `STDIN_PLACEHOLDER` prefix (stdin with a known file name).

    `fast`, `write_back`, and `mode` options are passed to
    :func:`format_file_in_place` or :func:`format_stdin_to_stdout`. Any
    exception is recorded as a failure on `report` rather than raised.
    """
    try:
        changed = Changed.NO

        src_text = str(src)
        if src_text == "-":
            is_stdin = True
        elif src_text.startswith(STDIN_PLACEHOLDER):
            is_stdin = True
            # Restore the original name in case something about this source
            # has to be shown to the user.
            src = Path(src_text[len(STDIN_PLACEHOLDER) :])
        else:
            is_stdin = False

        if is_stdin:
            # The file name given for stdin decides stub / notebook handling.
            suffix = src.suffix
            if suffix == ".pyi":
                mode = replace(mode, is_pyi=True)
            elif suffix == ".ipynb":
                mode = replace(mode, is_ipynb=True)
            if format_stdin_to_stdout(fast=fast, write_back=write_back, mode=mode):
                changed = Changed.YES
        else:
            cache: Cache = {}
            producing_diff = write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF)
            if not producing_diff:
                # The cache is only consulted when no diff is requested.
                cache = read_cache(mode)
                resolved = src.resolve()
                cache_key = str(resolved)
                if cache_key in cache and cache[cache_key] == get_cache_info(
                    resolved
                ):
                    changed = Changed.CACHED
            if changed is not Changed.CACHED and format_file_in_place(
                src, fast=fast, write_back=write_back, mode=mode
            ):
                changed = Changed.YES

            written_or_attempted = (
                write_back is WriteBack.YES and changed is not Changed.CACHED
            )
            checked_and_clean = (
                write_back is WriteBack.CHECK and changed is Changed.NO
            )
            if written_or_attempted or checked_and_clean:
                write_cache(cache, [src], mode)
        report.done(src, changed)
    except Exception as exc:
        if report.verbose:
            traceback.print_exc()
        report.failed(src, str(exc))
798
799
def format_file_in_place(
    src: Path,
    fast: bool,
    mode: Mode,
    write_back: WriteBack = WriteBack.NO,
    lock: Any = None,  # multiprocessing.Manager().Lock() is some crazy proxy
) -> bool:
    """Format the file at `src`. Return True if its contents changed.

    With `write_back` set to YES the reformatted code is written to the file;
    with DIFF or COLOR_DIFF a diff is written to stdout (while holding `lock`,
    if one is given). Any other value only computes the result.
    `mode` and `fast` options are passed to :func:`format_file_contents`.

    Raises ValueError when the contents cannot be parsed as JSON (a
    `JSONDecodeError` raised while formatting).
    """
    suffix = src.suffix
    if suffix == ".pyi":
        mode = replace(mode, is_pyi=True)
    elif suffix == ".ipynb":
        mode = replace(mode, is_ipynb=True)

    # Modification time of the original file, shown in the diff header.
    then = datetime.fromtimestamp(src.stat().st_mtime, timezone.utc)
    header = b""
    with open(src, "rb") as buf:
        if mode.skip_source_first_line:
            # Keep the first line aside; it is put back untouched below.
            header = buf.readline()
        src_contents, encoding, newline = decode_bytes(buf.read())
    try:
        dst_contents = format_file_contents(src_contents, fast=fast, mode=mode)
    except NothingChanged:
        return False
    except JSONDecodeError:
        raise ValueError(
            f"File '{src}' cannot be parsed as valid Jupyter notebook."
        ) from None
    header_text = header.decode(encoding)
    src_contents = header_text + src_contents
    dst_contents = header_text + dst_contents

    if write_back == WriteBack.YES:
        with open(src, "w", encoding=encoding, newline=newline) as f:
            f.write(dst_contents)
        return True

    if write_back not in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
        return True

    now = datetime.now(timezone.utc)
    src_name = f"{src}\t{then}"
    dst_name = f"{src}\t{now}"
    make_diff = ipynb_diff if mode.is_ipynb else diff
    diff_contents = make_diff(src_contents, dst_contents, src_name, dst_name)

    if write_back == WriteBack.COLOR_DIFF:
        diff_contents = color_diff(diff_contents)

    with lock or nullcontext():
        stream = io.TextIOWrapper(
            sys.stdout.buffer,
            encoding=encoding,
            newline=newline,
            write_through=True,
        )
        stream = wrap_stream_for_windows(stream)
        stream.write(diff_contents)
        stream.detach()

    return True
862
863
def format_stdin_to_stdout(
    fast: bool,
    *,
    content: Optional[str] = None,
    write_back: WriteBack = WriteBack.NO,
    mode: Mode,
) -> bool:
    """Format file on stdin. Return True if changed.

    If content is None, it's read from sys.stdin.

    If `write_back` is YES, write reformatted code back to stdout. If it is DIFF,
    write a diff to stdout. The `mode` argument is passed to
    :func:`format_file_contents`.

    The output is produced in a `finally` block, so it is also written when
    nothing changed, in which case the unformatted source is emitted.
    """
    # Start time, shown in the header of the diff's "before" side.
    then = datetime.now(timezone.utc)

    if content is None:
        src, encoding, newline = decode_bytes(sys.stdin.buffer.read())
    else:
        src, encoding, newline = content, "utf-8", ""

    # Default to the unformatted source so the `finally` block below has
    # something to emit when format_file_contents() does not return.
    dst = src
    try:
        dst = format_file_contents(src, fast=fast, mode=mode)
        return True

    except NothingChanged:
        return False

    finally:
        # Runs on every exit path: after either `return` above and also while
        # any other exception propagates.
        f = io.TextIOWrapper(
            sys.stdout.buffer, encoding=encoding, newline=newline, write_through=True
        )
        if write_back == WriteBack.YES:
            # Make sure there's a newline after the content
            if dst and dst[-1] != "\n":
                dst += "\n"
            f.write(dst)
        elif write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
            now = datetime.now(timezone.utc)
            src_name = f"STDIN\t{then}"
            dst_name = f"STDOUT\t{now}"
            d = diff(src, dst, src_name, dst_name)
            if write_back == WriteBack.COLOR_DIFF:
                d = color_diff(d)
                # The stream is only wrapped for Windows in the colored case.
                f = wrap_stream_for_windows(f)
            f.write(d)
        # Detach the wrapper from sys.stdout.buffer instead of closing it;
        # presumably so the underlying buffer stays usable afterwards.
        f.detach()
913
914
def check_stability_and_equivalence(
    src_contents: str, dst_contents: str, *, mode: Mode
) -> None:
    """Run the two safety checks on a formatting result.

    First :func:`assert_equivalent` compares the source with the formatted
    output, then :func:`assert_stable` verifies that formatting the output
    once more yields the same text. Either check raises AssertionError on
    failure.
    """
    assert_equivalent(src=src_contents, dst=dst_contents)
    assert_stable(src=src_contents, dst=dst_contents, mode=mode)
926
927
def format_file_contents(src_contents: str, *, fast: bool, mode: Mode) -> FileContent:
    """Return the reformatted version of `src_contents`.

    Notebook sources (``mode.is_ipynb``) go through
    :func:`format_ipynb_string`; everything else goes through
    :func:`format_str`. Raises NothingChanged when the result equals the input.

    Unless `fast` is set, non-notebook output is additionally validated with
    :func:`check_stability_and_equivalence`.
    """
    dst_contents = (
        format_ipynb_string(src_contents, fast=fast, mode=mode)
        if mode.is_ipynb
        else format_str(src_contents, mode=mode)
    )
    if dst_contents == src_contents:
        raise NothingChanged

    # Notebooks are checked cell-by-cell inside format_ipynb_string.
    if not (fast or mode.is_ipynb):
        check_stability_and_equivalence(src_contents, dst_contents, mode=mode)
    return dst_contents
946
947
def validate_cell(src: str, mode: Mode) -> None:
    """Refuse cells that cannot be formatted safely.

    Raises NothingChanged when the cell

    - contains any of the strings in ``TRANSFORMED_MAGICS``, or
    - starts with a ``%%`` cell magic whose name is neither in
      ``PYTHON_CELL_MAGICS`` nor in ``mode.python_cell_magics``.

    The first rule exists because a magic such as ``!ls`` and its transformed
    form ``get_ipython().system('ls')`` both transform to the same text:

        >>> TransformerManager().transform_cell("get_ipython().system('ls')")
        "get_ipython().system('ls')\n"
        >>> TransformerManager().transform_cell("!ls")
        "get_ipython().system('ls')\n"

    A cell that already contains the transformed form therefore cannot be
    round-tripped safely and is left alone.
    """
    for transformed_magic in TRANSFORMED_MAGICS:
        if transformed_magic in src:
            raise NothingChanged

    if src.startswith("%%"):
        # The magic name is the first whitespace-delimited word minus "%%".
        magic_name = src.split()[0][2:]
        if magic_name not in PYTHON_CELL_MAGICS | mode.python_cell_magics:
            raise NothingChanged
972
973
def format_cell(src: str, *, fast: bool, mode: Mode) -> str:
    """Format the code of a single Jupyter notebook cell.

    Steps, in order:

      1. validate the cell (see :func:`validate_cell`);
      2. remove a trailing semicolon, remembering whether there was one;
      3. mask IPython magics;
      4. format the masked code (and, unless `fast`, run the stability and
         equivalence checks on it);
      5. unmask the magics and put the trailing semicolon back;
      6. strip trailing newlines.

    Raises NothingChanged if the cell fails validation, if masking raises a
    SyntaxError (such cells could be automagics or multi-line magics, which
    are not supported), or if the result equals the original source.
    """
    validate_cell(src, mode)
    stripped_src, had_semicolon = remove_trailing_semicolon(src)
    try:
        masked_src, replacements = mask_cell(stripped_src)
    except SyntaxError:
        raise NothingChanged from None

    masked_dst = format_str(masked_src, mode=mode)
    if not fast:
        check_stability_and_equivalence(masked_src, masked_dst, mode=mode)

    unmasked_dst = unmask_cell(masked_dst, replacements)
    dst = put_trailing_semicolon_back(unmasked_dst, had_semicolon).rstrip("\n")
    if dst != src:
        return dst
    raise NothingChanged from None
1009
1010
def validate_metadata(nb: MutableMapping[str, Any]) -> None:
    """Raise NothingChanged for notebooks declared as non-Python.

    The language is looked up at ``metadata.language_info.name``. Notebook
    metadata fields are all optional (see
    https://nbformat.readthedocs.io/en/latest/format_description.html), so a
    notebook that does not declare a language is accepted.
    """
    metadata = nb.get("metadata", {})
    language_info = metadata.get("language_info", {})
    language = language_info.get("name", None)
    if language is None or language == "python":
        return
    raise NothingChanged from None
1021
1022
def format_ipynb_string(src_contents: str, *, fast: bool, mode: Mode) -> FileContent:
    """Format the JSON text of a Jupyter notebook.

    Each cell whose ``cell_type`` is ``"code"`` is passed to
    :func:`format_cell`; cells for which that raises NothingChanged are left
    untouched. The notebook is re-serialized with ``indent=1`` and
    ``ensure_ascii=False``, and a trailing newline is appended if the input
    ended with one.

    Raises NothingChanged if the input is empty, if :func:`validate_metadata`
    rejects the notebook, or if no cell was modified.
    """
    if not src_contents:
        raise NothingChanged

    ends_with_newline = src_contents.endswith("\n")
    nb = json.loads(src_contents)
    validate_metadata(nb)

    modified = False
    for cell in nb["cells"]:
        if cell.get("cell_type", None) != "code":
            continue
        try:
            cell_src = "".join(cell["source"])
            cell_dst = format_cell(cell_src, fast=fast, mode=mode)
        except NothingChanged:
            continue
        cell["source"] = cell_dst.splitlines(keepends=True)
        modified = True

    if not modified:
        raise NothingChanged

    dst_contents = json.dumps(nb, indent=1, ensure_ascii=False)
    if ends_with_newline:
        dst_contents = dst_contents + "\n"
    return dst_contents
1053
1054
def format_str(src_contents: str, *, mode: Mode) -> str:
    """Return `src_contents` reformatted according to `mode`.

    `mode` carries the formatting options, for instance the allowed line
    length.  Basic usage:

    >>> import black
    >>> print(black.format_str("def f(arg:str='')->None:...", mode=black.Mode()))
    def f(arg: str = "") -> None:
        ...

    With a customized mode:

    >>> print(
    ...   black.format_str(
    ...     "def f(arg:str='')->None: hey",
    ...     mode=black.Mode(
    ...       target_versions={black.TargetVersion.PY36},
    ...       line_length=10,
    ...       string_normalization=False,
    ...       is_pyi=False,
    ...     ),
    ...   ),
    ... )
    def f(
        arg: str = '',
    ) -> None:
        hey

    """
    first_pass = _format_str_once(src_contents, mode=mode)
    if first_pass == src_contents:
        return first_pass
    # The first pass changed something, so run a second one: optional trailing
    # commas added on pass 1 become forced trailing commas on pass 2, and that
    # interacts differently with optional parentheses.  Admittedly ugly.
    return _format_str_once(first_pass, mode=mode)
1092
1093
def _format_str_once(src_contents: str, *, mode: Mode) -> str:
    """Run a single formatting pass over `src_contents`.

    The source (with leading whitespace stripped) is parsed, each line
    produced by ``LineGenerator`` is given its surrounding empty lines by
    ``EmptyLineTracker`` and transformed with ``transform_line``, and the
    resulting blocks are joined into the output string.

    If no output lines are produced, the source's newline sequence is
    returned when the decoded source contains a newline, otherwise ``""``.
    """
    src_node = lib2to3_parse(src_contents.lstrip(), mode.target_versions)
    dst_blocks: List[LinesBlock] = []

    # Explicit target versions win; otherwise infer them from the source.
    if mode.target_versions:
        versions = mode.target_versions
    else:
        future_imports = get_future_imports(src_node)
        versions = detect_target_versions(src_node, future_imports=future_imports)

    context_manager_features: Set[Feature] = set()
    if supports_feature(versions, Feature.PARENTHESIZED_CONTEXT_MANAGERS):
        context_manager_features.add(Feature.PARENTHESIZED_CONTEXT_MANAGERS)

    normalize_fmt_off(src_node)
    line_generator = LineGenerator(mode=mode, features=context_manager_features)
    empty_line_tracker = EmptyLineTracker(mode=mode)

    split_line_features: Set[Feature] = set()
    for candidate in (Feature.TRAILING_COMMA_IN_CALL, Feature.TRAILING_COMMA_IN_DEF):
        if supports_feature(versions, candidate):
            split_line_features.add(candidate)

    for current_line in line_generator.visit(src_node):
        current_block = empty_line_tracker.maybe_empty_lines(current_line)
        dst_blocks.append(current_block)
        for transformed in transform_line(
            current_line, mode=mode, features=split_line_features
        ):
            current_block.content_lines.append(str(transformed))

    # No empty lines after the very last block.
    if dst_blocks:
        dst_blocks[-1].after = 0

    output_lines = [
        text for finished_block in dst_blocks for text in finished_block.all_lines()
    ]
    if output_lines:
        return "".join(output_lines)

    # Nothing was emitted.  Use decode_bytes to retrieve the source newline
    # (CRLF or LF) and check whether the normalized content has a newline.
    normalized_content, _, newline = decode_bytes(src_contents.encode("utf-8"))
    return newline if "\n" in normalized_content else ""
1137
1138
def decode_bytes(src: bytes) -> Tuple[FileContent, Encoding, NewLine]:
    """Decode `src` and return ``(decoded_contents, encoding, newline)``.

    The encoding is detected with :func:`tokenize.detect_encoding`.
    `newline` is ``"\\r\\n"`` if the first line ends with CRLF and ``"\\n"``
    otherwise, while `decoded_contents` is read with universal newlines and
    so only contains LF.  Empty input yields ``("", encoding, "\\n")``.
    """
    buffer = io.BytesIO(src)
    encoding, first_lines = tokenize.detect_encoding(buffer.readline)
    if not first_lines:
        return "", encoding, "\n"

    newline = "\r\n" if first_lines[0].endswith(b"\r\n") else "\n"
    # detect_encoding consumed the start of the buffer; rewind before decoding.
    buffer.seek(0)
    with io.TextIOWrapper(buffer, encoding) as reader:
        decoded = reader.read()
    return decoded, encoding, newline
1154
1155
def get_features_used(  # noqa: C901
    node: Node, *, future_imports: Optional[Set[str]] = None
) -> Set[Feature]:
    """Return a set of (relatively) new Python features used in this file.

    `node` is walked in pre-order; `future_imports`, if given, is the set of
    names imported from ``__future__``.

    Currently looking for:
    - f-strings (any valid prefix spelling, e.g. ``f``, ``F``, ``rf``, ``Rf``,
      ``fR``);
    - self-documenting expressions in f-strings (f"{x=}");
    - underscores in numeric literals;
    - trailing commas after * or ** in function signatures and calls;
    - positional only arguments in function signatures and lambdas;
    - assignment expression;
    - relaxed decorator syntax;
    - usage of __future__ flags (those listed in ``FUTURE_FLAG_TO_FEATURE``);
    - star expressions in the value of ``return`` / ``yield``;
    - annotated assignments whose right-hand side is a ``testlist_star_expr``;
    - parenthesized context managers;
    - match statements;
    - except* clause;
    - variadic generics;
    - type statements and type parameter lists.
    """
    features: Set[Feature] = set()
    if future_imports:
        features |= {
            FUTURE_FLAG_TO_FEATURE[future_import]
            for future_import in future_imports
            if future_import in FUTURE_FLAG_TO_FEATURE
        }

    for n in node.pre_order():
        if is_string_token(n):
            # String prefixes are case-insensitive and every mix of cases is
            # valid ("rF", "Rf", "fR", "Fr", ...), so compare in lower case
            # rather than enumerating only the all-lower/all-upper spellings.
            value_head = n.value[:2].lower()
            if value_head in {'f"', "f'", "rf", "fr"}:
                features.add(Feature.F_STRINGS)
                if Feature.DEBUG_F_STRINGS not in features:
                    for span_beg, span_end in iter_fexpr_spans(n.value):
                        # A replacement field ending in "=" (before the closing
                        # brace) is a self-documenting expression.
                        if n.value[span_beg : span_end - 1].rstrip().endswith("="):
                            features.add(Feature.DEBUG_F_STRINGS)
                            break

        elif is_number_token(n):
            if "_" in n.value:
                features.add(Feature.NUMERIC_UNDERSCORES)

        elif n.type == token.SLASH:
            # A bare "/" directly inside an argument list marks positional-only
            # parameters (as opposed to a division operator).
            if n.parent and n.parent.type in {
                syms.typedargslist,
                syms.arglist,
                syms.varargslist,
            }:
                features.add(Feature.POS_ONLY_ARGUMENTS)

        elif n.type == token.COLONEQUAL:
            features.add(Feature.ASSIGNMENT_EXPRESSIONS)

        elif n.type == syms.decorator:
            if len(n.children) > 1 and not is_simple_decorator_expression(
                n.children[1]
            ):
                features.add(Feature.RELAXED_DECORATORS)

        elif (
            n.type in {syms.typedargslist, syms.arglist}
            and n.children
            and n.children[-1].type == token.COMMA
        ):
            # The list ends with a trailing comma; that only counts as a
            # feature when the list also contains * or ** unpacking.
            if n.type == syms.typedargslist:
                feature = Feature.TRAILING_COMMA_IN_DEF
            else:
                feature = Feature.TRAILING_COMMA_IN_CALL

            for ch in n.children:
                if ch.type in STARS:
                    features.add(feature)

                if ch.type == syms.argument:
                    for argch in ch.children:
                        if argch.type in STARS:
                            features.add(feature)

        elif (
            n.type in {syms.return_stmt, syms.yield_expr}
            and len(n.children) >= 2
            and n.children[1].type == syms.testlist_star_expr
            and any(child.type == syms.star_expr for child in n.children[1].children)
        ):
            features.add(Feature.UNPACKING_ON_FLOW)

        elif (
            n.type == syms.annassign
            and len(n.children) >= 4
            and n.children[3].type == syms.testlist_star_expr
        ):
            features.add(Feature.ANN_ASSIGN_EXTENDED_RHS)

        elif (
            n.type == syms.with_stmt
            and len(n.children) > 2
            and n.children[1].type == syms.atom
        ):
            # Only "( testlist_gexp )" right after "with" counts.
            atom_children = n.children[1].children
            if (
                len(atom_children) == 3
                and atom_children[0].type == token.LPAR
                and atom_children[1].type == syms.testlist_gexp
                and atom_children[2].type == token.RPAR
            ):
                features.add(Feature.PARENTHESIZED_CONTEXT_MANAGERS)

        elif n.type == syms.match_stmt:
            features.add(Feature.PATTERN_MATCHING)

        elif (
            n.type == syms.except_clause
            and len(n.children) >= 2
            and n.children[1].type == token.STAR
        ):
            features.add(Feature.EXCEPT_STAR)

        elif n.type in {syms.subscriptlist, syms.trailer} and any(
            child.type == syms.star_expr for child in n.children
        ):
            features.add(Feature.VARIADIC_GENERICS)

        elif (
            n.type == syms.tname_star
            and len(n.children) == 3
            and n.children[2].type == syms.star_expr
        ):
            features.add(Feature.VARIADIC_GENERICS)

        elif n.type in (syms.type_stmt, syms.typeparams):
            features.add(Feature.TYPE_PARAMS)

    return features
1290
1291
def detect_target_versions(
    node: Node, *, future_imports: Optional[Set[str]] = None
) -> Set[TargetVersion]:
    """Return every target version that supports all features used in `node`.

    The used features come from :func:`get_features_used`; a version qualifies
    when they are a subset of ``VERSION_TO_FEATURES[version]``.
    """
    used_features = get_features_used(node, future_imports=future_imports)
    compatible: Set[TargetVersion] = set()
    for candidate in TargetVersion:
        if used_features <= VERSION_TO_FEATURES[candidate]:
            compatible.add(candidate)
    return compatible
1300
1301
def get_future_imports(node: Node) -> Set[str]:
    """Collect the names imported from ``__future__`` at the top of the file.

    Only the leading run of simple statements is examined: docstring-like
    string statements are skipped, ``from __future__ import ...`` statements
    contribute their (original, un-aliased) names, and anything else ends the
    scan.
    """
    imports: Set[str] = set()

    def get_imports_from_children(children: List[LN]) -> Generator[str, None, None]:
        # Yield imported names from the children that follow "import".
        for item in children:
            if isinstance(item, Leaf):
                # Non-NAME leaves (commas, parentheses, ...) are ignored.
                if item.type == token.NAME:
                    yield item.value
                continue

            if item.type == syms.import_as_name:
                # "name as alias": report the original name.
                orig_name = item.children[0]
                assert isinstance(orig_name, Leaf), "Invalid syntax parsing imports"
                assert orig_name.type == token.NAME, "Invalid syntax parsing imports"
                yield orig_name.value
                continue

            if item.type == syms.import_as_names:
                yield from get_imports_from_children(item.children)
                continue

            raise AssertionError("Invalid syntax parsing imports")

    for stmt in node.children:
        if stmt.type != syms.simple_stmt:
            break

        head = stmt.children[0]
        if isinstance(head, Leaf):
            # Continue looking if we see a docstring; otherwise stop.
            looks_like_docstring = (
                len(stmt.children) == 2
                and head.type == token.STRING
                and stmt.children[1].type == token.NEWLINE
            )
            if looks_like_docstring:
                continue
            break

        if head.type != syms.import_from:
            break

        module_name = head.children[1]
        if not isinstance(module_name, Leaf) or module_name.value != "__future__":
            break

        imports.update(get_imports_from_children(head.children[3:]))

    return imports
1350
1351
def assert_equivalent(src: str, dst: str) -> None:
    """Raise AssertionError unless `src` and `dst` have the same AST.

    Both strings are parsed with ``parse_ast`` and their stringified ASTs are
    compared. An AssertionError is raised if the source cannot be parsed, if
    the output cannot be parsed (the traceback and output are dumped to a
    file), or if the two ASTs differ (a diff is dumped to a file).
    """
    try:
        src_ast = parse_ast(src)
    except Exception as exc:
        message = (
            "cannot use --safe with this file; failed to parse source file AST: "
            f"{exc}\n"
            "This could be caused by running Black with an older Python version "
            "that does not support new syntax used in your source file."
        )
        raise AssertionError(message) from exc

    try:
        dst_ast = parse_ast(dst)
    except Exception as exc:
        log = dump_to_file("".join(traceback.format_tb(exc.__traceback__)), dst)
        message = (
            f"INTERNAL ERROR: Black produced invalid code: {exc}. "
            "Please report a bug on https://github.com/psf/black/issues.  "
            f"This invalid output might be helpful: {log}"
        )
        raise AssertionError(message) from None

    src_ast_str = "\n".join(stringify_ast(src_ast))
    dst_ast_str = "\n".join(stringify_ast(dst_ast))
    if src_ast_str == dst_ast_str:
        return

    log = dump_to_file(diff(src_ast_str, dst_ast_str, "src", "dst"))
    message = (
        "INTERNAL ERROR: Black produced code that is not equivalent to the"
        " source.  Please report a bug on "
        f"https://github.com/psf/black/issues.  This diff might be helpful: {log}"
    )
    raise AssertionError(message) from None
1383
1384
def assert_stable(src: str, dst: str, mode: Mode) -> None:
    """Raise AssertionError if formatting `dst` again changes it.

    On failure, the mode and the source -> first pass and first pass ->
    second pass diffs are dumped to a file named in the error message.
    """
    # Deliberately a single pass: format_str() formats twice and could hide a
    # bug where the output bounces back and forth between two versions.
    second_pass = _format_str_once(dst, mode=mode)
    if second_pass == dst:
        return

    first_pass_diff = diff(src, dst, "source", "first pass")
    second_pass_diff = diff(dst, second_pass, "first pass", "second pass")
    log = dump_to_file(str(mode), first_pass_diff, second_pass_diff)
    message = (
        "INTERNAL ERROR: Black produced different code on the second pass of the"
        " formatter.  Please report a bug on https://github.com/psf/black/issues."
        f"  This diff might be helpful: {log}"
    )
    raise AssertionError(message) from None
1402
1403
@contextmanager
def nullcontext() -> Iterator[None]:
    """Context manager that does nothing on entry or exit.

    The ``with`` target is bound to None.  Intended as a stand-in for the
    standard library's `nullcontext` (available from Python 3.7).
    """
    yield None
1411
1412
def patched_main() -> None:
    """Call ``main()``, enabling multiprocessing freeze support when frozen."""
    is_frozen = getattr(sys, "frozen", False)
    if is_frozen:
        # PyInstaller patches multiprocessing to need freeze_support() even in
        # non-Windows environments, so always call it when running frozen.
        import multiprocessing

        multiprocessing.freeze_support()

    main()
1422
1423
# Entry point when the module is executed as a script: go through
# patched_main() so the frozen-executable handling runs before main().
if __name__ == "__main__":
    patched_main()