git.madduck.net Git - etc/vim.git/blob - src/black/__init__.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

86a0b63744264c85dc23f8702768d333c047f221
[etc/vim.git] / src / black / __init__.py
1 import io
2 import json
3 import platform
4 import re
5 import sys
6 import tokenize
7 import traceback
8 from contextlib import contextmanager
9 from dataclasses import replace
10 from datetime import datetime
11 from enum import Enum
12 from json.decoder import JSONDecodeError
13 from pathlib import Path
14 from typing import (
15     Any,
16     Dict,
17     Generator,
18     Iterator,
19     List,
20     MutableMapping,
21     Optional,
22     Pattern,
23     Sequence,
24     Set,
25     Sized,
26     Tuple,
27     Union,
28 )
29
30 import click
31 from click.core import ParameterSource
32 from mypy_extensions import mypyc_attr
33 from pathspec.patterns.gitwildmatch import GitWildMatchPatternError
34
35 from _black_version import version as __version__
36 from black.cache import Cache, get_cache_info, read_cache, write_cache
37 from black.comments import normalize_fmt_off
38 from black.const import (
39     DEFAULT_EXCLUDES,
40     DEFAULT_INCLUDES,
41     DEFAULT_LINE_LENGTH,
42     STDIN_PLACEHOLDER,
43 )
44 from black.files import (
45     find_project_root,
46     find_pyproject_toml,
47     find_user_pyproject_toml,
48     gen_python_files,
49     get_gitignore,
50     normalize_path_maybe_ignore,
51     parse_pyproject_toml,
52     wrap_stream_for_windows,
53 )
54 from black.handle_ipynb_magics import (
55     PYTHON_CELL_MAGICS,
56     TRANSFORMED_MAGICS,
57     jupyter_dependencies_are_installed,
58     mask_cell,
59     put_trailing_semicolon_back,
60     remove_trailing_semicolon,
61     unmask_cell,
62 )
63 from black.linegen import LN, LineGenerator, transform_line
64 from black.lines import EmptyLineTracker, Line
65 from black.mode import (
66     FUTURE_FLAG_TO_FEATURE,
67     VERSION_TO_FEATURES,
68     Feature,
69     Mode,
70     TargetVersion,
71     supports_feature,
72 )
73 from black.nodes import (
74     STARS,
75     is_number_token,
76     is_simple_decorator_expression,
77     is_string_token,
78     syms,
79 )
80 from black.output import color_diff, diff, dump_to_file, err, ipynb_diff, out
81 from black.parsing import InvalidInput  # noqa F401
82 from black.parsing import lib2to3_parse, parse_ast, stringify_ast
83 from black.report import Changed, NothingChanged, Report
84 from black.trans import iter_fexpr_spans
85 from blib2to3.pgen2 import token
86 from blib2to3.pytree import Leaf, Node
87
# True when this module's own file has a ".pyd" or ".so" suffix, i.e. it is
# running from a compiled extension rather than from a plain source file.
COMPILED = Path(__file__).suffix in (".pyd", ".so")

# types
# Plain aliases of `str`; they only label what a given string is meant to hold.
FileContent = str
Encoding = str
NewLine = str
94
95
class WriteBack(Enum):
    """How the result of formatting should be delivered."""

    NO = 0
    YES = 1
    DIFF = 2
    CHECK = 3
    COLOR_DIFF = 4

    @classmethod
    def from_configuration(
        cls, *, check: bool, diff: bool, color: bool = False
    ) -> "WriteBack":
        """Translate the command-line flags into a `WriteBack` member.

        A requested diff always wins over `check`; `color` only matters when a
        diff was requested. Without a diff, `check` selects CHECK and anything
        else selects YES.
        """
        if diff:
            return cls.COLOR_DIFF if color else cls.DIFF

        return cls.CHECK if check else cls.YES
114
115
# Legacy name, left for integrations: `FileMode` is the very same object as
# `Mode`, so code that still refers to the old name keeps working.
FileMode = Mode
118
119
def read_pyproject_toml(
    ctx: click.Context, param: click.Parameter, value: Optional[str]
) -> Optional[str]:
    """Load Black settings from a "pyproject.toml" into the defaults of `ctx`.

    When `value` is empty the file is looked up based on the "src" parameter
    already stored on `ctx`. The parsed settings are layered on top of any
    existing `ctx.default_map`.

    Returns the path of the configuration file that was read, or None when no
    file was found or the file held no Black configuration.

    Raises `click.FileError` if the file cannot be read or parsed, and
    `click.BadOptionUsage` if "target_version" is present but not a list.
    """
    if not value:
        value = find_pyproject_toml(ctx.params.get("src", ()))
    if value is None:
        return None

    try:
        raw_config = parse_pyproject_toml(value)
    except (OSError, ValueError) as exc:
        raise click.FileError(
            filename=value, hint=f"Error reading configuration file: {exc}"
        ) from None

    if not raw_config:
        return None

    # Sanitize the values to be Click friendly: everything except lists and
    # dicts is turned into a string. For more information please see:
    # https://github.com/psf/black/issues/1458
    # https://github.com/pallets/click/issues/1567
    config: Dict[str, Any] = {}
    for key, item in raw_config.items():
        config[key] = item if isinstance(item, (list, dict)) else str(item)

    target_version = config.get("target_version")
    if not (target_version is None or isinstance(target_version, list)):
        raise click.BadOptionUsage(
            "target-version", "Config key target-version must be a list"
        )

    # Values from the file take precedence over defaults already on the context.
    merged_defaults: Dict[str, Any] = {}
    if ctx.default_map:
        merged_defaults.update(ctx.default_map)
    merged_defaults.update(config)
    ctx.default_map = merged_defaults

    return value
164
165
def target_version_option_callback(
    c: click.Context, p: Union[click.Option, click.Parameter], v: Tuple[str, ...]
) -> List[TargetVersion]:
    """Turn the values of --target-version into `TargetVersion` members.

    Each value is upper-cased and looked up by name, preserving input order.

    This is a named function rather than a lambda because mypy couldn't infer
    the type of the lambda correctly, which caused trouble for mypyc.
    """
    versions: List[TargetVersion] = []
    for name in v:
        versions.append(TargetVersion[name.upper()])
    return versions
