]> git.madduck.net Git - etc/vim.git/blob - src/black/__init__.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

daf6f88f58edac008203b5d28b8eb3034bae7100
[etc/vim.git] / src / black / __init__.py
1 import io
2 import json
3 import platform
4 import re
5 import sys
6 import tokenize
7 import traceback
8 from contextlib import contextmanager
9 from dataclasses import replace
10 from datetime import datetime
11 from enum import Enum
12 from json.decoder import JSONDecodeError
13 from pathlib import Path
14 from typing import (
15     Any,
16     Dict,
17     Generator,
18     Iterator,
19     List,
20     MutableMapping,
21     Optional,
22     Pattern,
23     Sequence,
24     Set,
25     Sized,
26     Tuple,
27     Union,
28 )
29
30 import click
31 from click.core import ParameterSource
32 from mypy_extensions import mypyc_attr
33 from pathspec import PathSpec
34 from pathspec.patterns.gitwildmatch import GitWildMatchPatternError
35
36 from _black_version import version as __version__
37 from black.cache import Cache, get_cache_info, read_cache, write_cache
38 from black.comments import normalize_fmt_off
39 from black.const import (
40     DEFAULT_EXCLUDES,
41     DEFAULT_INCLUDES,
42     DEFAULT_LINE_LENGTH,
43     STDIN_PLACEHOLDER,
44 )
45 from black.files import (
46     find_project_root,
47     find_pyproject_toml,
48     find_user_pyproject_toml,
49     gen_python_files,
50     get_gitignore,
51     normalize_path_maybe_ignore,
52     parse_pyproject_toml,
53     wrap_stream_for_windows,
54 )
55 from black.handle_ipynb_magics import (
56     PYTHON_CELL_MAGICS,
57     TRANSFORMED_MAGICS,
58     jupyter_dependencies_are_installed,
59     mask_cell,
60     put_trailing_semicolon_back,
61     remove_trailing_semicolon,
62     unmask_cell,
63 )
64 from black.linegen import LN, LineGenerator, transform_line
65 from black.lines import EmptyLineTracker, LinesBlock
66 from black.mode import (
67     FUTURE_FLAG_TO_FEATURE,
68     VERSION_TO_FEATURES,
69     Feature,
70     Mode,
71     TargetVersion,
72     supports_feature,
73 )
74 from black.nodes import (
75     STARS,
76     is_number_token,
77     is_simple_decorator_expression,
78     is_string_token,
79     syms,
80 )
81 from black.output import color_diff, diff, dump_to_file, err, ipynb_diff, out
82 from black.parsing import InvalidInput  # noqa F401
83 from black.parsing import lib2to3_parse, parse_ast, stringify_ast
84 from black.report import Changed, NothingChanged, Report
85 from black.trans import iter_fexpr_spans
86 from blib2to3.pgen2 import token
87 from blib2to3.pytree import Leaf, Node
88
# True when this module was loaded from a compiled extension module (".so" on
# POSIX, ".pyd" on Windows) rather than a plain .py file -- presumably a mypyc
# build. Reported in the --version message.
COMPILED = Path(__file__).suffix in (".pyd", ".so")

# types
# Plain aliases of `str`; they only document what a string value represents.
FileContent = str
Encoding = str
NewLine = str
95
96
class WriteBack(Enum):
    """What to do with the result of formatting each source.

    YES writes the reformatted code back; DIFF / COLOR_DIFF print a (colored)
    diff instead; CHECK only reports the status; NO does neither.
    """

    NO = 0
    YES = 1
    DIFF = 2
    CHECK = 3
    COLOR_DIFF = 4

    @classmethod
    def from_configuration(
        cls, *, check: bool, diff: bool, color: bool = False
    ) -> "WriteBack":
        """Translate the --check / --diff / --color command-line flags.

        --diff takes precedence over --check, and --color only has an effect
        together with --diff.
        """
        if diff:
            return cls.COLOR_DIFF if color else cls.DIFF
        return cls.CHECK if check else cls.YES
115
116
# Legacy name, left for integrations.
# `FileMode` is simply another name for `Mode`, so code that still imports the
# old name keeps working.
FileMode = Mode
119
120
def read_pyproject_toml(
    ctx: click.Context, param: click.Parameter, value: Optional[str]
) -> Optional[str]:
    """Click callback for ``--config``: layer a pyproject.toml onto ``ctx.default_map``.

    When no path was given, one is discovered from the ``src`` values parsed so
    far. Returns the path of the file whose (non-empty) Black configuration was
    applied, or None if no file was found or it held no settings.

    Raises click.FileError when the file cannot be read or parsed, and
    click.BadOptionUsage when ``target_version`` is not a list.
    """
    if not value:
        value = find_pyproject_toml(ctx.params.get("src", ()))
        if value is None:
            return None

    try:
        raw_config = parse_pyproject_toml(value)
    except (OSError, ValueError) as e:
        raise click.FileError(
            filename=value, hint=f"Error reading configuration file: {e}"
        ) from None

    if not raw_config:
        return None

    # Click only copes with string defaults (lists/dicts pass through), so every
    # scalar is stringified. Background:
    # https://github.com/psf/black/issues/1458
    # https://github.com/pallets/click/issues/1567
    config: Dict[str, Any] = {}
    for key, setting in raw_config.items():
        config[key] = setting if isinstance(setting, (list, dict)) else str(setting)

    target_version = config.get("target_version")
    if target_version is not None and not isinstance(target_version, list):
        raise click.BadOptionUsage(
            "target-version", "Config key target-version must be a list"
        )

    # Keep any defaults that are already present; the file's values win on conflict.
    default_map: Dict[str, Any] = dict(ctx.default_map or {})
    default_map.update(config)
    ctx.default_map = default_map
    return value
165
166
def target_version_option_callback(
    c: click.Context, p: Union[click.Option, click.Parameter], v: Tuple[str, ...]
) -> List[TargetVersion]:
    """Map the lower-case ``--target-version`` choices to TargetVersion members.

    This is a named function rather than a lambda because mypy could not infer
    the lambda's type, which caused problems when compiling with mypyc.
    """
    versions: List[TargetVersion] = []
    for name in v:
        versions.append(TargetVersion[name.upper()])
    return versions
176
177
def re_compile_maybe_verbose(regex: str) -> Pattern[str]:
    """Compile `regex`, enabling verbose mode for multi-line patterns.

    A pattern that contains a newline gets the inline ``(?x)`` flag prepended,
    so its whitespace and ``#`` comments are ignored by the regex engine.
    """
    flag_prefix = "(?x)" if "\n" in regex else ""
    compiled: Pattern[str] = re.compile(flag_prefix + regex)
    return compiled
