git.madduck.net Git - etc/vim.git/blob - src/black/__init__.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access, can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

afc76e1fa0cb5e7dda1a51544718ef778e846344
[etc/vim.git] / src / black / __init__.py
1 import io
2 import json
3 import os
4 import platform
5 import re
6 import sys
7 import tokenize
8 import traceback
9 from contextlib import contextmanager
10 from dataclasses import replace
11 from datetime import datetime
12 from enum import Enum
13 from json.decoder import JSONDecodeError
14 from pathlib import Path
15 from typing import (
16     Any,
17     Dict,
18     Generator,
19     Iterator,
20     List,
21     MutableMapping,
22     Optional,
23     Pattern,
24     Sequence,
25     Set,
26     Sized,
27     Tuple,
28     Union,
29 )
30
31 if sys.version_info >= (3, 8):
32     from typing import Final
33 else:
34     from typing_extensions import Final
35
36 import click
37 from click.core import ParameterSource
38 from mypy_extensions import mypyc_attr
39 from pathspec.patterns.gitwildmatch import GitWildMatchPatternError
40
41 from _black_version import version as __version__
42 from black.cache import Cache, get_cache_info, read_cache, write_cache
43 from black.comments import normalize_fmt_off
44 from black.const import (
45     DEFAULT_EXCLUDES,
46     DEFAULT_INCLUDES,
47     DEFAULT_LINE_LENGTH,
48     STDIN_PLACEHOLDER,
49 )
50 from black.files import (
51     find_project_root,
52     find_pyproject_toml,
53     find_user_pyproject_toml,
54     gen_python_files,
55     get_gitignore,
56     normalize_path_maybe_ignore,
57     parse_pyproject_toml,
58     wrap_stream_for_windows,
59 )
60 from black.handle_ipynb_magics import (
61     PYTHON_CELL_MAGICS,
62     TRANSFORMED_MAGICS,
63     jupyter_dependencies_are_installed,
64     mask_cell,
65     put_trailing_semicolon_back,
66     remove_trailing_semicolon,
67     unmask_cell,
68 )
69 from black.linegen import LN, LineGenerator, transform_line
70 from black.lines import EmptyLineTracker, Line
71 from black.mode import (
72     FUTURE_FLAG_TO_FEATURE,
73     VERSION_TO_FEATURES,
74     Feature,
75     Mode,
76     TargetVersion,
77     supports_feature,
78 )
79 from black.nodes import (
80     STARS,
81     is_number_token,
82     is_simple_decorator_expression,
83     is_string_token,
84     syms,
85 )
86 from black.output import color_diff, diff, dump_to_file, err, ipynb_diff, out
87 from black.parsing import InvalidInput  # noqa F401
88 from black.parsing import lib2to3_parse, parse_ast, stringify_ast
89 from black.report import Changed, NothingChanged, Report
90 from black.trans import iter_fexpr_spans
91 from blib2to3.pgen2 import token
92 from blib2to3.pytree import Leaf, Node
93
# True when this module was loaded from a compiled extension (".pyd" / ".so")
# rather than from a plain source file.
COMPILED = Path(__file__).suffix in {".pyd", ".so"}
# Default for the `--workers` option; whatever os.cpu_count() reports.
DEFAULT_WORKERS: Final = os.cpu_count()

# types
FileContent = str
Encoding = str
NewLine = str
101
102
class WriteBack(Enum):
    """How the result of formatting a source should be delivered."""

    # Nothing is written anywhere.
    NO = 0
    # The reformatted code is written back.
    YES = 1
    # A diff is written to stdout.
    DIFF = 2
    # Nothing is written; only the status is reported.
    CHECK = 3
    # Like DIFF, but the diff is colorized.
    COLOR_DIFF = 4

    @classmethod
    def from_configuration(
        cls, *, check: bool, diff: bool, color: bool = False
    ) -> "WriteBack":
        """Map the `--check`, `--diff` and `--color` flags onto a member.

        `diff` takes precedence over `check`; `color` only matters together
        with `diff`.
        """
        if not diff:
            return cls.CHECK if check else cls.YES

        return cls.COLOR_DIFF if color else cls.DIFF
121
122
# Legacy name for `Mode`, left in place for integrations that still refer to
# `FileMode`.
FileMode = Mode
125
126
def read_pyproject_toml(
    ctx: click.Context, param: click.Parameter, value: Optional[str]
) -> Optional[str]:
    """Inject Black configuration from "pyproject.toml" into defaults in `ctx`.

    Used as the callback of the `--config` option. When `value` is empty the
    configuration file is looked up from the `src` parameter instead.

    Returns the path to a successfully found and read configuration file, None
    otherwise (no file found, or the file holds no Black configuration).

    Raises `click.FileError` if the file cannot be read or parsed, and
    `click.BadOptionUsage` if `target_version` is present but not a list.
    """
    if not value:
        value = find_pyproject_toml(ctx.params.get("src", ()))
        if value is None:
            return None

    try:
        raw_config = parse_pyproject_toml(value)
    except (OSError, ValueError) as exc:
        raise click.FileError(
            filename=value, hint=f"Error reading configuration file: {exc}"
        ) from None

    if not raw_config:
        return None

    # Sanitize the values to be Click friendly. For more information please see:
    # https://github.com/psf/black/issues/1458
    # https://github.com/pallets/click/issues/1567
    config: Dict[str, Any] = {}
    for key, item in raw_config.items():
        config[key] = item if isinstance(item, (list, dict)) else str(item)

    target_version = config.get("target_version")
    if not (target_version is None or isinstance(target_version, list)):
        raise click.BadOptionUsage(
            "target-version", "Config key target-version must be a list"
        )

    # Values from the file override whatever defaults were already on `ctx`.
    merged: Dict[str, Any] = dict(ctx.default_map) if ctx.default_map else {}
    merged.update(config)

    ctx.default_map = merged
    return value
171
172
def target_version_option_callback(
    c: click.Context, p: Union[click.Option, click.Parameter], v: Tuple[str, ...]
) -> List[TargetVersion]:
    """Compute the target versions from a --target-version flag.

    Each given name is upper-cased and looked up in `TargetVersion`.

    This is its own function because mypy couldn't infer the type correctly
    when it was a lambda, causing mypyc trouble.
    """
    versions: List[TargetVersion] = []
    for name in v:
        versions.append(TargetVersion[name.upper()])
    return versions
