git.madduck.net Git - etc/vim.git/blob - src/black/__init__.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Bump mypy[c] from 0.971 to 0.991 (#3380)
[etc/vim.git] / src / black / __init__.py
1 import io
2 import json
3 import platform
4 import re
5 import sys
6 import tokenize
7 import traceback
8 from contextlib import contextmanager
9 from dataclasses import replace
10 from datetime import datetime
11 from enum import Enum
12 from json.decoder import JSONDecodeError
13 from pathlib import Path
14 from typing import (
15     Any,
16     Dict,
17     Generator,
18     Iterator,
19     List,
20     MutableMapping,
21     Optional,
22     Pattern,
23     Sequence,
24     Set,
25     Sized,
26     Tuple,
27     Union,
28 )
29
30 import click
31 from click.core import ParameterSource
32 from mypy_extensions import mypyc_attr
33 from pathspec import PathSpec
34 from pathspec.patterns.gitwildmatch import GitWildMatchPatternError
35
36 from _black_version import version as __version__
37 from black.cache import Cache, get_cache_info, read_cache, write_cache
38 from black.comments import normalize_fmt_off
39 from black.const import (
40     DEFAULT_EXCLUDES,
41     DEFAULT_INCLUDES,
42     DEFAULT_LINE_LENGTH,
43     STDIN_PLACEHOLDER,
44 )
45 from black.files import (
46     find_project_root,
47     find_pyproject_toml,
48     find_user_pyproject_toml,
49     gen_python_files,
50     get_gitignore,
51     normalize_path_maybe_ignore,
52     parse_pyproject_toml,
53     wrap_stream_for_windows,
54 )
55 from black.handle_ipynb_magics import (
56     PYTHON_CELL_MAGICS,
57     TRANSFORMED_MAGICS,
58     jupyter_dependencies_are_installed,
59     mask_cell,
60     put_trailing_semicolon_back,
61     remove_trailing_semicolon,
62     unmask_cell,
63 )
64 from black.linegen import LN, LineGenerator, transform_line
65 from black.lines import EmptyLineTracker, LinesBlock
66 from black.mode import (
67     FUTURE_FLAG_TO_FEATURE,
68     VERSION_TO_FEATURES,
69     Feature,
70     Mode,
71     TargetVersion,
72     supports_feature,
73 )
74 from black.nodes import (
75     STARS,
76     is_number_token,
77     is_simple_decorator_expression,
78     is_string_token,
79     syms,
80 )
81 from black.output import color_diff, diff, dump_to_file, err, ipynb_diff, out
82 from black.parsing import InvalidInput  # noqa F401
83 from black.parsing import lib2to3_parse, parse_ast, stringify_ast
84 from black.report import Changed, NothingChanged, Report
85 from black.trans import iter_fexpr_spans
86 from blib2to3.pgen2 import token
87 from blib2to3.pytree import Leaf, Node
88
# True when this module is running as a compiled extension module (".pyd" on
# Windows, ".so" elsewhere) rather than as plain ".py" source; reported by
# --version.
COMPILED = Path(__file__).suffix in {".pyd", ".so"}

# Type aliases used in signatures to make the meaning of plain strings explicit.
FileContent = str
Encoding = str
NewLine = str
95
96
class WriteBack(Enum):
    """What to do with the formatting result of each source.

    NO writes nothing, YES writes the reformatted code back, DIFF / COLOR_DIFF
    print a (colored) diff, and CHECK only reports whether changes are needed.
    """

    NO = 0
    YES = 1
    DIFF = 2
    CHECK = 3
    COLOR_DIFF = 4

    @classmethod
    def from_configuration(
        cls, *, check: bool, diff: bool, color: bool = False
    ) -> "WriteBack":
        """Map the --check / --diff / --color CLI flags to a WriteBack member.

        --diff wins over --check; --color only matters together with --diff.
        """
        if diff:
            return cls.COLOR_DIFF if color else cls.DIFF
        return cls.CHECK if check else cls.YES
115
116
# Legacy name, left for integrations.
# `FileMode` is simply an alias of `Mode`, kept so that external code still
# importing the old name keeps working.
FileMode = Mode
119
120
def read_pyproject_toml(
    ctx: click.Context, param: click.Parameter, value: Optional[str]
) -> Optional[str]:
    """Click callback for ``--config``: load Black settings into ``ctx.default_map``.

    When `value` is empty, a configuration file is looked up with
    :func:`find_pyproject_toml` using the ``src`` values found in ``ctx.params``.
    The parsed settings are merged over any existing ``ctx.default_map`` so that
    they act as defaults for the other options.

    Returns the path of the configuration file that was read, or None when no
    file was found or the parsed configuration is empty.

    Raises click.FileError if the file cannot be read or parsed, and
    click.BadOptionUsage if ``target_version`` is set to something other than a
    list.
    """
    if not value:
        value = find_pyproject_toml(ctx.params.get("src", ()))
        if value is None:
            return None

    try:
        parsed = parse_pyproject_toml(value)
    except (OSError, ValueError) as e:
        raise click.FileError(
            filename=value, hint=f"Error reading configuration file: {e}"
        ) from None

    if not parsed:
        return None

    # Sanitize the values to be Click friendly: everything except lists and
    # dicts is handed to Click as a string. For more information please see:
    # https://github.com/psf/black/issues/1458
    # https://github.com/pallets/click/issues/1567
    config = {
        key: setting if isinstance(setting, (list, dict)) else str(setting)
        for key, setting in parsed.items()
    }

    target_version = config.get("target_version")
    if target_version is not None and not isinstance(target_version, list):
        raise click.BadOptionUsage(
            "target-version", "Config key target-version must be a list"
        )

    # Values from the file override any defaults that were already present.
    default_map: Dict[str, Any] = {**(ctx.default_map or {}), **config}
    ctx.default_map = default_map
    return value
165
166
def target_version_option_callback(
    c: click.Context, p: Union[click.Option, click.Parameter], v: Tuple[str, ...]
) -> List[TargetVersion]:
    """Turn the lowercase --target-version choices into TargetVersion members.

    Kept as a named function (not a lambda) because mypy couldn't infer the type
    of the lambda correctly, which caused trouble for mypyc.
    """
    versions: List[TargetVersion] = []
    for name in v:
        versions.append(TargetVersion[name.upper()])
    return versions
176
177
def re_compile_maybe_verbose(regex: str) -> Pattern[str]:
    """Compile `regex`, switching to verbose mode for multi-line patterns.

    A pattern containing a newline is prefixed with the inline ``(?x)`` flag,
    so its whitespace and ``#`` comments are ignored by the regex engine.
    """
    verbose_flag = "(?x)" if "\n" in regex else ""
    return re.compile(verbose_flag + regex)
187
188
def validate_regex(
    ctx: click.Context,
    param: click.Parameter,
    value: Optional[str],
) -> Optional[Pattern[str]]:
    """Click callback: compile `value` with re_compile_maybe_verbose.

    None (option not given) is passed through unchanged; an invalid pattern is
    reported to the user as click.BadParameter.
    """
    if value is None:
        return None
    try:
        return re_compile_maybe_verbose(value)
    except re.error as e:
        raise click.BadParameter(f"Not a valid regular expression: {e}") from None
198
199
# Command-line entry point. Each decorator below declares one CLI option or
# argument; Click passes the parsed values to `main` as keyword arguments.
@click.command(
    context_settings={"help_option_names": ["-h", "--help"]},
    # While Click does set this field automatically using the docstring, mypyc
    # (annoyingly) strips 'em so we need to set it here too.
    help="The uncompromising code formatter.",
)
@click.option("-c", "--code", type=str, help="Format the code passed in as a string.")
@click.option(
    "-l",
    "--line-length",
    type=int,
    default=DEFAULT_LINE_LENGTH,
    help="How many characters per line to allow.",
    show_default=True,
)
@click.option(
    "-t",
    "--target-version",
    type=click.Choice([v.name.lower() for v in TargetVersion]),
    # Converts the lowercase choice names into TargetVersion members.
    callback=target_version_option_callback,
    multiple=True,
    help=(
        "Python versions that should be supported by Black's output. [default: per-file"
        " auto-detection]"
    ),
)
@click.option(
    "--pyi",
    is_flag=True,
    help=(
        "Format all input files like typing stubs regardless of file extension (useful"
        " when piping source on standard input)."
    ),
)
@click.option(
    "--ipynb",
    is_flag=True,
    help=(
        "Format all input files like Jupyter Notebooks regardless of file extension "
        "(useful when piping source on standard input)."
    ),
)
@click.option(
    "--python-cell-magics",
    multiple=True,
    help=(
        "When processing Jupyter Notebooks, add the given magic to the list"
        f" of known python-magics ({', '.join(PYTHON_CELL_MAGICS)})."
        " Useful for formatting cells with custom python magics."
    ),
    default=[],
)
@click.option(
    "-x",
    "--skip-source-first-line",
    is_flag=True,
    help="Skip the first line of the source code.",
)
@click.option(
    "-S",
    "--skip-string-normalization",
    is_flag=True,
    help="Don't normalize string quotes or prefixes.",
)
@click.option(
    "-C",
    "--skip-magic-trailing-comma",
    is_flag=True,
    help="Don't use trailing commas as a reason to split lines.",
)
# Deprecated and hidden from --help, but still accepted and forwarded to Mode.
@click.option(
    "--experimental-string-processing",
    is_flag=True,
    hidden=True,
    help="(DEPRECATED and now included in --preview) Normalize string literals.",
)
@click.option(
    "--preview",
    is_flag=True,
    help=(
        "Enable potentially disruptive style changes that may be added to Black's main"
        " functionality in the next major release."
    ),
)
@click.option(
    "--check",
    is_flag=True,
    help=(
        "Don't write the files back, just return the status. Return code 0 means"
        " nothing would change. Return code 1 means some files would be reformatted."
        " Return code 123 means there was an internal error."
    ),
)
@click.option(
    "--diff",
    is_flag=True,
    help="Don't write the files back, just output a diff for each file on stdout.",
)
@click.option(
    "--color/--no-color",
    is_flag=True,
    help="Show colored diff. Only applies when `--diff` is given.",
)
@click.option(
    "--fast/--safe",
    is_flag=True,
    help="If --fast given, skip temporary sanity checks. [default: --safe]",
)
@click.option(
    "--required-version",
    type=str,
    help=(
        "Require a specific version of Black to be running (useful for unifying results"
        " across many environments e.g. with a pyproject.toml file). It can be"
        " either a major version number or an exact version."
    ),
)
# The regex options below go through `validate_regex`, so `main` receives
# compiled patterns (or None) instead of strings.
@click.option(
    "--include",
    type=str,
    default=DEFAULT_INCLUDES,
    callback=validate_regex,
    help=(
        "A regular expression that matches files and directories that should be"
        " included on recursive searches. An empty value means all files are included"
        " regardless of the name. Use forward slashes for directories on all platforms"
        " (Windows, too). Exclusions are calculated first, inclusions later."
    ),
    show_default=True,
)
@click.option(
    "--exclude",
    type=str,
    callback=validate_regex,
    help=(
        "A regular expression that matches files and directories that should be"
        " excluded on recursive searches. An empty value means no paths are excluded."
        " Use forward slashes for directories on all platforms (Windows, too)."
        " Exclusions are calculated first, inclusions later. [default:"
        f" {DEFAULT_EXCLUDES}]"
    ),
    # No Click default on purpose: get_sources treats None as "use
    # DEFAULT_EXCLUDES and honour .gitignore files".
    show_default=False,
)
@click.option(
    "--extend-exclude",
    type=str,
    callback=validate_regex,
    help=(
        "Like --exclude, but adds additional files and directories on top of the"
        " excluded ones. (Useful if you simply want to add to the default)"
    ),
)
@click.option(
    "--force-exclude",
    type=str,
    callback=validate_regex,
    help=(
        "Like --exclude, but files and directories matching this regex will be "
        "excluded even when they are passed explicitly as arguments."
    ),
)
@click.option(
    "--stdin-filename",
    type=str,
    help=(
        "The name of the file when passing it through stdin. Useful to make "
        "sure Black will respect --force-exclude option on some "
        "editors that rely on using stdin."
    ),
)
# None is passed through unchanged to reformat_many.
@click.option(
    "-W",
    "--workers",
    type=click.IntRange(min=1),
    default=None,
    help="Number of parallel workers [default: number of CPUs in the system]",
)
@click.option(
    "-q",
    "--quiet",
    is_flag=True,
    help=(
        "Don't emit non-error messages to stderr. Errors are still emitted; silence"
        " those with 2>/dev/null."
    ),
)
@click.option(
    "-v",
    "--verbose",
    is_flag=True,
    help=(
        "Also emit messages to stderr about files that were not changed or were ignored"
        " due to exclusion patterns."
    ),
)
@click.version_option(
    version=__version__,
    message=(
        f"%(prog)s, %(version)s (compiled: {'yes' if COMPILED else 'no'})\n"
        f"Python ({platform.python_implementation()}) {platform.python_version()}"
    ),
)
@click.argument(
    "src",
    nargs=-1,
    type=click.Path(
        exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
    ),
    is_eager=True,
    metavar="SRC ...",
)
# Eager, like `src`. Its callback (read_pyproject_toml) consults ctx.params["src"]
# when no explicit path is given and stores the file's settings in
# ctx.default_map, so they become defaults for the other options.
@click.option(
    "--config",
    type=click.Path(
        exists=True,
        file_okay=True,
        dir_okay=False,
        readable=True,
        allow_dash=False,
        path_type=str,
    ),
    is_eager=True,
    callback=read_pyproject_toml,
    help="Read configuration from FILE path.",
)
@click.pass_context
def main(  # noqa: C901
    ctx: click.Context,
    code: Optional[str],
    line_length: int,
    target_version: List[TargetVersion],
    check: bool,
    diff: bool,
    color: bool,
    fast: bool,
    pyi: bool,
    ipynb: bool,
    python_cell_magics: Sequence[str],
    skip_source_first_line: bool,
    skip_string_normalization: bool,
    skip_magic_trailing_comma: bool,
    experimental_string_processing: bool,
    preview: bool,
    quiet: bool,
    verbose: bool,
    required_version: Optional[str],
    include: Pattern[str],
    exclude: Optional[Pattern[str]],
    extend_exclude: Optional[Pattern[str]],
    force_exclude: Optional[Pattern[str]],
    stdin_filename: Optional[str],
    workers: Optional[int],
    src: Tuple[str, ...],
    config: Optional[str],
) -> None:
    """The uncompromising code formatter."""
    # ctx.obj is used as a plain dict to share state (the project root) with
    # helpers such as get_sources.
    ctx.ensure_object(dict)

    # Exactly one input kind is accepted: SRC paths or a --code string.
    if src and code is not None:
        out(
            main.get_usage(ctx)
            + "\n\n'SRC' and 'code' cannot be passed simultaneously."
        )
        ctx.exit(1)
    if not src and code is None:
        out(main.get_usage(ctx) + "\n\nOne of 'SRC' or 'code' is required.")
        ctx.exit(1)

    # A project root is only looked up when formatting paths; --code has none.
    root, method = (
        find_project_root(src, stdin_filename) if code is None else (None, None)
    )
    ctx.obj["root"] = root

    if verbose:
        if root:
            out(
                f"Identified `{root}` as project root containing a {method}.",
                fg="blue",
            )

            # Pair each source with its normalized path; "-" (stdin) is kept
            # as-is. Sources whose normalized path is falsy are listed in red
            # (\033[31m) as skipped, then the color is reset to blue (\033[34m).
            normalized = [
                (source, source)
                if source == "-"
                else (normalize_path_maybe_ignore(Path(source), root), source)
                for source in src
            ]
            srcs_string = ", ".join(
                [
                    f'"{_norm}"'
                    if _norm
                    else f'\033[31m"{source} (skipping - invalid)"\033[34m'
                    for _norm, source in normalized
                ]
            )
            out(f"Sources to be formatted: {srcs_string}", fg="blue")

        if config:
            # Report where the configuration came from: the user-level file,
            # an automatically discovered file (value supplied as a default),
            # or an explicit --config path.
            config_source = ctx.get_parameter_source("config")
            user_level_config = str(find_user_pyproject_toml())
            if config == user_level_config:
                out(
                    (
                        "Using configuration from user-level config at "
                        f"'{user_level_config}'."
                    ),
                    fg="blue",
                )
            elif config_source in (
                ParameterSource.DEFAULT,
                ParameterSource.DEFAULT_MAP,
            ):
                out("Using configuration from project root.", fg="blue")
            else:
                out(f"Using configuration in '{config}'.", fg="blue")
            if ctx.default_map:
                for param, value in ctx.default_map.items():
                    out(f"{param}: {value}")

    error_msg = "Oh no! 💥 💔 💥"
    # --required-version accepts either the exact running version or just its
    # major component (the text before the first ".").
    if (
        required_version
        and required_version != __version__
        and required_version != __version__.split(".")[0]
    ):
        err(
            f"{error_msg} The required version `{required_version}` does not match"
            f" the running version `{__version__}`!"
        )
        ctx.exit(1)
    if ipynb and pyi:
        err("Cannot pass both `pyi` and `ipynb` flags!")
        ctx.exit(1)

    write_back = WriteBack.from_configuration(check=check, diff=diff, color=color)
    if target_version:
        versions = set(target_version)
    else:
        # We'll autodetect later.
        versions = set()
    # Note the inverted "skip" flags: Mode stores what is enabled.
    mode = Mode(
        target_versions=versions,
        line_length=line_length,
        is_pyi=pyi,
        is_ipynb=ipynb,
        skip_source_first_line=skip_source_first_line,
        string_normalization=not skip_string_normalization,
        magic_trailing_comma=not skip_magic_trailing_comma,
        experimental_string_processing=experimental_string_processing,
        preview=preview,
        python_cell_magics=set(python_cell_magics),
    )

    if code is not None:
        # Run in quiet mode by default with -c; the extra output isn't useful.
        # You can still pass -v to get verbose output.
        quiet = True

    report = Report(check=check, diff=diff, quiet=quiet, verbose=verbose)

    if code is not None:
        reformat_code(
            content=code, fast=fast, write_back=write_back, mode=mode, report=report
        )
    else:
        try:
            sources = get_sources(
                ctx=ctx,
                src=src,
                quiet=quiet,
                verbose=verbose,
                include=include,
                exclude=exclude,
                extend_exclude=extend_exclude,
                force_exclude=force_exclude,
                report=report,
                stdin_filename=stdin_filename,
            )
        except GitWildMatchPatternError:
            # Raised for an invalid .gitignore pattern; exit with status 1.
            ctx.exit(1)

        # Exits with status 0 when there is nothing to format.
        path_empty(
            sources,
            "No Python files are present to be formatted. Nothing to do 😴",
            quiet,
            verbose,
            ctx,
        )

        # A single source is formatted in-process; several sources go through
        # black.concurrency.reformat_many (imported only when needed), which
        # also receives the --workers value.
        if len(sources) == 1:
            reformat_one(
                src=sources.pop(),
                fast=fast,
                write_back=write_back,
                mode=mode,
                report=report,
            )
        else:
            from black.concurrency import reformat_many

            reformat_many(
                sources=sources,
                fast=fast,
                write_back=write_back,
                mode=mode,
                report=report,
                workers=workers,
            )

    # The closing summary is suppressed in quiet mode unless --verbose is given.
    if verbose or not quiet:
        if code is None and (verbose or report.change_count or report.failure_count):
            out()
        out(error_msg if report.return_code else "All done! ✨ 🍰 ✨")
        if code is None:
            click.echo(str(report), err=True)
    # The process exit status is taken from the report.
    ctx.exit(report.return_code)
615
616
def get_sources(
    *,
    ctx: click.Context,
    src: Tuple[str, ...],
    quiet: bool,
    verbose: bool,
    include: Pattern[str],
    exclude: Optional[Pattern[str]],
    extend_exclude: Optional[Pattern[str]],
    force_exclude: Optional[Pattern[str]],
    report: "Report",
    stdin_filename: Optional[str],
) -> Set[Path]:
    """Collect the set of paths to format from the SRC arguments.

    Each argument is handled as follows:

    * an existing file, or ``-`` together with --stdin-filename, is checked
      against --force-exclude (and, for notebooks, the Jupyter dependencies);
      stdin entries are tagged with STDIN_PLACEHOLDER so reformat_one can tell
      them apart from real files;
    * a directory is expanded with :func:`gen_python_files`;
    * any other literal ``-`` is kept as-is;
    * anything else is reported as an invalid path.
    """
    root = ctx.obj["root"]

    # .gitignore files are only consulted when the user did not pass --exclude.
    using_default_exclude = exclude is None
    if exclude is None:
        exclude = re_compile_maybe_verbose(DEFAULT_EXCLUDES)
    gitignore: Optional[Dict[Path, PathSpec]] = None
    root_gitignore = get_gitignore(root)

    collected: Set[Path] = set()
    for arg in src:
        if arg == "-" and stdin_filename:
            path, from_stdin = Path(stdin_filename), True
        else:
            path, from_stdin = Path(arg), False

        if from_stdin or path.is_file():
            normalized = normalize_path_maybe_ignore(path, root, report)
            if normalized is None:
                continue

            # Hard-exclude any file matching the `--force-exclude` regex; the
            # regex is searched against the normalized path with a leading "/".
            match = force_exclude.search("/" + normalized) if force_exclude else None
            if match and match.group(0):
                report.path_ignored(
                    path, "matches the --force-exclude regular expression"
                )
                continue

            if from_stdin:
                # reformat_one strips this prefix again before reporting.
                path = Path(f"{STDIN_PLACEHOLDER}{str(path)}")

            if path.suffix == ".ipynb" and not jupyter_dependencies_are_installed(
                verbose=verbose, quiet=quiet
            ):
                continue

            collected.add(path)
        elif path.is_dir():
            if using_default_exclude:
                gitignore = {
                    root: root_gitignore,
                    root / path: get_gitignore(path),
                }
            collected.update(
                gen_python_files(
                    path.iterdir(),
                    root,
                    include,
                    exclude,
                    extend_exclude,
                    force_exclude,
                    report,
                    gitignore,
                    verbose=verbose,
                    quiet=quiet,
                )
            )
        elif arg == "-":
            collected.add(path)
        else:
            err(f"invalid path: {arg}")
    return collected
696
697
def path_empty(
    src: Sized, msg: str, quiet: bool, verbose: bool, ctx: click.Context
) -> None:
    """Exit the command with status 0 when `src` is empty.

    Before exiting, `msg` is printed unless output is quiet (and not verbose).
    Does nothing when `src` is non-empty.
    """
    if src:
        return
    if verbose or not quiet:
        out(msg)
    ctx.exit(0)
708
709
def reformat_code(
    content: str, fast: bool, write_back: WriteBack, mode: Mode, report: Report
) -> None:
    """Format an in-memory source string and record the outcome in `report`.

    This is the ``--code`` counterpart of `reformat_one`: no file is touched and
    no child process is spawned. `fast`, `write_back`, and `mode` are forwarded
    to :func:`format_stdin_to_stdout`. An ``Exception`` raised while formatting
    is recorded with `report.failed` (with a traceback in verbose mode) instead
    of propagating.
    """
    path = Path("<string>")
    try:
        was_changed = format_stdin_to_stdout(
            content=content, fast=fast, write_back=write_back, mode=mode
        )
        report.done(path, Changed.YES if was_changed else Changed.NO)
    except Exception as exc:
        if report.verbose:
            traceback.print_exc()
        report.failed(path, str(exc))
732
733
# diff-shades depends on being able to monkeypatch this function to operate. I know
# it's not ideal, but this shouldn't cause any issues ... hopefully. ~ichard26
@mypyc_attr(patchable=True)
def reformat_one(
    src: Path, fast: bool, write_back: WriteBack, mode: Mode, report: "Report"
) -> None:
    """Format one source in-process and record the result in `report`.

    `src` is a real file path, ``-`` (read stdin), or a stdin filename tagged
    with STDIN_PLACEHOLDER. Stdin input goes through
    :func:`format_stdin_to_stdout`, with stub or notebook mode chosen from the
    filename suffix. Files go through :func:`format_file_in_place`, consulting
    and updating the cache unless a diff was requested. `fast`, `write_back`,
    and `mode` are forwarded to those functions. An ``Exception`` is recorded
    with `report.failed` instead of propagating.
    """
    try:
        changed = Changed.NO

        src_name = str(src)
        if src_name != "-" and src_name.startswith(STDIN_PLACEHOLDER):
            # Use the original name again in case we want to print something
            # to the user.
            src = Path(src_name[len(STDIN_PLACEHOLDER) :])
            is_stdin = True
        else:
            is_stdin = src_name == "-"

        if not is_stdin:
            cache: Cache = {}
            # Diffs never modify files, so the cache is neither read nor written.
            if write_back not in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
                cache = read_cache(mode)
                resolved = src.resolve()
                cache_key = str(resolved)
                if cache_key in cache and cache[cache_key] == get_cache_info(resolved):
                    changed = Changed.CACHED
            if changed is not Changed.CACHED and format_file_in_place(
                src, fast=fast, write_back=write_back, mode=mode
            ):
                changed = Changed.YES
            # Remember the file as formatted after a normal run (unless it was
            # already cached) or after --check found nothing to change.
            cache_after_write = (
                write_back is WriteBack.YES and changed is not Changed.CACHED
            )
            cache_after_check = write_back is WriteBack.CHECK and changed is Changed.NO
            if cache_after_write or cache_after_check:
                write_cache(cache, [src], mode)
        else:
            if src.suffix == ".pyi":
                mode = replace(mode, is_pyi=True)
            elif src.suffix == ".ipynb":
                mode = replace(mode, is_ipynb=True)
            if format_stdin_to_stdout(fast=fast, write_back=write_back, mode=mode):
                changed = Changed.YES
        report.done(src, changed)
    except Exception as exc:
        if report.verbose:
            traceback.print_exc()
        report.failed(src, str(exc))
786
787
def format_file_in_place(
    src: Path,
    fast: bool,
    mode: Mode,
    write_back: WriteBack = WriteBack.NO,
    lock: Any = None,  # multiprocessing.Manager().Lock() is some crazy proxy
) -> bool:
    """Format file under `src` path. Return True if changed.

    If `write_back` is YES, write reformatted code to the file. If it is DIFF or
    COLOR_DIFF, write a (colored) diff to stdout instead, holding `lock` (when
    given) while writing. Any other value writes nothing; only the return value
    reports whether the file would change.
    `mode` and `fast` options are passed to :func:`format_file_contents`; the
    mode is switched to stub or notebook handling for ``.pyi`` / ``.ipynb``
    files.

    Raises ValueError when formatting raises JSONDecodeError (a notebook that
    cannot be parsed).
    """
    if src.suffix == ".pyi":
        mode = replace(mode, is_pyi=True)
    elif src.suffix == ".ipynb":
        mode = replace(mode, is_ipynb=True)

    then = datetime.utcfromtimestamp(src.stat().st_mtime)
    header = b""
    with open(src, "rb") as buf:
        if mode.skip_source_first_line:
            header = buf.readline()
        src_contents, encoding, newline = decode_bytes(buf.read())
    try:
        dst_contents = format_file_contents(src_contents, fast=fast, mode=mode)
    except NothingChanged:
        return False
    except JSONDecodeError:
        raise ValueError(
            f"File '{src}' cannot be parsed as valid Jupyter notebook."
        ) from None

    # The decoded body must use "\n" line endings, because both writers below
    # translate "\n" into `newline`. The skipped header was read in binary mode
    # and keeps its raw line ending, so for a CRLF file it would otherwise be
    # written out as "\r\r\n". Normalize its ending in that case so it
    # round-trips unchanged.
    header_text = header.decode(encoding)
    if newline == "\r\n" and header_text.endswith("\r\n"):
        header_text = header_text[:-2] + "\n"
    src_contents = header_text + src_contents
    dst_contents = header_text + dst_contents

    if write_back == WriteBack.YES:
        with open(src, "w", encoding=encoding, newline=newline) as f:
            f.write(dst_contents)
    elif write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
        now = datetime.utcnow()
        src_name = f"{src}\t{then} +0000"
        dst_name = f"{src}\t{now} +0000"
        if mode.is_ipynb:
            diff_contents = ipynb_diff(src_contents, dst_contents, src_name, dst_name)
        else:
            diff_contents = diff(src_contents, dst_contents, src_name, dst_name)

        if write_back == WriteBack.COLOR_DIFF:
            diff_contents = color_diff(diff_contents)

        # Presumably held so diffs from parallel workers don't interleave.
        with lock or nullcontext():
            f = io.TextIOWrapper(
                sys.stdout.buffer,
                encoding=encoding,
                newline=newline,
                write_through=True,
            )
            f = wrap_stream_for_windows(f)
            f.write(diff_contents)
            # Detach so dropping the wrapper does not close sys.stdout.buffer.
            f.detach()

    return True
850
851
def format_stdin_to_stdout(
    fast: bool,
    *,
    content: Optional[str] = None,
    write_back: WriteBack = WriteBack.NO,
    mode: Mode,
) -> bool:
    """Format file on stdin. Return True if changed.

    If content is None, it's read from sys.stdin.

    If `write_back` is YES, write reformatted code back to stdout. If it is DIFF,
    write a diff to stdout; COLOR_DIFF writes a colorized diff. Other values
    write nothing. The `mode` argument is passed to
    :func:`format_file_contents`.

    Output is produced in a ``finally`` clause, so it also happens when
    formatting raises. `dst` starts out equal to the input, so in that case the
    unmodified input is echoed (or diffed against itself) before the exception
    propagates to the caller.
    """
    then = datetime.utcnow()

    if content is None:
        src, encoding, newline = decode_bytes(sys.stdin.buffer.read())
    else:
        # Content passed in directly (e.g. from --code) is treated as UTF-8;
        # newline="" disables newline translation when it is written below.
        src, encoding, newline = content, "utf-8", ""

    # Fallback output if formatting raises: the unchanged input.
    dst = src
    try:
        dst = format_file_contents(src, fast=fast, mode=mode)
        return True

    except NothingChanged:
        return False

    finally:
        # Runs on every exit path above, including exceptions that propagate.
        f = io.TextIOWrapper(
            sys.stdout.buffer, encoding=encoding, newline=newline, write_through=True
        )
        if write_back == WriteBack.YES:
            # Make sure there's a newline after the content
            if dst and dst[-1] != "\n":
                dst += "\n"
            f.write(dst)
        elif write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
            now = datetime.utcnow()
            src_name = f"STDIN\t{then} +0000"
            dst_name = f"STDOUT\t{now} +0000"
            d = diff(src, dst, src_name, dst_name)
            if write_back == WriteBack.COLOR_DIFF:
                d = color_diff(d)
                # Only colored output is wrapped (presumably for ANSI handling
                # on Windows terminals).
                f = wrap_stream_for_windows(f)
            f.write(d)
        # Detach so dropping the wrapper does not close sys.stdout.buffer.
        f.detach()
901
902
def check_stability_and_equivalence(
    src_contents: str, dst_contents: str, *, mode: Mode
) -> None:
    """Perform stability and equivalence checks.

    Raise AssertionError if source and destination contents are not
    equivalent, or if a second pass of the formatter would format the
    content differently.
    """
    # Equivalence is checked first; only the stability check needs `mode`.
    # Both helpers are not imported above, so presumably they are defined
    # elsewhere in this module.
    assert_equivalent(src_contents, dst_contents)
    assert_stable(src_contents, dst_contents, mode=mode)
914
915
def format_file_contents(src_contents: str, *, fast: bool, mode: Mode) -> FileContent:
    """Return `src_contents` reformatted according to `mode`.

    Raises NothingChanged when formatting would not alter the contents,
    including whitespace-only input outside preview mode. Notebooks are
    formatted with :func:`format_ipynb_string` and everything else with
    :func:`format_str`. Unless `fast` is set, non-notebook output is also
    verified with :func:`check_stability_and_equivalence`.
    """
    if not (mode.preview or src_contents.strip()):
        raise NothingChanged

    if mode.is_ipynb:
        dst_contents = format_ipynb_string(src_contents, fast=fast, mode=mode)
    else:
        dst_contents = format_str(src_contents, mode=mode)

    if dst_contents == src_contents:
        raise NothingChanged

    # Notebooks are expected to have been checked already by
    # format_ipynb_string, which receives `fast` itself.
    if not (fast or mode.is_ipynb):
        check_stability_and_equivalence(src_contents, dst_contents, mode=mode)
    return dst_contents
937
938
def validate_cell(src: str, mode: Mode) -> None:
    """Skip cells that cannot be safely round-tripped through magic masking.

    Raises NothingChanged when the cell already contains code produced by
    IPython's TransformerManager (any entry of ``TRANSFORMED_MAGICS``), or when
    it starts with a ``%%`` cell magic that is listed in neither
    ``PYTHON_CELL_MAGICS`` nor ``mode.python_cell_magics``. Such cells might
    cause tokenize-rt to break because of indentation.

    If a cell contains ``!ls``, then it'll be transformed to
    ``get_ipython().system('ls')``. However, if the cell originally contained
    ``get_ipython().system('ls')``, then it would get transformed in the same way:

        >>> TransformerManager().transform_cell("get_ipython().system('ls')")
        "get_ipython().system('ls')\n"
        >>> TransformerManager().transform_cell("!ls")
        "get_ipython().system('ls')\n"

    Due to the impossibility of safely roundtripping in such situations, cells
    containing transformed magics will be ignored.
    """
    if any(magic in src for magic in TRANSFORMED_MAGICS):
        raise NothingChanged
    if src.startswith("%%"):
        # The magic's name is the first whitespace-separated token minus "%%".
        cell_magic = src.split()[0][2:]
        if cell_magic not in PYTHON_CELL_MAGICS | mode.python_cell_magics:
            raise NothingChanged
963
964
def format_cell(src: str, *, fast: bool, mode: Mode) -> str:
    """Format the code of a single Jupyter notebook cell.

    Steps:

      - drop a trailing semicolon, if there is one;
      - mask IPython magics so the cell parses as Python;
      - format the masked code (and verify it unless `fast` is set);
      - restore the magics and the trailing semicolon;
      - strip trailing newlines.

    Raises NothingChanged when the cell is rejected by :func:`validate_cell`,
    when masking hits a SyntaxError (the cell could be an automagic or a
    multi-line magic, which are not supported), or when the result equals
    the input.
    """
    validate_cell(src, mode)
    stripped_src, had_semicolon = remove_trailing_semicolon(src)
    try:
        masked_src, replacements = mask_cell(stripped_src)
    except SyntaxError:
        raise NothingChanged from None

    masked_dst = format_str(masked_src, mode=mode)
    if not fast:
        check_stability_and_equivalence(masked_src, masked_dst, mode=mode)

    unmasked_dst = unmask_cell(masked_dst, replacements)
    dst = put_trailing_semicolon_back(unmasked_dst, had_semicolon).rstrip("\n")
    if dst == src:
        raise NothingChanged from None
    return dst
1000
1001
def validate_metadata(nb: MutableMapping[str, Any]) -> None:
    """Raise NothingChanged if the notebook's metadata marks it as non-Python.

    Every metadata field is optional (see
    https://nbformat.readthedocs.io/en/latest/format_description.html), so a
    notebook whose language cannot be determined is accepted and formatted.
    """
    metadata = nb.get("metadata", {})
    language_info = metadata.get("language_info", {})
    language = language_info.get("name", None)
    if language is None or language == "python":
        return
    raise NothingChanged from None
1012
1013
def format_ipynb_string(src_contents: str, *, fast: bool, mode: Mode) -> FileContent:
    """Format Jupyter notebook.

    Operate cell-by-cell, only on code cells, only for Python notebooks.
    If the ``.ipynb`` originally had a trailing newline, it'll be preserved.

    Raises NothingChanged when `src_contents` is empty, when
    :func:`validate_metadata` rejects the notebook as non-Python, or when no
    code cell was changed by :func:`format_cell`.
    """
    if not src_contents:
        # Guard the ``src_contents[-1]`` lookup below in every mode, not only
        # preview: an empty string would otherwise raise IndexError.
        raise NothingChanged

    trailing_newline = src_contents[-1] == "\n"
    modified = False
    nb = json.loads(src_contents)
    validate_metadata(nb)
    for cell in nb["cells"]:
        if cell.get("cell_type", None) != "code":
            continue
        try:
            # Joining works whether ``source`` is a list of lines or a single
            # string (joining a string's characters yields the same string).
            src = "".join(cell["source"])
            dst = format_cell(src, fast=fast, mode=mode)
        except NothingChanged:
            # Cell skipped or already formatted: leave it untouched.
            continue
        cell["source"] = dst.splitlines(keepends=True)
        modified = True

    if not modified:
        raise NothingChanged

    dst_contents = json.dumps(nb, indent=1, ensure_ascii=False)
    if trailing_newline:
        dst_contents = dst_contents + "\n"
    return dst_contents
1044
1045
def format_str(src_contents: str, *, mode: Mode) -> str:
    """Reformat a string and return new contents.

    `mode` determines formatting options, such as how many characters per line are
    allowed.  Example:

    >>> import black
    >>> print(black.format_str("def f(arg:str='')->None:...", mode=black.Mode()))
    def f(arg: str = "") -> None:
        ...

    A more complex example:

    >>> print(
    ...   black.format_str(
    ...     "def f(arg:str='')->None: hey",
    ...     mode=black.Mode(
    ...       target_versions={black.TargetVersion.PY36},
    ...       line_length=10,
    ...       string_normalization=False,
    ...       is_pyi=False,
    ...     ),
    ...   ),
    ... )
    def f(
        arg: str = '',
    ) -> None:
        hey

    If the first pass changes anything, a second pass is run over its output
    and that result is returned.
    """
    first_pass = _format_str_once(src_contents, mode=mode)
    if first_pass == src_contents:
        return first_pass
    # Forced second pass: optional trailing commas added on the first pass
    # become magic trailing commas on the next one, and interact differently
    # with optional parentheses.  Admittedly ugly.
    return _format_str_once(first_pass, mode=mode)
1083
1084
def _format_str_once(src_contents: str, *, mode: Mode) -> str:
    """Run exactly one formatting pass over `src_contents`.

    Target versions are taken from `mode` when given. Otherwise they are
    detected from the parsed source, including its ``__future__`` imports.

    In preview mode, a source that produces no output lines is returned as
    either an empty string or a single newline. It is a single newline when the
    source contained a line break, using the source's own CRLF/LF style.
    """
    src_node = lib2to3_parse(src_contents.lstrip(), mode.target_versions)
    if mode.target_versions:
        versions = mode.target_versions
    else:
        versions = detect_target_versions(
            src_node, future_imports=get_future_imports(src_node)
        )

    normalize_fmt_off(src_node, preview=mode.preview)
    line_generator = LineGenerator(mode=mode)
    empty_line_tracker = EmptyLineTracker(mode=mode)
    # Trailing commas may only be added where every target version allows them.
    split_line_features = {
        feature
        for feature in {Feature.TRAILING_COMMA_IN_CALL, Feature.TRAILING_COMMA_IN_DEF}
        if supports_feature(versions, feature)
    }

    blocks: List[LinesBlock] = []
    for logical_line in line_generator.visit(src_node):
        current_block = empty_line_tracker.maybe_empty_lines(logical_line)
        blocks.append(current_block)
        for output_line in transform_line(
            logical_line, mode=mode, features=split_line_features
        ):
            current_block.content_lines.append(str(output_line))

    if blocks:
        # No blank lines are emitted after the final block.
        blocks[-1].after = 0
    rendered = [text for lines_block in blocks for text in lines_block.all_lines()]

    if mode.preview and not rendered:
        # Use decode_bytes to recover the source newline style (CRLF or LF),
        # and to see whether the normalized content has more than one line.
        normalized_content, _, newline = decode_bytes(src_contents.encode("utf-8"))
        return newline if "\n" in normalized_content else ""
    return "".join(rendered)
1123
1124
def decode_bytes(src: bytes) -> Tuple[FileContent, Encoding, NewLine]:
    """Decode raw source bytes into ``(decoded_contents, encoding, newline)``.

    The encoding is detected with :func:`tokenize.detect_encoding`. `newline` is
    CRLF when the first line ends with ``\\r\\n`` and LF otherwise. The decoded
    text itself uses universal newlines, so it only ever contains LF.
    Empty input yields ``("", encoding, "\\n")``.
    """
    buffer = io.BytesIO(src)
    encoding, first_lines = tokenize.detect_encoding(buffer.readline)
    if not first_lines:
        return "", encoding, "\n"

    newline = "\r\n" if first_lines[0].endswith(b"\r\n") else "\n"
    buffer.seek(0)
    with io.TextIOWrapper(buffer, encoding) as wrapper:
        decoded = wrapper.read()
    return decoded, encoding, newline
1140
1141
def get_features_used(  # noqa: C901
    node: Node, *, future_imports: Optional[Set[str]] = None
) -> Set[Feature]:
    """Return a set of (relatively) new Python features used in this file.

    Currently looking for:
    - f-strings (any case combination of the ``f``/``rf``/``fr`` prefixes);
    - self-documenting expressions in f-strings (f"{x=}");
    - underscores in numeric literals;
    - trailing commas after * or ** in function signatures and calls;
    - positional only arguments in function signatures and lambdas;
    - assignment expression;
    - relaxed decorator syntax;
    - usage of __future__ flags (annotations);
    - starred unpacking in return / yield values;
    - unparenthesized tuples on the right-hand side of annotated assignments;
    - ``except*`` clauses;
    - variadic generics (star expressions in subscripts, ``*args: *Ts``).
    """
    features: Set[Feature] = set()
    if future_imports:
        features |= {
            FUTURE_FLAG_TO_FEATURE[future_import]
            for future_import in future_imports
            if future_import in FUTURE_FLAG_TO_FEATURE
        }

    for n in node.pre_order():
        if is_string_token(n):
            value_head = n.value[:2]
            # String prefixes are case-insensitive, so lowercase the head. This
            # also catches mixed-case forms such as Rf"..." and fR"...".
            if value_head.lower() in {'f"', "f'", "rf", "fr"}:
                features.add(Feature.F_STRINGS)
                if Feature.DEBUG_F_STRINGS not in features:
                    for span_beg, span_end in iter_fexpr_spans(n.value):
                        if n.value[span_beg : span_end - 1].rstrip().endswith("="):
                            features.add(Feature.DEBUG_F_STRINGS)
                            break

        elif is_number_token(n):
            if "_" in n.value:
                features.add(Feature.NUMERIC_UNDERSCORES)

        elif n.type == token.SLASH:
            if n.parent and n.parent.type in {
                syms.typedargslist,
                syms.arglist,
                syms.varargslist,
            }:
                features.add(Feature.POS_ONLY_ARGUMENTS)

        elif n.type == token.COLONEQUAL:
            features.add(Feature.ASSIGNMENT_EXPRESSIONS)

        elif n.type == syms.decorator:
            if len(n.children) > 1 and not is_simple_decorator_expression(
                n.children[1]
            ):
                features.add(Feature.RELAXED_DECORATORS)

        elif (
            n.type in {syms.typedargslist, syms.arglist}
            and n.children
            and n.children[-1].type == token.COMMA
        ):
            if n.type == syms.typedargslist:
                feature = Feature.TRAILING_COMMA_IN_DEF
            else:
                feature = Feature.TRAILING_COMMA_IN_CALL

            for ch in n.children:
                if ch.type in STARS:
                    features.add(feature)

                if ch.type == syms.argument:
                    for argch in ch.children:
                        if argch.type in STARS:
                            features.add(feature)

        elif (
            n.type in {syms.return_stmt, syms.yield_expr}
            and len(n.children) >= 2
            and n.children[1].type == syms.testlist_star_expr
            and any(child.type == syms.star_expr for child in n.children[1].children)
        ):
            features.add(Feature.UNPACKING_ON_FLOW)

        elif (
            n.type == syms.annassign
            and len(n.children) >= 4
            and n.children[3].type == syms.testlist_star_expr
        ):
            features.add(Feature.ANN_ASSIGN_EXTENDED_RHS)

        elif (
            n.type == syms.except_clause
            and len(n.children) >= 2
            and n.children[1].type == token.STAR
        ):
            features.add(Feature.EXCEPT_STAR)

        elif n.type in {syms.subscriptlist, syms.trailer} and any(
            child.type == syms.star_expr for child in n.children
        ):
            features.add(Feature.VARIADIC_GENERICS)

        elif (
            n.type == syms.tname_star
            and len(n.children) == 3
            and n.children[2].type == syms.star_expr
        ):
            features.add(Feature.VARIADIC_GENERICS)

    return features
1252
1253
def detect_target_versions(
    node: Node, *, future_imports: Optional[Set[str]] = None
) -> Set[TargetVersion]:
    """Return every target version that supports all features used in `node`."""
    used = get_features_used(node, future_imports=future_imports)
    compatible: Set[TargetVersion] = set()
    for version in TargetVersion:
        if used <= VERSION_TO_FEATURES[version]:
            compatible.add(version)
    return compatible
1262
1263
def get_future_imports(node: Node) -> Set[str]:
    """Return the names imported from ``__future__`` at the top of the module.

    Only the leading run of simple statements is scanned. String-only
    statements (such as a module docstring) are skipped. Each
    ``from __future__ import ...`` contributes its original (pre-``as``)
    names. Scanning stops at the first statement of any other kind.
    """
    imports: Set[str] = set()

    def names_from(children: List[LN]) -> Generator[str, None, None]:
        # Walk the tail of an import_from node and yield the imported names.
        for child in children:
            if isinstance(child, Leaf):
                # Punctuation leaves (parentheses, commas) are ignored.
                if child.type == token.NAME:
                    yield child.value
                continue
            if child.type == syms.import_as_name:
                original = child.children[0]
                assert isinstance(original, Leaf), "Invalid syntax parsing imports"
                assert original.type == token.NAME, "Invalid syntax parsing imports"
                yield original.value
            elif child.type == syms.import_as_names:
                yield from names_from(child.children)
            else:
                raise AssertionError("Invalid syntax parsing imports")

    for stmt in node.children:
        if stmt.type != syms.simple_stmt:
            break

        first = stmt.children[0]
        if isinstance(first, Leaf):
            is_docstring = (
                len(stmt.children) == 2
                and first.type == token.STRING
                and stmt.children[1].type == token.NEWLINE
            )
            if not is_docstring:
                break
            continue

        if first.type != syms.import_from:
            break
        module_name = first.children[1]
        if not isinstance(module_name, Leaf) or module_name.value != "__future__":
            break
        imports |= set(names_from(first.children[3:]))

    return imports
1312
1313
def assert_equivalent(src: str, dst: str) -> None:
    """Raise AssertionError unless `src` and `dst` parse to equivalent ASTs.

    A failure to parse `src` is reported as an unsupported input. A failure to
    parse `dst`, or an AST mismatch, is reported as an internal error; in both
    of those cases details are written out with :func:`dump_to_file`, and the
    resulting path is included in the message.
    """
    try:
        src_ast = parse_ast(src)
    except Exception as exc:
        raise AssertionError(
            "cannot use --safe with this file; failed to parse source file AST: "
            f"{exc}\n"
            "This could be caused by running Black with an older Python version "
            "that does not support new syntax used in your source file."
        ) from exc

    try:
        dst_ast = parse_ast(dst)
    except Exception as exc:
        formatted_tb = "".join(traceback.format_tb(exc.__traceback__))
        log = dump_to_file(formatted_tb, dst)
        raise AssertionError(
            f"INTERNAL ERROR: Black produced invalid code: {exc}. "
            "Please report a bug on https://github.com/psf/black/issues.  "
            f"This invalid output might be helpful: {log}"
        ) from None

    src_dump = "\n".join(stringify_ast(src_ast))
    dst_dump = "\n".join(stringify_ast(dst_ast))
    if src_dump == dst_dump:
        return

    log = dump_to_file(diff(src_dump, dst_dump, "src", "dst"))
    raise AssertionError(
        "INTERNAL ERROR: Black produced code that is not equivalent to the"
        " source.  Please report a bug on "
        f"https://github.com/psf/black/issues.  This diff might be helpful: {log}"
    ) from None
1345
1346
def assert_stable(src: str, dst: str, mode: Mode) -> None:
    """Raise AssertionError if formatting `dst` once more would change it.

    On failure, the mode and the diffs for both passes are written out with
    :func:`dump_to_file`, and that path is included in the message.
    """
    # Deliberately a single pass: format_str() formats twice internally, which
    # could hide a bug where the output bounces between two versions.
    second_pass = _format_str_once(dst, mode=mode)
    if second_pass == dst:
        return

    log = dump_to_file(
        str(mode),
        diff(src, dst, "source", "first pass"),
        diff(dst, second_pass, "first pass", "second pass"),
    )
    raise AssertionError(
        "INTERNAL ERROR: Black produced different code on the second pass of the"
        " formatter.  Please report a bug on https://github.com/psf/black/issues."
        f"  This diff might be helpful: {log}"
    ) from None
1364
1365
@contextmanager
def nullcontext() -> Iterator[None]:
    """Context manager that does nothing and binds ``None`` on entry.

    Serves the same purpose as :func:`contextlib.nullcontext` from Python 3.7.
    """
    yield None
1373
1374
def patch_click() -> None:
    """Make Click not crash on Python 3.6 with LANG=C.

    On some misconfigured environments, Python 3 picks ASCII as the default
    encoding, which restricts which paths can be accessed while the program
    runs. Click refuses to work in that situation and raises a RuntimeError.

    For Black, non-ASCII characters in file paths are unlikely, since it
    processes Python source code. The crash was also spurious on Python 3.7
    thanks to PEP 538 and PEP 540.

    The environment-verification hooks are replaced with no-ops in each
    importable Click module that defines them.
    """
    targets: List[Any] = []
    try:
        from click import core
    except ImportError:
        pass
    else:
        targets.append(core)
    try:
        # Removed in Click 8.1.0 and newer; we keep this around for users who have
        # older versions installed.
        from click import _unicodefun  # type: ignore
    except ImportError:
        pass
    else:
        targets.append(_unicodefun)

    hook_names = ("_verify_python3_env", "_verify_python_env")
    for target in targets:
        for hook_name in hook_names:
            if hasattr(target, hook_name):
                setattr(target, hook_name, lambda: None)
1407
1408
def patched_main() -> None:
    """Command-line entry point: apply runtime workarounds, then run ``main()``."""
    # PyInstaller patches multiprocessing so that freeze_support() is needed even
    # outside Windows, so always call it when running from a frozen build.
    is_frozen = getattr(sys, "frozen", False)
    if is_frozen:
        from multiprocessing import freeze_support

        freeze_support()

    patch_click()
    main()


if __name__ == "__main__":
    patched_main()