187
188
def validate_regex(
    ctx: click.Context,
    param: click.Parameter,
    value: Optional[str],
) -> Optional[Pattern[str]]:
    """Click callback that compiles a regex option value (None passes through).

    Raises click.BadParameter when the value is not a valid regular expression.
    """
    if value is None:
        return None
    try:
        return re_compile_maybe_verbose(value)
    except re.error as e:
        raise click.BadParameter(f"Not a valid regular expression: {e}") from None
198
199
@click.command(
    context_settings={"help_option_names": ["-h", "--help"]},
    # While Click does set this field automatically using the docstring, mypyc
    # (annoyingly) strips 'em so we need to set it here too.
    help="The uncompromising code formatter.",
)
@click.option("-c", "--code", type=str, help="Format the code passed in as a string.")
@click.option(
    "-l",
    "--line-length",
    type=int,
    default=DEFAULT_LINE_LENGTH,
    help="How many characters per line to allow.",
    show_default=True,
)
@click.option(
    "-t",
    "--target-version",
    type=click.Choice([v.name.lower() for v in TargetVersion]),
    # Converts the chosen lower-case names into TargetVersion members.
    callback=target_version_option_callback,
    multiple=True,
    help=(
        "Python versions that should be supported by Black's output. [default: per-file"
        " auto-detection]"
    ),
)
@click.option(
    "--pyi",
    is_flag=True,
    help=(
        "Format all input files like typing stubs regardless of file extension (useful"
        " when piping source on standard input)."
    ),
)
@click.option(
    "--ipynb",
    is_flag=True,
    help=(
        "Format all input files like Jupyter Notebooks regardless of file extension "
        "(useful when piping source on standard input)."
    ),
)
@click.option(
    "--python-cell-magics",
    multiple=True,
    help=(
        "When processing Jupyter Notebooks, add the given magic to the list"
        f" of known python-magics ({', '.join(PYTHON_CELL_MAGICS)})."
        " Useful for formatting cells with custom python magics."
    ),
    default=[],
)
@click.option(
    "-x",
    "--skip-source-first-line",
    is_flag=True,
    help="Skip the first line of the source code.",
)
@click.option(
    "-S",
    "--skip-string-normalization",
    is_flag=True,
    help="Don't normalize string quotes or prefixes.",
)
@click.option(
    "-C",
    "--skip-magic-trailing-comma",
    is_flag=True,
    help="Don't use trailing commas as a reason to split lines.",
)
@click.option(
    "--experimental-string-processing",
    is_flag=True,
    hidden=True,
    help="(DEPRECATED and now included in --preview) Normalize string literals.",
)
@click.option(
    "--preview",
    is_flag=True,
    help=(
        "Enable potentially disruptive style changes that may be added to Black's main"
        " functionality in the next major release."
    ),
)
@click.option(
    "--check",
    is_flag=True,
    help=(
        "Don't write the files back, just return the status. Return code 0 means"
        " nothing would change. Return code 1 means some files would be reformatted."
        " Return code 123 means there was an internal error."
    ),
)
@click.option(
    "--diff",
    is_flag=True,
    help="Don't write the files back, just output a diff for each file on stdout.",
)
@click.option(
    "--color/--no-color",
    is_flag=True,
    help="Show colored diff. Only applies when `--diff` is given.",
)
@click.option(
    "--fast/--safe",
    is_flag=True,
    help="If --fast given, skip temporary sanity checks. [default: --safe]",
)
@click.option(
    "--required-version",
    type=str,
    help=(
        "Require a specific version of Black to be running (useful for unifying results"
        " across many environments e.g. with a pyproject.toml file). It can be"
        " either a major version number or an exact version."
    ),
)
@click.option(
    "--include",
    type=str,
    default=DEFAULT_INCLUDES,
    callback=validate_regex,
    help=(
        "A regular expression that matches files and directories that should be"
        " included on recursive searches. An empty value means all files are included"
        " regardless of the name. Use forward slashes for directories on all platforms"
        " (Windows, too). Exclusions are calculated first, inclusions later."
    ),
    show_default=True,
)
@click.option(
    "--exclude",
    type=str,
    callback=validate_regex,
    help=(
        "A regular expression that matches files and directories that should be"
        " excluded on recursive searches. An empty value means no paths are excluded."
        " Use forward slashes for directories on all platforms (Windows, too)."
        " Exclusions are calculated first, inclusions later. [default:"
        f" {DEFAULT_EXCLUDES}]"
    ),
    # No click default on purpose: get_sources() treats None as "use
    # DEFAULT_EXCLUDES and also honour .gitignore files".
    show_default=False,
)
@click.option(
    "--extend-exclude",
    type=str,
    callback=validate_regex,
    help=(
        "Like --exclude, but adds additional files and directories on top of the"
        " excluded ones. (Useful if you simply want to add to the default)"
    ),
)
@click.option(
    "--force-exclude",
    type=str,
    callback=validate_regex,
    help=(
        "Like --exclude, but files and directories matching this regex will be "
        "excluded even when they are passed explicitly as arguments."
    ),
)
@click.option(
    "--stdin-filename",
    type=str,
    help=(
        "The name of the file when passing it through stdin. Useful to make "
        "sure Black will respect --force-exclude option on some "
        "editors that rely on using stdin."
    ),
)
@click.option(
    "-W",
    "--workers",
    type=click.IntRange(min=1),
    default=None,
    help="Number of parallel workers [default: number of CPUs in the system]",
)
@click.option(
    "-q",
    "--quiet",
    is_flag=True,
    help=(
        "Don't emit non-error messages to stderr. Errors are still emitted; silence"
        " those with 2>/dev/null."
    ),
)
@click.option(
    "-v",
    "--verbose",
    is_flag=True,
    help=(
        "Also emit messages to stderr about files that were not changed or were ignored"
        " due to exclusion patterns."
    ),
)
@click.version_option(
    version=__version__,
    message=(
        f"%(prog)s, %(version)s (compiled: {'yes' if COMPILED else 'no'})\n"
        f"Python ({platform.python_implementation()}) {platform.python_version()}"
    ),
)
@click.argument(
    "src",
    nargs=-1,
    type=click.Path(
        exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
    ),
    # Eager, presumably so that `src` is already in ctx.params when the
    # --config callback (read_pyproject_toml) searches for pyproject.toml.
    is_eager=True,
    metavar="SRC ...",
)
@click.option(
    "--config",
    type=click.Path(
        exists=True,
        file_okay=True,
        dir_okay=False,
        readable=True,
        allow_dash=False,
        path_type=str,
    ),
    is_eager=True,
    # Loads the file into ctx.default_map so its values act as option defaults.
    callback=read_pyproject_toml,
    help="Read configuration from FILE path.",
)
@click.pass_context
def main(  # noqa: C901
    ctx: click.Context,
    code: Optional[str],
    line_length: int,
    target_version: List[TargetVersion],
    check: bool,
    diff: bool,
    color: bool,
    fast: bool,
    pyi: bool,
    ipynb: bool,
    python_cell_magics: Sequence[str],
    skip_source_first_line: bool,
    skip_string_normalization: bool,
    skip_magic_trailing_comma: bool,
    experimental_string_processing: bool,
    preview: bool,
    quiet: bool,
    verbose: bool,
    required_version: Optional[str],
    include: Pattern[str],
    exclude: Optional[Pattern[str]],
    extend_exclude: Optional[Pattern[str]],
    force_exclude: Optional[Pattern[str]],
    stdin_filename: Optional[str],
    workers: Optional[int],
    src: Tuple[str, ...],
    config: Optional[str],
) -> None:
    """The uncompromising code formatter."""
    # ctx.obj is used as a plain dict; get_sources() reads the project root from it.
    ctx.ensure_object(dict)

    # Exactly one input mode is allowed: paths (SRC) or a string via --code.
    if src and code is not None:
        out(
            main.get_usage(ctx)
            + "\n\n'SRC' and 'code' cannot be passed simultaneously."
        )
        ctx.exit(1)
    if not src and code is None:
        out(main.get_usage(ctx) + "\n\nOne of 'SRC' or 'code' is required.")
        ctx.exit(1)

    # With --code there are no paths, so there is no project root to find.
    root, method = (
        find_project_root(src, stdin_filename) if code is None else (None, None)
    )
    ctx.obj["root"] = root

    if verbose:
        if root:
            out(
                f"Identified `{root}` as project root containing a {method}.",
                fg="blue",
            )

            # Pair each source with its root-relative form ("-" is kept as-is);
            # sources that cannot be normalized are shown in red as skipped,
            # after which the colour is switched back to blue.
            normalized = [
                (
                    (source, source)
                    if source == "-"
                    else (normalize_path_maybe_ignore(Path(source), root), source)
                )
                for source in src
            ]
            srcs_string = ", ".join(
                [
                    (
                        f'"{_norm}"'
                        if _norm
                        else f'\033[31m"{source} (skipping - invalid)"\033[34m'
                    )
                    for _norm, source in normalized
                ]
            )
            out(f"Sources to be formatted: {srcs_string}", fg="blue")

        if config:
            # Report where the configuration came from: the user-level file,
            # auto-discovery (click default / default map), or an explicit --config.
            config_source = ctx.get_parameter_source("config")
            user_level_config = str(find_user_pyproject_toml())
            if config == user_level_config:
                out(
                    (
                        "Using configuration from user-level config at "
                        f"'{user_level_config}'."
                    ),
                    fg="blue",
                )
            elif config_source in (
                ParameterSource.DEFAULT,
                ParameterSource.DEFAULT_MAP,
            ):
                out("Using configuration from project root.", fg="blue")
            else:
                out(f"Using configuration in '{config}'.", fg="blue")
            if ctx.default_map:
                for param, value in ctx.default_map.items():
                    out(f"{param}: {value}")

    error_msg = "Oh no! 💥 💔 💥"
    # --required-version may name either the exact version or just its major part.
    if (
        required_version
        and required_version != __version__
        and required_version != __version__.split(".")[0]
    ):
        err(
            f"{error_msg} The required version `{required_version}` does not match"
            f" the running version `{__version__}`!"
        )
        ctx.exit(1)
    # Stub and notebook formatting are mutually exclusive.
    if ipynb and pyi:
        err("Cannot pass both `pyi` and `ipynb` flags!")
        ctx.exit(1)

    write_back = WriteBack.from_configuration(check=check, diff=diff, color=color)
    if target_version:
        versions = set(target_version)
    else:
        # We'll autodetect later.
        versions = set()
    mode = Mode(
        target_versions=versions,
        line_length=line_length,
        is_pyi=pyi,
        is_ipynb=ipynb,
        skip_source_first_line=skip_source_first_line,
        string_normalization=not skip_string_normalization,
        magic_trailing_comma=not skip_magic_trailing_comma,
        experimental_string_processing=experimental_string_processing,
        preview=preview,
        python_cell_magics=set(python_cell_magics),
    )

    if code is not None:
        # Run in quiet mode by default with -c; the extra output isn't useful.
        # You can still pass -v to get verbose output.
        quiet = True

    report = Report(check=check, diff=diff, quiet=quiet, verbose=verbose)

    if code is not None:
        reformat_code(
            content=code, fast=fast, write_back=write_back, mode=mode, report=report
        )
    else:
        try:
            sources = get_sources(
                ctx=ctx,
                src=src,
                quiet=quiet,
                verbose=verbose,
                include=include,
                exclude=exclude,
                extend_exclude=extend_exclude,
                force_exclude=force_exclude,
                report=report,
                stdin_filename=stdin_filename,
            )
        except GitWildMatchPatternError:
            # Presumably raised by pathspec for an invalid .gitignore pattern;
            # abort with status 1.
            ctx.exit(1)

        # Exits with status 0 when there is nothing to format.
        path_empty(
            sources,
            "No Python files are present to be formatted. Nothing to do 😴",
            quiet,
            verbose,
            ctx,
        )

        if len(sources) == 1:
            reformat_one(
                src=sources.pop(),
                fast=fast,
                write_back=write_back,
                mode=mode,
                report=report,
            )
        else:
            # Imported lazily, presumably so single-file runs don't load the
            # concurrency machinery.
            from black.concurrency import reformat_many

            reformat_many(
                sources=sources,
                fast=fast,
                write_back=write_back,
                mode=mode,
                report=report,
                workers=workers,
            )

    if verbose or not quiet:
        if code is None and (verbose or report.change_count or report.failure_count):
            out()
        out(error_msg if report.return_code else "All done! ✨ 🍰 ✨")
        if code is None:
            click.echo(str(report), err=True)
    # The report's return code becomes the process exit status.
    ctx.exit(report.return_code)