182
183
def re_compile_maybe_verbose(regex: str) -> Pattern[str]:
    """Compile the regular expression string `regex`.

    A pattern spanning several lines is compiled in verbose mode by
    prefixing it with the inline `(?x)` flag.
    """
    source = f"(?x){regex}" if "\n" in regex else regex
    compiled: Pattern[str] = re.compile(source)
    return compiled
193
194
def validate_regex(
    ctx: click.Context,
    param: click.Parameter,
    value: Optional[str],
) -> Optional[Pattern[str]]:
    """Click callback turning a regex option value into a compiled pattern.

    None is passed through unchanged; a pattern that does not compile is
    reported as `click.BadParameter`.
    """
    if value is None:
        return None

    try:
        return re_compile_maybe_verbose(value)
    except re.error as exc:
        raise click.BadParameter(f"Not a valid regular expression: {exc}") from None
204
205
@click.command(
    context_settings={"help_option_names": ["-h", "--help"]},
    # While Click does set this field automatically using the docstring, mypyc
    # (annoyingly) strips 'em so we need to set it here too.
    help="The uncompromising code formatter.",
)
@click.option("-c", "--code", type=str, help="Format the code passed in as a string.")
@click.option(
    "-l",
    "--line-length",
    type=int,
    default=DEFAULT_LINE_LENGTH,
    help="How many characters per line to allow.",
    show_default=True,
)
@click.option(
    "-t",
    "--target-version",
    type=click.Choice([v.name.lower() for v in TargetVersion]),
    callback=target_version_option_callback,
    multiple=True,
    help=(
        "Python versions that should be supported by Black's output. [default: per-file"
        " auto-detection]"
    ),
)
@click.option(
    "--pyi",
    is_flag=True,
    help=(
        "Format all input files like typing stubs regardless of file extension (useful"
        " when piping source on standard input)."
    ),
)
@click.option(
    "--ipynb",
    is_flag=True,
    help=(
        "Format all input files like Jupyter Notebooks regardless of file extension "
        "(useful when piping source on standard input)."
    ),
)
@click.option(
    "--python-cell-magics",
    multiple=True,
    help=(
        "When processing Jupyter Notebooks, add the given magic to the list"
        f" of known python-magics ({', '.join(PYTHON_CELL_MAGICS)})."
        " Useful for formatting cells with custom python magics."
    ),
    default=[],
)
@click.option(
    "-S",
    "--skip-string-normalization",
    is_flag=True,
    help="Don't normalize string quotes or prefixes.",
)
@click.option(
    "-C",
    "--skip-magic-trailing-comma",
    is_flag=True,
    help="Don't use trailing commas as a reason to split lines.",
)
@click.option(
    "--experimental-string-processing",
    is_flag=True,
    hidden=True,
    help="(DEPRECATED and now included in --preview) Normalize string literals.",
)
@click.option(
    "--preview",
    is_flag=True,
    help=(
        "Enable potentially disruptive style changes that may be added to Black's main"
        " functionality in the next major release."
    ),
)
@click.option(
    "--check",
    is_flag=True,
    help=(
        "Don't write the files back, just return the status. Return code 0 means"
        " nothing would change. Return code 1 means some files would be reformatted."
        " Return code 123 means there was an internal error."
    ),
)
@click.option(
    "--diff",
    is_flag=True,
    help="Don't write the files back, just output a diff for each file on stdout.",
)
@click.option(
    "--color/--no-color",
    is_flag=True,
    help="Show colored diff. Only applies when `--diff` is given.",
)
@click.option(
    "--fast/--safe",
    is_flag=True,
    help="If --fast given, skip temporary sanity checks. [default: --safe]",
)
@click.option(
    "--required-version",
    type=str,
    help=(
        "Require a specific version of Black to be running (useful for unifying results"
        " across many environments e.g. with a pyproject.toml file). It can be"
        " either a major version number or an exact version."
    ),
)
@click.option(
    "--include",
    type=str,
    default=DEFAULT_INCLUDES,
    callback=validate_regex,
    help=(
        "A regular expression that matches files and directories that should be"
        " included on recursive searches. An empty value means all files are included"
        " regardless of the name. Use forward slashes for directories on all platforms"
        " (Windows, too). Exclusions are calculated first, inclusions later."
    ),
    show_default=True,
)
@click.option(
    "--exclude",
    type=str,
    callback=validate_regex,
    help=(
        "A regular expression that matches files and directories that should be"
        " excluded on recursive searches. An empty value means no paths are excluded."
        " Use forward slashes for directories on all platforms (Windows, too)."
        " Exclusions are calculated first, inclusions later. [default:"
        f" {DEFAULT_EXCLUDES}]"
    ),
    show_default=False,
)
@click.option(
    "--extend-exclude",
    type=str,
    callback=validate_regex,
    help=(
        "Like --exclude, but adds additional files and directories on top of the"
        " excluded ones. (Useful if you simply want to add to the default)"
    ),
)
@click.option(
    "--force-exclude",
    type=str,
    callback=validate_regex,
    help=(
        "Like --exclude, but files and directories matching this regex will be "
        "excluded even when they are passed explicitly as arguments."
    ),
)
@click.option(
    "--stdin-filename",
    type=str,
    help=(
        "The name of the file when passing it through stdin. Useful to make "
        "sure Black will respect --force-exclude option on some "
        "editors that rely on using stdin."
    ),
)
@click.option(
    "-W",
    "--workers",
    type=click.IntRange(min=1),
    default=DEFAULT_WORKERS,
    show_default=True,
    help="Number of parallel workers",
)
@click.option(
    "-q",
    "--quiet",
    is_flag=True,
    help=(
        "Don't emit non-error messages to stderr. Errors are still emitted; silence"
        " those with 2>/dev/null."
    ),
)
@click.option(
    "-v",
    "--verbose",
    is_flag=True,
    help=(
        "Also emit messages to stderr about files that were not changed or were ignored"
        " due to exclusion patterns."
    ),
)
@click.version_option(
    version=__version__,
    message=(
        f"%(prog)s, %(version)s (compiled: {'yes' if COMPILED else 'no'})\n"
        f"Python ({platform.python_implementation()}) {platform.python_version()}"
    ),
)
@click.argument(
    "src",
    nargs=-1,
    type=click.Path(
        exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
    ),
    is_eager=True,
    metavar="SRC ...",
)
@click.option(
    "--config",
    type=click.Path(
        exists=True,
        file_okay=True,
        dir_okay=False,
        readable=True,
        allow_dash=False,
        path_type=str,
    ),
    is_eager=True,
    callback=read_pyproject_toml,
    help="Read configuration from FILE path.",
)
@click.pass_context
def main(  # noqa: C901
    ctx: click.Context,
    code: Optional[str],
    line_length: int,
    target_version: List[TargetVersion],
    check: bool,
    diff: bool,
    color: bool,
    fast: bool,
    pyi: bool,
    ipynb: bool,
    python_cell_magics: Sequence[str],
    skip_string_normalization: bool,
    skip_magic_trailing_comma: bool,
    experimental_string_processing: bool,
    preview: bool,
    quiet: bool,
    verbose: bool,
    required_version: Optional[str],
    include: Pattern[str],
    exclude: Optional[Pattern[str]],
    extend_exclude: Optional[Pattern[str]],
    force_exclude: Optional[Pattern[str]],
    stdin_filename: Optional[str],
    workers: int,
    src: Tuple[str, ...],
    config: Optional[str],
) -> None:
    """The uncompromising code formatter."""
    ctx.ensure_object(dict)

    # Exactly one kind of input is accepted: SRC paths or a `--code` string.
    if code is not None and src:
        out(
            f"{main.get_usage(ctx)}\n\n'SRC' and 'code' cannot be passed"
            " simultaneously."
        )
        ctx.exit(1)
    if code is None and not src:
        out(f"{main.get_usage(ctx)}\n\nOne of 'SRC' or 'code' is required.")
        ctx.exit(1)

    # A project root is only looked up when formatting paths.
    root, method = (
        (None, None) if code is not None else find_project_root(src, stdin_filename)
    )
    ctx.obj["root"] = root

    if verbose:
        if root:
            out(
                f"Identified `{root}` as project root containing a {method}.",
                fg="blue",
            )

            described_sources: List[str] = []
            for source in src:
                # "-" (stdin) is shown as-is; everything else is normalized
                # against the project root.
                _norm: Optional[str] = (
                    source
                    if source == "-"
                    else normalize_path_maybe_ignore(Path(source), root)
                )
                if _norm:
                    described_sources.append(f'"{_norm}"')
                else:
                    # Red for the skipped entry, then back to blue.
                    described_sources.append(
                        f'\033[31m"{source} (skipping - invalid)"\033[34m'
                    )
            srcs_string = ", ".join(described_sources)
            out(f"Sources to be formatted: {srcs_string}", fg="blue")

        if config:
            config_source = ctx.get_parameter_source("config")
            user_level_config = str(find_user_pyproject_toml())
            if config == user_level_config:
                config_note = (
                    "Using configuration from user-level config at "
                    f"'{user_level_config}'."
                )
            elif config_source in (
                ParameterSource.DEFAULT,
                ParameterSource.DEFAULT_MAP,
            ):
                config_note = "Using configuration from project root."
            else:
                config_note = f"Using configuration in '{config}'."
            out(config_note, fg="blue")

    error_msg = "Oh no! 💥 💔 💥"
    # `--required-version` may name either the exact version or just the major one.
    if required_version and required_version not in (
        __version__,
        __version__.split(".")[0],
    ):
        err(
            f"{error_msg} The required version `{required_version}` does not match"
            f" the running version `{__version__}`!"
        )
        ctx.exit(1)
    if pyi and ipynb:
        err("Cannot pass both `pyi` and `ipynb` flags!")
        ctx.exit(1)

    write_back = WriteBack.from_configuration(check=check, diff=diff, color=color)
    # An empty set means: we'll autodetect later.
    versions: Set[TargetVersion] = set(target_version) if target_version else set()
    mode = Mode(
        target_versions=versions,
        line_length=line_length,
        is_pyi=pyi,
        is_ipynb=ipynb,
        string_normalization=not skip_string_normalization,
        magic_trailing_comma=not skip_magic_trailing_comma,
        experimental_string_processing=experimental_string_processing,
        preview=preview,
        python_cell_magics=set(python_cell_magics),
    )

    if code is not None:
        # Run in quiet mode by default with -c; the extra output isn't useful.
        # You can still pass -v to get verbose output.
        quiet = True

    report = Report(check=check, diff=diff, quiet=quiet, verbose=verbose)

    if code is None:
        try:
            sources = get_sources(
                ctx=ctx,
                src=src,
                quiet=quiet,
                verbose=verbose,
                include=include,
                exclude=exclude,
                extend_exclude=extend_exclude,
                force_exclude=force_exclude,
                report=report,
                stdin_filename=stdin_filename,
            )
        except GitWildMatchPatternError:
            ctx.exit(1)

        # Exits with status 0 when nothing is left to format.
        path_empty(
            sources,
            "No Python files are present to be formatted. Nothing to do 😴",
            quiet,
            verbose,
            ctx,
        )

        if len(sources) != 1:
            # Imported here so the concurrency module is only loaded when needed.
            from black.concurrency import reformat_many

            reformat_many(
                sources=sources,
                fast=fast,
                write_back=write_back,
                mode=mode,
                report=report,
                workers=workers,
            )
        else:
            reformat_one(
                src=sources.pop(),
                fast=fast,
                write_back=write_back,
                mode=mode,
                report=report,
            )
    else:
        reformat_code(
            content=code, fast=fast, write_back=write_back, mode=mode, report=report
        )

    if not quiet or verbose:
        if code is None and (verbose or report.change_count or report.failure_count):
            out()
        out(error_msg if report.return_code else "All done! ✨ 🍰 ✨")
        if code is None:
            click.echo(str(report), err=True)
    ctx.exit(report.return_code)