175
176
def re_compile_maybe_verbose(regex: str) -> Pattern[str]:
    """Compile the regular expression given as the string `regex`.

    A pattern spanning several lines is compiled in verbose mode by prefixing
    it with the inline `(?x)` flag; a single-line pattern is compiled as-is.
    """
    source = f"(?x){regex}" if "\n" in regex else regex
    pattern: Pattern[str] = re.compile(source)
    return pattern
186
187
def validate_regex(
    ctx: click.Context,
    param: click.Parameter,
    value: Optional[str],
) -> Optional[Pattern[str]]:
    """Click callback that compiles an option's value as a regular expression.

    None is passed through untouched. A value that is not a valid regular
    expression is reported as a `click.BadParameter` error.
    """
    if value is None:
        return None

    try:
        return re_compile_maybe_verbose(value)
    except re.error as exc:
        raise click.BadParameter(f"Not a valid regular expression: {exc}") from None
197
198
# Command-line entry point. Every option and argument declared below arrives
# in `main` as the keyword parameter of the same name (dashes become
# underscores).
@click.command(
    context_settings={"help_option_names": ["-h", "--help"]},
    # While Click does set this field automatically using the docstring, mypyc
    # (annoyingly) strips 'em so we need to set it here too.
    help="The uncompromising code formatter.",
)
@click.option("-c", "--code", type=str, help="Format the code passed in as a string.")
@click.option(
    "-l",
    "--line-length",
    type=int,
    default=DEFAULT_LINE_LENGTH,
    help="How many characters per line to allow.",
    show_default=True,
)
# The callback converts the chosen names into `TargetVersion` members.
@click.option(
    "-t",
    "--target-version",
    type=click.Choice([v.name.lower() for v in TargetVersion]),
    callback=target_version_option_callback,
    multiple=True,
    help=(
        "Python versions that should be supported by Black's output. [default: per-file"
        " auto-detection]"
    ),
)
@click.option(
    "--pyi",
    is_flag=True,
    help=(
        "Format all input files like typing stubs regardless of file extension (useful"
        " when piping source on standard input)."
    ),
)
@click.option(
    "--ipynb",
    is_flag=True,
    help=(
        "Format all input files like Jupyter Notebooks regardless of file extension "
        "(useful when piping source on standard input)."
    ),
)
@click.option(
    "--python-cell-magics",
    multiple=True,
    help=(
        "When processing Jupyter Notebooks, add the given magic to the list"
        f" of known python-magics ({', '.join(PYTHON_CELL_MAGICS)})."
        " Useful for formatting cells with custom python magics."
    ),
    default=[],
)
@click.option(
    "-S",
    "--skip-string-normalization",
    is_flag=True,
    help="Don't normalize string quotes or prefixes.",
)
@click.option(
    "-C",
    "--skip-magic-trailing-comma",
    is_flag=True,
    help="Don't use trailing commas as a reason to split lines.",
)
# Hidden from --help; still accepted and forwarded to `Mode` below.
@click.option(
    "--experimental-string-processing",
    is_flag=True,
    hidden=True,
    help="(DEPRECATED and now included in --preview) Normalize string literals.",
)
@click.option(
    "--preview",
    is_flag=True,
    help=(
        "Enable potentially disruptive style changes that may be added to Black's main"
        " functionality in the next major release."
    ),
)
@click.option(
    "--check",
    is_flag=True,
    help=(
        "Don't write the files back, just return the status. Return code 0 means"
        " nothing would change. Return code 1 means some files would be reformatted."
        " Return code 123 means there was an internal error."
    ),
)
@click.option(
    "--diff",
    is_flag=True,
    help="Don't write the files back, just output a diff for each file on stdout.",
)
@click.option(
    "--color/--no-color",
    is_flag=True,
    help="Show colored diff. Only applies when `--diff` is given.",
)
@click.option(
    "--fast/--safe",
    is_flag=True,
    help="If --fast given, skip temporary sanity checks. [default: --safe]",
)
@click.option(
    "--required-version",
    type=str,
    help=(
        "Require a specific version of Black to be running (useful for unifying results"
        " across many environments e.g. with a pyproject.toml file). It can be"
        " either a major version number or an exact version."
    ),
)
# The four regex options below are compiled by `validate_regex`, so `main`
# receives compiled patterns (or None) rather than strings.
@click.option(
    "--include",
    type=str,
    default=DEFAULT_INCLUDES,
    callback=validate_regex,
    help=(
        "A regular expression that matches files and directories that should be"
        " included on recursive searches. An empty value means all files are included"
        " regardless of the name. Use forward slashes for directories on all platforms"
        " (Windows, too). Exclusions are calculated first, inclusions later."
    ),
    show_default=True,
)
# No `default=` here: an absent --exclude reaches `main` as None, and the
# default is spelled out by hand in the help text instead.
@click.option(
    "--exclude",
    type=str,
    callback=validate_regex,
    help=(
        "A regular expression that matches files and directories that should be"
        " excluded on recursive searches. An empty value means no paths are excluded."
        " Use forward slashes for directories on all platforms (Windows, too)."
        " Exclusions are calculated first, inclusions later. [default:"
        f" {DEFAULT_EXCLUDES}]"
    ),
    show_default=False,
)
@click.option(
    "--extend-exclude",
    type=str,
    callback=validate_regex,
    help=(
        "Like --exclude, but adds additional files and directories on top of the"
        " excluded ones. (Useful if you simply want to add to the default)"
    ),
)
@click.option(
    "--force-exclude",
    type=str,
    callback=validate_regex,
    help=(
        "Like --exclude, but files and directories matching this regex will be "
        "excluded even when they are passed explicitly as arguments."
    ),
)
@click.option(
    "--stdin-filename",
    type=str,
    help=(
        "The name of the file when passing it through stdin. Useful to make "
        "sure Black will respect --force-exclude option on some "
        "editors that rely on using stdin."
    ),
)
@click.option(
    "-W",
    "--workers",
    type=click.IntRange(min=1),
    default=None,
    help="Number of parallel workers [default: number of CPUs in the system]",
)
@click.option(
    "-q",
    "--quiet",
    is_flag=True,
    help=(
        "Don't emit non-error messages to stderr. Errors are still emitted; silence"
        " those with 2>/dev/null."
    ),
)
@click.option(
    "-v",
    "--verbose",
    is_flag=True,
    help=(
        "Also emit messages to stderr about files that were not changed or were ignored"
        " due to exclusion patterns."
    ),
)
@click.version_option(
    version=__version__,
    message=(
        f"%(prog)s, %(version)s (compiled: {'yes' if COMPILED else 'no'})\n"
        f"Python ({platform.python_implementation()}) {platform.python_version()}"
    ),
)
# Both `src` and `--config` are marked eager. The --config callback
# (`read_pyproject_toml`) looks up "src" in `ctx.params` to locate a
# configuration file when none is given explicitly.
@click.argument(
    "src",
    nargs=-1,
    type=click.Path(
        exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
    ),
    is_eager=True,
    metavar="SRC ...",
)
@click.option(
    "--config",
    type=click.Path(
        exists=True,
        file_okay=True,
        dir_okay=False,
        readable=True,
        allow_dash=False,
        path_type=str,
    ),
    is_eager=True,
    callback=read_pyproject_toml,
    help="Read configuration from FILE path.",
)
@click.pass_context
def main(  # noqa: C901
    ctx: click.Context,
    code: Optional[str],
    line_length: int,
    target_version: List[TargetVersion],
    check: bool,
    diff: bool,
    color: bool,
    fast: bool,
    pyi: bool,
    ipynb: bool,
    python_cell_magics: Sequence[str],
    skip_string_normalization: bool,
    skip_magic_trailing_comma: bool,
    experimental_string_processing: bool,
    preview: bool,
    quiet: bool,
    verbose: bool,
    required_version: Optional[str],
    include: Pattern[str],
    exclude: Optional[Pattern[str]],
    extend_exclude: Optional[Pattern[str]],
    force_exclude: Optional[Pattern[str]],
    stdin_filename: Optional[str],
    workers: Optional[int],
    src: Tuple[str, ...],
    config: Optional[str],
) -> None:
    """The uncompromising code formatter."""
    # `ctx.obj` is used as a dict to hand the project root to `get_sources`.
    ctx.ensure_object(dict)

    # Exactly one input kind is accepted: SRC paths or a --code string.
    if src and code is not None:
        out(
            main.get_usage(ctx)
            + "\n\n'SRC' and 'code' cannot be passed simultaneously."
        )
        ctx.exit(1)
    if not src and code is None:
        out(main.get_usage(ctx) + "\n\nOne of 'SRC' or 'code' is required.")
        ctx.exit(1)

    # With --code there are no paths, so no project root is looked up.
    root, method = (
        find_project_root(src, stdin_filename) if code is None else (None, None)
    )
    ctx.obj["root"] = root

    if verbose:
        if root:
            out(
                f"Identified `{root}` as project root containing a {method}.",
                fg="blue",
            )

            # Pair each source with its normalized form; "-" (stdin) is kept
            # as-is. A falsy normalized value is reported as invalid below.
            normalized = [
                (source, source)
                if source == "-"
                else (normalize_path_maybe_ignore(Path(source), root), source)
                for source in src
            ]
            # The raw ANSI codes switch to red for an invalid entry and back to
            # blue afterwards, since the whole line is printed with fg="blue".
            srcs_string = ", ".join(
                [
                    f'"{_norm}"'
                    if _norm
                    else f'\033[31m"{source} (skipping - invalid)"\033[34m'
                    for _norm, source in normalized
                ]
            )
            out(f"Sources to be formatted: {srcs_string}", fg="blue")

        if config:
            # Tell the user where the configuration came from: the user-level
            # file, a file that was discovered rather than passed on the
            # command line, or an explicitly given path.
            config_source = ctx.get_parameter_source("config")
            user_level_config = str(find_user_pyproject_toml())
            if config == user_level_config:
                out(
                    "Using configuration from user-level config at "
                    f"'{user_level_config}'.",
                    fg="blue",
                )
            elif config_source in (
                ParameterSource.DEFAULT,
                ParameterSource.DEFAULT_MAP,
            ):
                out("Using configuration from project root.", fg="blue")
            else:
                out(f"Using configuration in '{config}'.", fg="blue")

    error_msg = "Oh no! 💥 💔 💥"
    # --required-version matches either the full running version or just its
    # first dot-separated component.
    if (
        required_version
        and required_version != __version__
        and required_version != __version__.split(".")[0]
    ):
        err(
            f"{error_msg} The required version `{required_version}` does not match"
            f" the running version `{__version__}`!"
        )
        ctx.exit(1)
    if ipynb and pyi:
        err("Cannot pass both `pyi` and `ipynb` flags!")
        ctx.exit(1)

    write_back = WriteBack.from_configuration(check=check, diff=diff, color=color)
    if target_version:
        versions = set(target_version)
    else:
        # We'll autodetect later.
        versions = set()
    mode = Mode(
        target_versions=versions,
        line_length=line_length,
        is_pyi=pyi,
        is_ipynb=ipynb,
        string_normalization=not skip_string_normalization,
        magic_trailing_comma=not skip_magic_trailing_comma,
        experimental_string_processing=experimental_string_processing,
        preview=preview,
        python_cell_magics=set(python_cell_magics),
    )

    if code is not None:
        # Run in quiet mode by default with -c; the extra output isn't useful.
        # You can still pass -v to get verbose output.
        quiet = True

    report = Report(check=check, diff=diff, quiet=quiet, verbose=verbose)

    if code is not None:
        reformat_code(
            content=code, fast=fast, write_back=write_back, mode=mode, report=report
        )
    else:
        try:
            sources = get_sources(
                ctx=ctx,
                src=src,
                quiet=quiet,
                verbose=verbose,
                include=include,
                exclude=exclude,
                extend_exclude=extend_exclude,
                force_exclude=force_exclude,
                report=report,
                stdin_filename=stdin_filename,
            )
        except GitWildMatchPatternError:
            ctx.exit(1)

        # Exits with status 0 right here when nothing is left to format.
        path_empty(
            sources,
            "No Python files are present to be formatted. Nothing to do 😴",
            quiet,
            verbose,
            ctx,
        )

        if len(sources) == 1:
            reformat_one(
                src=sources.pop(),
                fast=fast,
                write_back=write_back,
                mode=mode,
                report=report,
            )
        else:
            # Imported only on this path, so a single-file run never loads
            # the concurrency module.
            from black.concurrency import reformat_many

            reformat_many(
                sources=sources,
                fast=fast,
                write_back=write_back,
                mode=mode,
                report=report,
                workers=workers,
            )

    if verbose or not quiet:
        # The blank separator line and the report summary are only printed
        # when paths (not a --code string) were formatted.
        if code is None and (verbose or report.change_count or report.failure_count):
            out()
        out(error_msg if report.return_code else "All done! ✨ 🍰 ✨")
        if code is None:
            click.echo(str(report), err=True)
    ctx.exit(report.return_code)