619
620
def get_sources(
    *,
    ctx: click.Context,
    src: Tuple[str, ...],
    quiet: bool,
    verbose: bool,
    include: Pattern[str],
    exclude: Optional[Pattern[str]],
    extend_exclude: Optional[Pattern[str]],
    force_exclude: Optional[Pattern[str]],
    report: "Report",
    stdin_filename: Optional[str],
) -> Set[Path]:
    """Compute the set of files to be formatted.

    Each entry of `src` is handled as follows:

    * a file (or ``-`` combined with `stdin_filename`): kept unless it cannot be
      normalized relative to the project root or matches `force_exclude`.
      Stdin entries are tagged with ``STDIN_PLACEHOLDER``. ``.ipynb`` files are
      dropped when the Jupyter dependencies are not installed.
    * a directory: expanded through ``gen_python_files`` with the
      include/exclude/extend_exclude/force_exclude patterns. When no explicit
      `exclude` was given, the root's and that directory's .gitignore files are
      used as well. A directory that cannot be normalized relative to the
      project root is skipped, like a file in the same situation.
    * a bare ``-``: kept as-is, to be read from stdin later.
    * anything else: reported as an invalid path and skipped.
    """
    sources: Set[Path] = set()
    root = ctx.obj["root"]

    # An explicit --exclude disables both DEFAULT_EXCLUDES and .gitignore handling.
    using_default_exclude = exclude is None
    exclude = re_compile_maybe_verbose(DEFAULT_EXCLUDES) if exclude is None else exclude
    gitignore: Optional[Dict[Path, PathSpec]] = None
    root_gitignore = get_gitignore(root)

    for s in src:
        if s == "-" and stdin_filename:
            p = Path(stdin_filename)
            is_stdin = True
        else:
            p = Path(s)
            is_stdin = False

        if is_stdin or p.is_file():
            normalized_path = normalize_path_maybe_ignore(p, root, report)
            if normalized_path is None:
                continue

            normalized_path = "/" + normalized_path
            # Hard-exclude any files that matches the `--force-exclude` regex.
            if force_exclude:
                force_exclude_match = force_exclude.search(normalized_path)
            else:
                force_exclude_match = None
            if force_exclude_match and force_exclude_match.group(0):
                report.path_ignored(p, "matches the --force-exclude regular expression")
                continue

            if is_stdin:
                # Tag the path so reformat_one() knows to read from stdin.
                p = Path(f"{STDIN_PLACEHOLDER}{str(p)}")

            if p.suffix == ".ipynb" and not jupyter_dependencies_are_installed(
                verbose=verbose, quiet=quiet
            ):
                continue

            sources.add(p)
        elif p.is_dir():
            p_relative = normalize_path_maybe_ignore(p, root, report)
            if p_relative is None:
                # Same handling as the file branch above. Previously this
                # evaluated `root / None` and crashed with a TypeError.
                continue
            p = root / p_relative
            if using_default_exclude:
                gitignore = {
                    root: root_gitignore,
                    p: get_gitignore(p),
                }
            sources.update(
                gen_python_files(
                    p.iterdir(),
                    root,
                    include,
                    exclude,
                    extend_exclude,
                    force_exclude,
                    report,
                    gitignore,
                    verbose=verbose,
                    quiet=quiet,
                )
            )
        elif s == "-":
            sources.add(p)
        else:
            err(f"invalid path: {s}")
    return sources