609
610
def get_sources(
    *,
    ctx: click.Context,
    src: Tuple[str, ...],
    quiet: bool,
    verbose: bool,
    include: Pattern[str],
    exclude: Optional[Pattern[str]],
    extend_exclude: Optional[Pattern[str]],
    force_exclude: Optional[Pattern[str]],
    report: "Report",
    stdin_filename: Optional[str],
) -> Set[Path]:
    """Compute the set of files to be formatted.

    Files (and stdin given a `stdin_filename`) are checked against
    `force_exclude` only; directories are expanded via `gen_python_files`
    with all include/exclude patterns. A bare "-" is added as-is, and any
    other entry is reported as an invalid path.
    """
    collected: Set[Path] = set()

    # The gitignore is only consulted when no explicit --exclude was given.
    if exclude is not None:
        gitignore = None
    else:
        exclude = re_compile_maybe_verbose(DEFAULT_EXCLUDES)
        gitignore = get_gitignore(ctx.obj["root"])

    for s in src:
        if s == "-" and stdin_filename:
            p = Path(stdin_filename)
            is_stdin = True
        else:
            p = Path(s)
            is_stdin = False

        if not (is_stdin or p.is_file()):
            if p.is_dir():
                collected.update(
                    gen_python_files(
                        p.iterdir(),
                        ctx.obj["root"],
                        include,
                        exclude,
                        extend_exclude,
                        force_exclude,
                        report,
                        gitignore,
                        verbose=verbose,
                        quiet=quiet,
                    )
                )
            elif s == "-":
                collected.add(p)
            else:
                err(f"invalid path: {s}")
            continue

        normalized_path = normalize_path_maybe_ignore(p, ctx.obj["root"], report)
        if normalized_path is None:
            continue

        rooted_path = f"/{normalized_path}"
        # Hard-exclude any files that matches the `--force-exclude` regex.
        # An empty match does not count as a match.
        forced_out = force_exclude.search(rooted_path) if force_exclude else None
        if forced_out and forced_out.group(0):
            report.path_ignored(p, "matches the --force-exclude regular expression")
            continue

        if is_stdin:
            # Tag the path so it can later be recognized as coming from stdin.
            p = Path(f"{STDIN_PLACEHOLDER}{p}")

        is_notebook = p.suffix == ".ipynb"
        if is_notebook and not jupyter_dependencies_are_installed(
            verbose=verbose, quiet=quiet
        ):
            continue

        collected.add(p)
    return collected