601
602
def get_sources(
    *,
    ctx: click.Context,
    src: Tuple[str, ...],
    quiet: bool,
    verbose: bool,
    include: Pattern[str],
    exclude: Optional[Pattern[str]],
    extend_exclude: Optional[Pattern[str]],
    force_exclude: Optional[Pattern[str]],
    report: "Report",
    stdin_filename: Optional[str],
) -> Set[Path]:
    """Compute the set of files to be formatted.

    Each entry of `src` is handled according to what it names:

    - a file, or "-" combined with `stdin_filename`: added unless its
      normalized path is rejected or matches `force_exclude`; stdin entries
      are stored with the `STDIN_PLACEHOLDER` prefix, and ".ipynb" files are
      skipped when the Jupyter dependencies are not installed;
    - a directory: expanded through `gen_python_files`;
    - a bare "-": added as-is;
    - anything else: reported as an invalid path and skipped.

    When `exclude` is None, the default excludes and the project root's
    gitignore are applied to every directory in `src`. An explicit `exclude`
    turns the gitignore handling off.
    """
    sources: Set[Path] = set()
    root = ctx.obj["root"]

    # Decide once, up front, whether the defaults apply. `exclude` is replaced
    # by the compiled default pattern on the first directory, so re-testing
    # `exclude is None` inside the loop would drop the gitignore for every
    # directory after the first one.
    using_default_exclude = exclude is None
    # Loaded lazily so that runs without any directory never read it.
    root_gitignore = None
    defaults_loaded = False

    for s in src:
        if s == "-" and stdin_filename:
            p = Path(stdin_filename)
            is_stdin = True
        else:
            p = Path(s)
            is_stdin = False

        if is_stdin or p.is_file():
            normalized_path = normalize_path_maybe_ignore(p, root, report)
            if normalized_path is None:
                continue

            normalized_path = "/" + normalized_path
            # Hard-exclude any files that matches the `--force-exclude` regex.
            if force_exclude:
                force_exclude_match = force_exclude.search(normalized_path)
            else:
                force_exclude_match = None
            # An empty match (e.g. from an empty pattern) does not exclude.
            if force_exclude_match and force_exclude_match.group(0):
                report.path_ignored(p, "matches the --force-exclude regular expression")
                continue

            if is_stdin:
                p = Path(f"{STDIN_PLACEHOLDER}{str(p)}")

            if p.suffix == ".ipynb" and not jupyter_dependencies_are_installed(
                verbose=verbose, quiet=quiet
            ):
                continue

            sources.add(p)
        elif p.is_dir():
            if using_default_exclude:
                if not defaults_loaded:
                    exclude = re_compile_maybe_verbose(DEFAULT_EXCLUDES)
                    root_gitignore = get_gitignore(root)
                    defaults_loaded = True
                gitignore = root_gitignore
            else:
                gitignore = None
            sources.update(
                gen_python_files(
                    p.iterdir(),
                    root,
                    include,
                    exclude,
                    extend_exclude,
                    force_exclude,
                    report,
                    gitignore,
                    verbose=verbose,
                    quiet=quiet,
                )
            )
        elif s == "-":
            sources.add(p)
        else:
            err(f"invalid path: {s}")
    return sources