701
702
def path_empty(
    src: Sized, msg: str, quiet: bool, verbose: bool, ctx: click.Context
) -> None:
    """Exit with status 0 when `src` is empty; otherwise return without effect.

    Before exiting, `msg` is printed unless running in quiet mode without --verbose.
    """
    if src:
        return
    if not quiet or verbose:
        out(msg)
    ctx.exit(0)
713
714
def reformat_code(
    content: str, fast: bool, write_back: WriteBack, mode: Mode, report: Report
) -> None:
    """Format a code string given via ``--code`` and print the result.

    This is the in-memory counterpart of `reformat_one` and runs in the current
    process. `fast`, `write_back` and `mode` are passed to
    :func:`format_stdin_to_stdout`. The outcome, including any failure, is
    recorded in `report` under the pseudo-path ``<string>``.
    """
    path = Path("<string>")
    try:
        was_changed = format_stdin_to_stdout(
            content=content, fast=fast, write_back=write_back, mode=mode
        )
        report.done(path, Changed.YES if was_changed else Changed.NO)
    except Exception as exc:
        if report.verbose:
            traceback.print_exc()
        report.failed(path, str(exc))
737
738
# diff-shades depends on being able to monkeypatch this function to operate. I know
# it's not ideal, but this shouldn't cause any issues ... hopefully. ~ichard26
@mypyc_attr(patchable=True)
def reformat_one(
    src: Path, fast: bool, write_back: WriteBack, mode: Mode, report: "Report"
) -> None:
    """Reformat the single source `src` in the current process.

    `src` is one of:

    * ``-``, meaning stdin;
    * a stdin path tagged with STDIN_PLACEHOLDER (the tag is stripped so that
      messages show the user's filename);
    * a regular file.

    For regular files the cache may mark the file as unchanged and skip
    formatting. The cache is updated after a non-cached WriteBack.YES run, or
    after a CHECK run that found nothing to change.

    `fast`, `write_back` and `mode` are passed to :func:`format_file_in_place`
    or :func:`format_stdin_to_stdout`. Exceptions are recorded in `report`
    (with a traceback in verbose mode) instead of propagating.
    """
    try:
        changed = Changed.NO

        name = str(src)
        if name == "-":
            is_stdin = True
        else:
            is_stdin = name.startswith(STDIN_PLACEHOLDER)
            if is_stdin:
                # Recover the user-supplied filename for any messages.
                src = Path(name[len(STDIN_PLACEHOLDER) :])

        if is_stdin:
            # Stdin input picks stub/notebook handling from the supplied name.
            suffix_flag = {".pyi": "is_pyi", ".ipynb": "is_ipynb"}.get(src.suffix)
            if suffix_flag is not None:
                mode = replace(mode, **{suffix_flag: True})
            if format_stdin_to_stdout(fast=fast, write_back=write_back, mode=mode):
                changed = Changed.YES
        else:
            cache: Cache = {}
            # The cache is bypassed for diff output (presumably so a diff is
            # always produced).
            if write_back not in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
                cache = read_cache(mode)
                resolved = src.resolve()
                key = str(resolved)
                if key in cache and cache[key] == get_cache_info(resolved):
                    changed = Changed.CACHED
            if changed is not Changed.CACHED:
                if format_file_in_place(
                    src, fast=fast, write_back=write_back, mode=mode
                ):
                    changed = Changed.YES
            refresh_after_yes = (
                write_back is WriteBack.YES and changed is not Changed.CACHED
            )
            refresh_after_check = (
                write_back is WriteBack.CHECK and changed is Changed.NO
            )
            if refresh_after_yes or refresh_after_check:
                write_cache(cache, [src], mode)
        report.done(src, changed)
    except Exception as exc:
        if report.verbose:
            traceback.print_exc()
        report.failed(src, str(exc))