685
686
def path_empty(
    src: Sized, msg: str, quiet: bool, verbose: bool, ctx: click.Context
) -> None:
    """Exit with status 0 if there is no `src` provided for formatting.

    `msg` is printed first unless output is silenced (`quiet` without
    `verbose`).
    """
    if src:
        return

    if not quiet or verbose:
        out(msg)
    ctx.exit(0)
697
698
def reformat_code(
    content: str, fast: bool, write_back: WriteBack, mode: Mode, report: Report
) -> None:
    """Reformat and print out `content` without spawning child processes.

    Similar to `reformat_one`, but for string content. The outcome is recorded
    on `report` under the pseudo-path "<string>".

    `fast`, `write_back`, and `mode` options are passed to
    :func:`format_stdin_to_stdout`.
    """
    path = Path("<string>")
    try:
        was_changed = format_stdin_to_stdout(
            content=content, fast=fast, write_back=write_back, mode=mode
        )
        report.done(path, Changed.YES if was_changed else Changed.NO)
    except Exception as exc:
        # Any failure is reported rather than propagated.
        if report.verbose:
            traceback.print_exc()
        report.failed(path, str(exc))
721
722
# diff-shades depends on being able to monkeypatch this function to operate. I
# know it's not ideal, but this shouldn't cause any issues ... hopefully. ~ichard26
@mypyc_attr(patchable=True)
def reformat_one(
    src: Path, fast: bool, write_back: WriteBack, mode: Mode, report: "Report"
) -> None:
    """Reformat a single file under `src` without spawning child processes.

    `src` may also be "-" or a stdin placeholder path, in which case the input
    is read from stdin. The outcome (or failure) is recorded on `report`.

    `fast`, `write_back`, and `mode` options are passed to
    :func:`format_file_in_place` or :func:`format_stdin_to_stdout`.
    """
    try:
        changed = Changed.NO

        src_text = str(src)
        has_placeholder = src_text != "-" and src_text.startswith(STDIN_PLACEHOLDER)
        is_stdin = src_text == "-" or has_placeholder
        if has_placeholder:
            # Use the original name again in case we want to print something
            # to the user
            src = Path(src_text[len(STDIN_PLACEHOLDER) :])

        if not is_stdin:
            cache: Cache = {}
            # The cache is bypassed entirely when producing diffs.
            if write_back not in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
                cache = read_cache(mode)
                resolved = src.resolve()
                cache_key = str(resolved)
                if cache_key in cache and cache[cache_key] == get_cache_info(
                    resolved
                ):
                    changed = Changed.CACHED
            if changed is not Changed.CACHED and format_file_in_place(
                src, fast=fast, write_back=write_back, mode=mode
            ):
                changed = Changed.YES
            wrote_file = write_back is WriteBack.YES and changed is not Changed.CACHED
            checked_clean = write_back is WriteBack.CHECK and changed is Changed.NO
            if wrote_file or checked_clean:
                write_cache(cache, [src], mode)
        else:
            # For stdin the (stdin) file name still selects stub/notebook mode.
            if src.suffix == ".pyi":
                mode = replace(mode, is_pyi=True)
            elif src.suffix == ".ipynb":
                mode = replace(mode, is_ipynb=True)
            if format_stdin_to_stdout(fast=fast, write_back=write_back, mode=mode):
                changed = Changed.YES
        report.done(src, changed)
    except Exception as exc:
        if report.verbose:
            traceback.print_exc()
        report.failed(src, str(exc))