677
678
def path_empty(
    src: Sized, msg: str, quiet: bool, verbose: bool, ctx: click.Context
) -> None:
    """Exit with status 0 if there is no `src` provided for formatting.

    `msg` is printed first unless output is silenced, i.e. unless `quiet` is
    set without `verbose`. Does nothing when `src` is non-empty.
    """
    if src:
        return

    if verbose or not quiet:
        out(msg)
    ctx.exit(0)
689
690
def reformat_code(
    content: str, fast: bool, write_back: WriteBack, mode: Mode, report: Report
) -> None:
    """Reformat the string `content` and print the result, in-process.

    This is the string counterpart of `reformat_one`: `fast`, `write_back` and
    `mode` are forwarded to :func:`format_stdin_to_stdout`. The outcome is
    recorded on `report` under the pseudo-path "<string>"; any exception is
    recorded as a failure instead of propagating, with a traceback printed
    when the report is verbose.
    """
    path = Path("<string>")
    try:
        was_changed = format_stdin_to_stdout(
            content=content, fast=fast, write_back=write_back, mode=mode
        )
        report.done(path, Changed.YES if was_changed else Changed.NO)
    except Exception as exc:
        if report.verbose:
            traceback.print_exc()
        report.failed(path, str(exc))
713
714
# diff-shades depends on being able to monkeypatch this function to operate. I
# know it's not ideal, but this shouldn't cause any issues ... hopefully.
# ~ichard26
@mypyc_attr(patchable=True)
def reformat_one(
    src: Path, fast: bool, write_back: WriteBack, mode: Mode, report: "Report"
) -> None:
    """Reformat the single file `src` in-process and record the outcome.

    `src` is treated as standard input when it is "-" or carries the
    `STDIN_PLACEHOLDER` prefix. Regular files go through the cache first.

    `fast`, `write_back`, and `mode` options are passed to
    :func:`format_file_in_place` or :func:`format_stdin_to_stdout`. Any
    exception is recorded on `report` as a failure instead of propagating.
    """
    try:
        changed = Changed.NO
        src_text = str(src)

        is_stdin = src_text == "-"
        if not is_stdin and src_text.startswith(STDIN_PLACEHOLDER):
            is_stdin = True
            # Drop the placeholder so that anything printed to the user shows
            # the original name.
            src = Path(src_text[len(STDIN_PLACEHOLDER) :])

        if is_stdin:
            suffix = src.suffix
            if suffix == ".pyi":
                mode = replace(mode, is_pyi=True)
            elif suffix == ".ipynb":
                mode = replace(mode, is_ipynb=True)
            if format_stdin_to_stdout(fast=fast, write_back=write_back, mode=mode):
                changed = Changed.YES
        else:
            diff_requested = write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF)
            cache: Cache = {}
            if not diff_requested:
                # The cache is bypassed entirely when a diff was requested.
                cache = read_cache(mode)
                resolved = src.resolve()
                cache_key = str(resolved)
                if cache_key in cache and cache[cache_key] == get_cache_info(
                    resolved
                ):
                    changed = Changed.CACHED

            if changed is not Changed.CACHED:
                if format_file_in_place(
                    src, fast=fast, write_back=write_back, mode=mode
                ):
                    changed = Changed.YES

            # Remember the file once it is known to be in its formatted state:
            # either it was just (re)written, or a check found nothing to do.
            written_now = write_back is WriteBack.YES and changed is not Changed.CACHED
            checked_clean = write_back is WriteBack.CHECK and changed is Changed.NO
            if written_now or checked_clean:
                write_cache(cache, [src], mode)
        report.done(src, changed)
    except Exception as exc:
        if report.verbose:
            traceback.print_exc()
        report.failed(src, str(exc))