791
792
def format_file_in_place(
    src: Path,
    fast: bool,
    mode: Mode,
    write_back: WriteBack = WriteBack.NO,
    lock: Any = None,  # multiprocessing.Manager().Lock() is some crazy proxy
) -> bool:
    """Format the file at `src`; return True if its contents would change.

    The file suffix switches `mode` to stub (.pyi) or notebook (.ipynb)
    handling.

    * With WriteBack.YES, the result is written back using the file's detected
      encoding and newline style.
    * With DIFF / COLOR_DIFF, a diff with UTC timestamps is written to stdout
      while holding `lock`, if one is given.

    `mode` and `fast` are passed to :func:`format_file_contents`.

    Raises ValueError when formatting raises JSONDecodeError, i.e. the input
    is not a valid Jupyter notebook.
    """
    suffix = src.suffix
    if suffix == ".pyi":
        mode = replace(mode, is_pyi=True)
    elif suffix == ".ipynb":
        mode = replace(mode, is_ipynb=True)

    then = datetime.utcfromtimestamp(src.stat().st_mtime)
    with open(src, "rb") as buf:
        # Optionally hold the first line back so it bypasses formatting.
        header = buf.readline() if mode.skip_source_first_line else b""
        src_contents, encoding, newline = decode_bytes(buf.read())
    try:
        dst_contents = format_file_contents(src_contents, fast=fast, mode=mode)
    except NothingChanged:
        return False
    except JSONDecodeError:
        raise ValueError(
            f"File '{src}' cannot be parsed as valid Jupyter notebook."
        ) from None

    # Re-attach the held-back first line to both versions.
    prefix = header.decode(encoding)
    src_contents = prefix + src_contents
    dst_contents = prefix + dst_contents

    if write_back == WriteBack.YES:
        with open(src, "w", encoding=encoding, newline=newline) as f:
            f.write(dst_contents)
        return True

    if write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
        now = datetime.utcnow()
        src_name = f"{src}\t{then} +0000"
        dst_name = f"{src}\t{now} +0000"
        make_diff = ipynb_diff if mode.is_ipynb else diff
        diff_contents = make_diff(src_contents, dst_contents, src_name, dst_name)
        if write_back == WriteBack.COLOR_DIFF:
            diff_contents = color_diff(diff_contents)

        # Presumably serializes stdout between parallel workers when a lock is given.
        with lock or nullcontext():
            stdout = io.TextIOWrapper(
                sys.stdout.buffer,
                encoding=encoding,
                newline=newline,
                write_through=True,
            )
            stdout = wrap_stream_for_windows(stdout)
            stdout.write(diff_contents)
            # Detach so discarding the wrapper does not close sys.stdout.buffer.
            stdout.detach()

    return True
855
856
def format_stdin_to_stdout(
    fast: bool,
    *,
    content: Optional[str] = None,
    write_back: WriteBack = WriteBack.NO,
    mode: Mode,
) -> bool:
    """Format file on stdin. Return True if changed.

    If content is None, it's read from sys.stdin.

    If `write_back` is YES, write reformatted code back to stdout. If it is DIFF
    or COLOR_DIFF, write a (colored) diff to stdout. The `mode` argument is
    passed to :func:`format_file_contents`.

    Output is produced in a ``finally`` clause, so it happens even when
    formatting fails. With YES, the unmodified input is echoed back in that
    case, and any exception other than NothingChanged then propagates to the
    caller.
    """
    # Timestamp for the source side of a diff header.
    then = datetime.utcnow()

    if content is None:
        src, encoding, newline = decode_bytes(sys.stdin.buffer.read())
    else:
        # Content from --code is treated as UTF-8; newline="" means
        # TextIOWrapper performs no newline translation when writing.
        src, encoding, newline = content, "utf-8", ""

    # Fallback output if formatting raises or changes nothing: the input itself.
    dst = src
    try:
        dst = format_file_contents(src, fast=fast, mode=mode)
        return True

    except NothingChanged:
        return False

    finally:
        # Write text through stdout's binary buffer using the input's encoding
        # and newline style.
        f = io.TextIOWrapper(
            sys.stdout.buffer, encoding=encoding, newline=newline, write_through=True
        )
        if write_back == WriteBack.YES:
            # Make sure there's a newline after the content
            if dst and dst[-1] != "\n":
                dst += "\n"
            f.write(dst)
        elif write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
            now = datetime.utcnow()
            src_name = f"STDIN\t{then} +0000"
            dst_name = f"STDOUT\t{now} +0000"
            d = diff(src, dst, src_name, dst_name)
            if write_back == WriteBack.COLOR_DIFF:
                d = color_diff(d)
                f = wrap_stream_for_windows(f)
            f.write(d)
        # Detach so discarding the wrapper does not close sys.stdout.buffer.
        f.detach()
906
907
def check_stability_and_equivalence(
    src_contents: str, dst_contents: str, *, mode: Mode
) -> None:
    """Verify that turning `src_contents` into `dst_contents` was safe.

    Two checks run in order:

    1. the formatted code must be equivalent to the original;
    2. formatting it a second time must not change it further.

    A failed check raises AssertionError.
    """
    original, formatted = src_contents, dst_contents
    assert_equivalent(original, formatted)  # meaning preserved
    assert_stable(original, formatted, mode=mode)  # a second pass is a no-op
919
920
def format_file_contents(src_contents: str, *, fast: bool, mode: Mode) -> FileContent:
    """Return the reformatted version of a whole file's contents.

    Notebooks (``mode.is_ipynb``) go through :func:`format_ipynb_string`; all
    other sources go through :func:`format_str`.  Unless ``fast`` is set, the
    result of a plain-source format is verified with
    :func:`check_stability_and_equivalence`.

    Raises:
        NothingChanged: if the input is blank (outside preview mode) or the
            formatter returned the input unchanged.
    """
    if not (mode.preview or src_contents.strip()):
        raise NothingChanged

    if mode.is_ipynb:
        dst_contents = format_ipynb_string(src_contents, fast=fast, mode=mode)
    else:
        dst_contents = format_str(src_contents, mode=mode)

    if dst_contents == src_contents:
        raise NothingChanged

    # Notebook cells are already verified individually inside format_cell(),
    # so only plain sources are checked here.
    if not (fast or mode.is_ipynb):
        check_stability_and_equivalence(src_contents, dst_contents, mode=mode)
    return dst_contents