775
776
def format_file_in_place(
    src: Path,
    fast: bool,
    mode: Mode,
    write_back: WriteBack = WriteBack.NO,
    lock: Any = None,  # multiprocessing.Manager().Lock() is some crazy proxy
) -> bool:
    """Format file under `src` path. Return True if changed.

    If `write_back` is DIFF (or COLOR_DIFF), write a diff to stdout. If it is
    YES, write reformatted code to the file.
    `mode` and `fast` options are passed to :func:`format_file_contents`.
    `lock`, when given, is held while the diff is written to stdout.

    Raises ValueError if the file is handled as a Jupyter notebook but is not
    valid JSON.
    """
    # The file extension overrides the stub / notebook flags of `mode`.
    if src.suffix == ".pyi":
        mode = replace(mode, is_pyi=True)
    elif src.suffix == ".ipynb":
        mode = replace(mode, is_ipynb=True)

    then = datetime.utcfromtimestamp(src.stat().st_mtime)
    with open(src, "rb") as buf:
        src_contents, encoding, newline = decode_bytes(buf.read())
    try:
        dst_contents = format_file_contents(src_contents, fast=fast, mode=mode)
    except NothingChanged:
        return False
    except JSONDecodeError:
        raise ValueError(
            f"File '{src}' cannot be parsed as valid Jupyter notebook."
        ) from None

    if write_back == WriteBack.YES:
        # Write back using the encoding and newline detected on read.
        with open(src, "w", encoding=encoding, newline=newline) as f:
            f.write(dst_contents)
    elif write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
        now = datetime.utcnow()
        src_name = f"{src}\t{then} +0000"
        dst_name = f"{src}\t{now} +0000"
        if mode.is_ipynb:
            diff_contents = ipynb_diff(src_contents, dst_contents, src_name, dst_name)
        else:
            diff_contents = diff(src_contents, dst_contents, src_name, dst_name)

        if write_back == WriteBack.COLOR_DIFF:
            diff_contents = color_diff(diff_contents)

        with lock or nullcontext():
            f = io.TextIOWrapper(
                sys.stdout.buffer,
                encoding=encoding,
                newline=newline,
                write_through=True,
            )
            try:
                f = wrap_stream_for_windows(f)
                f.write(diff_contents)
            finally:
                # Always hand sys.stdout.buffer back, even when writing fails
                # (e.g. an encoding error or a broken pipe). A wrapper left
                # attached would close the buffer when it is finalized.
                f.detach()

    return True
834
835
def format_stdin_to_stdout(
    fast: bool,
    *,
    content: Optional[str] = None,
    write_back: WriteBack = WriteBack.NO,
    mode: Mode,
) -> bool:
    """Format file on stdin. Return True if changed.

    If content is None, it's read from sys.stdin.

    If `write_back` is YES, write reformatted code back to stdout. If it is DIFF
    (or COLOR_DIFF), write a diff to stdout. The `mode` argument is passed to
    :func:`format_file_contents`.

    The output step runs in a `finally` clause, so it also happens when nothing
    changed or when formatting raised; in those cases the unmodified source is
    what gets written or diffed.
    """
    then = datetime.utcnow()

    if content is None:
        src, encoding, newline = decode_bytes(sys.stdin.buffer.read())
    else:
        src, encoding, newline = content, "utf-8", ""

    dst = src
    try:
        dst = format_file_contents(src, fast=fast, mode=mode)
        return True

    except NothingChanged:
        return False

    finally:
        f = io.TextIOWrapper(
            sys.stdout.buffer, encoding=encoding, newline=newline, write_through=True
        )
        try:
            if write_back == WriteBack.YES:
                # Make sure there's a newline after the content
                if dst and dst[-1] != "\n":
                    dst += "\n"
                f.write(dst)
            elif write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
                now = datetime.utcnow()
                src_name = f"STDIN\t{then} +0000"
                dst_name = f"STDOUT\t{now} +0000"
                d = diff(src, dst, src_name, dst_name)
                if write_back == WriteBack.COLOR_DIFF:
                    d = color_diff(d)
                    f = wrap_stream_for_windows(f)
                f.write(d)
        finally:
            # Always hand sys.stdout.buffer back, even when writing fails
            # (e.g. an encoding error or a broken pipe). A wrapper left
            # attached would close the buffer when it is finalized.
            f.detach()
885
886
def check_stability_and_equivalence(
    src_contents: str, dst_contents: str, *, mode: Mode
) -> None:
    """Perform stability and equivalence checks.

    Runs :func:`assert_equivalent` and then :func:`assert_stable` on the
    source/destination pair; `mode` is only forwarded to the stability check.

    Raise AssertionError if source and destination contents are not
    equivalent, or if a second pass of the formatter would format the
    content differently.
    """
    assert_equivalent(src_contents, dst_contents)
    assert_stable(src_contents, dst_contents, mode=mode)
898
899
def format_file_contents(src_contents: str, *, fast: bool, mode: Mode) -> FileContent:
    """Reformat contents of a file and return new contents.

    Raises NothingChanged when the input is empty / whitespace-only or when
    formatting produces identical output.

    If `fast` is False, additionally confirm that the reformatted code is
    valid by calling :func:`assert_equivalent` and :func:`assert_stable` on it.
    `mode` is passed to :func:`format_str` (or :func:`format_ipynb_string` for
    notebooks).
    """
    if src_contents.strip() == "":
        raise NothingChanged

    formatter_is_notebook = mode.is_ipynb
    if formatter_is_notebook:
        dst_contents = format_ipynb_string(src_contents, fast=fast, mode=mode)
    else:
        dst_contents = format_str(src_contents, mode=mode)

    if dst_contents == src_contents:
        raise NothingChanged

    # Jupyter notebooks will already have been checked above.
    if not (fast or mode.is_ipynb):
        check_stability_and_equivalence(src_contents, dst_contents, mode=mode)
    return dst_contents
921
922
def validate_cell(src: str, mode: Mode) -> None:
    """Check that cell does not already contain TransformerManager transformations,
    or non-Python cell magics, which might cause tokenizer_rt to break because of
    indentations.

    If a cell contains ``!ls``, then it'll be transformed to
    ``get_ipython().system('ls')``. However, if the cell originally contained
    ``get_ipython().system('ls')``, then it would get transformed in the same way:

        >>> TransformerManager().transform_cell("get_ipython().system('ls')")
        "get_ipython().system('ls')\n"
        >>> TransformerManager().transform_cell("!ls")
        "get_ipython().system('ls')\n"

    Due to the impossibility of safely roundtripping in such situations, cells
    containing transformed magics will be ignored.

    Such cells are signalled by raising NothingChanged.
    """
    for transformed_magic in TRANSFORMED_MAGICS:
        if transformed_magic in src:
            raise NothingChanged

    if src.startswith("%%"):
        # Name of the cell magic: first whitespace-delimited word, minus "%%".
        cell_magic = src.split()[0][2:]
        if cell_magic not in PYTHON_CELL_MAGICS | mode.python_cell_magics:
            raise NothingChanged