767
768
def format_file_in_place(
    src: Path,
    fast: bool,
    mode: Mode,
    write_back: WriteBack = WriteBack.NO,
    lock: Any = None,  # multiprocessing.Manager().Lock() is some crazy proxy
) -> bool:
    """Format file under `src` path. Return True if changed.

    If `write_back` is DIFF or COLOR_DIFF, write a diff to stdout. If it is
    YES, write reformatted code to the file, using the encoding and newline
    style detected when it was read. For any other value nothing is written.

    A ".pyi" or ".ipynb" suffix switches `mode` to stub or notebook formatting.
    `mode` and `fast` options are passed to :func:`format_file_contents`.

    `lock`, when given, is held while the diff is written to stdout.

    Raises ValueError if the notebook contents cannot be decoded as JSON.
    """
    if src.suffix == ".pyi":
        mode = replace(mode, is_pyi=True)
    elif src.suffix == ".ipynb":
        mode = replace(mode, is_ipynb=True)

    # Taken before reading, so the diff header carries the modification time
    # of the contents being formatted.
    then = datetime.utcfromtimestamp(src.stat().st_mtime)
    with open(src, "rb") as buf:
        src_contents, encoding, newline = decode_bytes(buf.read())
    try:
        dst_contents = format_file_contents(src_contents, fast=fast, mode=mode)
    except NothingChanged:
        return False
    except JSONDecodeError:
        raise ValueError(
            f"File '{src}' cannot be parsed as valid Jupyter notebook."
        ) from None

    if write_back == WriteBack.YES:
        with open(src, "w", encoding=encoding, newline=newline) as f:
            f.write(dst_contents)
    elif write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
        now = datetime.utcnow()
        src_name = f"{src}\t{then} +0000"
        dst_name = f"{src}\t{now} +0000"
        if mode.is_ipynb:
            diff_contents = ipynb_diff(src_contents, dst_contents, src_name, dst_name)
        else:
            diff_contents = diff(src_contents, dst_contents, src_name, dst_name)

        if write_back == WriteBack.COLOR_DIFF:
            diff_contents = color_diff(diff_contents)

        with lock or nullcontext():
            f = io.TextIOWrapper(
                sys.stdout.buffer,
                encoding=encoding,
                newline=newline,
                write_through=True,
            )
            try:
                f = wrap_stream_for_windows(f)
                f.write(diff_contents)
            finally:
                # Always give stdout's buffer back: a wrapper abandoned after a
                # failed write would close the underlying buffer when it is
                # garbage-collected.
                f.detach()

    return True
826
827
def format_stdin_to_stdout(
    fast: bool,
    *,
    content: Optional[str] = None,
    write_back: WriteBack = WriteBack.NO,
    mode: Mode,
) -> bool:
    """Format file on stdin. Return True if changed.

    If content is None, it's read from sys.stdin.

    If `write_back` is YES, write reformatted code back to stdout. If it is
    DIFF or COLOR_DIFF, write a diff to stdout. The output happens in a
    `finally` clause, so when formatting raises, the unmodified source is
    what gets written (or diffed against itself) before the exception
    propagates. The `mode` argument is passed to
    :func:`format_file_contents`.
    """
    then = datetime.utcnow()

    if content is None:
        src, encoding, newline = decode_bytes(sys.stdin.buffer.read())
    else:
        src, encoding, newline = content, "utf-8", ""

    # Fall back to the input, so there is something to emit on any outcome.
    dst = src
    try:
        dst = format_file_contents(src, fast=fast, mode=mode)
        return True

    except NothingChanged:
        return False

    finally:
        f = io.TextIOWrapper(
            sys.stdout.buffer, encoding=encoding, newline=newline, write_through=True
        )
        try:
            if write_back == WriteBack.YES:
                # Make sure there's a newline after the content
                if dst and dst[-1] != "\n":
                    dst += "\n"
                f.write(dst)
            elif write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
                now = datetime.utcnow()
                src_name = f"STDIN\t{then} +0000"
                dst_name = f"STDOUT\t{now} +0000"
                d = diff(src, dst, src_name, dst_name)
                if write_back == WriteBack.COLOR_DIFF:
                    d = color_diff(d)
                    f = wrap_stream_for_windows(f)
                f.write(d)
        finally:
            # Always give stdout's buffer back: a wrapper abandoned after a
            # failed write would close the underlying buffer when it is
            # garbage-collected.
            f.detach()
877
878
def check_stability_and_equivalence(
    src_contents: str, dst_contents: str, *, mode: Mode
) -> None:
    """Perform stability and equivalence checks.

    Raise AssertionError if source and destination contents are not
    equivalent, or if a second pass of the formatter would format the
    content differently.
    """
    # Equivalence is checked first, so a failure there is raised before the
    # stability check runs at all.
    assert_equivalent(src_contents, dst_contents)
    # Only the stability check receives `mode`.
    assert_stable(src_contents, dst_contents, mode=mode)
890
891
def format_file_contents(src_contents: str, *, fast: bool, mode: Mode) -> FileContent:
    """Return the reformatted version of `src_contents`.

    Raises `NothingChanged` when the input is empty or whitespace-only, and
    when formatting leaves it untouched.

    Notebook contents (`mode.is_ipynb`) are handled by
    :func:`format_ipynb_string`, which also receives `fast`; everything else
    goes through :func:`format_str`. Unless `fast` is set, non-notebook output
    is additionally verified with :func:`check_stability_and_equivalence`.
    """
    if not src_contents.strip():
        raise NothingChanged

    is_notebook = mode.is_ipynb
    dst_contents = (
        format_ipynb_string(src_contents, fast=fast, mode=mode)
        if is_notebook
        else format_str(src_contents, mode=mode)
    )
    if dst_contents == src_contents:
        raise NothingChanged

    # Jupyter notebooks will already have been checked while being formatted.
    if not (fast or is_notebook):
        check_stability_and_equivalence(src_contents, dst_contents, mode=mode)
    return dst_contents