942
943
def validate_cell(src: str, mode: Mode) -> None:
    """Reject cells that cannot be safely round-tripped through magic masking.

    Two kinds of cell are refused (by raising ``NothingChanged``):

    * cells that already contain code IPython's ``TransformerManager`` would
      produce from a magic (anything in ``TRANSFORMED_MAGICS``);
    * cells starting with a ``%%`` cell magic that is neither a known Python
      cell magic nor listed in ``mode.python_cell_magics``.

    The first case matters because transformation is not reversible.  For
    example, ``!ls`` and ``get_ipython().system('ls')`` both transform to the
    same text:

        >>> TransformerManager().transform_cell("get_ipython().system('ls')")
        "get_ipython().system('ls')\n"
        >>> TransformerManager().transform_cell("!ls")
        "get_ipython().system('ls')\n"

    so a cell that already holds the transformed form cannot be told apart
    from one that held the magic.
    """
    if any(magic in src for magic in TRANSFORMED_MAGICS):
        raise NothingChanged
    if src.startswith("%%"):
        # The cell magic's name is the first token with the "%%" stripped.
        cell_magic = src.split()[0][2:]
        if cell_magic not in PYTHON_CELL_MAGICS | mode.python_cell_magics:
            raise NothingChanged
968
969
def format_cell(src: str, *, fast: bool, mode: Mode) -> str:
    """Format the source of a single Jupyter code cell.

    Steps:

      1. validate the cell (see :func:`validate_cell`);
      2. strip a trailing semicolon, remembering whether there was one;
      3. mask IPython magics so the cell parses as Python;
      4. format the masked source (and verify it unless ``fast``);
      5. restore the magics and the trailing semicolon;
      6. drop trailing newlines.

    A cell whose masked form still fails to parse is left untouched, since it
    may hold automagics or multi-line magics, which are not supported.

    Raises:
        NothingChanged: if the cell is skipped or its formatted form equals
            the original.
    """
    validate_cell(src, mode)
    stripped_src, had_semicolon = remove_trailing_semicolon(src)
    try:
        masked_src, replacements = mask_cell(stripped_src)
    except SyntaxError:
        raise NothingChanged from None

    masked_dst = format_str(masked_src, mode=mode)
    if not fast:
        check_stability_and_equivalence(masked_src, masked_dst, mode=mode)

    unmasked_dst = unmask_cell(masked_dst, replacements)
    result = put_trailing_semicolon_back(unmasked_dst, had_semicolon).rstrip("\n")
    if result == src:
        raise NothingChanged from None
    return result
1005
1006
def validate_metadata(nb: MutableMapping[str, Any]) -> None:
    """Refuse notebooks whose metadata names a language other than Python.

    Every metadata level is optional in the notebook format
    (https://nbformat.readthedocs.io/en/latest/format_description.html), so a
    missing ``metadata``, ``language_info`` or ``name`` means "try anyway".

    Raises:
        NothingChanged: if ``metadata.language_info.name`` is present and is
            not ``"python"``.
    """
    language_info = nb.get("metadata", {}).get("language_info", {})
    language = language_info.get("name", None)
    if language is None or language == "python":
        return
    raise NothingChanged from None
1017
1018
def format_ipynb_string(src_contents: str, *, fast: bool, mode: Mode) -> FileContent:
    """Format Jupyter notebook.

    Operate cell-by-cell, only on code cells, only for Python notebooks.
    If the ``.ipynb`` originally had a trailing newline, it'll be preserved.

    Raises:
        NothingChanged: if ``src_contents`` is empty, if the notebook metadata
            declares a non-Python language, or if no code cell was changed.
    """
    # Check for empty input regardless of ``mode.preview``.  Otherwise a direct
    # non-preview call with "" fails with IndexError instead of NothingChanged.
    if not src_contents:
        raise NothingChanged

    trailing_newline = src_contents.endswith("\n")
    modified = False
    nb = json.loads(src_contents)
    validate_metadata(nb)
    for cell in nb["cells"]:
        if cell.get("cell_type", None) != "code":
            continue
        try:
            src = "".join(cell["source"])
            dst = format_cell(src, fast=fast, mode=mode)
        except NothingChanged:
            continue
        # Store the formatted source back as a list of lines (newlines kept).
        cell["source"] = dst.splitlines(keepends=True)
        modified = True

    if not modified:
        raise NothingChanged
    dst_contents = json.dumps(nb, indent=1, ensure_ascii=False)
    if trailing_newline:
        dst_contents += "\n"
    return dst_contents
1049
1050
def format_str(src_contents: str, *, mode: Mode) -> str:
    """Reformat a string and return new contents.

    `mode` determines formatting options, such as how many characters per line are
    allowed.  Example:

    >>> import black
    >>> print(black.format_str("def f(arg:str='')->None:...", mode=black.Mode()))
    def f(arg: str = "") -> None:
        ...

    A more complex example:

    >>> print(
    ...   black.format_str(
    ...     "def f(arg:str='')->None: hey",
    ...     mode=black.Mode(
    ...       target_versions={black.TargetVersion.PY36},
    ...       line_length=10,
    ...       string_normalization=False,
    ...       is_pyi=False,
    ...     ),
    ...   ),
    ... )
    def f(
        arg: str = '',
    ) -> None:
        hey

    If the first formatting pass changes anything, a second pass is run over
    its output and that result is returned.
    """
    first_pass = _format_str_once(src_contents, mode=mode)
    if first_pass == src_contents:
        # Input was already formatted; skip the second pass.
        return first_pass
    # The second pass is forced because optional trailing commas added on the
    # first pass become forced trailing commas on the next one, and those
    # interact differently with optional parentheses.  Admittedly ugly.
    return _format_str_once(first_pass, mode=mode)
1088
1089
def _format_str_once(src_contents: str, *, mode: Mode) -> str:
    """Run a single formatting pass over `src_contents` and return the result.

    Leading whitespace is stripped before parsing.  If `mode.target_versions`
    is empty, target versions are inferred from the syntax and ``__future__``
    imports found in the source.
    """
    src_node = lib2to3_parse(src_contents.lstrip(), mode.target_versions)
    if mode.target_versions:
        versions = mode.target_versions
    else:
        versions = detect_target_versions(
            src_node, future_imports=get_future_imports(src_node)
        )

    # Features handed to the line generator: only parenthesized context
    # managers, and only if `supports_feature` accepts them for `versions`.
    context_manager_features: Set[Feature] = set()
    if supports_feature(versions, Feature.PARENTHESIZED_CONTEXT_MANAGERS):
        context_manager_features.add(Feature.PARENTHESIZED_CONTEXT_MANAGERS)

    normalize_fmt_off(src_node, preview=mode.preview)
    line_generator = LineGenerator(mode=mode, features=context_manager_features)
    empty_line_tracker = EmptyLineTracker(mode=mode)
    split_line_features = {
        candidate
        for candidate in {
            Feature.TRAILING_COMMA_IN_CALL,
            Feature.TRAILING_COMMA_IN_DEF,
        }
        if supports_feature(versions, candidate)
    }

    blocks: List[LinesBlock] = []
    for logical_line in line_generator.visit(src_node):
        lines_block = empty_line_tracker.maybe_empty_lines(logical_line)
        blocks.append(lines_block)
        lines_block.content_lines.extend(
            str(split)
            for split in transform_line(
                logical_line, mode=mode, features=split_line_features
            )
        )
    if blocks:
        # Zero `after` on the final block (presumably its trailing blank-line
        # count) so the output does not end with extra empty lines.
        blocks[-1].after = 0

    dst_lines = [text for lines_block in blocks for text in lines_block.all_lines()]
    if mode.preview and not dst_lines:
        # Empty output: decode_bytes recovers the source's newline style (CRLF
        # or LF); emit one newline only if the source spanned multiple lines.
        normalized_content, _, newline = decode_bytes(src_contents.encode("utf-8"))
        return newline if "\n" in normalized_content else ""
    return "".join(dst_lines)