947
948
def format_cell(src: str, *, fast: bool, mode: Mode) -> str:
    """Format the source of a single Jupyter notebook code cell.

    Steps, in order:

      1. validate the cell (see :func:`validate_cell`);
      2. drop a trailing semicolon, remembering whether there was one;
      3. mask IPython magics;
      4. format the masked code (and, unless `fast`, check the result);
      5. restore the magics;
      6. restore the trailing semicolon if the cell had one;
      7. strip trailing newlines.

    Raises NothingChanged if the cell is rejected by validation, if masking
    fails with a SyntaxError (such cells may be automagics or multi-line
    magics, which are not supported), or if the result equals `src`.
    """
    validate_cell(src, mode)
    stripped_src, had_semicolon = remove_trailing_semicolon(src)
    try:
        masked_src, replacements = mask_cell(stripped_src)
    except SyntaxError:
        raise NothingChanged from None

    masked_dst = format_str(masked_src, mode=mode)
    if not fast:
        check_stability_and_equivalence(masked_src, masked_dst, mode=mode)

    unmasked_dst = unmask_cell(masked_dst, replacements)
    dst = put_trailing_semicolon_back(unmasked_dst, had_semicolon).rstrip("\n")
    if dst == src:
        raise NothingChanged from None
    return dst
984
985
def validate_metadata(nb: MutableMapping[str, Any]) -> None:
    """Raise NothingChanged if the notebook declares a non-Python language.

    Every metadata field is optional (see
    https://nbformat.readthedocs.io/en/latest/format_description.html), so a
    notebook without ``metadata.language_info.name`` is accepted and will be
    formatted as if it were Python.
    """
    metadata = nb.get("metadata", {})
    language_info = metadata.get("language_info", {})
    language = language_info.get("name", None)
    if language is None or language == "python":
        return
    raise NothingChanged from None
996
997
def format_ipynb_string(src_contents: str, *, fast: bool, mode: Mode) -> FileContent:
    """Format Jupyter notebook.

    Operate cell-by-cell, only on code cells, only for Python notebooks.
    If the ``.ipynb`` originally had a trailing newline, it'll be preserved.

    `src_contents` is the raw JSON text of the notebook; `fast` and `mode`
    are forwarded to :func:`format_cell` for every code cell.

    Returns the re-serialized notebook.  Raises NothingChanged if the input
    is empty, if the notebook is marked as non-Python, or if no cell was
    modified.
    """
    # An empty string is not a notebook at all; bail out before indexing
    # into it or handing it to the JSON parser.
    if not src_contents:
        raise NothingChanged

    trailing_newline = src_contents.endswith("\n")
    modified = False
    nb = json.loads(src_contents)
    validate_metadata(nb)
    for cell in nb["cells"]:
        if cell.get("cell_type", None) == "code":
            try:
                # Cell source is stored as a list of lines; join it back
                # into one string for formatting.
                src = "".join(cell["source"])
                dst = format_cell(src, fast=fast, mode=mode)
            except NothingChanged:
                # This cell is left as it is.
                pass
            else:
                cell["source"] = dst.splitlines(keepends=True)
                modified = True
    if modified:
        dst_contents = json.dumps(nb, indent=1, ensure_ascii=False)
        if trailing_newline:
            dst_contents = dst_contents + "\n"
        return dst_contents
    else:
        raise NothingChanged
1025
1026
def format_str(src_contents: str, *, mode: Mode) -> str:
    """Return `src_contents` reformatted according to `mode`.

    `mode` carries the formatting options, e.g. the maximum line length.
    A simple example:

    >>> import black
    >>> print(black.format_str("def f(arg:str='')->None:...", mode=black.Mode()))
    def f(arg: str = "") -> None:
        ...

    And one with a customized mode:

    >>> print(
    ...   black.format_str(
    ...     "def f(arg:str='')->None: hey",
    ...     mode=black.Mode(
    ...       target_versions={black.TargetVersion.PY36},
    ...       line_length=10,
    ...       string_normalization=False,
    ...       is_pyi=False,
    ...     ),
    ...   ),
    ... )
    def f(
        arg: str = '',
    ) -> None:
        hey

    """
    first_pass = _format_str_once(src_contents, mode=mode)
    if first_pass == src_contents:
        return first_pass
    # A second pass is forced whenever the first one changed something:
    # optional trailing commas become forced trailing commas on pass 2 and
    # then interact differently with optional parentheses.  Admittedly ugly.
    return _format_str_once(first_pass, mode=mode)
1064
1065
def _format_str_once(src_contents: str, *, mode: Mode) -> str:
    """Run a single formatting pass over `src_contents` and return the result.

    The source is parsed (after stripping leading whitespace), ``fmt: off``
    regions are normalized, and every line produced by the line generator is
    transformed and emitted together with the blank lines requested by the
    empty-line tracker.
    """
    src_node = lib2to3_parse(src_contents.lstrip(), mode.target_versions)
    pieces = []
    if mode.target_versions:
        versions = mode.target_versions
    else:
        # No explicit targets: infer them from the features the source uses.
        versions = detect_target_versions(
            src_node, future_imports=get_future_imports(src_node)
        )

    normalize_fmt_off(src_node, preview=mode.preview)
    line_generator = LineGenerator(mode=mode)
    blank_tracker = EmptyLineTracker(is_pyi=mode.is_pyi)
    blank_line = Line(mode=mode)
    pending_after = 0
    trailing_comma_features = {
        candidate
        for candidate in (
            Feature.TRAILING_COMMA_IN_CALL,
            Feature.TRAILING_COMMA_IN_DEF,
        )
        if supports_feature(versions, candidate)
    }
    for logical_line in line_generator.visit(src_node):
        # Blank lines owed after the previous line go first, then the ones
        # required before the current line.
        pieces.append(str(blank_line) * pending_after)
        before, pending_after = blank_tracker.maybe_empty_lines(logical_line)
        pieces.append(str(blank_line) * before)
        pieces.extend(
            str(line)
            for line in transform_line(
                logical_line, mode=mode, features=trailing_comma_features
            )
        )
    return "".join(pieces)