913
914
def validate_cell(src: str, mode: Mode) -> None:
    """Reject cells that cannot be formatted safely by raising `NothingChanged`.

    Two kinds of cell are rejected:

    - cells that already contain the output of a TransformerManager
      transformation. A cell holding ``!ls`` is transformed to
      ``get_ipython().system('ls')``, but so is a cell that literally held
      ``get_ipython().system('ls')`` in the first place:

        >>> TransformerManager().transform_cell("get_ipython().system('ls')")
        "get_ipython().system('ls')\n"
        >>> TransformerManager().transform_cell("!ls")
        "get_ipython().system('ls')\n"

      Roundtripping is therefore ambiguous, so such cells are left alone;

    - cells starting with a ``%%`` cell magic that is neither in
      `PYTHON_CELL_MAGICS` nor in `mode.python_cell_magics`, which might cause
      tokenizer_rt to break because of indentations.
    """
    for transformed_magic in TRANSFORMED_MAGICS:
        if transformed_magic in src:
            raise NothingChanged

    if src[:2] != "%%":
        return

    # First whitespace-delimited word of the cell, minus the leading "%%".
    cell_magic = src.split()[0][2:]
    known_python_magics = PYTHON_CELL_MAGICS | mode.python_cell_magics
    if cell_magic not in known_python_magics:
        raise NothingChanged
939
940
def format_cell(src: str, *, fast: bool, mode: Mode) -> str:
    """Format the code of a single Jupyter notebook cell.

    The steps are, in order:

      - reject cells that :func:`validate_cell` refuses;
      - drop a trailing semicolon, remembering whether there was one;
      - mask IPython magics;
      - format the masked code (and, unless `fast`, check the result);
      - put the IPython magics back;
      - put the trailing semicolon back (if originally present);
      - strip trailing newlines.

    Raise NothingChanged when the cell ends up unchanged, or when masking hits
    a syntax error: such cells could be automagics or multi-line magics, which
    are currently not supported.
    """
    validate_cell(src, mode)
    stripped_src, had_semicolon = remove_trailing_semicolon(src)
    try:
        masked_src, replacements = mask_cell(stripped_src)
    except SyntaxError:
        raise NothingChanged from None

    masked_dst = format_str(masked_src, mode=mode)
    if not fast:
        check_stability_and_equivalence(masked_src, masked_dst, mode=mode)

    unmasked_dst = unmask_cell(masked_dst, replacements)
    dst = put_trailing_semicolon_back(unmasked_dst, had_semicolon).rstrip("\n")
    if dst == src:
        raise NothingChanged from None
    return dst
976
977
def validate_metadata(nb: MutableMapping[str, Any]) -> None:
    """Raise NothingChanged when the notebook declares a non-Python language.

    The language is looked up under ``metadata.language_info.name``.  All
    notebook metadata fields are optional, see
    https://nbformat.readthedocs.io/en/latest/format_description.html, so a
    notebook that declares no language at all is accepted.
    """
    metadata = nb.get("metadata", {})
    language_info = metadata.get("language_info", {})
    language = language_info.get("name", None)
    if language is None or language == "python":
        return
    raise NothingChanged from None
988
989
def format_ipynb_string(src_contents: str, *, fast: bool, mode: Mode) -> FileContent:
    """Format Jupyter notebook.

    Operate cell-by-cell, only on code cells, only for Python notebooks.
    If the ``.ipynb`` originally had a trailing newline, it'll be preserved.

    `src_contents` is the notebook as a JSON string; `fast` and `mode` are
    forwarded to :func:`format_cell` for every code cell.

    Return the re-serialized notebook.  Raise NothingChanged when the input is
    empty, when :func:`validate_metadata` rejects the notebook, or when no
    code cell was changed.
    """
    if not src_contents:
        # Nothing to format; bail out instead of failing on the checks below.
        raise NothingChanged

    trailing_newline = src_contents.endswith("\n")
    nb = json.loads(src_contents)
    validate_metadata(nb)

    modified = False
    for cell in nb["cells"]:
        if cell.get("cell_type", None) != "code":
            continue
        try:
            src = "".join(cell["source"])
            dst = format_cell(src, fast=fast, mode=mode)
        except NothingChanged:
            # This cell stays as it is; keep going with the remaining cells.
            continue
        cell["source"] = dst.splitlines(keepends=True)
        modified = True

    if not modified:
        raise NothingChanged

    dst_contents = json.dumps(nb, indent=1, ensure_ascii=False)
    if trailing_newline:
        dst_contents += "\n"
    return dst_contents
1017
1018
def format_str(src_contents: str, *, mode: Mode) -> str:
    """Return `src_contents` reformatted according to `mode`.

    `mode` carries the formatting options, such as how many characters per
    line are allowed.  Example:

    >>> import black
    >>> print(black.format_str("def f(arg:str='')->None:...", mode=black.Mode()))
    def f(arg: str = "") -> None:
        ...

    A more complex example:

    >>> print(
    ...   black.format_str(
    ...     "def f(arg:str='')->None: hey",
    ...     mode=black.Mode(
    ...       target_versions={black.TargetVersion.PY36},
    ...       line_length=10,
    ...       string_normalization=False,
    ...       is_pyi=False,
    ...     ),
    ...   ),
    ... )
    def f(
        arg: str = '',
    ) -> None:
        hey

    """
    first_pass = _format_str_once(src_contents, mode=mode)
    if first_pass == src_contents:
        return first_pass

    # The second pass is forced: optional trailing commas turn into forced
    # trailing commas on pass 2, and those interact differently with optional
    # parentheses.  Admittedly ugly.
    return _format_str_once(first_pass, mode=mode)