1133
1134
def decode_bytes(src: bytes) -> Tuple[FileContent, Encoding, NewLine]:
    """Decode raw source bytes into ``(text, encoding, newline)``.

    ``encoding`` is what :func:`tokenize.detect_encoding` reports (honouring a
    BOM or coding cookie).  ``newline`` is CRLF if the first line ends with
    CRLF and LF otherwise.  ``text`` is decoded with universal newlines, so it
    only contains LF line endings.  Empty input yields ``("", encoding, LF)``.
    """
    buffer = io.BytesIO(src)
    encoding, first_lines = tokenize.detect_encoding(buffer.readline)
    if not first_lines:
        return "", encoding, "\n"

    newline = "\r\n" if first_lines[0].endswith(b"\r\n") else "\n"
    # detect_encoding consumed up to two lines; rewind before decoding it all.
    buffer.seek(0)
    with io.TextIOWrapper(buffer, encoding) as wrapper:
        text = wrapper.read()
    return text, encoding, newline
1150
1151
def get_features_used(  # noqa: C901
    node: Node, *, future_imports: Optional[Set[str]] = None
) -> Set[Feature]:
    """Return a set of (relatively) new Python features used in this file.

    Currently looking for:
    - f-strings (any letter case in the ``f``/``rf``/``fr`` prefix);
    - self-documenting expressions in f-strings (f"{x=}");
    - underscores in numeric literals;
    - trailing commas after * or ** in function signatures and calls;
    - positional only arguments in function signatures and lambdas;
    - assignment expression;
    - relaxed decorator syntax;
    - usage of __future__ flags (annotations);
    - unparenthesized starred unpacking in return / yield;
    - unparenthesized tuple on the right-hand side of an annotated assignment;
    - parenthesized context managers;
    - match statements;
    - except* clause;
    - variadic generics;

    `future_imports`, when given, adds the features that
    ``FUTURE_FLAG_TO_FEATURE`` maps those ``__future__`` names to.
    """
    features: Set[Feature] = set()
    if future_imports:
        features |= {
            FUTURE_FLAG_TO_FEATURE[future_import]
            for future_import in future_imports
            if future_import in FUTURE_FLAG_TO_FEATURE
        }

    for n in node.pre_order():
        if is_string_token(n):
            # String prefixes are case-insensitive ("Rf", "fR", ... are valid
            # f-string prefixes too), so compare the lower-cased head.
            value_head = n.value[:2].lower()
            if value_head in {'f"', "f'", "rf", "fr"}:
                features.add(Feature.F_STRINGS)
                if Feature.DEBUG_F_STRINGS not in features:
                    for span_beg, span_end in iter_fexpr_spans(n.value):
                        if n.value[span_beg : span_end - 1].rstrip().endswith("="):
                            features.add(Feature.DEBUG_F_STRINGS)
                            break

        elif is_number_token(n):
            if "_" in n.value:
                features.add(Feature.NUMERIC_UNDERSCORES)

        elif n.type == token.SLASH:
            if n.parent and n.parent.type in {
                syms.typedargslist,
                syms.arglist,
                syms.varargslist,
            }:
                features.add(Feature.POS_ONLY_ARGUMENTS)

        elif n.type == token.COLONEQUAL:
            features.add(Feature.ASSIGNMENT_EXPRESSIONS)

        elif n.type == syms.decorator:
            if len(n.children) > 1 and not is_simple_decorator_expression(
                n.children[1]
            ):
                features.add(Feature.RELAXED_DECORATORS)

        elif (
            n.type in {syms.typedargslist, syms.arglist}
            and n.children
            and n.children[-1].type == token.COMMA
        ):
            if n.type == syms.typedargslist:
                feature = Feature.TRAILING_COMMA_IN_DEF
            else:
                feature = Feature.TRAILING_COMMA_IN_CALL

            for ch in n.children:
                if ch.type in STARS:
                    features.add(feature)

                if ch.type == syms.argument:
                    for argch in ch.children:
                        if argch.type in STARS:
                            features.add(feature)

        elif (
            n.type in {syms.return_stmt, syms.yield_expr}
            and len(n.children) >= 2
            and n.children[1].type == syms.testlist_star_expr
            and any(child.type == syms.star_expr for child in n.children[1].children)
        ):
            features.add(Feature.UNPACKING_ON_FLOW)

        elif (
            n.type == syms.annassign
            and len(n.children) >= 4
            and n.children[3].type == syms.testlist_star_expr
        ):
            features.add(Feature.ANN_ASSIGN_EXTENDED_RHS)

        elif (
            n.type == syms.with_stmt
            and len(n.children) > 2
            and n.children[1].type == syms.atom
        ):
            atom_children = n.children[1].children
            if (
                len(atom_children) == 3
                and atom_children[0].type == token.LPAR
                and atom_children[1].type == syms.testlist_gexp
                and atom_children[2].type == token.RPAR
            ):
                features.add(Feature.PARENTHESIZED_CONTEXT_MANAGERS)

        elif n.type == syms.match_stmt:
            features.add(Feature.PATTERN_MATCHING)

        elif (
            n.type == syms.except_clause
            and len(n.children) >= 2
            and n.children[1].type == token.STAR
        ):
            features.add(Feature.EXCEPT_STAR)

        elif n.type in {syms.subscriptlist, syms.trailer} and any(
            child.type == syms.star_expr for child in n.children
        ):
            features.add(Feature.VARIADIC_GENERICS)

        elif (
            n.type == syms.tname_star
            and len(n.children) == 3
            and n.children[2].type == syms.star_expr
        ):
            features.add(Feature.VARIADIC_GENERICS)

    return features