1094
1095
def decode_bytes(src: bytes) -> Tuple[FileContent, Encoding, NewLine]:
    """Decode `src` and return ``(decoded_contents, encoding, newline)``.

    The encoding is detected with :func:`tokenize.detect_encoding`.
    `newline` is CRLF if the first line ends with it and LF otherwise, while
    `decoded_contents` is read with universal newlines, so it only ever
    contains LF.
    """
    buffer = io.BytesIO(src)
    encoding, first_lines = tokenize.detect_encoding(buffer.readline)
    if not first_lines:
        return "", encoding, "\n"

    if first_lines[0].endswith(b"\r\n"):
        newline = "\r\n"
    else:
        newline = "\n"

    # Encoding detection consumed part of the buffer; rewind before decoding.
    buffer.seek(0)
    with io.TextIOWrapper(buffer, encoding) as reader:
        return reader.read(), encoding, newline
1111
1112
def get_features_used(  # noqa: C901
    node: Node, *, future_imports: Optional[Set[str]] = None
) -> Set[Feature]:
    """Return a set of (relatively) new Python features used in this file.

    `node` is the parsed tree to inspect; `future_imports` is the optional
    set of names imported from ``__future__``.

    Currently looking for:
    - f-strings;
    - self-documenting expressions in f-strings (f"{x=}");
    - underscores in numeric literals;
    - trailing commas after * or ** in function signatures and calls;
    - positional only arguments in function signatures and lambdas;
    - assignment expression;
    - relaxed decorator syntax;
    - usage of __future__ flags (those listed in FUTURE_FLAG_TO_FEATURE);
    - star expressions in the value of ``return`` / ``yield``;
    - annotated assignments whose right-hand side is a star-expression list;
    - ``except*`` clauses;
    - star expressions inside subscripts and in starred parameter
      annotations (variadic generics).
    """
    features: Set[Feature] = set()
    if future_imports:
        features |= {
            FUTURE_FLAG_TO_FEATURE[future_import]
            for future_import in future_imports
            if future_import in FUTURE_FLAG_TO_FEATURE
        }

    for n in node.pre_order():
        if is_string_token(n):
            # String prefixes are case-insensitive, so compare in lower case;
            # this also catches mixed-case raw f-strings such as Rf"..." or
            # fR"...".
            value_head = n.value[:2].lower()
            if value_head in {'f"', "f'", "rf", "fr"}:
                features.add(Feature.F_STRINGS)
                if Feature.DEBUG_F_STRINGS not in features:
                    for span_beg, span_end in iter_fexpr_spans(n.value):
                        # An expression ending in "=" right before the
                        # closing brace is a self-documenting expression.
                        if n.value[span_beg : span_end - 1].rstrip().endswith("="):
                            features.add(Feature.DEBUG_F_STRINGS)
                            break

        elif is_number_token(n):
            if "_" in n.value:
                features.add(Feature.NUMERIC_UNDERSCORES)

        elif n.type == token.SLASH:
            # Only a slash sitting directly in an argument list counts.
            if n.parent and n.parent.type in {
                syms.typedargslist,
                syms.arglist,
                syms.varargslist,
            }:
                features.add(Feature.POS_ONLY_ARGUMENTS)

        elif n.type == token.COLONEQUAL:
            features.add(Feature.ASSIGNMENT_EXPRESSIONS)

        elif n.type == syms.decorator:
            if len(n.children) > 1 and not is_simple_decorator_expression(
                n.children[1]
            ):
                features.add(Feature.RELAXED_DECORATORS)

        elif (
            n.type in {syms.typedargslist, syms.arglist}
            and n.children
            and n.children[-1].type == token.COMMA
        ):
            # The argument list ends with a trailing comma; that is only a
            # feature of interest when a * or ** argument is present.
            if n.type == syms.typedargslist:
                feature = Feature.TRAILING_COMMA_IN_DEF
            else:
                feature = Feature.TRAILING_COMMA_IN_CALL

            for ch in n.children:
                if ch.type in STARS:
                    features.add(feature)

                if ch.type == syms.argument:
                    for argch in ch.children:
                        if argch.type in STARS:
                            features.add(feature)

        elif (
            n.type in {syms.return_stmt, syms.yield_expr}
            and len(n.children) >= 2
            and n.children[1].type == syms.testlist_star_expr
            and any(child.type == syms.star_expr for child in n.children[1].children)
        ):
            features.add(Feature.UNPACKING_ON_FLOW)

        elif (
            n.type == syms.annassign
            and len(n.children) >= 4
            and n.children[3].type == syms.testlist_star_expr
        ):
            features.add(Feature.ANN_ASSIGN_EXTENDED_RHS)

        elif (
            n.type == syms.except_clause
            and len(n.children) >= 2
            and n.children[1].type == token.STAR
        ):
            features.add(Feature.EXCEPT_STAR)

        elif n.type in {syms.subscriptlist, syms.trailer} and any(
            child.type == syms.star_expr for child in n.children
        ):
            features.add(Feature.VARIADIC_GENERICS)

        elif (
            n.type == syms.tname_star
            and len(n.children) == 3
            and n.children[2].type == syms.star_expr
        ):
            features.add(Feature.VARIADIC_GENERICS)

    return features
1223
1224
def detect_target_versions(
    node: Node, *, future_imports: Optional[Set[str]] = None
) -> Set[TargetVersion]:
    """Return every target version that supports all features used in `node`.

    The used features are collected with :func:`get_features_used`; a version
    qualifies when that set is a subset of its entry in VERSION_TO_FEATURES.
    """
    used_features = get_features_used(node, future_imports=future_imports)
    compatible: Set[TargetVersion] = set()
    for candidate in TargetVersion:
        if used_features <= VERSION_TO_FEATURES[candidate]:
            compatible.add(candidate)
    return compatible
