git.madduck.net Git - etc/vim.git/blob - src/black/__init__.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access, can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Enforce empty lines before classes/functions with sticky leading comments. (#3302)
[etc/vim.git] / src / black / __init__.py
1 import io
2 import json
3 import platform
4 import re
5 import sys
6 import tokenize
7 import traceback
8 from contextlib import contextmanager
9 from dataclasses import replace
10 from datetime import datetime
11 from enum import Enum
12 from json.decoder import JSONDecodeError
13 from pathlib import Path
14 from typing import (
15     Any,
16     Dict,
17     Generator,
18     Iterator,
19     List,
20     MutableMapping,
21     Optional,
22     Pattern,
23     Sequence,
24     Set,
25     Sized,
26     Tuple,
27     Union,
28 )
29
30 import click
31 from click.core import ParameterSource
32 from mypy_extensions import mypyc_attr
33 from pathspec.patterns.gitwildmatch import GitWildMatchPatternError
34
35 from _black_version import version as __version__
36 from black.cache import Cache, get_cache_info, read_cache, write_cache
37 from black.comments import normalize_fmt_off
38 from black.const import (
39     DEFAULT_EXCLUDES,
40     DEFAULT_INCLUDES,
41     DEFAULT_LINE_LENGTH,
42     STDIN_PLACEHOLDER,
43 )
44 from black.files import (
45     find_project_root,
46     find_pyproject_toml,
47     find_user_pyproject_toml,
48     gen_python_files,
49     get_gitignore,
50     normalize_path_maybe_ignore,
51     parse_pyproject_toml,
52     wrap_stream_for_windows,
53 )
54 from black.handle_ipynb_magics import (
55     PYTHON_CELL_MAGICS,
56     TRANSFORMED_MAGICS,
57     jupyter_dependencies_are_installed,
58     mask_cell,
59     put_trailing_semicolon_back,
60     remove_trailing_semicolon,
61     unmask_cell,
62 )
63 from black.linegen import LN, LineGenerator, transform_line
64 from black.lines import EmptyLineTracker, LinesBlock
65 from black.mode import (
66     FUTURE_FLAG_TO_FEATURE,
67     VERSION_TO_FEATURES,
68     Feature,
69     Mode,
70     TargetVersion,
71     supports_feature,
72 )
73 from black.nodes import (
74     STARS,
75     is_number_token,
76     is_simple_decorator_expression,
77     is_string_token,
78     syms,
79 )
80 from black.output import color_diff, diff, dump_to_file, err, ipynb_diff, out
81 from black.parsing import InvalidInput  # noqa F401
82 from black.parsing import lib2to3_parse, parse_ast, stringify_ast
83 from black.report import Changed, NothingChanged, Report
84 from black.trans import iter_fexpr_spans
85 from blib2to3.pgen2 import token
86 from blib2to3.pytree import Leaf, Node
87
# True when this module was loaded from a compiled extension module (.pyd or
# .so) instead of a .py source file; reported by `--version`.
COMPILED = Path(__file__).suffix in {".pyd", ".so"}

# Type aliases that document what a plain `str` means in signatures.
FileContent = str  # text content of a source file
Encoding = str  # name of the text encoding of a file
NewLine = str  # newline sequence of a file
94
95
class WriteBack(Enum):
    """How formatting results are reported or written back."""

    NO = 0
    YES = 1
    DIFF = 2
    CHECK = 3
    COLOR_DIFF = 4

    @classmethod
    def from_configuration(
        cls, *, check: bool, diff: bool, color: bool = False
    ) -> "WriteBack":
        """Map the --check / --diff / --color flags onto a write-back mode.

        ``diff`` takes precedence over ``check``; ``color`` only has an effect
        together with ``diff``. With neither flag the result is ``YES``.
        """
        if diff:
            return cls.COLOR_DIFF if color else cls.DIFF
        return cls.CHECK if check else cls.YES
114
115
# Legacy name, left for integrations.
# `FileMode` is simply an alias of `Mode`, kept so that external code still
# importing the old name continues to work.
FileMode = Mode
118
119
def read_pyproject_toml(
    ctx: click.Context, param: click.Parameter, value: Optional[str]
) -> Optional[str]:
    """Click callback that loads Black's settings from "pyproject.toml".

    When ``value`` is empty, the file is searched for with
    :func:`find_pyproject_toml`, starting from the ``src`` arguments parsed so
    far. The settings found are merged into ``ctx.default_map`` so that they act
    as option defaults.

    Returns the path of the configuration file that was read, or None when no
    file was found or it contained no Black configuration.

    Raises click.FileError when the file cannot be read or parsed, and
    click.BadOptionUsage when ``target-version`` is not a list.
    """
    if not value:
        value = find_pyproject_toml(ctx.params.get("src", ()))
        if value is None:
            return None

    try:
        raw_config = parse_pyproject_toml(value)
    except (OSError, ValueError) as e:
        raise click.FileError(
            filename=value, hint=f"Error reading configuration file: {e}"
        ) from None

    if not raw_config:
        return None

    # Click wants scalar defaults as strings; lists and tables pass through.
    # For more information please see:
    # https://github.com/psf/black/issues/1458
    # https://github.com/pallets/click/issues/1567
    config = {
        key: val if isinstance(val, (list, dict)) else str(val)
        for key, val in raw_config.items()
    }

    target_version = config.get("target_version")
    if target_version is not None and not isinstance(target_version, list):
        raise click.BadOptionUsage(
            "target-version", "Config key target-version must be a list"
        )

    # Settings from the file override any defaults already in the map.
    merged: Dict[str, Any] = dict(ctx.default_map or {})
    merged.update(config)
    ctx.default_map = merged
    return value
164
165
def target_version_option_callback(
    c: click.Context, p: Union[click.Option, click.Parameter], v: Tuple[str, ...]
) -> List[TargetVersion]:
    """Convert the --target-version values into TargetVersion members.

    Each value is looked up by its upper-cased member name. This is a named
    function rather than a lambda because mypy could not infer the lambda's
    type, which caused trouble for mypyc.
    """
    versions: List[TargetVersion] = []
    for name in v:
        versions.append(TargetVersion[name.upper()])
    return versions
175
176
def re_compile_maybe_verbose(regex: str) -> Pattern[str]:
    """Compile ``regex``, enabling verbose mode when it spans several lines.

    A pattern that contains a newline is prefixed with the inline ``(?x)``
    flag, so whitespace and ``#`` comments inside it are ignored.
    """
    pattern = "(?x)" + regex if "\n" in regex else regex
    compiled: Pattern[str] = re.compile(pattern)
    return compiled
186
187
def validate_regex(
    ctx: click.Context,
    param: click.Parameter,
    value: Optional[str],
) -> Optional[Pattern[str]]:
    """Click callback that compiles a regular-expression option value.

    ``None`` (option not given) is passed through unchanged. An invalid
    pattern is reported to the user as click.BadParameter.
    """
    if value is None:
        return None
    try:
        return re_compile_maybe_verbose(value)
    except re.error as e:
        raise click.BadParameter(f"Not a valid regular expression: {e}") from None
197
198
@click.command(
    context_settings={"help_option_names": ["-h", "--help"]},
    # While Click does set this field automatically using the docstring, mypyc
    # (annoyingly) strips 'em so we need to set it here too.
    help="The uncompromising code formatter.",
)
@click.option("-c", "--code", type=str, help="Format the code passed in as a string.")
@click.option(
    "-l",
    "--line-length",
    type=int,
    default=DEFAULT_LINE_LENGTH,
    help="How many characters per line to allow.",
    show_default=True,
)
@click.option(
    "-t",
    "--target-version",
    type=click.Choice([v.name.lower() for v in TargetVersion]),
    # Converts the chosen names into TargetVersion members.
    callback=target_version_option_callback,
    multiple=True,
    help=(
        "Python versions that should be supported by Black's output. [default: per-file"
        " auto-detection]"
    ),
)
@click.option(
    "--pyi",
    is_flag=True,
    help=(
        "Format all input files like typing stubs regardless of file extension (useful"
        " when piping source on standard input)."
    ),
)
@click.option(
    "--ipynb",
    is_flag=True,
    help=(
        "Format all input files like Jupyter Notebooks regardless of file extension "
        "(useful when piping source on standard input)."
    ),
)
@click.option(
    "--python-cell-magics",
    multiple=True,
    help=(
        "When processing Jupyter Notebooks, add the given magic to the list"
        f" of known python-magics ({', '.join(PYTHON_CELL_MAGICS)})."
        " Useful for formatting cells with custom python magics."
    ),
    default=[],
)
@click.option(
    "-x",
    "--skip-source-first-line",
    is_flag=True,
    help="Skip the first line of the source code.",
)
@click.option(
    "-S",
    "--skip-string-normalization",
    is_flag=True,
    help="Don't normalize string quotes or prefixes.",
)
@click.option(
    "-C",
    "--skip-magic-trailing-comma",
    is_flag=True,
    help="Don't use trailing commas as a reason to split lines.",
)
@click.option(
    "--experimental-string-processing",
    is_flag=True,
    hidden=True,
    help="(DEPRECATED and now included in --preview) Normalize string literals.",
)
@click.option(
    "--preview",
    is_flag=True,
    help=(
        "Enable potentially disruptive style changes that may be added to Black's main"
        " functionality in the next major release."
    ),
)
@click.option(
    "--check",
    is_flag=True,
    help=(
        "Don't write the files back, just return the status. Return code 0 means"
        " nothing would change. Return code 1 means some files would be reformatted."
        " Return code 123 means there was an internal error."
    ),
)
@click.option(
    "--diff",
    is_flag=True,
    help="Don't write the files back, just output a diff for each file on stdout.",
)
@click.option(
    "--color/--no-color",
    is_flag=True,
    help="Show colored diff. Only applies when `--diff` is given.",
)
@click.option(
    "--fast/--safe",
    is_flag=True,
    help="If --fast given, skip temporary sanity checks. [default: --safe]",
)
@click.option(
    "--required-version",
    type=str,
    help=(
        "Require a specific version of Black to be running (useful for unifying results"
        " across many environments e.g. with a pyproject.toml file). It can be"
        " either a major version number or an exact version."
    ),
)
@click.option(
    "--include",
    type=str,
    default=DEFAULT_INCLUDES,
    callback=validate_regex,
    help=(
        "A regular expression that matches files and directories that should be"
        " included on recursive searches. An empty value means all files are included"
        " regardless of the name. Use forward slashes for directories on all platforms"
        " (Windows, too). Exclusions are calculated first, inclusions later."
    ),
    show_default=True,
)
# No default value here: a value of None makes get_sources() fall back to
# DEFAULT_EXCLUDES plus .gitignore rules for directory arguments.
@click.option(
    "--exclude",
    type=str,
    callback=validate_regex,
    help=(
        "A regular expression that matches files and directories that should be"
        " excluded on recursive searches. An empty value means no paths are excluded."
        " Use forward slashes for directories on all platforms (Windows, too)."
        " Exclusions are calculated first, inclusions later. [default:"
        f" {DEFAULT_EXCLUDES}]"
    ),
    show_default=False,
)
@click.option(
    "--extend-exclude",
    type=str,
    callback=validate_regex,
    help=(
        "Like --exclude, but adds additional files and directories on top of the"
        " excluded ones. (Useful if you simply want to add to the default)"
    ),
)
@click.option(
    "--force-exclude",
    type=str,
    callback=validate_regex,
    help=(
        "Like --exclude, but files and directories matching this regex will be "
        "excluded even when they are passed explicitly as arguments."
    ),
)
@click.option(
    "--stdin-filename",
    type=str,
    help=(
        "The name of the file when passing it through stdin. Useful to make "
        "sure Black will respect --force-exclude option on some "
        "editors that rely on using stdin."
    ),
)
@click.option(
    "-W",
    "--workers",
    type=click.IntRange(min=1),
    default=None,
    help="Number of parallel workers [default: number of CPUs in the system]",
)
@click.option(
    "-q",
    "--quiet",
    is_flag=True,
    help=(
        "Don't emit non-error messages to stderr. Errors are still emitted; silence"
        " those with 2>/dev/null."
    ),
)
@click.option(
    "-v",
    "--verbose",
    is_flag=True,
    help=(
        "Also emit messages to stderr about files that were not changed or were ignored"
        " due to exclusion patterns."
    ),
)
@click.version_option(
    version=__version__,
    message=(
        f"%(prog)s, %(version)s (compiled: {'yes' if COMPILED else 'no'})\n"
        f"Python ({platform.python_implementation()}) {platform.python_version()}"
    ),
)
# Eager, presumably so that `src` is already in ctx.params when
# read_pyproject_toml() looks up the project's pyproject.toml -- confirm.
@click.argument(
    "src",
    nargs=-1,
    type=click.Path(
        exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
    ),
    is_eager=True,
    metavar="SRC ...",
)
# read_pyproject_toml() merges the file's settings into ctx.default_map, which
# is why this option is eager.
@click.option(
    "--config",
    type=click.Path(
        exists=True,
        file_okay=True,
        dir_okay=False,
        readable=True,
        allow_dash=False,
        path_type=str,
    ),
    is_eager=True,
    callback=read_pyproject_toml,
    help="Read configuration from FILE path.",
)
@click.pass_context
def main(  # noqa: C901
    ctx: click.Context,
    code: Optional[str],
    line_length: int,
    target_version: List[TargetVersion],
    check: bool,
    diff: bool,
    color: bool,
    fast: bool,
    pyi: bool,
    ipynb: bool,
    python_cell_magics: Sequence[str],
    skip_source_first_line: bool,
    skip_string_normalization: bool,
    skip_magic_trailing_comma: bool,
    experimental_string_processing: bool,
    preview: bool,
    quiet: bool,
    verbose: bool,
    required_version: Optional[str],
    include: Pattern[str],
    exclude: Optional[Pattern[str]],
    extend_exclude: Optional[Pattern[str]],
    force_exclude: Optional[Pattern[str]],
    stdin_filename: Optional[str],
    workers: Optional[int],
    src: Tuple[str, ...],
    config: Optional[str],
) -> None:
    """The uncompromising code formatter."""
    # ctx.obj is a dict used to hand the project root to get_sources().
    ctx.ensure_object(dict)

    # Exactly one of SRC paths or --code must be given.
    if src and code is not None:
        out(
            main.get_usage(ctx)
            + "\n\n'SRC' and 'code' cannot be passed simultaneously."
        )
        ctx.exit(1)
    if not src and code is None:
        out(main.get_usage(ctx) + "\n\nOne of 'SRC' or 'code' is required.")
        ctx.exit(1)

    # With --code there are no files, so no project root is looked up.
    root, method = (
        find_project_root(src, stdin_filename) if code is None else (None, None)
    )
    ctx.obj["root"] = root

    if verbose:
        if root:
            out(
                f"Identified `{root}` as project root containing a {method}.",
                fg="blue",
            )

            # Show each source relative to the root; sources that cannot be
            # normalized are highlighted in red as skipped.
            normalized = [
                (source, source)
                if source == "-"
                else (normalize_path_maybe_ignore(Path(source), root), source)
                for source in src
            ]
            srcs_string = ", ".join(
                [
                    f'"{_norm}"'
                    if _norm
                    else f'\033[31m"{source} (skipping - invalid)"\033[34m'
                    for _norm, source in normalized
                ]
            )
            out(f"Sources to be formatted: {srcs_string}", fg="blue")

        if config:
            config_source = ctx.get_parameter_source("config")
            user_level_config = str(find_user_pyproject_toml())
            if config == user_level_config:
                out(
                    "Using configuration from user-level config at "
                    f"'{user_level_config}'.",
                    fg="blue",
                )
            elif config_source in (
                ParameterSource.DEFAULT,
                ParameterSource.DEFAULT_MAP,
            ):
                out("Using configuration from project root.", fg="blue")
            else:
                out(f"Using configuration in '{config}'.", fg="blue")

    error_msg = "Oh no! 💥 💔 💥"
    # The requirement is met by either the exact version string or just its
    # major component (the text before the first ".").
    if (
        required_version
        and required_version != __version__
        and required_version != __version__.split(".")[0]
    ):
        err(
            f"{error_msg} The required version `{required_version}` does not match"
            f" the running version `{__version__}`!"
        )
        ctx.exit(1)
    if ipynb and pyi:
        err("Cannot pass both `pyi` and `ipynb` flags!")
        ctx.exit(1)

    write_back = WriteBack.from_configuration(check=check, diff=diff, color=color)
    if target_version:
        versions = set(target_version)
    else:
        # We'll autodetect later.
        versions = set()
    mode = Mode(
        target_versions=versions,
        line_length=line_length,
        is_pyi=pyi,
        is_ipynb=ipynb,
        skip_source_first_line=skip_source_first_line,
        string_normalization=not skip_string_normalization,
        magic_trailing_comma=not skip_magic_trailing_comma,
        experimental_string_processing=experimental_string_processing,
        preview=preview,
        python_cell_magics=set(python_cell_magics),
    )

    if code is not None:
        # Run in quiet mode by default with -c; the extra output isn't useful.
        # You can still pass -v to get verbose output.
        quiet = True

    report = Report(check=check, diff=diff, quiet=quiet, verbose=verbose)

    if code is not None:
        reformat_code(
            content=code, fast=fast, write_back=write_back, mode=mode, report=report
        )
    else:
        try:
            sources = get_sources(
                ctx=ctx,
                src=src,
                quiet=quiet,
                verbose=verbose,
                include=include,
                exclude=exclude,
                extend_exclude=extend_exclude,
                force_exclude=force_exclude,
                report=report,
                stdin_filename=stdin_filename,
            )
        except GitWildMatchPatternError:
            # An unparsable gitignore pattern aborts the run.
            ctx.exit(1)

        # Exits with status 0 when there is nothing to format.
        path_empty(
            sources,
            "No Python files are present to be formatted. Nothing to do 😴",
            quiet,
            verbose,
            ctx,
        )

        # A single source is formatted in-process; several go through
        # black.concurrency.reformat_many, which is only imported when needed.
        if len(sources) == 1:
            reformat_one(
                src=sources.pop(),
                fast=fast,
                write_back=write_back,
                mode=mode,
                report=report,
            )
        else:
            from black.concurrency import reformat_many

            reformat_many(
                sources=sources,
                fast=fast,
                write_back=write_back,
                mode=mode,
                report=report,
                workers=workers,
            )

    # Final summary; the exit status always comes from the report.
    if verbose or not quiet:
        if code is None and (verbose or report.change_count or report.failure_count):
            out()
        out(error_msg if report.return_code else "All done! ✨ 🍰 ✨")
        if code is None:
            click.echo(str(report), err=True)
    ctx.exit(report.return_code)
609
610
def get_sources(
    *,
    ctx: click.Context,
    src: Tuple[str, ...],
    quiet: bool,
    verbose: bool,
    include: Pattern[str],
    exclude: Optional[Pattern[str]],
    extend_exclude: Optional[Pattern[str]],
    force_exclude: Optional[Pattern[str]],
    report: "Report",
    stdin_filename: Optional[str],
) -> Set[Path]:
    """Compute the set of files to be formatted.

    Each entry of ``src`` is handled as follows:

    * ``-`` combined with ``stdin_filename``, or an existing file: added unless
      it cannot be normalized against the project root or matches
      ``force_exclude``. Stdin entries are tagged with STDIN_PLACEHOLDER, and
      ``.ipynb`` files are dropped when the Jupyter dependencies are missing.
    * a directory: expanded with :func:`gen_python_files`. When no ``exclude``
      was supplied, DEFAULT_EXCLUDES is used together with the .gitignore rules
      of the project root and of that directory -- for every directory given.
    * a bare ``-``: added as-is (read from stdin later).
    * anything else: reported with ``err`` as an invalid path.
    """
    sources: Set[Path] = set()
    root = ctx.obj["root"]

    # Decide once, before the loop, whether the user supplied an exclude regex.
    # Previously `exclude` was replaced by the default inside the loop, so only
    # the first directory argument had its .gitignore files honoured.
    using_default_exclude = exclude is None
    if exclude is None:
        exclude = re_compile_maybe_verbose(DEFAULT_EXCLUDES)

    for s in src:
        if s == "-" and stdin_filename:
            p = Path(stdin_filename)
            is_stdin = True
        else:
            p = Path(s)
            is_stdin = False

        if is_stdin or p.is_file():
            normalized_path = normalize_path_maybe_ignore(p, root, report)
            if normalized_path is None:
                continue

            normalized_path = "/" + normalized_path
            # Hard-exclude any files that matches the `--force-exclude` regex.
            if force_exclude:
                force_exclude_match = force_exclude.search(normalized_path)
            else:
                force_exclude_match = None
            if force_exclude_match and force_exclude_match.group(0):
                report.path_ignored(p, "matches the --force-exclude regular expression")
                continue

            if is_stdin:
                p = Path(f"{STDIN_PLACEHOLDER}{str(p)}")

            if p.suffix == ".ipynb" and not jupyter_dependencies_are_installed(
                verbose=verbose, quiet=quiet
            ):
                continue

            sources.add(p)
        elif p.is_dir():
            if using_default_exclude:
                gitignore = get_gitignore(root)
                p_gitignore = get_gitignore(p)
                # No need to use p's gitignore if it is identical to root's gitignore
                # (i.e. root and p point to the same directory).
                if gitignore != p_gitignore:
                    # `+` builds a new PathSpec; pathspec's `+=` would extend the
                    # spec returned by get_gitignore(root) in place instead.
                    gitignore = gitignore + p_gitignore
            else:
                gitignore = None
            sources.update(
                gen_python_files(
                    p.iterdir(),
                    root,
                    include,
                    exclude,
                    extend_exclude,
                    force_exclude,
                    report,
                    gitignore,
                    verbose=verbose,
                    quiet=quiet,
                )
            )
        elif s == "-":
            sources.add(p)
        else:
            err(f"invalid path: {s}")
    return sources
690
691
def path_empty(
    src: Sized, msg: str, quiet: bool, verbose: bool, ctx: click.Context
) -> None:
    """Exit with status 0 when there is nothing in ``src`` to format.

    ``msg`` is printed first unless output is silenced (``quiet`` without
    ``verbose``). A non-empty ``src`` makes this a no-op.
    """
    if src:
        return
    if verbose or not quiet:
        out(msg)
    ctx.exit(0)
702
703
def reformat_code(
    content: str, fast: bool, write_back: WriteBack, mode: Mode, report: Report
) -> None:
    """Format the string ``content`` in-process and record the result.

    This is the string counterpart of :func:`reformat_one`. The work is done by
    :func:`format_stdin_to_stdout`, which receives ``fast``, ``write_back`` and
    ``mode``. The outcome is reported for the pseudo-path ``<string>``. Any
    exception is recorded via ``report.failed`` instead of propagating, and the
    traceback is printed when the report is verbose.
    """
    pseudo_path = Path("<string>")
    try:
        was_changed = format_stdin_to_stdout(
            content=content, fast=fast, write_back=write_back, mode=mode
        )
        report.done(pseudo_path, Changed.YES if was_changed else Changed.NO)
    except Exception as exc:
        if report.verbose:
            traceback.print_exc()
        report.failed(pseudo_path, str(exc))
726
727
# diff-shades depends on being able to monkeypatch this function to operate. I know
# it's not ideal, but this shouldn't cause any issues ... hopefully. ~ichard26
@mypyc_attr(patchable=True)
def reformat_one(
    src: Path, fast: bool, write_back: WriteBack, mode: Mode, report: "Report"
) -> None:
    """Format a single source in-process and record the result in ``report``.

    ``src`` is one of three things:

    * ``-`` (stdin);
    * a path tagged with STDIN_PLACEHOLDER (stdin with a user-supplied name);
    * a real file.

    Stdin input goes through :func:`format_stdin_to_stdout`. Files go through
    :func:`format_file_in_place`, unless their cache entry is still current;
    the cache is not consulted in the diff modes. ``fast``, ``write_back`` and
    ``mode`` are forwarded to those helpers. Any exception is recorded via
    ``report.failed`` instead of propagating.
    """
    try:
        changed = Changed.NO

        src_text = str(src)
        from_stdin = src_text == "-"
        if not from_stdin and src_text.startswith(STDIN_PLACEHOLDER):
            # Stdin content named via --stdin-filename: drop the placeholder so
            # that anything printed to the user shows the original name.
            from_stdin = True
            src = Path(src_text[len(STDIN_PLACEHOLDER) :])

        if from_stdin:
            # The (possibly user-supplied) suffix selects stub / notebook mode.
            if src.suffix == ".pyi":
                mode = replace(mode, is_pyi=True)
            elif src.suffix == ".ipynb":
                mode = replace(mode, is_ipynb=True)
            if format_stdin_to_stdout(fast=fast, write_back=write_back, mode=mode):
                changed = Changed.YES
        else:
            cache: Cache = {}
            uses_cache = write_back not in (WriteBack.DIFF, WriteBack.COLOR_DIFF)
            if uses_cache:
                cache = read_cache(mode)
                resolved = src.resolve()
                cache_key = str(resolved)
                if cache_key in cache and cache[cache_key] == get_cache_info(resolved):
                    changed = Changed.CACHED
            if changed is not Changed.CACHED and format_file_in_place(
                src, fast=fast, write_back=write_back, mode=mode
            ):
                changed = Changed.YES
            # Cache the file once it is known to be formatted: after an
            # in-place run that did not hit the cache, or after a --check run
            # that found nothing to change.
            formatted_in_place = (
                write_back is WriteBack.YES and changed is not Changed.CACHED
            )
            checked_clean = write_back is WriteBack.CHECK and changed is Changed.NO
            if formatted_in_place or checked_clean:
                write_cache(cache, [src], mode)
        report.done(src, changed)
    except Exception as exc:
        if report.verbose:
            traceback.print_exc()
        report.failed(src, str(exc))
780
781
def format_file_in_place(
    src: Path,
    fast: bool,
    mode: Mode,
    write_back: WriteBack = WriteBack.NO,
    lock: Any = None,  # multiprocessing.Manager().Lock() is some crazy proxy
) -> bool:
    """Format file under `src` path. Return True if changed.

    A ``.pyi`` or ``.ipynb`` suffix switches `mode` to stub or notebook
    formatting. With ``mode.skip_source_first_line`` the first line is kept
    verbatim and excluded from formatting.

    If `write_back` is YES, write the reformatted code back to the file. If it
    is DIFF or COLOR_DIFF, write a (colored) diff to stdout instead. In that
    case `lock`, when given, is held for the whole write so that concurrent
    diffs do not interleave.
    `mode` and `fast` options are passed to :func:`format_file_contents`.

    Raises ValueError when a notebook cannot be parsed as JSON.
    """
    if src.suffix == ".pyi":
        mode = replace(mode, is_pyi=True)
    elif src.suffix == ".ipynb":
        mode = replace(mode, is_ipynb=True)

    then = datetime.utcfromtimestamp(src.stat().st_mtime)
    header = b""
    with open(src, "rb") as buf:
        if mode.skip_source_first_line:
            header = buf.readline()
        src_contents, encoding, newline = decode_bytes(buf.read())
    try:
        dst_contents = format_file_contents(src_contents, fast=fast, mode=mode)
    except NothingChanged:
        return False
    except JSONDecodeError:
        raise ValueError(
            f"File '{src}' cannot be parsed as valid Jupyter notebook."
        ) from None

    decoded_header = header.decode(encoding)
    if newline == "\r\n" and decoded_header.endswith("\r\n"):
        # `header` holds the raw first line, but all output below goes through
        # text streams opened with newline=newline, which turn every "\n" into
        # "\r\n". Writing a raw "\r\n" that way would produce "\r\r\n", so
        # reduce it to "\n" and let the stream restore the CRLF.
        decoded_header = decoded_header[:-2] + "\n"
    src_contents = decoded_header + src_contents
    dst_contents = decoded_header + dst_contents

    if write_back == WriteBack.YES:
        with open(src, "w", encoding=encoding, newline=newline) as f:
            f.write(dst_contents)
    elif write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
        now = datetime.utcnow()
        src_name = f"{src}\t{then} +0000"
        dst_name = f"{src}\t{now} +0000"
        if mode.is_ipynb:
            diff_contents = ipynb_diff(src_contents, dst_contents, src_name, dst_name)
        else:
            diff_contents = diff(src_contents, dst_contents, src_name, dst_name)

        if write_back == WriteBack.COLOR_DIFF:
            diff_contents = color_diff(diff_contents)

        with lock or nullcontext():
            f = io.TextIOWrapper(
                sys.stdout.buffer,
                encoding=encoding,
                newline=newline,
                write_through=True,
            )
            f = wrap_stream_for_windows(f)
            try:
                f.write(diff_contents)
            finally:
                # Always detach, even if the write failed. Otherwise the wrapper
                # would close sys.stdout.buffer when it is garbage-collected.
                f.detach()

    return True
844
845
def format_stdin_to_stdout(
    fast: bool,
    *,
    content: Optional[str] = None,
    write_back: WriteBack = WriteBack.NO,
    mode: Mode,
) -> bool:
    """Format file on stdin. Return True if changed.

    If content is None, it's read from sys.stdin. A given `content` is treated
    as UTF-8 text and written without newline translation.

    If `write_back` is YES, write reformatted code back to stdout. If it is DIFF,
    write a diff to stdout; COLOR_DIFF colors that diff. The `mode` argument is
    passed to :func:`format_file_contents`.

    Output is produced in a ``finally`` clause. With YES, the input is echoed
    back unchanged even when formatting made no change or raised.
    """
    then = datetime.utcnow()

    if content is None:
        src, encoding, newline = decode_bytes(sys.stdin.buffer.read())
    else:
        src, encoding, newline = content, "utf-8", ""

    # Fallback output if formatting does not produce new contents.
    dst = src
    try:
        dst = format_file_contents(src, fast=fast, mode=mode)
        return True

    except NothingChanged:
        return False

    finally:
        f = io.TextIOWrapper(
            sys.stdout.buffer, encoding=encoding, newline=newline, write_through=True
        )
        try:
            if write_back == WriteBack.YES:
                # Make sure there's a newline after the content
                if dst and dst[-1] != "\n":
                    dst += "\n"
                f.write(dst)
            elif write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
                now = datetime.utcnow()
                src_name = f"STDIN\t{then} +0000"
                dst_name = f"STDOUT\t{now} +0000"
                d = diff(src, dst, src_name, dst_name)
                if write_back == WriteBack.COLOR_DIFF:
                    d = color_diff(d)
                    f = wrap_stream_for_windows(f)
                f.write(d)
        finally:
            # Always detach, even if a write failed. Otherwise the wrapper would
            # close sys.stdout.buffer when it is garbage-collected.
            f.detach()
895
896
def check_stability_and_equivalence(
    src_contents: str, dst_contents: str, *, mode: Mode
) -> None:
    """Check that ``dst_contents`` safely and stably reformats ``src_contents``.

    :func:`assert_equivalent` is run first, to ensure the two contents are
    equivalent. :func:`assert_stable` then ensures that another formatting
    pass under ``mode`` would not change ``dst_contents``. A failed check
    raises AssertionError.
    """
    assert_equivalent(src_contents, dst_contents)
    assert_stable(src_contents, dst_contents, mode=mode)
908
909
def format_file_contents(src_contents: str, *, fast: bool, mode: Mode) -> FileContent:
    """Return the reformatted version of a file's contents.

    Raises NothingChanged when ``src_contents`` is blank or formatting leaves
    it unchanged. Notebooks (``mode.is_ipynb``) are formatted with
    :func:`format_ipynb_string`; everything else with :func:`format_str`.
    Unless ``fast`` is set, non-notebook output is also verified with
    :func:`check_stability_and_equivalence`.
    """
    if not src_contents.strip():
        raise NothingChanged

    notebook = mode.is_ipynb
    dst_contents = (
        format_ipynb_string(src_contents, fast=fast, mode=mode)
        if notebook
        else format_str(src_contents, mode=mode)
    )
    if dst_contents == src_contents:
        raise NothingChanged

    # Jupyter notebooks have already been checked by format_ipynb_string().
    if not (fast or notebook):
        check_stability_and_equivalence(src_contents, dst_contents, mode=mode)
    return dst_contents
931
932
def validate_cell(src: str, mode: Mode) -> None:
    """Reject cells that cannot be safely round-tripped through masking.

    A cell is rejected (by raising NothingChanged) when it already contains
    code that TransformerManager produces from magics, or when it starts with
    a cell magic (``%%name``) that is not a known Python cell magic, either
    built in or configured via ``mode.python_cell_magics``.

    For example, ``!ls`` becomes ``get_ipython().system('ls')``, and a cell that
    already contained ``get_ipython().system('ls')`` transforms identically:

        >>> TransformerManager().transform_cell("get_ipython().system('ls')")
        "get_ipython().system('ls')\n"
        >>> TransformerManager().transform_cell("!ls")
        "get_ipython().system('ls')\n"

    Since the two cannot be told apart afterwards, such cells are skipped.
    """
    for magic in TRANSFORMED_MAGICS:
        if magic in src:
            raise NothingChanged

    if not src.startswith("%%"):
        return

    # The cell magic name is the first whitespace-delimited token minus "%%".
    cell_magic = src.split()[0][2:]
    if cell_magic not in PYTHON_CELL_MAGICS | mode.python_cell_magics:
        raise NothingChanged
957
958
def format_cell(src: str, *, fast: bool, mode: Mode) -> str:
    """Format the source of a single Jupyter notebook code cell.

    Steps: validate the cell, drop a trailing semicolon, mask IPython magics,
    format with :func:`format_str` (optionally verifying the result), unmask
    the magics, restore the semicolon, and strip trailing newlines.

    Raises NothingChanged if the cell is rejected by :func:`validate_cell`,
    cannot be masked because of a SyntaxError (it could be an automagic or a
    multi-line magic, which are unsupported), or is already formatted.
    """
    validate_cell(src, mode)
    code_only, had_semicolon = remove_trailing_semicolon(src)

    try:
        masked_src, replacements = mask_cell(code_only)
    except SyntaxError:
        raise NothingChanged from None

    masked_dst = format_str(masked_src, mode=mode)
    if not fast:
        check_stability_and_equivalence(masked_src, masked_dst, mode=mode)

    restored = unmask_cell(masked_dst, replacements)
    result = put_trailing_semicolon_back(restored, had_semicolon).rstrip("\n")
    if result == src:
        raise NothingChanged from None
    return result
994
995
def validate_metadata(nb: MutableMapping[str, Any]) -> None:
    """Raise NothingChanged when the notebook declares a non-Python language.

    The language is read from ``metadata.language_info.name``. Every level is
    optional in the notebook format
    (https://nbformat.readthedocs.io/en/latest/format_description.html), so a
    missing value is treated as "possibly Python" and the notebook is allowed.
    """
    language_info = nb.get("metadata", {}).get("language_info", {})
    language = language_info.get("name", None)
    if language is None or language == "python":
        return
    raise NothingChanged from None
1006
1007
def format_ipynb_string(src_contents: str, *, fast: bool, mode: Mode) -> FileContent:
    """Format Jupyter notebook.

    Operate cell-by-cell, only on code cells, only for Python notebooks.
    If the ``.ipynb`` originally had a trailing newline, it'll be preserved.

    Raises NothingChanged when `src_contents` is empty, when the notebook's
    metadata marks it as non-Python, or when no code cell was reformatted.
    JSON decoding errors from malformed input are not caught here.
    """
    if not src_contents:
        # An empty string would otherwise crash on the trailing-newline probe
        # below (and is not valid JSON anyway).
        raise NothingChanged

    trailing_newline = src_contents.endswith("\n")
    modified = False
    nb = json.loads(src_contents)
    validate_metadata(nb)
    for cell in nb["cells"]:
        if cell.get("cell_type", None) != "code":
            continue
        try:
            src = "".join(cell["source"])
            dst = format_cell(src, fast=fast, mode=mode)
        except NothingChanged:
            continue
        # Store the formatted source as a list of lines that keep their endings.
        cell["source"] = dst.splitlines(keepends=True)
        modified = True

    if not modified:
        raise NothingChanged

    dst_contents = json.dumps(nb, indent=1, ensure_ascii=False)
    if trailing_newline:
        dst_contents += "\n"
    return dst_contents
1035
1036
def format_str(src_contents: str, *, mode: Mode) -> str:
    """Return `src_contents` reformatted according to `mode`.

    `mode` controls options such as the maximum number of characters allowed
    per line.  Example:

    >>> import black
    >>> print(black.format_str("def f(arg:str='')->None:...", mode=black.Mode()))
    def f(arg: str = "") -> None:
        ...

    A more complex example:

    >>> print(
    ...   black.format_str(
    ...     "def f(arg:str='')->None: hey",
    ...     mode=black.Mode(
    ...       target_versions={black.TargetVersion.PY36},
    ...       line_length=10,
    ...       string_normalization=False,
    ...       is_pyi=False,
    ...     ),
    ...   ),
    ... )
    def f(
        arg: str = '',
    ) -> None:
        hey

    """
    first_pass = _format_str_once(src_contents, mode=mode)
    if first_pass == src_contents:
        return first_pass
    # A second pass is forced because optional trailing commas added in the
    # first pass become magic trailing commas in the next one and interact
    # differently with optional parentheses.  Admittedly ugly.
    return _format_str_once(first_pass, mode=mode)
1074
1075
def _format_str_once(src_contents: str, *, mode: Mode) -> str:
    """Run exactly one formatting pass over `src_contents` and return the text.

    Target versions come from `mode` when given; otherwise they are inferred
    from the parsed tree (including its ``__future__`` imports).
    """
    src_node = lib2to3_parse(src_contents.lstrip(), mode.target_versions)
    versions = mode.target_versions
    if not versions:
        versions = detect_target_versions(
            src_node, future_imports=get_future_imports(src_node)
        )

    normalize_fmt_off(src_node, preview=mode.preview)
    line_generator = LineGenerator(mode=mode)
    empty_line_tracker = EmptyLineTracker(mode=mode)
    trailing_comma_features = (
        Feature.TRAILING_COMMA_IN_CALL,
        Feature.TRAILING_COMMA_IN_DEF,
    )
    split_line_features = {
        feat for feat in trailing_comma_features if supports_feature(versions, feat)
    }

    dst_blocks: List[LinesBlock] = []
    for current_line in line_generator.visit(src_node):
        block = empty_line_tracker.maybe_empty_lines(current_line)
        dst_blocks.append(block)
        block.content_lines.extend(
            str(out_line)
            for out_line in transform_line(
                current_line, mode=mode, features=split_line_features
            )
        )

    # The final block must not emit trailing empty lines.
    if dst_blocks:
        dst_blocks[-1].after = 0
    return "".join(text for blk in dst_blocks for text in blk.all_lines())
1107
1108
def decode_bytes(src: bytes) -> Tuple[FileContent, Encoding, NewLine]:
    """Decode raw source bytes and report their encoding and newline style.

    The encoding is detected the way Python itself does (coding cookie or BOM,
    defaulting to UTF-8). The newline style is CRLF if the first line ends in
    ``\\r\\n`` and LF otherwise. The decoded text always uses universal newlines,
    i.e. it only contains LF.
    """
    buffer = io.BytesIO(src)
    encoding, first_lines = tokenize.detect_encoding(buffer.readline)
    if not first_lines:
        return "", encoding, "\n"

    newline = "\r\n" if first_lines[0].endswith(b"\r\n") else "\n"
    buffer.seek(0)
    with io.TextIOWrapper(buffer, encoding) as reader:
        decoded = reader.read()
    return decoded, encoding, newline
1124
1125
def get_features_used(  # noqa: C901
    node: Node, *, future_imports: Optional[Set[str]] = None
) -> Set[Feature]:
    """Return a set of (relatively) new Python features used in this file.

    Currently looking for:
    - f-strings (with any letter-case combination of the f/rf/fr prefixes);
    - self-documenting expressions in f-strings (f"{x=}");
    - underscores in numeric literals;
    - trailing commas after * or ** in function signatures and calls;
    - positional only arguments in function signatures and lambdas;
    - assignment expression;
    - relaxed decorator syntax;
    - usage of __future__ flags that appear in FUTURE_FLAG_TO_FEATURE;
    - star-unpacking in return/yield values (``return 1, *rest``);
    - tuple/star right-hand sides in annotated assignments;
    - except* clauses;
    - variadic generics (star expressions in subscripts, ``*args: *Ts``).
    """
    features: Set[Feature] = set()
    if future_imports:
        features |= {
            FUTURE_FLAG_TO_FEATURE[future_import]
            for future_import in future_imports
            if future_import in FUTURE_FLAG_TO_FEATURE
        }

    for n in node.pre_order():
        if is_string_token(n):
            value_head = n.value[:2]
            # String prefixes are case-insensitive, so mixed-case prefixes such
            # as ``Rf`` or ``fR`` must be recognized as f-strings too.
            if value_head.lower() in {'f"', "f'", "rf", "fr"}:
                features.add(Feature.F_STRINGS)
                if Feature.DEBUG_F_STRINGS not in features:
                    for span_beg, span_end in iter_fexpr_spans(n.value):
                        if n.value[span_beg : span_end - 1].rstrip().endswith("="):
                            features.add(Feature.DEBUG_F_STRINGS)
                            break

        elif is_number_token(n):
            if "_" in n.value:
                features.add(Feature.NUMERIC_UNDERSCORES)

        elif n.type == token.SLASH:
            if n.parent and n.parent.type in {
                syms.typedargslist,
                syms.arglist,
                syms.varargslist,
            }:
                features.add(Feature.POS_ONLY_ARGUMENTS)

        elif n.type == token.COLONEQUAL:
            features.add(Feature.ASSIGNMENT_EXPRESSIONS)

        elif n.type == syms.decorator:
            if len(n.children) > 1 and not is_simple_decorator_expression(
                n.children[1]
            ):
                features.add(Feature.RELAXED_DECORATORS)

        elif (
            n.type in {syms.typedargslist, syms.arglist}
            and n.children
            and n.children[-1].type == token.COMMA
        ):
            if n.type == syms.typedargslist:
                feature = Feature.TRAILING_COMMA_IN_DEF
            else:
                feature = Feature.TRAILING_COMMA_IN_CALL

            for ch in n.children:
                if ch.type in STARS:
                    features.add(feature)

                if ch.type == syms.argument:
                    for argch in ch.children:
                        if argch.type in STARS:
                            features.add(feature)

        elif (
            n.type in {syms.return_stmt, syms.yield_expr}
            and len(n.children) >= 2
            and n.children[1].type == syms.testlist_star_expr
            and any(child.type == syms.star_expr for child in n.children[1].children)
        ):
            features.add(Feature.UNPACKING_ON_FLOW)

        elif (
            n.type == syms.annassign
            and len(n.children) >= 4
            and n.children[3].type == syms.testlist_star_expr
        ):
            features.add(Feature.ANN_ASSIGN_EXTENDED_RHS)

        elif (
            n.type == syms.except_clause
            and len(n.children) >= 2
            and n.children[1].type == token.STAR
        ):
            features.add(Feature.EXCEPT_STAR)

        elif n.type in {syms.subscriptlist, syms.trailer} and any(
            child.type == syms.star_expr for child in n.children
        ):
            features.add(Feature.VARIADIC_GENERICS)

        elif (
            n.type == syms.tname_star
            and len(n.children) == 3
            and n.children[2].type == syms.star_expr
        ):
            features.add(Feature.VARIADIC_GENERICS)

    return features
1236
1237
def detect_target_versions(
    node: Node, *, future_imports: Optional[Set[str]] = None
) -> Set[TargetVersion]:
    """Return every target version that supports all features used in `node`."""
    used_features = get_features_used(node, future_imports=future_imports)
    return {
        target
        for target in TargetVersion
        if VERSION_TO_FEATURES[target].issuperset(used_features)
    }
1246
1247
def get_future_imports(node: Node) -> Set[str]:
    """Return the names imported via ``from __future__ import ...`` in `node`.

    Only the leading simple statements of the module are scanned. Lone string
    statements (docstrings) are skipped, and scanning stops at the first
    statement that is neither such a string nor a ``__future__`` import.
    """
    found: Set[str] = set()

    def iter_imported_names(nodes: List[LN]) -> Generator[str, None, None]:
        # Walk the tail of an import_from node, yielding the original names.
        for item in nodes:
            if not isinstance(item, Leaf):
                if item.type == syms.import_as_name:
                    original = item.children[0]
                    assert isinstance(original, Leaf), "Invalid syntax parsing imports"
                    assert original.type == token.NAME, "Invalid syntax parsing imports"
                    yield original.value
                elif item.type == syms.import_as_names:
                    yield from iter_imported_names(item.children)
                else:
                    raise AssertionError("Invalid syntax parsing imports")
            elif item.type == token.NAME:
                yield item.value

    for stmt in node.children:
        if stmt.type != syms.simple_stmt:
            break

        head = stmt.children[0]
        if isinstance(head, Leaf):
            is_docstring = (
                len(stmt.children) == 2
                and head.type == token.STRING
                and stmt.children[1].type == token.NEWLINE
            )
            if not is_docstring:
                break
            continue

        if head.type != syms.import_from:
            break

        module = head.children[1]
        if not isinstance(module, Leaf) or module.value != "__future__":
            break

        found.update(iter_imported_names(head.children[3:]))

    return found
1296
1297
def assert_equivalent(src: str, dst: str) -> None:
    """Raise AssertionError unless `src` and `dst` have equivalent ASTs.

    A source that fails to parse is reported as unusable with --safe. A
    destination that fails to parse, or one whose AST dump differs from the
    source's, is reported as an internal Black error, with details written to
    a log file via dump_to_file.
    """
    try:
        src_ast = parse_ast(src)
    except Exception as parse_error:
        raise AssertionError(
            "cannot use --safe with this file; failed to parse source file AST: "
            f"{parse_error}\n"
            "This could be caused by running Black with an older Python version "
            "that does not support new syntax used in your source file."
        ) from parse_error

    try:
        dst_ast = parse_ast(dst)
    except Exception as parse_error:
        traceback_text = "".join(traceback.format_tb(parse_error.__traceback__))
        log_path = dump_to_file(traceback_text, dst)
        raise AssertionError(
            f"INTERNAL ERROR: Black produced invalid code: {parse_error}. "
            "Please report a bug on https://github.com/psf/black/issues.  "
            f"This invalid output might be helpful: {log_path}"
        ) from None

    src_dump = "\n".join(stringify_ast(src_ast))
    dst_dump = "\n".join(stringify_ast(dst_ast))
    if src_dump == dst_dump:
        return

    log_path = dump_to_file(diff(src_dump, dst_dump, "src", "dst"))
    raise AssertionError(
        "INTERNAL ERROR: Black produced code that is not equivalent to the"
        " source.  Please report a bug on "
        f"https://github.com/psf/black/issues.  This diff might be helpful: {log_path}"
    ) from None
1329
1330
def assert_stable(src: str, dst: str, mode: Mode) -> None:
    """Raise AssertionError if formatting `dst` once more changes it.

    On failure, the mode plus both diffs (source -> first pass, first pass ->
    second pass) are written to a log file whose path is in the message.
    """
    # Deliberately a single pass: format_str() formats twice and could mask a
    # bug where the output oscillates between two versions.
    second_pass = _format_str_once(dst, mode=mode)
    if second_pass == dst:
        return

    log_path = dump_to_file(
        str(mode),
        diff(src, dst, "source", "first pass"),
        diff(dst, second_pass, "first pass", "second pass"),
    )
    raise AssertionError(
        "INTERNAL ERROR: Black produced different code on the second pass of the"
        " formatter.  Please report a bug on https://github.com/psf/black/issues."
        f"  This diff might be helpful: {log_path}"
    ) from None
1348
1349
@contextmanager
def nullcontext() -> Iterator[None]:
    """Provide a context manager that does nothing on entry or exit.

    Stand-in for ``contextlib.nullcontext``, which only exists from Python 3.7.
    The ``as`` target, if any, is bound to None.
    """
    yield None
1357
1358
def patch_click() -> None:
    """Make Click not crash on Python 3.6 with LANG=C.

    In some misconfigured environments Python 3 picks ASCII as the default
    encoding, which restricts the paths it can access during the process's
    lifetime, and Click refuses to run, raising a RuntimeError.

    For Black, non-ASCII characters in file paths are unlikely since it formats
    Python source code.  Moreover, this crash was spurious on Python 3.7
    thanks to PEP 538 and PEP 540.
    """
    click_modules: List[Any] = []
    try:
        from click import core
    except ImportError:
        pass
    else:
        click_modules.append(core)
    try:
        # Removed in Click 8.1.0 and newer; we keep this around for users who have
        # older versions installed.
        from click import _unicodefun  # type: ignore
    except ImportError:
        pass
    else:
        click_modules.append(_unicodefun)

    # Replace whichever environment-check hook each module defines with a no-op.
    for click_module in click_modules:
        for hook_name in ("_verify_python3_env", "_verify_python_env"):
            if hasattr(click_module, hook_name):
                setattr(click_module, hook_name, lambda: None)
1391
1392
def patched_main() -> None:
    """Console entry point: prepare the runtime, patch Click, then run main()."""
    # PyInstaller patches multiprocessing to need freeze_support() even in
    # non-Windows environments, so always call it when running frozen.
    is_frozen = getattr(sys, "frozen", False)
    if is_frozen:
        from multiprocessing import freeze_support

        freeze_support()

    patch_click()
    main()
1403
1404
# Run the patched CLI entry point when this file is executed as a script.
if __name__ == "__main__":
    patched_main()