1283
1284
def detect_target_versions(
    node: Node, *, future_imports: Optional[Set[str]] = None
) -> Set[TargetVersion]:
    """Return every target version whose feature set covers all used features.

    A version qualifies when the features found by :func:`get_features_used`
    are a subset of ``VERSION_TO_FEATURES`` for that version.
    """
    used = get_features_used(node, future_imports=future_imports)
    compatible: Set[TargetVersion] = set()
    for version in TargetVersion:
        if used <= VERSION_TO_FEATURES[version]:
            compatible.add(version)
    return compatible
1293
1294
def get_future_imports(node: Node) -> Set[str]:
    """Collect names imported via ``from __future__ import ...`` at module top.

    Only the leading run of simple statements is scanned.  Statements that
    consist solely of a string literal (such as a docstring) are skipped.
    Scanning stops at the first statement that is neither such a string nor a
    ``from __future__ import``.
    """

    def names_from(children: List[LN]) -> Generator[str, None, None]:
        # Yield each imported name from the tail of an ``import_from`` node.
        # For ``x as y`` the original name ``x`` is reported; non-NAME leaves
        # such as commas and parentheses are ignored.
        for child in children:
            if isinstance(child, Leaf):
                if child.type == token.NAME:
                    yield child.value
                continue
            if child.type == syms.import_as_name:
                orig_name = child.children[0]
                assert isinstance(orig_name, Leaf), "Invalid syntax parsing imports"
                assert orig_name.type == token.NAME, "Invalid syntax parsing imports"
                yield orig_name.value
            elif child.type == syms.import_as_names:
                yield from names_from(child.children)
            else:
                raise AssertionError("Invalid syntax parsing imports")

    future_names: Set[str] = set()
    for stmt in node.children:
        if stmt.type != syms.simple_stmt:
            break

        head = stmt.children[0]
        if isinstance(head, Leaf):
            is_string_only = (
                len(stmt.children) == 2
                and head.type == token.STRING
                and stmt.children[1].type == token.NEWLINE
            )
            if not is_string_only:
                break
            continue

        if head.type != syms.import_from:
            break
        module_name = head.children[1]
        if not isinstance(module_name, Leaf) or module_name.value != "__future__":
            break
        # children[3:] skips the "from", module and "import" tokens.
        future_names.update(names_from(head.children[3:]))

    return future_names
1343
1344
def assert_equivalent(src: str, dst: str) -> None:
    """Ensure `src` and `dst` parse to the same (stringified) AST.

    Raises:
        AssertionError: if `src` cannot be parsed, if `dst` cannot be parsed
            (a traceback log is dumped), or if the two ASTs differ (a diff log
            is dumped).
    """
    try:
        src_ast = parse_ast(src)
    except Exception as exc:
        message = (
            "cannot use --safe with this file; failed to parse source file AST: "
            f"{exc}\n"
            "This could be caused by running Black with an older Python version "
            "that does not support new syntax used in your source file."
        )
        raise AssertionError(message) from exc

    try:
        dst_ast = parse_ast(dst)
    except Exception as exc:
        tb_text = "".join(traceback.format_tb(exc.__traceback__))
        log = dump_to_file(tb_text, dst)
        message = (
            f"INTERNAL ERROR: Black produced invalid code: {exc}. "
            "Please report a bug on https://github.com/psf/black/issues.  "
            f"This invalid output might be helpful: {log}"
        )
        raise AssertionError(message) from None

    src_dump = "\n".join(stringify_ast(src_ast))
    dst_dump = "\n".join(stringify_ast(dst_ast))
    if src_dump == dst_dump:
        return

    log = dump_to_file(diff(src_dump, dst_dump, "src", "dst"))
    message = (
        "INTERNAL ERROR: Black produced code that is not equivalent to the"
        " source.  Please report a bug on "
        f"https://github.com/psf/black/issues.  This diff might be helpful: {log}"
    )
    raise AssertionError(message) from None
1376
1377
def assert_stable(src: str, dst: str, mode: Mode) -> None:
    """Ensure one more formatting pass over `dst` leaves it unchanged.

    Raises:
        AssertionError: if the extra pass changes `dst`.  Before raising, the
            mode and both diffs (source -> first pass, first -> second pass)
            are written out via `dump_to_file`.
    """
    # Deliberately a single pass: format_str() formats twice internally, which
    # could hide output that bounces back and forth between two versions.
    second_pass = _format_str_once(dst, mode=mode)
    if second_pass == dst:
        return

    log = dump_to_file(
        str(mode),
        diff(src, dst, "source", "first pass"),
        diff(dst, second_pass, "first pass", "second pass"),
    )
    message = (
        "INTERNAL ERROR: Black produced different code on the second pass of the"
        " formatter.  Please report a bug on https://github.com/psf/black/issues."
        f"  This diff might be helpful: {log}"
    )
    raise AssertionError(message) from None
1395
1396
1397 @contextmanager
1398 def nullcontext() -> Iterator[None]:
1399     """Return an empty context manager.
1400
1401     To be used like `nullcontext` in Python 3.7.
1402     """
1403     yield
1404
1405
def patch_click() -> None:
    """Make Click not crash on Python 3.6 with LANG=C.

    Some misconfigured environments make Python 3 pick ASCII as the default
    encoding.  Click detects this and refuses to run, raising RuntimeError.
    Black only deals with Python source paths, where non-ASCII characters are
    unlikely, and on Python 3.7 the crash was spurious anyway thanks to
    PEP 538 and PEP 540.  So Click's environment checks are replaced with
    no-ops in every Click module that defines them.
    """
    candidates: List[Any] = []
    try:
        from click import core
    except ImportError:
        pass
    else:
        candidates.append(core)
    try:
        # `_unicodefun` was removed in Click 8.1.0; it is still patched for
        # users who have older Click versions installed.
        from click import _unicodefun  # type: ignore
    except ImportError:
        pass
    else:
        candidates.append(_unicodefun)

    for module in candidates:
        for hook_name in ("_verify_python3_env", "_verify_python_env"):
            if hasattr(module, hook_name):
                setattr(module, hook_name, lambda: None)
1438
1439
def patched_main() -> None:
    """Command-line entry point: prepare the environment, then run `main`."""
    # PyInstaller patches multiprocessing so that freeze_support() is needed
    # even outside Windows; any frozen build is assumed to need it.
    is_frozen = getattr(sys, "frozen", False)
    if is_frozen:
        from multiprocessing import freeze_support

        freeze_support()

    patch_click()
    main()
1450
1451
if __name__ == "__main__":
    # Only runs when this file is executed directly, not when it is imported.
    patched_main()