git.madduck.net Git - etc/vim.git/blob - src/black/__init__.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Bump pypa/cibuildwheel from 2.10.2 to 2.11.2 (#3367)
[etc/vim.git] / src / black / __init__.py
1 import io
2 import json
3 import platform
4 import re
5 import sys
6 import tokenize
7 import traceback
8 from contextlib import contextmanager
9 from dataclasses import replace
10 from datetime import datetime
11 from enum import Enum
12 from json.decoder import JSONDecodeError
13 from pathlib import Path
14 from typing import (
15     Any,
16     Dict,
17     Generator,
18     Iterator,
19     List,
20     MutableMapping,
21     Optional,
22     Pattern,
23     Sequence,
24     Set,
25     Sized,
26     Tuple,
27     Union,
28 )
29
30 import click
31 from click.core import ParameterSource
32 from mypy_extensions import mypyc_attr
33 from pathspec.patterns.gitwildmatch import GitWildMatchPatternError
34
35 from _black_version import version as __version__
36 from black.cache import Cache, get_cache_info, read_cache, write_cache
37 from black.comments import normalize_fmt_off
38 from black.const import (
39     DEFAULT_EXCLUDES,
40     DEFAULT_INCLUDES,
41     DEFAULT_LINE_LENGTH,
42     STDIN_PLACEHOLDER,
43 )
44 from black.files import (
45     find_project_root,
46     find_pyproject_toml,
47     find_user_pyproject_toml,
48     gen_python_files,
49     get_gitignore,
50     normalize_path_maybe_ignore,
51     parse_pyproject_toml,
52     wrap_stream_for_windows,
53 )
54 from black.handle_ipynb_magics import (
55     PYTHON_CELL_MAGICS,
56     TRANSFORMED_MAGICS,
57     jupyter_dependencies_are_installed,
58     mask_cell,
59     put_trailing_semicolon_back,
60     remove_trailing_semicolon,
61     unmask_cell,
62 )
63 from black.linegen import LN, LineGenerator, transform_line
64 from black.lines import EmptyLineTracker, LinesBlock
65 from black.mode import (
66     FUTURE_FLAG_TO_FEATURE,
67     VERSION_TO_FEATURES,
68     Feature,
69     Mode,
70     TargetVersion,
71     supports_feature,
72 )
73 from black.nodes import (
74     STARS,
75     is_number_token,
76     is_simple_decorator_expression,
77     is_string_token,
78     syms,
79 )
80 from black.output import color_diff, diff, dump_to_file, err, ipynb_diff, out
81 from black.parsing import InvalidInput  # noqa F401
82 from black.parsing import lib2to3_parse, parse_ast, stringify_ast
83 from black.report import Changed, NothingChanged, Report
84 from black.trans import iter_fexpr_spans
85 from blib2to3.pgen2 import token
86 from blib2to3.pytree import Leaf, Node
87
# True when this module was loaded from a compiled extension module (".so" or
# ".pyd") instead of a plain .py file; reported in the `--version` message.
COMPILED = Path(__file__).suffix in (".pyd", ".so")

# types
# Descriptive aliases for `str`, used in signatures to say what the string holds.
FileContent = str
Encoding = str
NewLine = str
94
95
class WriteBack(Enum):
    """Where formatted results go: nowhere, back to the file, a diff, or a check."""

    NO = 0
    YES = 1
    DIFF = 2
    CHECK = 3
    COLOR_DIFF = 4

    @classmethod
    def from_configuration(
        cls, *, check: bool, diff: bool, color: bool = False
    ) -> "WriteBack":
        """Translate the --check / --diff / --color flags into a member.

        A requested diff always wins over ``check`` and becomes COLOR_DIFF when
        ``color`` is set. Otherwise ``check`` yields CHECK, and no flags yield YES.
        """
        if diff:
            return cls.COLOR_DIFF if color else cls.DIFF
        if check:
            return cls.CHECK
        return cls.YES
114
115
# Legacy name, left for integrations: `FileMode` is simply an alias of `Mode`
# so that code still referring to the old name keeps working.
FileMode = Mode
118
119
def read_pyproject_toml(
    ctx: click.Context, param: click.Parameter, value: Optional[str]
) -> Optional[str]:
    """Click callback merging Black's "pyproject.toml" settings into `ctx` defaults.

    When `value` is empty, a configuration file is looked up from the SRC paths
    already present in `ctx.params`. Settings read from the file are layered on
    top of any existing `ctx.default_map`.

    Returns the path of the configuration file that was found and read, or None
    when no file was found or it held no settings.
    """
    if not value:
        value = find_pyproject_toml(ctx.params.get("src", ()))
        if value is None:
            return None

    try:
        raw_config = parse_pyproject_toml(value)
    except (OSError, ValueError) as e:
        raise click.FileError(
            filename=value, hint=f"Error reading configuration file: {e}"
        ) from None

    if not raw_config:
        return None

    # Click wants plain strings for scalar options; lists and tables pass through
    # untouched. Background:
    # https://github.com/psf/black/issues/1458
    # https://github.com/pallets/click/issues/1567
    config = {
        key: val if isinstance(val, (list, dict)) else str(val)
        for key, val in raw_config.items()
    }

    target_version = config.get("target_version")
    if target_version is not None and not isinstance(target_version, list):
        raise click.BadOptionUsage(
            "target-version", "Config key target-version must be a list"
        )

    merged_defaults: Dict[str, Any] = dict(ctx.default_map or {})
    merged_defaults.update(config)
    ctx.default_map = merged_defaults
    return value
164
165
def target_version_option_callback(
    c: click.Context, p: Union[click.Option, click.Parameter], v: Tuple[str, ...]
) -> List[TargetVersion]:
    """Map the lowercase --target-version choices onto TargetVersion members.

    Kept as a named function instead of a lambda: mypy could not infer the
    lambda's type, which caused problems for mypyc.
    """
    return [TargetVersion[choice.upper()] for choice in v]
175
176
def re_compile_maybe_verbose(regex: str) -> Pattern[str]:
    """Compile `regex`, enabling verbose mode for multi-line patterns.

    A pattern that contains a newline gets the inline ``(?x)`` flag prepended,
    so its whitespace and ``#`` comments are ignored by the regex engine.
    """
    pattern_text = f"(?x){regex}" if "\n" in regex else regex
    compiled: Pattern[str] = re.compile(pattern_text)
    return compiled
186
187
def validate_regex(
    ctx: click.Context,
    param: click.Parameter,
    value: Optional[str],
) -> Optional[Pattern[str]]:
    """Click callback compiling a regex option value; None passes through.

    Raises click.BadParameter when `value` is not a valid regular expression.
    """
    if value is None:
        return None
    try:
        return re_compile_maybe_verbose(value)
    except re.error as e:
        raise click.BadParameter(f"Not a valid regular expression: {e}") from None
197
198
# Command-line entry point. Option values arrive as keyword arguments; values not
# given on the command line may come from the pyproject.toml that the eager
# --config callback (`read_pyproject_toml`) loads into `ctx.default_map`. The
# function validates the options, builds a `Mode`, formats the requested code or
# sources, and always finishes via `ctx.exit(report.return_code)`.
@click.command(
    context_settings={"help_option_names": ["-h", "--help"]},
    # While Click does set this field automatically using the docstring, mypyc
    # (annoyingly) strips 'em so we need to set it here too.
    help="The uncompromising code formatter.",
)
@click.option("-c", "--code", type=str, help="Format the code passed in as a string.")
@click.option(
    "-l",
    "--line-length",
    type=int,
    default=DEFAULT_LINE_LENGTH,
    help="How many characters per line to allow.",
    show_default=True,
)
@click.option(
    "-t",
    "--target-version",
    type=click.Choice([v.name.lower() for v in TargetVersion]),
    # Converts the chosen names into TargetVersion members.
    callback=target_version_option_callback,
    multiple=True,
    help=(
        "Python versions that should be supported by Black's output. [default: per-file"
        " auto-detection]"
    ),
)
@click.option(
    "--pyi",
    is_flag=True,
    help=(
        "Format all input files like typing stubs regardless of file extension (useful"
        " when piping source on standard input)."
    ),
)
@click.option(
    "--ipynb",
    is_flag=True,
    help=(
        "Format all input files like Jupyter Notebooks regardless of file extension "
        "(useful when piping source on standard input)."
    ),
)
@click.option(
    "--python-cell-magics",
    multiple=True,
    help=(
        "When processing Jupyter Notebooks, add the given magic to the list"
        f" of known python-magics ({', '.join(PYTHON_CELL_MAGICS)})."
        " Useful for formatting cells with custom python magics."
    ),
    default=[],
)
@click.option(
    "-x",
    "--skip-source-first-line",
    is_flag=True,
    help="Skip the first line of the source code.",
)
@click.option(
    "-S",
    "--skip-string-normalization",
    is_flag=True,
    help="Don't normalize string quotes or prefixes.",
)
@click.option(
    "-C",
    "--skip-magic-trailing-comma",
    is_flag=True,
    help="Don't use trailing commas as a reason to split lines.",
)
@click.option(
    "--experimental-string-processing",
    is_flag=True,
    hidden=True,
    help="(DEPRECATED and now included in --preview) Normalize string literals.",
)
@click.option(
    "--preview",
    is_flag=True,
    help=(
        "Enable potentially disruptive style changes that may be added to Black's main"
        " functionality in the next major release."
    ),
)
@click.option(
    "--check",
    is_flag=True,
    help=(
        "Don't write the files back, just return the status. Return code 0 means"
        " nothing would change. Return code 1 means some files would be reformatted."
        " Return code 123 means there was an internal error."
    ),
)
@click.option(
    "--diff",
    is_flag=True,
    help="Don't write the files back, just output a diff for each file on stdout.",
)
@click.option(
    "--color/--no-color",
    is_flag=True,
    help="Show colored diff. Only applies when `--diff` is given.",
)
@click.option(
    "--fast/--safe",
    is_flag=True,
    help="If --fast given, skip temporary sanity checks. [default: --safe]",
)
@click.option(
    "--required-version",
    type=str,
    help=(
        "Require a specific version of Black to be running (useful for unifying results"
        " across many environments e.g. with a pyproject.toml file). It can be"
        " either a major version number or an exact version."
    ),
)
@click.option(
    "--include",
    type=str,
    default=DEFAULT_INCLUDES,
    # Regex options are compiled (and validated) by this callback.
    callback=validate_regex,
    help=(
        "A regular expression that matches files and directories that should be"
        " included on recursive searches. An empty value means all files are included"
        " regardless of the name. Use forward slashes for directories on all platforms"
        " (Windows, too). Exclusions are calculated first, inclusions later."
    ),
    show_default=True,
)
@click.option(
    "--exclude",
    type=str,
    callback=validate_regex,
    help=(
        "A regular expression that matches files and directories that should be"
        " excluded on recursive searches. An empty value means no paths are excluded."
        " Use forward slashes for directories on all platforms (Windows, too)."
        " Exclusions are calculated first, inclusions later. [default:"
        f" {DEFAULT_EXCLUDES}]"
    ),
    show_default=False,
)
@click.option(
    "--extend-exclude",
    type=str,
    callback=validate_regex,
    help=(
        "Like --exclude, but adds additional files and directories on top of the"
        " excluded ones. (Useful if you simply want to add to the default)"
    ),
)
@click.option(
    "--force-exclude",
    type=str,
    callback=validate_regex,
    help=(
        "Like --exclude, but files and directories matching this regex will be "
        "excluded even when they are passed explicitly as arguments."
    ),
)
@click.option(
    "--stdin-filename",
    type=str,
    help=(
        "The name of the file when passing it through stdin. Useful to make "
        "sure Black will respect --force-exclude option on some "
        "editors that rely on using stdin."
    ),
)
@click.option(
    "-W",
    "--workers",
    type=click.IntRange(min=1),
    default=None,
    help="Number of parallel workers [default: number of CPUs in the system]",
)
@click.option(
    "-q",
    "--quiet",
    is_flag=True,
    help=(
        "Don't emit non-error messages to stderr. Errors are still emitted; silence"
        " those with 2>/dev/null."
    ),
)
@click.option(
    "-v",
    "--verbose",
    is_flag=True,
    help=(
        "Also emit messages to stderr about files that were not changed or were ignored"
        " due to exclusion patterns."
    ),
)
@click.version_option(
    version=__version__,
    message=(
        f"%(prog)s, %(version)s (compiled: {'yes' if COMPILED else 'no'})\n"
        f"Python ({platform.python_implementation()}) {platform.python_version()}"
    ),
)
@click.argument(
    "src",
    nargs=-1,
    type=click.Path(
        exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
    ),
    # Presumably eager so that `src` is already in ctx.params when the --config
    # callback searches for a pyproject.toml from it -- NOTE(review): confirm.
    is_eager=True,
    metavar="SRC ...",
)
@click.option(
    "--config",
    type=click.Path(
        exists=True,
        file_okay=True,
        dir_okay=False,
        readable=True,
        allow_dash=False,
        path_type=str,
    ),
    is_eager=True,
    # Loads the configuration file into ctx.default_map (see read_pyproject_toml).
    callback=read_pyproject_toml,
    help="Read configuration from FILE path.",
)
@click.pass_context
def main(  # noqa: C901
    ctx: click.Context,
    code: Optional[str],
    line_length: int,
    target_version: List[TargetVersion],
    check: bool,
    diff: bool,
    color: bool,
    fast: bool,
    pyi: bool,
    ipynb: bool,
    python_cell_magics: Sequence[str],
    skip_source_first_line: bool,
    skip_string_normalization: bool,
    skip_magic_trailing_comma: bool,
    experimental_string_processing: bool,
    preview: bool,
    quiet: bool,
    verbose: bool,
    required_version: Optional[str],
    include: Pattern[str],
    exclude: Optional[Pattern[str]],
    extend_exclude: Optional[Pattern[str]],
    force_exclude: Optional[Pattern[str]],
    stdin_filename: Optional[str],
    workers: Optional[int],
    src: Tuple[str, ...],
    config: Optional[str],
) -> None:
    """The uncompromising code formatter."""
    ctx.ensure_object(dict)

    # Exactly one of SRC paths or --code must be supplied.
    if src and code is not None:
        out(
            main.get_usage(ctx)
            + "\n\n'SRC' and 'code' cannot be passed simultaneously."
        )
        ctx.exit(1)
    if not src and code is None:
        out(main.get_usage(ctx) + "\n\nOne of 'SRC' or 'code' is required.")
        ctx.exit(1)

    # A project root is only looked up when formatting paths; with --code both
    # the root and the method used to find it stay None.
    root, method = (
        find_project_root(src, stdin_filename) if code is None else (None, None)
    )
    # Shared with get_sources(), which reads ctx.obj["root"].
    ctx.obj["root"] = root

    if verbose:
        if root:
            out(
                f"Identified `{root}` as project root containing a {method}.",
                fg="blue",
            )

            # Show each source as normalized against the root; sources that
            # normalize_path_maybe_ignore rejects (falsy result) are shown in
            # red (ANSI 31) before switching back to blue (ANSI 34).
            normalized = [
                (source, source)
                if source == "-"
                else (normalize_path_maybe_ignore(Path(source), root), source)
                for source in src
            ]
            srcs_string = ", ".join(
                [
                    f'"{_norm}"'
                    if _norm
                    else f'\033[31m"{source} (skipping - invalid)"\033[34m'
                    for _norm, source in normalized
                ]
            )
            out(f"Sources to be formatted: {srcs_string}", fg="blue")

        # Report where the configuration came from: the user-level file, a file
        # picked up without an explicit --config (DEFAULT / DEFAULT_MAP source),
        # or an explicitly passed path.
        if config:
            config_source = ctx.get_parameter_source("config")
            user_level_config = str(find_user_pyproject_toml())
            if config == user_level_config:
                out(
                    (
                        "Using configuration from user-level config at "
                        f"'{user_level_config}'."
                    ),
                    fg="blue",
                )
            elif config_source in (
                ParameterSource.DEFAULT,
                ParameterSource.DEFAULT_MAP,
            ):
                out("Using configuration from project root.", fg="blue")
            else:
                out(f"Using configuration in '{config}'.", fg="blue")

    error_msg = "Oh no! 💥 💔 💥"
    # --required-version matches either the exact running version or just its
    # major component (the part before the first ".").
    if (
        required_version
        and required_version != __version__
        and required_version != __version__.split(".")[0]
    ):
        err(
            f"{error_msg} The required version `{required_version}` does not match"
            f" the running version `{__version__}`!"
        )
        ctx.exit(1)
    if ipynb and pyi:
        err("Cannot pass both `pyi` and `ipynb` flags!")
        ctx.exit(1)

    write_back = WriteBack.from_configuration(check=check, diff=diff, color=color)
    if target_version:
        versions = set(target_version)
    else:
        # We'll autodetect later.
        versions = set()
    mode = Mode(
        target_versions=versions,
        line_length=line_length,
        is_pyi=pyi,
        is_ipynb=ipynb,
        skip_source_first_line=skip_source_first_line,
        string_normalization=not skip_string_normalization,
        magic_trailing_comma=not skip_magic_trailing_comma,
        experimental_string_processing=experimental_string_processing,
        preview=preview,
        python_cell_magics=set(python_cell_magics),
    )

    if code is not None:
        # Run in quiet mode by default with -c; the extra output isn't useful.
        # You can still pass -v to get verbose output.
        quiet = True

    report = Report(check=check, diff=diff, quiet=quiet, verbose=verbose)

    if code is not None:
        reformat_code(
            content=code, fast=fast, write_back=write_back, mode=mode, report=report
        )
    else:
        try:
            sources = get_sources(
                ctx=ctx,
                src=src,
                quiet=quiet,
                verbose=verbose,
                include=include,
                exclude=exclude,
                extend_exclude=extend_exclude,
                force_exclude=force_exclude,
                report=report,
                stdin_filename=stdin_filename,
            )
        except GitWildMatchPatternError:
            ctx.exit(1)

        # Exits with status 0 when nothing was collected.
        path_empty(
            sources,
            "No Python files are present to be formatted. Nothing to do 😴",
            quiet,
            verbose,
            ctx,
        )

        # A single source is formatted in-process; several sources go through
        # black.concurrency, which is only imported on that path.
        if len(sources) == 1:
            reformat_one(
                src=sources.pop(),
                fast=fast,
                write_back=write_back,
                mode=mode,
                report=report,
            )
        else:
            from black.concurrency import reformat_many

            reformat_many(
                sources=sources,
                fast=fast,
                write_back=write_back,
                mode=mode,
                report=report,
                workers=workers,
            )

    # Final summary, then exit with the status code accumulated by the report.
    if verbose or not quiet:
        if code is None and (verbose or report.change_count or report.failure_count):
            out()
        out(error_msg if report.return_code else "All done! ✨ 🍰 ✨")
        if code is None:
            click.echo(str(report), err=True)
    ctx.exit(report.return_code)
611
612
def get_sources(
    *,
    ctx: click.Context,
    src: Tuple[str, ...],
    quiet: bool,
    verbose: bool,
    include: Pattern[str],
    exclude: Optional[Pattern[str]],
    extend_exclude: Optional[Pattern[str]],
    force_exclude: Optional[Pattern[str]],
    report: "Report",
    stdin_filename: Optional[str],
) -> Set[Path]:
    """Compute the set of files to be formatted.

    Each entry of `src` is handled according to what it names:

    * "-" combined with `stdin_filename`, or an existing file: added directly,
      unless normalize_path_maybe_ignore() rejects it, it matches
      `force_exclude`, or it is a notebook while the Jupyter dependencies are
      not installed. A stdin entry is recorded as STDIN_PLACEHOLDER followed by
      `stdin_filename`.
    * a directory: walked with gen_python_files(). When no `exclude` was given,
      the default excludes apply together with the .gitignore rules of both the
      project root and that directory; this holds for EVERY directory passed,
      not just the first one.
    * a bare "-" without `stdin_filename`: added as-is.
    * anything else: reported as an invalid path.
    """
    sources: Set[Path] = set()
    root = ctx.obj["root"]

    # Decide once, before the loop, whether --exclude was supplied. `exclude`
    # must not be rebound inside the loop: otherwise every directory after the
    # first would look as if it had an explicit --exclude and lose its
    # .gitignore handling.
    using_default_exclude = exclude is None
    if exclude is None:
        exclude = re_compile_maybe_verbose(DEFAULT_EXCLUDES)

    for s in src:
        if s == "-" and stdin_filename:
            p = Path(stdin_filename)
            is_stdin = True
        else:
            p = Path(s)
            is_stdin = False

        if is_stdin or p.is_file():
            normalized_path = normalize_path_maybe_ignore(p, root, report)
            if normalized_path is None:
                continue

            normalized_path = "/" + normalized_path
            # Hard-exclude any files that match the `--force-exclude` regex.
            if force_exclude:
                force_exclude_match = force_exclude.search(normalized_path)
            else:
                force_exclude_match = None
            if force_exclude_match and force_exclude_match.group(0):
                report.path_ignored(p, "matches the --force-exclude regular expression")
                continue

            if is_stdin:
                p = Path(f"{STDIN_PLACEHOLDER}{str(p)}")

            if p.suffix == ".ipynb" and not jupyter_dependencies_are_installed(
                verbose=verbose, quiet=quiet
            ):
                continue

            sources.add(p)
        elif p.is_dir():
            if using_default_exclude:
                # Fresh PathSpec objects each time, so the in-place `+=` below
                # never leaks one directory's rules into another's.
                gitignore = get_gitignore(root)
                p_gitignore = get_gitignore(p)
                # No need to use p's gitignore if it is identical to root's gitignore
                # (i.e. root and p point to the same directory).
                if gitignore != p_gitignore:
                    gitignore += p_gitignore
            else:
                gitignore = None
            sources.update(
                gen_python_files(
                    p.iterdir(),
                    root,
                    include,
                    exclude,
                    extend_exclude,
                    force_exclude,
                    report,
                    gitignore,
                    verbose=verbose,
                    quiet=quiet,
                )
            )
        elif s == "-":
            sources.add(p)
        else:
            err(f"invalid path: {s}")
    return sources
692
693
def path_empty(
    src: Sized, msg: str, quiet: bool, verbose: bool, ctx: click.Context
) -> None:
    """
    Stop the program with exit status 0 when `src` is empty.

    Before exiting, `msg` is printed unless output is quieted; `verbose`
    overrides `quiet`. A non-empty `src` makes this a no-op.
    """
    if src:
        return
    if verbose or not quiet:
        out(msg)
    ctx.exit(0)
704
705
def reformat_code(
    content: str, fast: bool, write_back: WriteBack, mode: Mode, report: Report
) -> None:
    """
    Format the string `content` in-process and record the outcome on `report`.

    This is the string counterpart of `reformat_one`. The work is delegated to
    :func:`format_stdin_to_stdout`, which receives `fast`, `write_back` and
    `mode`. The result is recorded under the pseudo-path "<string>". Any
    exception is reported as a failure instead of propagating; the traceback is
    printed when the report is verbose.
    """
    path = Path("<string>")
    try:
        was_changed = format_stdin_to_stdout(
            content=content, fast=fast, write_back=write_back, mode=mode
        )
        report.done(path, Changed.YES if was_changed else Changed.NO)
    except Exception as exc:
        if report.verbose:
            traceback.print_exc()
        report.failed(path, str(exc))
728
729
# diff-shades depends on being able to monkeypatch this function to operate. I
# know it's not ideal, but this shouldn't cause any issues ... hopefully. ~ichard26
@mypyc_attr(patchable=True)
def reformat_one(
    src: Path, fast: bool, write_back: WriteBack, mode: Mode, report: "Report"
) -> None:
    """Format a single source in-process and record the result on `report`.

    `src` is either a regular file or stdin. Stdin is given as "-" or as a path
    prefixed with STDIN_PLACEHOLDER; the prefix is stripped so that messages
    show the name the user supplied. Stdin is handled by
    :func:`format_stdin_to_stdout`, with `.pyi` / `.ipynb` names adjusting
    `mode`. Files are handled by :func:`format_file_in_place`. `fast`,
    `write_back` and `mode` are forwarded in both cases.

    For files, the cache is consulted unless a diff was requested. Files whose
    cache entry is current are skipped, and the cache is updated after a
    write-back or a clean check. Any exception is reported as a failure instead
    of propagating.
    """
    try:
        changed = Changed.NO
        src_text = str(src)
        is_stdin = src_text == "-"
        if not is_stdin and src_text.startswith(STDIN_PLACEHOLDER):
            is_stdin = True
            src = Path(src_text[len(STDIN_PLACEHOLDER) :])

        if is_stdin:
            if src.suffix == ".pyi":
                mode = replace(mode, is_pyi=True)
            elif src.suffix == ".ipynb":
                mode = replace(mode, is_ipynb=True)
            if format_stdin_to_stdout(fast=fast, write_back=write_back, mode=mode):
                changed = Changed.YES
        else:
            wants_cache = write_back not in (WriteBack.DIFF, WriteBack.COLOR_DIFF)
            cache: Cache = read_cache(mode) if wants_cache else {}
            if wants_cache:
                resolved = src.resolve()
                cache_key = str(resolved)
                if cache_key in cache and cache[cache_key] == get_cache_info(resolved):
                    changed = Changed.CACHED
            if changed is not Changed.CACHED and format_file_in_place(
                src, fast=fast, write_back=write_back, mode=mode
            ):
                changed = Changed.YES
            should_update_cache = (
                write_back is WriteBack.YES and changed is not Changed.CACHED
            ) or (write_back is WriteBack.CHECK and changed is Changed.NO)
            if should_update_cache:
                write_cache(cache, [src], mode)
        report.done(src, changed)
    except Exception as exc:
        if report.verbose:
            traceback.print_exc()
        report.failed(src, str(exc))
782
783
def format_file_in_place(
    src: Path,
    fast: bool,
    mode: Mode,
    write_back: WriteBack = WriteBack.NO,
    lock: Any = None,  # multiprocessing.Manager().Lock() is some crazy proxy
) -> bool:
    """Reformat the file at `src` and return True if its contents changed.

    A `.pyi` or `.ipynb` suffix switches `mode` to stub or notebook handling.
    With `mode.skip_source_first_line`, the first line is kept out of
    formatting and re-attached unchanged.

    When the contents change, `write_back` decides what happens:

    * YES overwrites the file, keeping the detected encoding and newline style.
    * DIFF / COLOR_DIFF print a (colored) diff to stdout, holding `lock` while
      writing when one is given.
    * Any other value writes nothing.

    `fast` and `mode` are forwarded to :func:`format_file_contents`.

    Raises ValueError when the contents fail to decode as JSON (notebooks).
    """
    if src.suffix == ".pyi":
        mode = replace(mode, is_pyi=True)
    elif src.suffix == ".ipynb":
        mode = replace(mode, is_ipynb=True)

    then = datetime.utcfromtimestamp(src.stat().st_mtime)
    with open(src, "rb") as buf:
        header = buf.readline() if mode.skip_source_first_line else b""
        src_contents, encoding, newline = decode_bytes(buf.read())

    try:
        dst_contents = format_file_contents(src_contents, fast=fast, mode=mode)
    except NothingChanged:
        return False
    except JSONDecodeError:
        raise ValueError(
            f"File '{src}' cannot be parsed as valid Jupyter notebook."
        ) from None

    decoded_header = header.decode(encoding)
    src_contents = decoded_header + src_contents
    dst_contents = decoded_header + dst_contents

    if write_back == WriteBack.YES:
        with open(src, "w", encoding=encoding, newline=newline) as f:
            f.write(dst_contents)
        return True

    if write_back not in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
        return True

    now = datetime.utcnow()
    src_name = f"{src}\t{then} +0000"
    dst_name = f"{src}\t{now} +0000"
    if mode.is_ipynb:
        diff_contents = ipynb_diff(src_contents, dst_contents, src_name, dst_name)
    else:
        diff_contents = diff(src_contents, dst_contents, src_name, dst_name)
    if write_back == WriteBack.COLOR_DIFF:
        diff_contents = color_diff(diff_contents)

    with lock or nullcontext():
        stdout_text = io.TextIOWrapper(
            sys.stdout.buffer,
            encoding=encoding,
            newline=newline,
            write_through=True,
        )
        stdout_text = wrap_stream_for_windows(stdout_text)
        stdout_text.write(diff_contents)
        # Detach so the wrapper never closes the shared sys.stdout.buffer.
        stdout_text.detach()
    return True
846
847
def format_stdin_to_stdout(
    fast: bool,
    *,
    content: Optional[str] = None,
    write_back: WriteBack = WriteBack.NO,
    mode: Mode,
) -> bool:
    """Format source from stdin (or from `content`) and emit the result on stdout.

    If content is None, it's read from sys.stdin and decoded with
    :func:`decode_bytes`, which also supplies the encoding and newline style
    used for output. A `content` string is written back as UTF-8 with no
    newline translation.

    If `write_back` is YES, the (possibly unchanged) code is written to stdout,
    with a trailing newline added if missing. If it is DIFF or COLOR_DIFF, a
    diff labelled STDIN/STDOUT is written instead. Any other value writes
    nothing. The `mode` argument is passed to :func:`format_file_contents`.

    Returns True if formatting changed the code, False if it did not. Other
    exceptions from formatting propagate, but only after the `finally` block
    has emitted output based on the unchanged source.
    """
    # Timestamp for the "STDIN" side of a diff header.
    then = datetime.utcnow()

    if content is None:
        src, encoding, newline = decode_bytes(sys.stdin.buffer.read())
    else:
        # Already a str (e.g. from --code): emit as UTF-8 and, via newline="",
        # without any newline translation.
        src, encoding, newline = content, "utf-8", ""

    # Fallback output: if formatting changes nothing or raises, the `finally`
    # block below still emits the original source.
    dst = src
    try:
        dst = format_file_contents(src, fast=fast, mode=mode)
        return True

    except NothingChanged:
        return False

    finally:
        # Output is produced here so it happens on every exit path.
        f = io.TextIOWrapper(
            sys.stdout.buffer, encoding=encoding, newline=newline, write_through=True
        )
        if write_back == WriteBack.YES:
            # Make sure there's a newline after the content
            if dst and dst[-1] != "\n":
                dst += "\n"
            f.write(dst)
        elif write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
            now = datetime.utcnow()
            src_name = f"STDIN\t{then} +0000"
            dst_name = f"STDOUT\t{now} +0000"
            d = diff(src, dst, src_name, dst_name)
            if write_back == WriteBack.COLOR_DIFF:
                d = color_diff(d)
                f = wrap_stream_for_windows(f)
            f.write(d)
        # Detach so discarding the wrapper does not close sys.stdout.buffer.
        f.detach()
897
898
def check_stability_and_equivalence(
    src_contents: str, dst_contents: str, *, mode: Mode
) -> None:
    """Verify that turning `src_contents` into `dst_contents` was safe.

    First checks that the two versions are equivalent. Then checks stability:
    formatting again under `mode` must not change the content further. Either
    check raises AssertionError on failure.
    """
    assert_equivalent(src_contents, dst_contents)
    assert_stable(src_contents, dst_contents, mode=mode)
910
911
def format_file_contents(src_contents: str, *, fast: bool, mode: Mode) -> FileContent:
    """Return `src_contents` reformatted according to `mode`.

    Raises NothingChanged when the source is blank or formatting leaves it
    identical. Notebooks are formatted by :func:`format_ipynb_string`; all other
    content goes through :func:`format_str`. Unless `fast` is set, non-notebook
    output is then verified by :func:`check_stability_and_equivalence`.
    Notebooks skip that step here because :func:`format_ipynb_string` already
    received `fast` and handles its own checking.
    """
    if not src_contents.strip():
        raise NothingChanged

    formatted = (
        format_ipynb_string(src_contents, fast=fast, mode=mode)
        if mode.is_ipynb
        else format_str(src_contents, mode=mode)
    )
    if formatted == src_contents:
        raise NothingChanged

    if not (fast or mode.is_ipynb):
        check_stability_and_equivalence(src_contents, formatted, mode=mode)
    return formatted
933
934
def validate_cell(src: str, mode: Mode) -> None:
    """Raise NothingChanged for notebook cells Black must leave alone.

    Two kinds of cells are skipped.

    The first kind already contains text that IPython's ``TransformerManager``
    produces (anything in ``TRANSFORMED_MAGICS``). Such a cell can't be
    round-tripped safely, because a magic and its transformed form look the
    same after transformation:

        >>> TransformerManager().transform_cell("get_ipython().system('ls')")
        "get_ipython().system('ls')\n"
        >>> TransformerManager().transform_cell("!ls")
        "get_ipython().system('ls')\n"

    The second kind starts with a ``%%`` cell magic that is neither a known
    Python cell magic nor listed in ``mode.python_cell_magics``. Such cells might
    break tokenization because of their indentation.
    """
    if any(magic in src for magic in TRANSFORMED_MAGICS):
        raise NothingChanged
    if src.startswith("%%"):
        # Name of the cell magic, e.g. "%%time" -> "time".
        cell_magic = src.split()[0][2:]
        if cell_magic not in PYTHON_CELL_MAGICS | mode.python_cell_magics:
            raise NothingChanged
959
960
def format_cell(src: str, *, fast: bool, mode: Mode) -> str:
    """Format the source of one Jupyter notebook code cell.

    Steps, in order:

      - reject cells that :func:`validate_cell` says must not be touched;
      - drop a trailing semicolon, remembering whether there was one;
      - mask IPython magics so the code parses as plain Python;
      - format the masked code (and verify it unless `fast` is true);
      - restore the magics and the trailing semicolon;
      - strip trailing newlines.

    A cell that fails to mask because of a syntax error is left alone; it might
    be an automagic or a multi-line magic, which are not supported.

    Raises NothingChanged when the cell is skipped or its formatting is unchanged.
    """
    validate_cell(src, mode)
    stripped_src, had_semicolon = remove_trailing_semicolon(src)
    try:
        masked_src, replacements = mask_cell(stripped_src)
    except SyntaxError:
        raise NothingChanged from None

    masked_dst = format_str(masked_src, mode=mode)
    if not fast:
        check_stability_and_equivalence(masked_src, masked_dst, mode=mode)

    unmasked_dst = unmask_cell(masked_dst, replacements)
    dst = put_trailing_semicolon_back(unmasked_dst, had_semicolon).rstrip("\n")
    if dst == src:
        raise NothingChanged from None
    return dst
996
997
def validate_metadata(nb: MutableMapping[str, Any]) -> None:
    """Raise NothingChanged for notebooks declared to be in a non-Python language.

    Every metadata field of a notebook is optional (see
    https://nbformat.readthedocs.io/en/latest/format_description.html). A notebook
    without a declared language is therefore treated as Python and formatted.
    """
    language_info = nb.get("metadata", {}).get("language_info", {})
    language = language_info.get("name", None)
    if language is None or language == "python":
        return
    raise NothingChanged from None
1008
1009
def format_ipynb_string(src_contents: str, *, fast: bool, mode: Mode) -> FileContent:
    """Format a Jupyter notebook given as its JSON text.

    Only code cells are formatted, one at a time via :func:`format_cell`.
    Notebooks whose metadata declares a non-Python language are rejected by
    :func:`validate_metadata`. If the original text ended with a newline, so
    does the result.

    Args:
        src_contents: the notebook's raw JSON text.
        fast: passed to :func:`format_cell`; skips per-cell verification.
        mode: formatting options.

    Returns:
        The re-serialized notebook JSON with the formatted cell sources.

    Raises:
        NothingChanged: no code cell changed, or the notebook is not Python.
    """
    # endswith() instead of src_contents[-1] so an empty string doesn't raise
    # IndexError before json.loads gets a chance to report it properly.
    trailing_newline = src_contents.endswith("\n")
    modified = False
    nb = json.loads(src_contents)
    validate_metadata(nb)
    for cell in nb["cells"]:
        if cell.get("cell_type", None) != "code":
            continue
        try:
            src = "".join(cell["source"])
            dst = format_cell(src, fast=fast, mode=mode)
        except NothingChanged:
            continue
        cell["source"] = dst.splitlines(keepends=True)
        modified = True

    if not modified:
        raise NothingChanged

    dst_contents = json.dumps(nb, indent=1, ensure_ascii=False)
    if trailing_newline:
        dst_contents += "\n"
    return dst_contents
1037
1038
def format_str(src_contents: str, *, mode: Mode) -> str:
    """Return `src_contents` reformatted according to `mode`.

    `mode` carries the formatting options, such as the maximum line length.
    Example:

    >>> import black
    >>> print(black.format_str("def f(arg:str='')->None:...", mode=black.Mode()))
    def f(arg: str = "") -> None:
        ...

    A more complex example:

    >>> print(
    ...   black.format_str(
    ...     "def f(arg:str='')->None: hey",
    ...     mode=black.Mode(
    ...       target_versions={black.TargetVersion.PY36},
    ...       line_length=10,
    ...       string_normalization=False,
    ...       is_pyi=False,
    ...     ),
    ...   ),
    ... )
    def f(
        arg: str = '',
    ) -> None:
        hey

    """
    first_pass = _format_str_once(src_contents, mode=mode)
    if first_pass == src_contents:
        return first_pass
    # A second pass is forced when the first one changed anything. Optional
    # trailing commas added in pass 1 become forced trailing commas in pass 2,
    # which can interact differently with optional parentheses. Admittedly ugly.
    return _format_str_once(first_pass, mode=mode)
1076
1077
def _format_str_once(src_contents: str, *, mode: Mode) -> str:
    """Run exactly one formatting pass over `src_contents` and return the result.

    Unlike :func:`format_str`, this never makes a second pass. That lets
    :func:`assert_stable` notice output that keeps changing between passes.
    """
    src_node = lib2to3_parse(src_contents.lstrip(), mode.target_versions)

    # Use the user's target versions if given; otherwise infer them from the
    # features (and __future__ imports) the parsed source already uses.
    if mode.target_versions:
        versions = mode.target_versions
    else:
        versions = detect_target_versions(
            src_node, future_imports=get_future_imports(src_node)
        )

    normalize_fmt_off(src_node, preview=mode.preview)
    line_generator = LineGenerator(mode=mode)
    empty_line_tracker = EmptyLineTracker(mode=mode)

    # Only the trailing-comma features that supports_feature() accepts for
    # `versions` may be used when splitting lines.
    split_line_features: Set[Feature] = set()
    for feature in (Feature.TRAILING_COMMA_IN_CALL, Feature.TRAILING_COMMA_IN_DEF):
        if supports_feature(versions, feature):
            split_line_features.add(feature)

    blocks: List[LinesBlock] = []
    for logical_line in line_generator.visit(src_node):
        lines_block = empty_line_tracker.maybe_empty_lines(logical_line)
        blocks.append(lines_block)
        lines_block.content_lines.extend(
            [
                str(rendered)
                for rendered in transform_line(
                    logical_line, mode=mode, features=split_line_features
                )
            ]
        )

    # No blank lines after the very last block.
    if blocks:
        blocks[-1].after = 0
    return "".join(text for b in blocks for text in b.all_lines())
1109
1110
def decode_bytes(src: bytes) -> Tuple[FileContent, Encoding, NewLine]:
    """Decode raw source bytes and report the encoding and newline style.

    The encoding is detected with :func:`tokenize.detect_encoding`. `newline` is
    CRLF when the first line ends in CRLF and LF otherwise. The decoded text
    itself uses universal newlines, so it only contains LF. Empty input yields
    ``("", <detected encoding>, "\\n")``.
    """
    buffer = io.BytesIO(src)
    encoding, first_lines = tokenize.detect_encoding(buffer.readline)
    if not first_lines:
        return "", encoding, "\n"

    newline = "\r\n" if first_lines[0].endswith(b"\r\n") else "\n"
    # detect_encoding consumed part of the buffer; rewind before decoding.
    buffer.seek(0)
    with io.TextIOWrapper(buffer, encoding) as wrapper:
        return wrapper.read(), encoding, newline
1126
1127
def get_features_used(  # noqa: C901
    node: Node, *, future_imports: Optional[Set[str]] = None
) -> Set[Feature]:
    """Return a set of (relatively) new Python features used in this file.

    Currently looking for:
    - f-strings (the ``f``, ``rf`` and ``fr`` prefixes, in any letter case);
    - self-documenting expressions in f-strings (f"{x=}");
    - underscores in numeric literals;
    - trailing commas after * or ** in function signatures and calls;
    - positional only arguments in function signatures and lambdas;
    - assignment expression;
    - relaxed decorator syntax;
    - usage of __future__ flags (annotations);
    - star-unpacking in ``return`` / ``yield`` values;
    - a ``testlist_star_expr`` (e.g. an unparenthesized tuple) as the value of an
      annotated assignment;
    - ``except*`` clauses;
    - star expressions in subscripts, in call/subscript trailers, and as a
      parameter annotation (variadic generics).
    """
    features: Set[Feature] = set()
    if future_imports:
        features |= {
            FUTURE_FLAG_TO_FEATURE[future_import]
            for future_import in future_imports
            if future_import in FUTURE_FLAG_TO_FEATURE
        }

    for n in node.pre_order():
        if is_string_token(n):
            # String prefixes are case-insensitive, so spellings such as
            # Rf"..." or fR"..." are f-strings too; compare the lowercased head.
            value_head = n.value[:2].lower()
            if value_head in {'f"', "f'", "rf", "fr"}:
                features.add(Feature.F_STRINGS)
                if Feature.DEBUG_F_STRINGS not in features:
                    for span_beg, span_end in iter_fexpr_spans(n.value):
                        if n.value[span_beg : span_end - 1].rstrip().endswith("="):
                            features.add(Feature.DEBUG_F_STRINGS)
                            break

        elif is_number_token(n):
            if "_" in n.value:
                features.add(Feature.NUMERIC_UNDERSCORES)

        elif n.type == token.SLASH:
            if n.parent and n.parent.type in {
                syms.typedargslist,
                syms.arglist,
                syms.varargslist,
            }:
                features.add(Feature.POS_ONLY_ARGUMENTS)

        elif n.type == token.COLONEQUAL:
            features.add(Feature.ASSIGNMENT_EXPRESSIONS)

        elif n.type == syms.decorator:
            if len(n.children) > 1 and not is_simple_decorator_expression(
                n.children[1]
            ):
                features.add(Feature.RELAXED_DECORATORS)

        elif (
            n.type in {syms.typedargslist, syms.arglist}
            and n.children
            and n.children[-1].type == token.COMMA
        ):
            if n.type == syms.typedargslist:
                feature = Feature.TRAILING_COMMA_IN_DEF
            else:
                feature = Feature.TRAILING_COMMA_IN_CALL

            for ch in n.children:
                if ch.type in STARS:
                    features.add(feature)

                if ch.type == syms.argument:
                    for argch in ch.children:
                        if argch.type in STARS:
                            features.add(feature)

        elif (
            n.type in {syms.return_stmt, syms.yield_expr}
            and len(n.children) >= 2
            and n.children[1].type == syms.testlist_star_expr
            and any(child.type == syms.star_expr for child in n.children[1].children)
        ):
            features.add(Feature.UNPACKING_ON_FLOW)

        elif (
            n.type == syms.annassign
            and len(n.children) >= 4
            and n.children[3].type == syms.testlist_star_expr
        ):
            features.add(Feature.ANN_ASSIGN_EXTENDED_RHS)

        elif (
            n.type == syms.except_clause
            and len(n.children) >= 2
            and n.children[1].type == token.STAR
        ):
            features.add(Feature.EXCEPT_STAR)

        elif n.type in {syms.subscriptlist, syms.trailer} and any(
            child.type == syms.star_expr for child in n.children
        ):
            features.add(Feature.VARIADIC_GENERICS)

        elif (
            n.type == syms.tname_star
            and len(n.children) == 3
            and n.children[2].type == syms.star_expr
        ):
            features.add(Feature.VARIADIC_GENERICS)

    return features
1238
1239
def detect_target_versions(
    node: Node, *, future_imports: Optional[Set[str]] = None
) -> Set[TargetVersion]:
    """Return every TargetVersion whose feature set covers all features `node` uses."""
    used = get_features_used(node, future_imports=future_imports)
    compatible: Set[TargetVersion] = set()
    for version in TargetVersion:
        if used <= VERSION_TO_FEATURES[version]:
            compatible.add(version)
    return compatible
1248
1249
def get_future_imports(node: Node) -> Set[str]:
    """Collect the names imported with ``from __future__ import ...``.

    Only the leading statements of the module are inspected. Scanning stops at
    the first statement that is neither a bare string statement (a docstring)
    nor a ``from __future__`` import.
    """
    imports: Set[str] = set()

    def names_from(children: List[LN]) -> Generator[str, None, None]:
        # Yield the original (pre-"as") names from an import's name list.
        for child in children:
            if isinstance(child, Leaf):
                # Punctuation leaves such as parentheses and commas are skipped.
                if child.type == token.NAME:
                    yield child.value
                continue

            if child.type == syms.import_as_name:
                orig_name = child.children[0]
                assert isinstance(orig_name, Leaf), "Invalid syntax parsing imports"
                assert orig_name.type == token.NAME, "Invalid syntax parsing imports"
                yield orig_name.value
            elif child.type == syms.import_as_names:
                yield from names_from(child.children)
            else:
                raise AssertionError("Invalid syntax parsing imports")

    for stmt in node.children:
        if stmt.type != syms.simple_stmt:
            break

        first = stmt.children[0]
        if isinstance(first, Leaf):
            # A docstring lets the scan continue; any other leaf-led statement
            # ends it.
            is_docstring = (
                len(stmt.children) == 2
                and first.type == token.STRING
                and stmt.children[1].type == token.NEWLINE
            )
            if not is_docstring:
                break
            continue

        if first.type != syms.import_from:
            break

        module_name = first.children[1]
        if not isinstance(module_name, Leaf) or module_name.value != "__future__":
            break

        # children[3:] is whatever follows "from __future__ import".
        imports |= set(names_from(first.children[3:]))

    return imports
1298
1299
def assert_equivalent(src: str, dst: str) -> None:
    """Check that `src` and `dst` parse to equivalent ASTs.

    Raises AssertionError when:
    - `src` can't be parsed (the message suggests an older Python may be at fault);
    - `dst` can't be parsed (Black produced invalid code); or
    - the stringified ASTs differ (Black changed the program's meaning).

    In the last two cases, debugging material is written via ``dump_to_file``
    and its location is included in the message.
    """
    try:
        src_ast = parse_ast(src)
    except Exception as exc:
        message = (
            "cannot use --safe with this file; failed to parse source file AST: "
            f"{exc}\n"
            "This could be caused by running Black with an older Python version "
            "that does not support new syntax used in your source file."
        )
        raise AssertionError(message) from exc

    try:
        dst_ast = parse_ast(dst)
    except Exception as exc:
        log = dump_to_file("".join(traceback.format_tb(exc.__traceback__)), dst)
        message = (
            f"INTERNAL ERROR: Black produced invalid code: {exc}. "
            "Please report a bug on https://github.com/psf/black/issues.  "
            f"This invalid output might be helpful: {log}"
        )
        raise AssertionError(message) from None

    src_dump = "\n".join(stringify_ast(src_ast))
    dst_dump = "\n".join(stringify_ast(dst_ast))
    if src_dump == dst_dump:
        return

    log = dump_to_file(diff(src_dump, dst_dump, "src", "dst"))
    message = (
        "INTERNAL ERROR: Black produced code that is not equivalent to the"
        " source.  Please report a bug on "
        f"https://github.com/psf/black/issues.  This diff might be helpful: {log}"
    )
    raise AssertionError(message) from None
1331
1332
def assert_stable(src: str, dst: str, mode: Mode) -> None:
    """Raise AssertionError if formatting `dst` once more would change it.

    On failure, the mode and the diffs of both passes are written via
    ``dump_to_file`` and the log location is reported in the message.
    """
    # Deliberately a single pass: format_str() formats twice, which could hide a
    # bug where the output bounces back and forth between two versions.
    second_pass = _format_str_once(dst, mode=mode)
    if second_pass == dst:
        return

    first_diff = diff(src, dst, "source", "first pass")
    second_diff = diff(dst, second_pass, "first pass", "second pass")
    log = dump_to_file(str(mode), first_diff, second_diff)
    raise AssertionError(
        "INTERNAL ERROR: Black produced different code on the second pass of the"
        " formatter.  Please report a bug on https://github.com/psf/black/issues."
        f"  This diff might be helpful: {log}"
    ) from None
1350
1351
@contextmanager
def nullcontext() -> Iterator[None]:
    """Context manager that does nothing on entry or exit.

    A stand-in for ``contextlib.nullcontext`` (added in Python 3.7). The ``with``
    target receives ``None``, and exceptions raised in the body propagate
    unchanged.
    """
    yield None
1359
1360
def patch_click() -> None:
    """Disable Click's Python 3 environment check so Black runs under LANG=C.

    Some misconfigured environments make Python 3 choose ASCII as the default
    encoding, which limits the paths it can access. Click then refuses to run by
    raising a RuntimeError. Black only handles Python source files, whose paths
    are very unlikely to contain non-ASCII characters, and on Python 3.7 the crash
    was spurious anyway thanks to PEP 538 and PEP 540. So every
    ``_verify_python3_env`` / ``_verify_python_env`` hook found is replaced with a
    no-op.
    """
    modules: List[Any] = []
    try:
        from click import core
    except ImportError:
        pass
    else:
        modules.append(core)
    try:
        # `_unicodefun` was removed in Click 8.1.0; it is still patched for users
        # who have older Click releases installed.
        from click import _unicodefun  # type: ignore
    except ImportError:
        pass
    else:
        modules.append(_unicodefun)

    for module in modules:
        for hook_name in ("_verify_python3_env", "_verify_python_env"):
            if hasattr(module, hook_name):
                setattr(module, hook_name, lambda: None)
1393
1394
def patched_main() -> None:
    """Apply runtime workarounds, then run the Black CLI via :func:`main`."""
    is_frozen = getattr(sys, "frozen", False)
    if is_frozen:
        # PyInstaller patches multiprocessing so that freeze_support() is needed
        # even outside Windows; always call it in a frozen executable.
        from multiprocessing import freeze_support

        freeze_support()

    patch_click()
    main()
1405
1406
if __name__ == "__main__":
    # When this file is run as a script, go through patched_main(), which applies
    # the frozen-executable and Click workarounds before calling main().
    patched_main()