git.madduck.net Git - etc/vim.git/blob - src/black/__init__.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Remove blank lines before class docstring (#3692)
[etc/vim.git] / src / black / __init__.py
1 import io
2 import json
3 import platform
4 import re
5 import sys
6 import tokenize
7 import traceback
8 from contextlib import contextmanager
9 from dataclasses import replace
10 from datetime import datetime
11 from enum import Enum
12 from json.decoder import JSONDecodeError
13 from pathlib import Path
14 from typing import (
15     Any,
16     Dict,
17     Generator,
18     Iterator,
19     List,
20     MutableMapping,
21     Optional,
22     Pattern,
23     Sequence,
24     Set,
25     Sized,
26     Tuple,
27     Union,
28 )
29
30 import click
31 from click.core import ParameterSource
32 from mypy_extensions import mypyc_attr
33 from pathspec import PathSpec
34 from pathspec.patterns.gitwildmatch import GitWildMatchPatternError
35
36 from _black_version import version as __version__
37 from black.cache import Cache, get_cache_info, read_cache, write_cache
38 from black.comments import normalize_fmt_off
39 from black.const import (
40     DEFAULT_EXCLUDES,
41     DEFAULT_INCLUDES,
42     DEFAULT_LINE_LENGTH,
43     STDIN_PLACEHOLDER,
44 )
45 from black.files import (
46     find_project_root,
47     find_pyproject_toml,
48     find_user_pyproject_toml,
49     gen_python_files,
50     get_gitignore,
51     normalize_path_maybe_ignore,
52     parse_pyproject_toml,
53     wrap_stream_for_windows,
54 )
55 from black.handle_ipynb_magics import (
56     PYTHON_CELL_MAGICS,
57     TRANSFORMED_MAGICS,
58     jupyter_dependencies_are_installed,
59     mask_cell,
60     put_trailing_semicolon_back,
61     remove_trailing_semicolon,
62     unmask_cell,
63 )
64 from black.linegen import LN, LineGenerator, transform_line
65 from black.lines import EmptyLineTracker, LinesBlock
66 from black.mode import (
67     FUTURE_FLAG_TO_FEATURE,
68     VERSION_TO_FEATURES,
69     Feature,
70     Mode,
71     TargetVersion,
72     supports_feature,
73 )
74 from black.nodes import (
75     STARS,
76     is_number_token,
77     is_simple_decorator_expression,
78     is_string_token,
79     syms,
80 )
81 from black.output import color_diff, diff, dump_to_file, err, ipynb_diff, out
82 from black.parsing import InvalidInput  # noqa F401
83 from black.parsing import lib2to3_parse, parse_ast, stringify_ast
84 from black.report import Changed, NothingChanged, Report
85 from black.trans import iter_fexpr_spans
86 from blib2to3.pgen2 import token
87 from blib2to3.pytree import Leaf, Node
88
# True when this module file has a ".pyd" or ".so" suffix, i.e. it was loaded
# from a compiled extension rather than from a plain ".py" source file.
COMPILED = Path(__file__).suffix in (".pyd", ".so")

# types
# Plain `str` aliases; they exist only to make annotations say what kind of
# string is being passed around.
FileContent = str
Encoding = str
NewLine = str
95
96
class WriteBack(Enum):
    """Enumeration of the write-back modes selectable from the CLI flags."""

    NO = 0
    YES = 1
    DIFF = 2
    CHECK = 3
    COLOR_DIFF = 4

    @classmethod
    def from_configuration(
        cls, *, check: bool, diff: bool, color: bool = False
    ) -> "WriteBack":
        """Map the `check`, `diff` and `color` flags onto a member.

        `diff` takes precedence over `check`, and `color` is only looked at
        when `diff` is set. With neither `check` nor `diff`, the result is YES.
        """
        if diff:
            return cls.COLOR_DIFF if color else cls.DIFF

        return cls.CHECK if check else cls.YES
115
116
# Legacy name, left for integrations.
# This is a plain alias: `FileMode` and `Mode` refer to the same class object.
FileMode = Mode
119
120
def read_pyproject_toml(
    ctx: click.Context, param: click.Parameter, value: Optional[str]
) -> Optional[str]:
    """Merge Black settings from a "pyproject.toml" into `ctx.default_map`.

    When `value` is empty the file is looked up with `find_pyproject_toml`,
    using whatever "src" values are already present in `ctx.params`.

    Returns the path of the configuration file that was read, or None when no
    file could be located or the file yielded no configuration.

    Raises `click.FileError` if parsing the file fails with OSError or
    ValueError, and `click.BadOptionUsage` if "target_version" is present but
    is not a list.
    """
    config_path = value or find_pyproject_toml(ctx.params.get("src", ()))
    if config_path is None:
        return None

    try:
        raw_config = parse_pyproject_toml(config_path)
    except (OSError, ValueError) as e:
        raise click.FileError(
            filename=config_path, hint=f"Error reading configuration file: {e}"
        ) from None

    if not raw_config:
        return None

    # Sanitize the values to be Click friendly. For more information please see:
    # https://github.com/psf/black/issues/1458
    # https://github.com/pallets/click/issues/1567
    config: Dict[str, Any] = {}
    for key, setting in raw_config.items():
        # Lists and dicts pass through untouched; everything else is stringified.
        config[key] = setting if isinstance(setting, (list, dict)) else str(setting)

    target_version = config.get("target_version")
    if not (target_version is None or isinstance(target_version, list)):
        raise click.BadOptionUsage(
            "target-version", "Config key target-version must be a list"
        )

    # Values from the file override any defaults already registered on `ctx`.
    merged: Dict[str, Any] = dict(ctx.default_map) if ctx.default_map else {}
    merged.update(config)
    ctx.default_map = merged
    return config_path
165
166
def target_version_option_callback(
    c: click.Context, p: Union[click.Option, click.Parameter], v: Tuple[str, ...]
) -> List[TargetVersion]:
    """Turn the values of a --target-version flag into `TargetVersion` members.

    Each value is upper-cased and looked up by name on `TargetVersion`.

    This is its own function because mypy couldn't infer the type correctly
    when it was a lambda, causing mypyc trouble.
    """
    versions: List[TargetVersion] = []
    for name in v:
        versions.append(TargetVersion[name.upper()])
    return versions
176
177
def re_compile_maybe_verbose(regex: str) -> Pattern[str]:
    """Return the compiled form of the regular expression string `regex`.

    A pattern that contains a newline is compiled in verbose mode by
    prefixing it with the inline `(?x)` flag.
    """
    pattern_source = f"(?x){regex}" if "\n" in regex else regex
    return re.compile(pattern_source)
187
188
def validate_regex(
    ctx: click.Context,
    param: click.Parameter,
    value: Optional[str],
) -> Optional[Pattern[str]]:
    """Click callback that compiles a regex option value.

    Returns None when `value` is None, otherwise the pattern produced by
    `re_compile_maybe_verbose`. Raises `click.BadParameter` if the value is
    not a valid regular expression.
    """
    if value is None:
        return None

    try:
        return re_compile_maybe_verbose(value)
    except re.error as e:
        raise click.BadParameter(f"Not a valid regular expression: {e}") from None
198
199
@click.command(
    context_settings={"help_option_names": ["-h", "--help"]},
    # While Click does set this field automatically using the docstring, mypyc
    # (annoyingly) strips 'em so we need to set it here too.
    help="The uncompromising code formatter.",
)
@click.option("-c", "--code", type=str, help="Format the code passed in as a string.")
@click.option(
    "-l",
    "--line-length",
    type=int,
    default=DEFAULT_LINE_LENGTH,
    help="How many characters per line to allow.",
    show_default=True,
)
@click.option(
    "-t",
    "--target-version",
    type=click.Choice([v.name.lower() for v in TargetVersion]),
    callback=target_version_option_callback,
    multiple=True,
    help=(
        "Python versions that should be supported by Black's output. By default, Black"
        " will try to infer this from the project metadata in pyproject.toml. If this"
        " does not yield conclusive results, Black will use per-file auto-detection."
    ),
)
@click.option(
    "--pyi",
    is_flag=True,
    help=(
        "Format all input files like typing stubs regardless of file extension (useful"
        " when piping source on standard input)."
    ),
)
@click.option(
    "--ipynb",
    is_flag=True,
    help=(
        "Format all input files like Jupyter Notebooks regardless of file extension "
        "(useful when piping source on standard input)."
    ),
)
@click.option(
    "--python-cell-magics",
    multiple=True,
    help=(
        "When processing Jupyter Notebooks, add the given magic to the list"
        f" of known python-magics ({', '.join(sorted(PYTHON_CELL_MAGICS))})."
        " Useful for formatting cells with custom python magics."
    ),
    default=[],
)
@click.option(
    "-x",
    "--skip-source-first-line",
    is_flag=True,
    help="Skip the first line of the source code.",
)
@click.option(
    "-S",
    "--skip-string-normalization",
    is_flag=True,
    help="Don't normalize string quotes or prefixes.",
)
@click.option(
    "-C",
    "--skip-magic-trailing-comma",
    is_flag=True,
    help="Don't use trailing commas as a reason to split lines.",
)
@click.option(
    "--experimental-string-processing",
    is_flag=True,
    hidden=True,
    help="(DEPRECATED and now included in --preview) Normalize string literals.",
)
@click.option(
    "--preview",
    is_flag=True,
    help=(
        "Enable potentially disruptive style changes that may be added to Black's main"
        " functionality in the next major release."
    ),
)
@click.option(
    "--check",
    is_flag=True,
    help=(
        "Don't write the files back, just return the status. Return code 0 means"
        " nothing would change. Return code 1 means some files would be reformatted."
        " Return code 123 means there was an internal error."
    ),
)
@click.option(
    "--diff",
    is_flag=True,
    help="Don't write the files back, just output a diff for each file on stdout.",
)
@click.option(
    "--color/--no-color",
    is_flag=True,
    help="Show colored diff. Only applies when `--diff` is given.",
)
@click.option(
    "--fast/--safe",
    is_flag=True,
    help="If --fast given, skip temporary sanity checks. [default: --safe]",
)
@click.option(
    "--required-version",
    type=str,
    help=(
        "Require a specific version of Black to be running (useful for unifying results"
        " across many environments e.g. with a pyproject.toml file). It can be"
        " either a major version number or an exact version."
    ),
)
@click.option(
    "--include",
    type=str,
    default=DEFAULT_INCLUDES,
    callback=validate_regex,
    help=(
        "A regular expression that matches files and directories that should be"
        " included on recursive searches. An empty value means all files are included"
        " regardless of the name. Use forward slashes for directories on all platforms"
        " (Windows, too). Exclusions are calculated first, inclusions later."
    ),
    show_default=True,
)
@click.option(
    "--exclude",
    type=str,
    callback=validate_regex,
    help=(
        "A regular expression that matches files and directories that should be"
        " excluded on recursive searches. An empty value means no paths are excluded."
        " Use forward slashes for directories on all platforms (Windows, too)."
        " Exclusions are calculated first, inclusions later. [default:"
        f" {DEFAULT_EXCLUDES}]"
    ),
    show_default=False,
)
@click.option(
    "--extend-exclude",
    type=str,
    callback=validate_regex,
    help=(
        "Like --exclude, but adds additional files and directories on top of the"
        " excluded ones. (Useful if you simply want to add to the default)"
    ),
)
@click.option(
    "--force-exclude",
    type=str,
    callback=validate_regex,
    help=(
        "Like --exclude, but files and directories matching this regex will be "
        "excluded even when they are passed explicitly as arguments."
    ),
)
@click.option(
    "--stdin-filename",
    type=str,
    help=(
        "The name of the file when passing it through stdin. Useful to make "
        "sure Black will respect --force-exclude option on some "
        "editors that rely on using stdin."
    ),
)
@click.option(
    "-W",
    "--workers",
    type=click.IntRange(min=1),
    default=None,
    help="Number of parallel workers [default: number of CPUs in the system]",
)
@click.option(
    "-q",
    "--quiet",
    is_flag=True,
    help=(
        "Don't emit non-error messages to stderr. Errors are still emitted; silence"
        " those with 2>/dev/null."
    ),
)
@click.option(
    "-v",
    "--verbose",
    is_flag=True,
    help=(
        "Also emit messages to stderr about files that were not changed or were ignored"
        " due to exclusion patterns."
    ),
)
@click.version_option(
    version=__version__,
    message=(
        f"%(prog)s, %(version)s (compiled: {'yes' if COMPILED else 'no'})\n"
        f"Python ({platform.python_implementation()}) {platform.python_version()}"
    ),
)
# Both "src" and "--config" are eager, so they are processed ahead of the
# non-eager options; the "--config" callback reads "src" from `ctx.params`.
@click.argument(
    "src",
    nargs=-1,
    type=click.Path(
        exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
    ),
    is_eager=True,
    metavar="SRC ...",
)
@click.option(
    "--config",
    type=click.Path(
        exists=True,
        file_okay=True,
        dir_okay=False,
        readable=True,
        allow_dash=False,
        path_type=str,
    ),
    is_eager=True,
    callback=read_pyproject_toml,
    help="Read configuration from FILE path.",
)
@click.pass_context
def main(  # noqa: C901
    ctx: click.Context,
    code: Optional[str],
    line_length: int,
    target_version: List[TargetVersion],
    check: bool,
    diff: bool,
    color: bool,
    fast: bool,
    pyi: bool,
    ipynb: bool,
    python_cell_magics: Sequence[str],
    skip_source_first_line: bool,
    skip_string_normalization: bool,
    skip_magic_trailing_comma: bool,
    experimental_string_processing: bool,
    preview: bool,
    quiet: bool,
    verbose: bool,
    required_version: Optional[str],
    include: Pattern[str],
    exclude: Optional[Pattern[str]],
    extend_exclude: Optional[Pattern[str]],
    force_exclude: Optional[Pattern[str]],
    stdin_filename: Optional[str],
    workers: Optional[int],
    src: Tuple[str, ...],
    config: Optional[str],
) -> None:
    """The uncompromising code formatter."""
    # Command entry point: validates the arguments, builds a `Mode` and a
    # `Report`, formats either the --code string or the collected sources, and
    # always finishes through `ctx.exit(...)`.
    ctx.ensure_object(dict)

    # Exactly one of SRC and --code has to be supplied.
    if src and code is not None:
        out(
            main.get_usage(ctx)
            + "\n\n'SRC' and 'code' cannot be passed simultaneously."
        )
        ctx.exit(1)
    if not src and code is None:
        out(main.get_usage(ctx) + "\n\nOne of 'SRC' or 'code' is required.")
        ctx.exit(1)

    # No project root is looked up when formatting a --code string.
    root, method = (
        find_project_root(src, stdin_filename) if code is None else (None, None)
    )
    # Stored on the context so that `get_sources` can read it back.
    ctx.obj["root"] = root

    if verbose:
        if root:
            out(
                f"Identified `{root}` as project root containing a {method}.",
                fg="blue",
            )

            # Pair every source with its normalized form; "-" (stdin) is kept
            # as-is. A falsy normalized value is reported as invalid below.
            normalized = [
                (
                    (source, source)
                    if source == "-"
                    else (normalize_path_maybe_ignore(Path(source), root), source)
                )
                for source in src
            ]
            srcs_string = ", ".join(
                [
                    (
                        f'"{_norm}"'
                        if _norm
                        else f'\033[31m"{source} (skipping - invalid)"\033[34m'
                    )
                    for _norm, source in normalized
                ]
            )
            out(f"Sources to be formatted: {srcs_string}", fg="blue")

        if config:
            # Report where the configuration came from: the user-level file,
            # a default (project root), or an explicitly given path.
            config_source = ctx.get_parameter_source("config")
            user_level_config = str(find_user_pyproject_toml())
            if config == user_level_config:
                out(
                    "Using configuration from user-level config at "
                    f"'{user_level_config}'.",
                    fg="blue",
                )
            elif config_source in (
                ParameterSource.DEFAULT,
                ParameterSource.DEFAULT_MAP,
            ):
                out("Using configuration from project root.", fg="blue")
            else:
                out(f"Using configuration in '{config}'.", fg="blue")
            if ctx.default_map:
                for param, value in ctx.default_map.items():
                    out(f"{param}: {value}")

    error_msg = "Oh no! 💥 💔 💥"
    # --required-version matches on either the full version string or only its
    # first dot-separated component.
    if (
        required_version
        and required_version != __version__
        and required_version != __version__.split(".")[0]
    ):
        err(
            f"{error_msg} The required version `{required_version}` does not match"
            f" the running version `{__version__}`!"
        )
        ctx.exit(1)
    if ipynb and pyi:
        err("Cannot pass both `pyi` and `ipynb` flags!")
        ctx.exit(1)

    write_back = WriteBack.from_configuration(check=check, diff=diff, color=color)
    if target_version:
        versions = set(target_version)
    else:
        # We'll autodetect later.
        versions = set()
    mode = Mode(
        target_versions=versions,
        line_length=line_length,
        is_pyi=pyi,
        is_ipynb=ipynb,
        skip_source_first_line=skip_source_first_line,
        string_normalization=not skip_string_normalization,
        magic_trailing_comma=not skip_magic_trailing_comma,
        experimental_string_processing=experimental_string_processing,
        preview=preview,
        python_cell_magics=set(python_cell_magics),
    )

    if code is not None:
        # Run in quiet mode by default with -c; the extra output isn't useful.
        # You can still pass -v to get verbose output.
        quiet = True

    report = Report(check=check, diff=diff, quiet=quiet, verbose=verbose)

    if code is not None:
        reformat_code(
            content=code, fast=fast, write_back=write_back, mode=mode, report=report
        )
    else:
        try:
            sources = get_sources(
                ctx=ctx,
                src=src,
                quiet=quiet,
                verbose=verbose,
                include=include,
                exclude=exclude,
                extend_exclude=extend_exclude,
                force_exclude=force_exclude,
                report=report,
                stdin_filename=stdin_filename,
            )
        except GitWildMatchPatternError:
            ctx.exit(1)

        # Exits with status 0 right here when no sources were collected.
        path_empty(
            sources,
            "No Python files are present to be formatted. Nothing to do 😴",
            quiet,
            verbose,
            ctx,
        )

        # A single source is handled in-process; the concurrency module is
        # only imported when there is more than one source.
        if len(sources) == 1:
            reformat_one(
                src=sources.pop(),
                fast=fast,
                write_back=write_back,
                mode=mode,
                report=report,
            )
        else:
            from black.concurrency import reformat_many

            reformat_many(
                sources=sources,
                fast=fast,
                write_back=write_back,
                mode=mode,
                report=report,
                workers=workers,
            )

    if verbose or not quiet:
        if code is None and (verbose or report.change_count or report.failure_count):
            out()
        out(error_msg if report.return_code else "All done! ✨ 🍰 ✨")
        if code is None:
            click.echo(str(report), err=True)
    # The process exit status is whatever the report computed.
    ctx.exit(report.return_code)
618
619
def get_sources(
    *,
    ctx: click.Context,
    src: Tuple[str, ...],
    quiet: bool,
    verbose: bool,
    include: Pattern[str],
    exclude: Optional[Pattern[str]],
    extend_exclude: Optional[Pattern[str]],
    force_exclude: Optional[Pattern[str]],
    report: "Report",
    stdin_filename: Optional[str],
) -> Set[Path]:
    """Compute the set of files to be formatted.

    Each entry of `src` is handled according to what it names:

    - a file, or "-" together with `stdin_filename`: added directly unless it
      matches `force_exclude`, cannot be normalized against the project root,
      or is a notebook while the Jupyter dependencies are missing;
    - a directory: expanded through `gen_python_files`;
    - a bare "-": added as-is;
    - anything else: reported with `err` and skipped.

    The project root is read from `ctx.obj["root"]`. When `exclude` is None,
    `DEFAULT_EXCLUDES` is used and gitignore specs are passed along for
    directory traversal.
    """
    sources: Set[Path] = set()
    root = ctx.obj["root"]

    using_default_exclude = exclude is None
    exclude = re_compile_maybe_verbose(DEFAULT_EXCLUDES) if exclude is None else exclude
    gitignore: Optional[Dict[Path, PathSpec]] = None
    root_gitignore = get_gitignore(root)

    for s in src:
        if s == "-" and stdin_filename:
            p = Path(stdin_filename)
            is_stdin = True
        else:
            p = Path(s)
            is_stdin = False

        if is_stdin or p.is_file():
            normalized_path = normalize_path_maybe_ignore(p, root, report)
            if normalized_path is None:
                continue

            normalized_path = "/" + normalized_path
            # Hard-exclude any files that matches the `--force-exclude` regex.
            if force_exclude:
                force_exclude_match = force_exclude.search(normalized_path)
            else:
                force_exclude_match = None
            # An empty match (e.g. from an empty regex) does not exclude.
            if force_exclude_match and force_exclude_match.group(0):
                report.path_ignored(p, "matches the --force-exclude regular expression")
                continue

            if is_stdin:
                # Tag the path so `reformat_one` can tell it came from stdin.
                p = Path(f"{STDIN_PLACEHOLDER}{str(p)}")

            if p.suffix == ".ipynb" and not jupyter_dependencies_are_installed(
                verbose=verbose, quiet=quiet
            ):
                continue

            sources.add(p)
        elif p.is_dir():
            relative_dir = normalize_path_maybe_ignore(p, root, report)
            if relative_dir is None:
                # Mirror the file branch: a path that could not be normalized
                # is skipped instead of crashing on `root / None`.
                continue
            p = root / relative_dir
            if using_default_exclude:
                gitignore = {
                    root: root_gitignore,
                    p: get_gitignore(p),
                }
            sources.update(
                gen_python_files(
                    p.iterdir(),
                    root,
                    include,
                    exclude,
                    extend_exclude,
                    force_exclude,
                    report,
                    gitignore,
                    verbose=verbose,
                    quiet=quiet,
                )
            )
        elif s == "-":
            sources.add(p)
        else:
            err(f"invalid path: {s}")
    return sources
700
701
def path_empty(
    src: Sized, msg: str, quiet: bool, verbose: bool, ctx: click.Context
) -> None:
    """
    Exit with status 0 when `src` is empty, printing `msg` unless silenced.

    `msg` is shown when `verbose` is set or `quiet` is not. Nothing happens
    when `src` is non-empty.
    """
    if src:
        return

    if verbose or not quiet:
        out(msg)
    ctx.exit(0)
712
713
def reformat_code(
    content: str, fast: bool, write_back: WriteBack, mode: Mode, report: Report
) -> None:
    """
    Format the string `content` and print the result, in-process.

    This is the string-input counterpart of `reformat_one`. `fast`,
    `write_back`, and `mode` are forwarded to :func:`format_stdin_to_stdout`.
    The outcome is recorded on `report` under the path "<string>"; any
    exception is recorded as a failure instead of propagating.
    """
    path = Path("<string>")
    try:
        was_changed = format_stdin_to_stdout(
            content=content, fast=fast, write_back=write_back, mode=mode
        )
        report.done(path, Changed.YES if was_changed else Changed.NO)
    except Exception as exc:
        if report.verbose:
            traceback.print_exc()
        report.failed(path, str(exc))
736
737
# diff-shades depends on being able to monkeypatch this function to operate. I
# know it's not ideal, but this shouldn't cause any issues ... hopefully. ~ichard26
@mypyc_attr(patchable=True)
def reformat_one(
    src: Path, fast: bool, write_back: WriteBack, mode: Mode, report: "Report"
) -> None:
    """Reformat a single file under `src` without spawning child processes.

    `fast`, `write_back`, and `mode` options are passed to
    :func:`format_file_in_place` or :func:`format_stdin_to_stdout`.

    The outcome is recorded on `report`; any exception raised while formatting
    is recorded as a failure rather than propagated.
    """
    try:
        changed = Changed.NO

        # Input comes from stdin either for a bare "-" or for a path carrying
        # the STDIN_PLACEHOLDER prefix.
        if str(src) == "-":
            is_stdin = True
        elif str(src).startswith(STDIN_PLACEHOLDER):
            is_stdin = True
            # Use the original name again in case we want to print something
            # to the user
            src = Path(str(src)[len(STDIN_PLACEHOLDER) :])
        else:
            is_stdin = False

        if is_stdin:
            # The file suffix decides stub / notebook handling.
            if src.suffix == ".pyi":
                mode = replace(mode, is_pyi=True)
            elif src.suffix == ".ipynb":
                mode = replace(mode, is_ipynb=True)
            if format_stdin_to_stdout(fast=fast, write_back=write_back, mode=mode):
                changed = Changed.YES
        else:
            cache: Cache = {}
            # The cache is not consulted when a diff was requested.
            if write_back not in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
                cache = read_cache(mode)
                res_src = src.resolve()
                res_src_s = str(res_src)
                if res_src_s in cache and cache[res_src_s] == get_cache_info(res_src):
                    changed = Changed.CACHED
            if changed is not Changed.CACHED and format_file_in_place(
                src, fast=fast, write_back=write_back, mode=mode
            ):
                changed = Changed.YES
            # Record the file in the cache when it was just written back, or
            # when a check found nothing to change.
            if (write_back is WriteBack.YES and changed is not Changed.CACHED) or (
                write_back is WriteBack.CHECK and changed is Changed.NO
            ):
                write_cache(cache, [src], mode)
        report.done(src, changed)
    except Exception as exc:
        if report.verbose:
            traceback.print_exc()
        report.failed(src, str(exc))
790
791
def format_file_in_place(
    src: Path,
    fast: bool,
    mode: Mode,
    write_back: WriteBack = WriteBack.NO,
    lock: Any = None,  # multiprocessing.Manager().Lock() is some crazy proxy
) -> bool:
    """Format the file at `src`. Return True if its contents changed.

    With `write_back` set to YES the reformatted code is written to the file;
    with DIFF or COLOR_DIFF a diff is written to stdout instead (under `lock`
    when one is given). `mode` and `fast` are forwarded to
    :func:`format_file_contents`; a ".pyi" or ".ipynb" suffix switches `mode`
    to stub or notebook handling.

    Raises ValueError when the contents cannot be decoded as JSON while being
    formatted as a notebook.
    """
    suffix = src.suffix
    if suffix == ".pyi":
        mode = replace(mode, is_pyi=True)
    elif suffix == ".ipynb":
        mode = replace(mode, is_ipynb=True)

    then = datetime.utcfromtimestamp(src.stat().st_mtime)
    with open(src, "rb") as buf:
        # The skipped first line is kept aside and glued back on afterwards.
        header = buf.readline() if mode.skip_source_first_line else b""
        src_contents, encoding, newline = decode_bytes(buf.read())
    try:
        dst_contents = format_file_contents(src_contents, fast=fast, mode=mode)
    except NothingChanged:
        return False
    except JSONDecodeError:
        raise ValueError(
            f"File '{src}' cannot be parsed as valid Jupyter notebook."
        ) from None
    header_text = header.decode(encoding)
    src_contents = header_text + src_contents
    dst_contents = header_text + dst_contents

    if write_back == WriteBack.YES:
        with open(src, "w", encoding=encoding, newline=newline) as f:
            f.write(dst_contents)
    elif write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
        now = datetime.utcnow()
        make_diff = ipynb_diff if mode.is_ipynb else diff
        diff_contents = make_diff(
            src_contents,
            dst_contents,
            f"{src}\t{then} +0000",
            f"{src}\t{now} +0000",
        )

        if write_back == WriteBack.COLOR_DIFF:
            diff_contents = color_diff(diff_contents)

        with lock or nullcontext():
            stream = io.TextIOWrapper(
                sys.stdout.buffer,
                encoding=encoding,
                newline=newline,
                write_through=True,
            )
            stream = wrap_stream_for_windows(stream)
            stream.write(diff_contents)
            # Detach so that dropping the wrapper does not close stdout.
            stream.detach()

    return True
854
855
def format_stdin_to_stdout(
    fast: bool,
    *,
    content: Optional[str] = None,
    write_back: WriteBack = WriteBack.NO,
    mode: Mode,
) -> bool:
    """Format file on stdin. Return True if changed.

    If content is None, it's read from sys.stdin.

    If `write_back` is YES, write reformatted code back to stdout. If it is DIFF,
    write a diff to stdout. The `mode` argument is passed to
    :func:`format_file_contents`.
    """
    then = datetime.utcnow()

    if content is None:
        src, encoding, newline = decode_bytes(sys.stdin.buffer.read())
    else:
        # Content passed in as a string is treated as UTF-8 with no newline
        # translation.
        src, encoding, newline = content, "utf-8", ""

    # `dst` starts out as the source, so the `finally` block below emits the
    # unchanged input whenever formatting raises (including NothingChanged).
    dst = src
    try:
        dst = format_file_contents(src, fast=fast, mode=mode)
        return True

    except NothingChanged:
        return False

    finally:
        # Runs on every exit path, before the return value is handed back.
        f = io.TextIOWrapper(
            sys.stdout.buffer, encoding=encoding, newline=newline, write_through=True
        )
        if write_back == WriteBack.YES:
            # Make sure there's a newline after the content
            if dst and dst[-1] != "\n":
                dst += "\n"
            f.write(dst)
        elif write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
            now = datetime.utcnow()
            src_name = f"STDIN\t{then} +0000"
            dst_name = f"STDOUT\t{now} +0000"
            d = diff(src, dst, src_name, dst_name)
            if write_back == WriteBack.COLOR_DIFF:
                d = color_diff(d)
                f = wrap_stream_for_windows(f)
            f.write(d)
        # Detach rather than close, leaving the underlying stdout buffer open.
        f.detach()
905
906
def check_stability_and_equivalence(
    src_contents: str, dst_contents: str, *, mode: Mode
) -> None:
    """Run the equivalence check, then the stability check.

    Delegates to `assert_equivalent` and `assert_stable`. An AssertionError
    is raised if the source and destination contents are not equivalent, or
    if a second formatting pass would produce different output.
    """
    # Equivalence is verified first; stability is only checked afterwards.
    assert_equivalent(src_contents, dst_contents)
    assert_stable(src_contents, dst_contents, mode=mode)
918
919
def format_file_contents(src_contents: str, *, fast: bool, mode: Mode) -> FileContent:
    """Return the reformatted version of `src_contents`.

    Notebooks (`mode.is_ipynb`) are handled by :func:`format_ipynb_string`;
    any other input goes through :func:`format_str`, which receives `mode`.
    Raise NothingChanged when the result is identical to the input.

    Unless `fast` is True, the result is also validated with
    :func:`check_stability_and_equivalence` (i.e. :func:`assert_equivalent`
    and :func:`assert_stable`) before it is returned.
    """
    is_notebook = mode.is_ipynb
    if is_notebook:
        dst_contents = format_ipynb_string(src_contents, fast=fast, mode=mode)
    else:
        dst_contents = format_str(src_contents, mode=mode)

    if dst_contents == src_contents:
        raise NothingChanged

    # Notebook cells were already verified one by one while being formatted,
    # so the whole-document check only applies to plain sources.
    if not (fast or is_notebook):
        check_stability_and_equivalence(src_contents, dst_contents, mode=mode)
    return dst_contents
938
939
def validate_cell(src: str, mode: Mode) -> None:
    """Raise NothingChanged for cells that cannot be safely round-tripped.

    That covers cells which already contain TransformerManager
    transformations, and cells starting with a non-Python cell magic, which
    might cause tokenizer_rt to break because of indentations.

    If a cell contains ``!ls``, then it'll be transformed to
    ``get_ipython().system('ls')``. However, if the cell originally contained
    ``get_ipython().system('ls')``, then it would get transformed in the same way:

        >>> TransformerManager().transform_cell("get_ipython().system('ls')")
        "get_ipython().system('ls')\n"
        >>> TransformerManager().transform_cell("!ls")
        "get_ipython().system('ls')\n"

    Due to the impossibility of safely roundtripping in such situations, cells
    containing transformed magics will be ignored.
    """
    for transformed_magic in TRANSFORMED_MAGICS:
        if transformed_magic in src:
            raise NothingChanged

    if src[:2] != "%%":
        return

    # Name of the cell magic: the first whitespace-delimited word minus "%%".
    cell_magic = src.split()[0][2:]
    if cell_magic not in PYTHON_CELL_MAGICS | mode.python_cell_magics:
        raise NothingChanged
964
965
def format_cell(src: str, *, fast: bool, mode: Mode) -> str:
    """Return the formatted source of a single Jupyter notebook code cell.

    The steps are, in order:

      - reject cells that cannot be round-tripped (:func:`validate_cell`);
      - drop a trailing semicolon, remembering whether there was one;
      - mask IPython magics;
      - format the masked code (and verify the result unless `fast`);
      - restore the magics, then the trailing semicolon;
      - strip trailing newlines.

    Raise NothingChanged if the result equals `src`, or if masking the cell
    raises SyntaxError: such cells could be automagics or multi-line magics,
    which are currently not supported.
    """
    validate_cell(src, mode)
    stripped_src, had_semicolon = remove_trailing_semicolon(src)
    try:
        masked_src, replacements = mask_cell(stripped_src)
    except SyntaxError:
        raise NothingChanged from None

    masked_dst = format_str(masked_src, mode=mode)
    if not fast:
        check_stability_and_equivalence(masked_src, masked_dst, mode=mode)

    unmasked_dst = unmask_cell(masked_dst, replacements)
    formatted = put_trailing_semicolon_back(unmasked_dst, had_semicolon)
    formatted = formatted.rstrip("\n")
    if formatted == src:
        raise NothingChanged from None
    return formatted
1001
1002
def validate_metadata(nb: MutableMapping[str, Any]) -> None:
    """Raise NothingChanged if the notebook declares a non-Python language.

    All notebook metadata fields are optional, see
    https://nbformat.readthedocs.io/en/latest/format_description.html. A
    notebook that does not name its language is therefore accepted and we
    will try to parse it anyway.
    """
    metadata = nb.get("metadata", {})
    language_info = metadata.get("language_info", {})
    language = language_info.get("name", None)
    if language is None or language == "python":
        return
    raise NothingChanged from None
1013
1014
def format_ipynb_string(src_contents: str, *, fast: bool, mode: Mode) -> FileContent:
    """Return the re-serialized notebook with its code cells formatted.

    `src_contents` is parsed as JSON; only cells whose ``cell_type`` is
    ``"code"`` are passed to :func:`format_cell`, and only if
    :func:`validate_metadata` accepts the notebook.  A trailing newline in the
    original ``.ipynb`` is preserved.

    Raise NothingChanged if the input is empty or if no cell was modified.
    """
    if not src_contents:
        raise NothingChanged

    ends_with_newline = src_contents.endswith("\n")
    nb = json.loads(src_contents)
    validate_metadata(nb)

    changed_any = False
    for cell in nb["cells"]:
        if cell.get("cell_type", None) != "code":
            continue
        try:
            cell_src = "".join(cell["source"])
            cell_dst = format_cell(cell_src, fast=fast, mode=mode)
        except NothingChanged:
            # This cell stays as it is; keep going with the next one.
            continue
        cell["source"] = cell_dst.splitlines(keepends=True)
        changed_any = True

    if not changed_any:
        raise NothingChanged

    serialized = json.dumps(nb, indent=1, ensure_ascii=False)
    if ends_with_newline:
        serialized += "\n"
    return serialized
1045
1046
def format_str(src_contents: str, *, mode: Mode) -> str:
    """Return `src_contents` reformatted according to `mode`.

    `mode` carries the formatting options, such as how many characters per
    line are allowed.  Example:

    >>> import black
    >>> print(black.format_str("def f(arg:str='')->None:...", mode=black.Mode()))
    def f(arg: str = "") -> None:
        ...

    A more complex example:

    >>> print(
    ...   black.format_str(
    ...     "def f(arg:str='')->None: hey",
    ...     mode=black.Mode(
    ...       target_versions={black.TargetVersion.PY36},
    ...       line_length=10,
    ...       string_normalization=False,
    ...       is_pyi=False,
    ...     ),
    ...   ),
    ... )
    def f(
        arg: str = '',
    ) -> None:
        hey

    """
    first_pass = _format_str_once(src_contents, mode=mode)
    if first_pass == src_contents:
        return first_pass

    # Forced second pass to work around optional trailing commas (becoming
    # forced trailing commas on pass 2) interacting differently with optional
    # parentheses.  Admittedly ugly.
    return _format_str_once(first_pass, mode=mode)
1084
1085
def _format_str_once(src_contents: str, *, mode: Mode) -> str:
    """Run a single formatting pass over `src_contents` and return the result.

    Leading whitespace is stripped before parsing.  When `mode` names no
    target versions they are detected from the parsed source.  If formatting
    yields no lines at all, the result is the source's newline sequence when
    the (newline-normalized) source contained a line break, else "".
    """
    src_node = lib2to3_parse(src_contents.lstrip(), mode.target_versions)

    versions = mode.target_versions
    if not versions:
        versions = detect_target_versions(
            src_node, future_imports=get_future_imports(src_node)
        )

    def supported(candidates: Set[Feature]) -> Set[Feature]:
        # Keep only the features every targeted version supports.
        return {
            feature for feature in candidates if supports_feature(versions, feature)
        }

    context_manager_features = supported({Feature.PARENTHESIZED_CONTEXT_MANAGERS})
    normalize_fmt_off(src_node)
    line_generator = LineGenerator(mode=mode, features=context_manager_features)
    empty_line_tracker = EmptyLineTracker(mode=mode)
    split_line_features = supported(
        {Feature.TRAILING_COMMA_IN_CALL, Feature.TRAILING_COMMA_IN_DEF}
    )

    dst_blocks: List[LinesBlock] = []
    for current_line in line_generator.visit(src_node):
        current_block = empty_line_tracker.maybe_empty_lines(current_line)
        dst_blocks.append(current_block)
        for split_line in transform_line(
            current_line, mode=mode, features=split_line_features
        ):
            current_block.content_lines.append(str(split_line))

    if dst_blocks:
        dst_blocks[-1].after = 0

    dst_lines = [line for done in dst_blocks for line in done.all_lines()]
    if dst_lines:
        return "".join(dst_lines)

    # Use decode_bytes to retrieve the correct source newline (CRLF or LF),
    # and check if normalized_content has more than one line
    normalized_content, _, newline = decode_bytes(src_contents.encode("utf-8"))
    return newline if "\n" in normalized_content else ""
1129
1130
def decode_bytes(src: bytes) -> Tuple[FileContent, Encoding, NewLine]:
    """Decode `src` and return (decoded_contents, encoding, newline).

    The encoding is detected with :func:`tokenize.detect_encoding`.
    `newline` is CRLF if the first line ends with CRLF and LF otherwise
    (LF as well for empty input), while `decoded_contents` is read with
    universal newlines (i.e. only contains LF).
    """
    buffer = io.BytesIO(src)
    encoding, first_lines = tokenize.detect_encoding(buffer.readline)
    if not first_lines:
        return "", encoding, "\n"

    if first_lines[0].endswith(b"\r\n"):
        newline = "\r\n"
    else:
        newline = "\n"

    # detect_encoding consumed the start of the buffer; rewind before decoding.
    buffer.seek(0)
    with io.TextIOWrapper(buffer, encoding) as reader:
        decoded = reader.read()
    return decoded, encoding, newline
1146
1147
def get_features_used(  # noqa: C901
    node: Node, *, future_imports: Optional[Set[str]] = None
) -> Set[Feature]:
    """Return a set of (relatively) new Python features used in this file.

    `node` is walked in pre-order; `future_imports`, if given, contributes the
    features mapped to those ``__future__`` flags in FUTURE_FLAG_TO_FEATURE.

    Currently looking for:
    - f-strings (any capitalization of the ``f`` / ``rf`` / ``fr`` prefix);
    - self-documenting expressions in f-strings (f"{x=}");
    - underscores in numeric literals;
    - trailing commas after * or ** in function signatures and calls;
    - positional only arguments in function signatures and lambdas;
    - assignment expression;
    - relaxed decorator syntax;
    - usage of __future__ flags (annotations);
    - print / exec statements;
    - star-unpacking in return and yield statements;
    - annotated assignments whose right-hand side is a ``testlist_star_expr``;
    - parenthesized context managers;
    - match statements;
    - except* clause;
    - variadic generics;
    """
    features: Set[Feature] = set()
    if future_imports:
        features |= {
            FUTURE_FLAG_TO_FEATURE[future_import]
            for future_import in future_imports
            if future_import in FUTURE_FLAG_TO_FEATURE
        }

    for n in node.pre_order():
        if is_string_token(n):
            # String prefixes are case-insensitive, so compare in lowercase;
            # otherwise mixed-case raw f-strings (Rf, rF, fR, Fr) are missed.
            value_head = n.value[:2].lower()
            if value_head in {'f"', "f'", "rf", "fr"}:
                features.add(Feature.F_STRINGS)
                if Feature.DEBUG_F_STRINGS not in features:
                    for span_beg, span_end in iter_fexpr_spans(n.value):
                        if n.value[span_beg : span_end - 1].rstrip().endswith("="):
                            features.add(Feature.DEBUG_F_STRINGS)
                            break

        elif is_number_token(n):
            if "_" in n.value:
                features.add(Feature.NUMERIC_UNDERSCORES)

        elif n.type == token.SLASH:
            if n.parent and n.parent.type in {
                syms.typedargslist,
                syms.arglist,
                syms.varargslist,
            }:
                features.add(Feature.POS_ONLY_ARGUMENTS)

        elif n.type == token.COLONEQUAL:
            features.add(Feature.ASSIGNMENT_EXPRESSIONS)

        elif n.type == syms.decorator:
            if len(n.children) > 1 and not is_simple_decorator_expression(
                n.children[1]
            ):
                features.add(Feature.RELAXED_DECORATORS)

        elif (
            n.type in {syms.typedargslist, syms.arglist}
            and n.children
            and n.children[-1].type == token.COMMA
        ):
            # A trailing comma only counts when the list also contains a
            # * or ** item, directly or inside an `argument` child.
            if n.type == syms.typedargslist:
                feature = Feature.TRAILING_COMMA_IN_DEF
            else:
                feature = Feature.TRAILING_COMMA_IN_CALL

            for ch in n.children:
                if ch.type in STARS:
                    features.add(feature)

                if ch.type == syms.argument:
                    for argch in ch.children:
                        if argch.type in STARS:
                            features.add(feature)

        elif (
            n.type in {syms.return_stmt, syms.yield_expr}
            and len(n.children) >= 2
            and n.children[1].type == syms.testlist_star_expr
            and any(child.type == syms.star_expr for child in n.children[1].children)
        ):
            features.add(Feature.UNPACKING_ON_FLOW)

        elif (
            n.type == syms.annassign
            and len(n.children) >= 4
            and n.children[3].type == syms.testlist_star_expr
        ):
            features.add(Feature.ANN_ASSIGN_EXTENDED_RHS)

        elif (
            n.type == syms.with_stmt
            and len(n.children) > 2
            and n.children[1].type == syms.atom
        ):
            # Only an atom of the exact shape "(" testlist_gexp ")" counts.
            atom_children = n.children[1].children
            if (
                len(atom_children) == 3
                and atom_children[0].type == token.LPAR
                and atom_children[1].type == syms.testlist_gexp
                and atom_children[2].type == token.RPAR
            ):
                features.add(Feature.PARENTHESIZED_CONTEXT_MANAGERS)

        elif n.type == syms.match_stmt:
            features.add(Feature.PATTERN_MATCHING)

        elif (
            n.type == syms.except_clause
            and len(n.children) >= 2
            and n.children[1].type == token.STAR
        ):
            features.add(Feature.EXCEPT_STAR)

        elif n.type in {syms.subscriptlist, syms.trailer} and any(
            child.type == syms.star_expr for child in n.children
        ):
            features.add(Feature.VARIADIC_GENERICS)

        elif (
            n.type == syms.tname_star
            and len(n.children) == 3
            and n.children[2].type == syms.star_expr
        ):
            features.add(Feature.VARIADIC_GENERICS)

    return features
1279
1280
def detect_target_versions(
    node: Node, *, future_imports: Optional[Set[str]] = None
) -> Set[TargetVersion]:
    """Return every target version that supports all features used in `node`.

    The features come from :func:`get_features_used`; a version qualifies when
    they form a subset of its entry in VERSION_TO_FEATURES.
    """
    features = get_features_used(node, future_imports=future_imports)
    compatible: Set[TargetVersion] = set()
    for version in TargetVersion:
        if features <= VERSION_TO_FEATURES[version]:
            compatible.add(version)
    return compatible
1289
1290
def get_future_imports(node: Node) -> Set[str]:
    """Collect the names imported from ``__future__`` at the top of the file.

    Only the leading run of simple statements is inspected: docstring-like
    string statements are skipped, ``from __future__ import ...`` statements
    are collected, and anything else ends the scan.
    """

    def iter_imported_names(children: List[LN]) -> Generator[str, None, None]:
        # Yield the original (pre-``as``) name of every imported item.
        for child in children:
            if isinstance(child, Leaf):
                if child.type == token.NAME:
                    yield child.value

            elif child.type == syms.import_as_name:
                orig_name = child.children[0]
                assert isinstance(orig_name, Leaf), "Invalid syntax parsing imports"
                assert orig_name.type == token.NAME, "Invalid syntax parsing imports"
                yield orig_name.value

            elif child.type == syms.import_as_names:
                yield from iter_imported_names(child.children)

            else:
                raise AssertionError("Invalid syntax parsing imports")

    found: Set[str] = set()
    for statement in node.children:
        if statement.type != syms.simple_stmt:
            break

        head = statement.children[0]
        if isinstance(head, Leaf):
            # Continue looking if we see a docstring; otherwise stop.
            is_docstring = (
                len(statement.children) == 2
                and head.type == token.STRING
                and statement.children[1].type == token.NEWLINE
            )
            if is_docstring:
                continue
            break

        if head.type != syms.import_from:
            break

        module_name = head.children[1]
        if not isinstance(module_name, Leaf) or module_name.value != "__future__":
            break

        found |= set(iter_imported_names(head.children[3:]))

    return found
1339
1340
def assert_equivalent(src: str, dst: str) -> None:
    """Raise AssertionError unless `src` and `dst` produce the same AST dump.

    An AssertionError is also raised when either input cannot be parsed.  For
    an unparsable `dst`, and for mismatching ASTs, diagnostics are written
    with `dump_to_file` and its return value is included in the message.
    """
    try:
        src_ast = parse_ast(src)
    except Exception as exc:
        raise AssertionError(
            "cannot use --safe with this file; failed to parse source file AST: "
            f"{exc}\n"
            "This could be caused by running Black with an older Python version "
            "that does not support new syntax used in your source file."
        ) from exc

    try:
        dst_ast = parse_ast(dst)
    except Exception as exc:
        formatted_tb = "".join(traceback.format_tb(exc.__traceback__))
        log = dump_to_file(formatted_tb, dst)
        raise AssertionError(
            f"INTERNAL ERROR: Black produced invalid code: {exc}. "
            "Please report a bug on https://github.com/psf/black/issues.  "
            f"This invalid output might be helpful: {log}"
        ) from None

    src_ast_str = "\n".join(stringify_ast(src_ast))
    dst_ast_str = "\n".join(stringify_ast(dst_ast))
    if src_ast_str == dst_ast_str:
        return

    ast_diff = diff(src_ast_str, dst_ast_str, "src", "dst")
    log = dump_to_file(ast_diff)
    raise AssertionError(
        "INTERNAL ERROR: Black produced code that is not equivalent to the"
        " source.  Please report a bug on "
        f"https://github.com/psf/black/issues.  This diff might be helpful: {log}"
    ) from None
1372
1373
def assert_stable(src: str, dst: str, mode: Mode) -> None:
    """Raise AssertionError if formatting `dst` once more would change it.

    On failure the mode and two diffs (source -> first pass, first pass ->
    second pass) are handed to `dump_to_file`, whose return value is included
    in the error message.
    """
    # We shouldn't call format_str() here, because that formats the string
    # twice and may hide a bug where we bounce back and forth between two
    # versions.
    second_pass = _format_str_once(dst, mode=mode)
    if second_pass == dst:
        return

    first_pass_diff = diff(src, dst, "source", "first pass")
    second_pass_diff = diff(dst, second_pass, "first pass", "second pass")
    log = dump_to_file(str(mode), first_pass_diff, second_pass_diff)
    raise AssertionError(
        "INTERNAL ERROR: Black produced different code on the second pass of the"
        " formatter.  Please report a bug on https://github.com/psf/black/issues."
        f"  This diff might be helpful: {log}"
    ) from None
1391
1392
@contextmanager
def nullcontext() -> Iterator[None]:
    """Return a context manager that does nothing on entry or exit.

    Entering it yields None.  To be used like `nullcontext` in Python 3.7.
    """
    yield None
1400
1401
def patch_click() -> None:
    """Make Click not crash on Python 3.6 with LANG=C.

    On certain misconfigured environments, Python 3 selects the ASCII encoding as the
    default which restricts paths that it can access during the lifetime of the
    application.  Click refuses to work in this scenario by raising a RuntimeError.

    In case of Black the likelihood that non-ASCII characters are going to be used in
    file paths is minimal since it's Python source code.  Moreover, this crash was
    spurious on Python 3.7 thanks to PEP 538 and PEP 540.

    The environment checks found on `click.core` and `click._unicodefun` (when
    those modules can be imported) are replaced with no-ops.
    """
    patch_targets: List[Any] = []

    try:
        from click import core
    except ImportError:
        pass
    else:
        patch_targets.append(core)

    try:
        # Removed in Click 8.1.0 and newer; we keep this around for users who have
        # older versions installed.
        from click import _unicodefun  # type: ignore
    except ImportError:
        pass
    else:
        patch_targets.append(_unicodefun)

    for target in patch_targets:
        for check_name in ("_verify_python3_env", "_verify_python_env"):
            if hasattr(target, check_name):
                setattr(target, check_name, lambda: None)
1434
1435
def patched_main() -> None:
    """Run `main` after applying the frozen-build and Click workarounds."""
    running_frozen = getattr(sys, "frozen", False)
    if running_frozen:
        # PyInstaller patches multiprocessing to need freeze_support() even in
        # non-Windows environments so just assume we always need to call it if
        # frozen.
        from multiprocessing import freeze_support

        freeze_support()

    patch_click()
    main()
1446
1447
# Script entry point: run the patched CLI only when this module is executed
# directly, not when it is imported.
if __name__ == "__main__":
    patched_main()