1056
1057
def _format_str_once(src_contents: str, *, mode: Mode) -> str:
    """Run one formatting pass over `src_contents` and return the result.

    The source is parsed with leading whitespace stripped.  When `mode` names
    no target versions, they are detected from the parsed source instead.
    """
    src_node = lib2to3_parse(src_contents.lstrip(), mode.target_versions)

    versions = mode.target_versions
    if not versions:
        versions = detect_target_versions(
            src_node, future_imports=get_future_imports(src_node)
        )

    # Version detection above inspects the tree before this call touches it.
    normalize_fmt_off(src_node, preview=mode.preview)
    line_generator = LineGenerator(mode=mode)
    tracker = EmptyLineTracker(is_pyi=mode.is_pyi)
    blank_line = Line(mode=mode)

    split_line_features = set()
    for feature in {Feature.TRAILING_COMMA_IN_CALL, Feature.TRAILING_COMMA_IN_DEF}:
        if supports_feature(versions, feature):
            split_line_features.add(feature)

    chunks = []
    pending_after = 0
    for current_line in line_generator.visit(src_node):
        # Blank lines requested after the previous line are emitted first.
        chunks.append(str(blank_line) * pending_after)
        before, pending_after = tracker.maybe_empty_lines(current_line)
        chunks.append(str(blank_line) * before)
        chunks.extend(
            str(line)
            for line in transform_line(
                current_line, mode=mode, features=split_line_features
            )
        )
    return "".join(chunks)
1086
1087
def decode_bytes(src: bytes) -> Tuple[FileContent, Encoding, NewLine]:
    """Decode `src` and return ``(decoded_contents, encoding, newline)``.

    The encoding is detected with :func:`tokenize.detect_encoding`.  `newline`
    is CRLF when the first line ends with CRLF and LF otherwise, while
    `decoded_contents` is read with universal newlines (i.e. only contains LF).
    """
    buffer = io.BytesIO(src)
    encoding, first_lines = tokenize.detect_encoding(buffer.readline)
    if first_lines:
        newline = "\r\n" if first_lines[0].endswith(b"\r\n") else "\n"
        # Detection consumed part of the buffer; decode from the start again.
        buffer.seek(0)
        with io.TextIOWrapper(buffer, encoding) as reader:
            return reader.read(), encoding, newline

    return "", encoding, "\n"
1103
1104
def get_features_used(  # noqa: C901
    node: Node, *, future_imports: Optional[Set[str]] = None
) -> Set[Feature]:
    """Return a set of (relatively) new Python features used in this file.

    `node` is walked in pre-order; `future_imports`, when given, is the set of
    names imported from ``__future__``.

    Currently looking for:
    - f-strings (with any casing of the ``f``/``rf``/``fr`` prefix);
    - self-documenting expressions in f-strings (f"{x=}");
    - underscores in numeric literals;
    - trailing commas after * or ** in function signatures and calls;
    - positional only arguments in function signatures and lambdas;
    - assignment expression;
    - relaxed decorator syntax;
    - star-unpacking in the value of a ``return`` or ``yield``;
    - annotated assignments whose right-hand side is a ``testlist_star_expr``;
    - ``except*`` clauses;
    - star expressions in subscripts and in ``tname_star`` annotations;
    - usage of __future__ flags that map to a feature.
    """
    features: Set[Feature] = set()
    if future_imports:
        features |= {
            FUTURE_FLAG_TO_FEATURE[future_import]
            for future_import in future_imports
            if future_import in FUTURE_FLAG_TO_FEATURE
        }

    for n in node.pre_order():
        if is_string_token(n):
            # String prefixes are case-insensitive and may mix cases (e.g.
            # ``Rf"..."``), so compare the lowercased head of the token.
            value_head = n.value[:2].lower()
            if value_head in {'f"', "f'", "rf", "fr"}:
                features.add(Feature.F_STRINGS)
                if Feature.DEBUG_F_STRINGS not in features:
                    for span_beg, span_end in iter_fexpr_spans(n.value):
                        if n.value[span_beg : span_end - 1].rstrip().endswith("="):
                            features.add(Feature.DEBUG_F_STRINGS)
                            break

        elif is_number_token(n):
            if "_" in n.value:
                features.add(Feature.NUMERIC_UNDERSCORES)

        elif n.type == token.SLASH:
            if n.parent and n.parent.type in {
                syms.typedargslist,
                syms.arglist,
                syms.varargslist,
            }:
                features.add(Feature.POS_ONLY_ARGUMENTS)

        elif n.type == token.COLONEQUAL:
            features.add(Feature.ASSIGNMENT_EXPRESSIONS)

        elif n.type == syms.decorator:
            if len(n.children) > 1 and not is_simple_decorator_expression(
                n.children[1]
            ):
                features.add(Feature.RELAXED_DECORATORS)

        elif (
            n.type in {syms.typedargslist, syms.arglist}
            and n.children
            and n.children[-1].type == token.COMMA
        ):
            # A trailing comma only counts as a "new" feature when the
            # argument list also contains * or ** unpacking.
            if n.type == syms.typedargslist:
                feature = Feature.TRAILING_COMMA_IN_DEF
            else:
                feature = Feature.TRAILING_COMMA_IN_CALL

            for ch in n.children:
                if ch.type in STARS:
                    features.add(feature)

                if ch.type == syms.argument:
                    for argch in ch.children:
                        if argch.type in STARS:
                            features.add(feature)

        elif (
            n.type in {syms.return_stmt, syms.yield_expr}
            and len(n.children) >= 2
            and n.children[1].type == syms.testlist_star_expr
            and any(child.type == syms.star_expr for child in n.children[1].children)
        ):
            features.add(Feature.UNPACKING_ON_FLOW)

        elif (
            n.type == syms.annassign
            and len(n.children) >= 4
            and n.children[3].type == syms.testlist_star_expr
        ):
            features.add(Feature.ANN_ASSIGN_EXTENDED_RHS)

        elif (
            n.type == syms.except_clause
            and len(n.children) >= 2
            and n.children[1].type == token.STAR
        ):
            features.add(Feature.EXCEPT_STAR)

        elif n.type in {syms.subscriptlist, syms.trailer} and any(
            child.type == syms.star_expr for child in n.children
        ):
            features.add(Feature.VARIADIC_GENERICS)

        elif (
            n.type == syms.tname_star
            and len(n.children) == 3
            and n.children[2].type == syms.star_expr
        ):
            features.add(Feature.VARIADIC_GENERICS)

    return features