1233
1234
def get_future_imports(node: Node) -> Set[str]:
    """Collect the names imported from ``__future__`` at the top of the file.

    Only the leading run of simple statements is examined: docstring-like
    string statements are skipped, ``from __future__ import ...`` statements
    contribute their names, and anything else ends the scan.
    """

    def imported_names(children: List[LN]) -> Generator[str, None, None]:
        # Yield the original (pre-"as") name of every imported item.
        for item in children:
            if isinstance(item, Leaf):
                # Leaves other than names carry no import name.
                if item.type == token.NAME:
                    yield item.value
                continue

            if item.type == syms.import_as_name:
                original = item.children[0]
                assert isinstance(original, Leaf), "Invalid syntax parsing imports"
                assert original.type == token.NAME, "Invalid syntax parsing imports"
                yield original.value
            elif item.type == syms.import_as_names:
                yield from imported_names(item.children)
            else:
                raise AssertionError("Invalid syntax parsing imports")

    found: Set[str] = set()
    for statement in node.children:
        if statement.type != syms.simple_stmt:
            break

        head = statement.children[0]
        if isinstance(head, Leaf):
            # Continue looking if we see a docstring; otherwise stop.
            looks_like_docstring = (
                len(statement.children) == 2
                and head.type == token.STRING
                and statement.children[1].type == token.NEWLINE
            )
            if looks_like_docstring:
                continue
            break

        if head.type != syms.import_from:
            break

        module_name = head.children[1]
        if not isinstance(module_name, Leaf) or module_name.value != "__future__":
            break

        found.update(imported_names(head.children[3:]))

    return found
1283
1284
def assert_equivalent(src: str, dst: str) -> None:
    """Check that `src` and `dst` parse to the same AST.

    Raises AssertionError if `src` cannot be parsed, if `dst` cannot be
    parsed, or if the stringified ASTs differ.  For the last two cases
    diagnostic data is written out with `dump_to_file` and referenced in the
    error message.
    """
    try:
        src_ast = parse_ast(src)
    except Exception as exc:
        message = (
            "cannot use --safe with this file; failed to parse source file AST: "
            f"{exc}\n"
            "This could be caused by running Black with an older Python version "
            "that does not support new syntax used in your source file."
        )
        raise AssertionError(message) from exc

    try:
        dst_ast = parse_ast(dst)
    except Exception as exc:
        parse_traceback = "".join(traceback.format_tb(exc.__traceback__))
        log = dump_to_file(parse_traceback, dst)
        message = (
            f"INTERNAL ERROR: Black produced invalid code: {exc}. "
            "Please report a bug on https://github.com/psf/black/issues.  "
            f"This invalid output might be helpful: {log}"
        )
        raise AssertionError(message) from None

    src_ast_str = "\n".join(stringify_ast(src_ast))
    dst_ast_str = "\n".join(stringify_ast(dst_ast))
    if src_ast_str == dst_ast_str:
        return

    log = dump_to_file(diff(src_ast_str, dst_ast_str, "src", "dst"))
    message = (
        "INTERNAL ERROR: Black produced code that is not equivalent to the"
        " source.  Please report a bug on "
        f"https://github.com/psf/black/issues.  This diff might be helpful: {log}"
    )
    raise AssertionError(message) from None
1316
1317
def assert_stable(src: str, dst: str, mode: Mode) -> None:
    """Check that formatting `dst` once more leaves it unchanged.

    Raises AssertionError otherwise, after dumping the mode and the diffs
    between the source, the first pass and the second pass to a file.
    """
    # format_str() must not be used here: it formats the string twice and
    # could therefore hide a bug where the output bounces back and forth
    # between two versions.
    second_pass = _format_str_once(dst, mode=mode)
    if second_pass == dst:
        return

    first_pass_diff = diff(src, dst, "source", "first pass")
    second_pass_diff = diff(dst, second_pass, "first pass", "second pass")
    log = dump_to_file(str(mode), first_pass_diff, second_pass_diff)
    message = (
        "INTERNAL ERROR: Black produced different code on the second pass of the"
        " formatter.  Please report a bug on https://github.com/psf/black/issues."
        f"  This diff might be helpful: {log}"
    )
    raise AssertionError(message) from None
1335
1336
@contextmanager
def nullcontext() -> Iterator[None]:
    """Context manager that does nothing on entry or exit.

    Meant as a stand-in for the `nullcontext` available in Python 3.7; the
    value bound by ``with ... as`` is None.
    """
    yield None
1344
1345
def patch_click() -> None:
    """Disable Click's environment check so it won't crash with LANG=C.

    In some misconfigured environments (notably Python 3.6 with LANG=C)
    Python 3 picks ASCII as the default encoding, which limits the paths it
    can access for the lifetime of the application.  Click reacts to that by
    raising a RuntimeError.

    For Black, non-ASCII characters in file paths are unlikely since it deals
    with Python source code, and on Python 3.7 the crash was spurious anyway
    thanks to PEP 538 and PEP 540.  So the verification hooks are replaced
    with no-ops wherever they exist.
    """
    patch_targets: List[Any] = []
    try:
        from click import core
    except ImportError:
        pass
    else:
        patch_targets.append(core)
    try:
        # Removed in Click 8.1.0 and newer; we keep this around for users who
        # have older versions installed.
        from click import _unicodefun  # type: ignore
    except ImportError:
        pass
    else:
        patch_targets.append(_unicodefun)

    hook_names = ("_verify_python3_env", "_verify_python_env")
    for target in patch_targets:
        for hook_name in hook_names:
            if hasattr(target, hook_name):
                setattr(target, hook_name, lambda: None)
1378
1379
def patched_main() -> None:
    """Prepare the runtime environment, then run :func:`main`.

    For a frozen executable on Windows, `multiprocessing.freeze_support()`
    is called first.  Click is then patched via :func:`patch_click` before
    control is handed to `main`.
    """
    if sys.platform == "win32":
        if getattr(sys, "frozen", False):
            from multiprocessing import freeze_support

            freeze_support()

    patch_click()
    main()
1388
1389
# Run the patched entry point when this module is executed as a script.
if __name__ == "__main__":
    patched_main()