1215
1216
def detect_target_versions(
    node: Node, *, future_imports: Optional[Set[str]] = None
) -> Set[TargetVersion]:
    """Return every target version that supports all features used in `node`.

    The used features come from :func:`get_features_used`, which also receives
    `future_imports`.
    """
    used_features = get_features_used(node, future_imports=future_imports)
    compatible = set()
    for version in TargetVersion:
        if used_features <= VERSION_TO_FEATURES[version]:
            compatible.add(version)
    return compatible
1225
1226
def get_future_imports(node: Node) -> Set[str]:
    """Collect the names imported from ``__future__`` at the top of the file.

    Only the leading run of simple statements is inspected: docstring-like
    string statements are skipped, ``from __future__ import ...`` statements
    contribute their names, and anything else ends the scan.
    """

    def imported_names(children: List[LN]) -> Generator[str, None, None]:
        # Yield the original (pre-``as``) name of every imported item.
        for item in children:
            if isinstance(item, Leaf):
                if item.type == token.NAME:
                    yield item.value

            elif item.type == syms.import_as_name:
                name_leaf = item.children[0]
                assert isinstance(name_leaf, Leaf), "Invalid syntax parsing imports"
                assert name_leaf.type == token.NAME, "Invalid syntax parsing imports"
                yield name_leaf.value

            elif item.type == syms.import_as_names:
                yield from imported_names(item.children)

            else:
                raise AssertionError("Invalid syntax parsing imports")

    found: Set[str] = set()
    for stmt in node.children:
        if stmt.type != syms.simple_stmt:
            break

        head = stmt.children[0]
        if isinstance(head, Leaf):
            # Continue looking if we see a docstring; otherwise stop.
            is_docstring = (
                len(stmt.children) == 2
                and head.type == token.STRING
                and stmt.children[1].type == token.NEWLINE
            )
            if not is_docstring:
                break
            continue

        if head.type != syms.import_from:
            break

        module_name = head.children[1]
        if not (isinstance(module_name, Leaf) and module_name.value == "__future__"):
            break

        found.update(imported_names(head.children[3:]))

    return found
1275
1276
def assert_equivalent(src: str, dst: str) -> None:
    """Raise AssertionError unless `src` and `dst` produce the same AST dump.

    Both strings are parsed with ``parse_ast`` and compared through their
    ``stringify_ast`` output.  A source that fails to parse, a destination that
    fails to parse, and a mismatch between the two are each reported with
    their own message.
    """
    try:
        src_ast = parse_ast(src)
    except Exception as exc:
        message = (
            "cannot use --safe with this file; failed to parse source file AST: "
            f"{exc}\n"
            "This could be caused by running Black with an older Python version "
            "that does not support new syntax used in your source file."
        )
        raise AssertionError(message) from exc

    try:
        dst_ast = parse_ast(dst)
    except Exception as exc:
        # Keep the traceback and the offending output around for the report.
        formatted_tb = "".join(traceback.format_tb(exc.__traceback__))
        log = dump_to_file(formatted_tb, dst)
        message = (
            f"INTERNAL ERROR: Black produced invalid code: {exc}. "
            "Please report a bug on https://github.com/psf/black/issues.  "
            f"This invalid output might be helpful: {log}"
        )
        raise AssertionError(message) from None

    src_ast_str = "\n".join(stringify_ast(src_ast))
    dst_ast_str = "\n".join(stringify_ast(dst_ast))
    if src_ast_str == dst_ast_str:
        return

    log = dump_to_file(diff(src_ast_str, dst_ast_str, "src", "dst"))
    message = (
        "INTERNAL ERROR: Black produced code that is not equivalent to the"
        " source.  Please report a bug on "
        f"https://github.com/psf/black/issues.  This diff might be helpful: {log}"
    )
    raise AssertionError(message) from None
1308
1309
def assert_stable(src: str, dst: str, mode: Mode) -> None:
    """Raise AssertionError if formatting `dst` once more changes it.

    `src` is only used to build the diagnostic diffs written on failure.
    """
    # format_str() must not be used here: it formats the string twice, which
    # could hide a bug where we bounce back and forth between two versions.
    second_pass = _format_str_once(dst, mode=mode)
    if second_pass == dst:
        return

    mode_description = str(mode)
    source_to_first = diff(src, dst, "source", "first pass")
    first_to_second = diff(dst, second_pass, "first pass", "second pass")
    log = dump_to_file(mode_description, source_to_first, first_to_second)
    message = (
        "INTERNAL ERROR: Black produced different code on the second pass of the"
        " formatter.  Please report a bug on https://github.com/psf/black/issues."
        f"  This diff might be helpful: {log}"
    )
    raise AssertionError(message) from None
1327
1328
@contextmanager
def nullcontext() -> Iterator[None]:
    """Return a context manager that does nothing on entry or exit.

    Meant as a drop-in for the `nullcontext` found in Python 3.7; the value
    bound by ``with ... as`` is None.
    """
    yield None
1336
1337
def patch_click() -> None:
    """Make Click not crash on Python 3.6 with LANG=C.

    On certain misconfigured environments, Python 3 selects the ASCII encoding
    as the default, which restricts the paths it can access during the lifetime
    of the application.  Click refuses to work in this scenario by raising a
    RuntimeError.

    In case of Black the likelihood that non-ASCII characters are going to be
    used in file paths is minimal since it's Python source code.  Moreover,
    this crash was spurious on Python 3.7 thanks to PEP 538 and PEP 540.

    The environment checks are replaced with no-ops in every Click module that
    defines them.
    """
    patchable: List[Any] = []
    try:
        from click import core

        patchable.append(core)
    except ImportError:
        pass
    try:
        # Removed in Click 8.1.0 and newer; we keep this around for users who have
        # older versions installed.
        from click import _unicodefun  # type: ignore

        patchable.append(_unicodefun)
    except ImportError:
        pass

    for module in patchable:
        for checker_name in ("_verify_python3_env", "_verify_python_env"):
            if hasattr(module, checker_name):
                setattr(module, checker_name, lambda: None)
1370
1371
def patched_main() -> None:
    """Apply the Click patch and then run :func:`main`.

    On win32, when ``sys.frozen`` is set, ``multiprocessing.freeze_support()``
    is called first.
    """
    frozen_on_windows = sys.platform == "win32" and getattr(sys, "frozen", False)
    if frozen_on_windows:
        from multiprocessing import freeze_support

        freeze_support()

    patch_click()
    main()
1380
1381
# Run the patched entry point when this module is executed as a script.
if __name__ == "__main__":
    patched_main()