git.madduck.net Git - etc/vim.git/blob - src/black/__init__.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access, can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Update stable branch after publishing to PyPI (#3223)
[etc/vim.git] / src / black / __init__.py
1 import io
2 import json
3 import platform
4 import re
5 import sys
6 import tokenize
7 import traceback
8 from contextlib import contextmanager
9 from dataclasses import replace
10 from datetime import datetime
11 from enum import Enum
12 from json.decoder import JSONDecodeError
13 from pathlib import Path
14 from typing import (
15     Any,
16     Dict,
17     Generator,
18     Iterator,
19     List,
20     MutableMapping,
21     Optional,
22     Pattern,
23     Sequence,
24     Set,
25     Sized,
26     Tuple,
27     Union,
28 )
29
30 import click
31 from click.core import ParameterSource
32 from mypy_extensions import mypyc_attr
33 from pathspec.patterns.gitwildmatch import GitWildMatchPatternError
34
35 from _black_version import version as __version__
36 from black.cache import Cache, get_cache_info, read_cache, write_cache
37 from black.comments import normalize_fmt_off
38 from black.const import (
39     DEFAULT_EXCLUDES,
40     DEFAULT_INCLUDES,
41     DEFAULT_LINE_LENGTH,
42     STDIN_PLACEHOLDER,
43 )
44 from black.files import (
45     find_project_root,
46     find_pyproject_toml,
47     find_user_pyproject_toml,
48     gen_python_files,
49     get_gitignore,
50     normalize_path_maybe_ignore,
51     parse_pyproject_toml,
52     wrap_stream_for_windows,
53 )
54 from black.handle_ipynb_magics import (
55     PYTHON_CELL_MAGICS,
56     TRANSFORMED_MAGICS,
57     jupyter_dependencies_are_installed,
58     mask_cell,
59     put_trailing_semicolon_back,
60     remove_trailing_semicolon,
61     unmask_cell,
62 )
63 from black.linegen import LN, LineGenerator, transform_line
64 from black.lines import EmptyLineTracker, Line
65 from black.mode import (
66     FUTURE_FLAG_TO_FEATURE,
67     VERSION_TO_FEATURES,
68     Feature,
69     Mode,
70     TargetVersion,
71     supports_feature,
72 )
73 from black.nodes import (
74     STARS,
75     is_number_token,
76     is_simple_decorator_expression,
77     is_string_token,
78     syms,
79 )
80 from black.output import color_diff, diff, dump_to_file, err, ipynb_diff, out
81 from black.parsing import InvalidInput  # noqa F401
82 from black.parsing import lib2to3_parse, parse_ast, stringify_ast
83 from black.report import Changed, NothingChanged, Report
84 from black.trans import iter_fexpr_spans
85 from blib2to3.pgen2 import token
86 from blib2to3.pytree import Leaf, Node
87
# True when this module's file has a ".pyd" or ".so" suffix, i.e. it was
# loaded as a compiled extension rather than as plain Python source.
COMPILED = Path(__file__).suffix in (".pyd", ".so")

# types
# All three aliases are plain `str`; the names only document intent in
# signatures.
FileContent = str
Encoding = str
NewLine = str
94
95
class WriteBack(Enum):
    """What should happen with the formatted result of a source."""

    NO = 0
    YES = 1
    DIFF = 2
    CHECK = 3
    COLOR_DIFF = 4

    @classmethod
    def from_configuration(
        cls, *, check: bool, diff: bool, color: bool = False
    ) -> "WriteBack":
        """Pick the write-back mode implied by the given flags.

        `diff` wins over `check`, and `color` only has an effect together
        with `diff`. When neither `check` nor `diff` is set the result is
        ``YES``.
        """
        if diff:
            # Any diff request beats --check; color merely selects the flavor.
            return cls.COLOR_DIFF if color else cls.DIFF

        if check:
            return cls.CHECK

        return cls.YES
114
115
116 # Legacy name, left for integrations.
117 FileMode = Mode
118
119
def read_pyproject_toml(
    ctx: click.Context, param: click.Parameter, value: Optional[str]
) -> Optional[str]:
    """Merge Black settings from a "pyproject.toml" into the defaults of `ctx`.

    When `value` is empty, the file is looked up via `find_pyproject_toml`
    using the "src" parameter already stored on `ctx`. The parsed settings
    are layered on top of any existing ``ctx.default_map``.

    Returns the path of the configuration file that was applied, or None when
    no file was found or the file held no settings.

    Raises `click.FileError` if the file cannot be read or parsed, and
    `click.BadOptionUsage` if "target_version" is present but is not a list.
    """
    config_path = value or find_pyproject_toml(ctx.params.get("src", ()))
    if config_path is None:
        return None

    try:
        raw_config = parse_pyproject_toml(config_path)
    except (OSError, ValueError) as e:
        raise click.FileError(
            filename=config_path, hint=f"Error reading configuration file: {e}"
        ) from None

    if not raw_config:
        return None

    # Sanitize the values to be Click friendly. For more information please see:
    # https://github.com/psf/black/issues/1458
    # https://github.com/pallets/click/issues/1567
    config: Dict[str, Any] = {}
    for key, setting in raw_config.items():
        config[key] = setting if isinstance(setting, (list, dict)) else str(setting)

    versions = config.get("target_version")
    if not (versions is None or isinstance(versions, list)):
        raise click.BadOptionUsage(
            "target-version", "Config key target-version must be a list"
        )

    # Existing defaults first, so values from the file take precedence.
    merged: Dict[str, Any] = dict(ctx.default_map) if ctx.default_map else {}
    merged.update(config)
    ctx.default_map = merged

    return config_path
164
165
def target_version_option_callback(
    c: click.Context, p: Union[click.Option, click.Parameter], v: Tuple[str, ...]
) -> List[TargetVersion]:
    """Translate the values of --target-version into `TargetVersion` members.

    Each value is upper-cased and looked up by name on the enum.

    Kept as a named function (not a lambda) because mypy could not infer the
    lambda's type, which in turn caused trouble for mypyc.
    """
    targets: List[TargetVersion] = []
    for name in v:
        targets.append(TargetVersion[name.upper()])
    return targets
175
176
def re_compile_maybe_verbose(regex: str) -> Pattern[str]:
    """Return `regex` compiled into a pattern object.

    A pattern that spans several lines is compiled in verbose mode by
    prefixing it with the inline ``(?x)`` flag.
    """
    source = f"(?x){regex}" if "\n" in regex else regex
    return re.compile(source)
186
187
def validate_regex(
    ctx: click.Context,
    param: click.Parameter,
    value: Optional[str],
) -> Optional[Pattern[str]]:
    """Click callback that compiles a regex option value.

    None is passed through untouched; any other value is compiled with
    `re_compile_maybe_verbose`. A pattern that fails to compile is reported
    as `click.BadParameter`.
    """
    if value is None:
        return None

    try:
        return re_compile_maybe_verbose(value)
    except re.error as e:
        raise click.BadParameter(f"Not a valid regular expression: {e}") from None
197
198
@click.command(
    context_settings={"help_option_names": ["-h", "--help"]},
    # While Click does set this field automatically using the docstring, mypyc
    # (annoyingly) strips 'em so we need to set it here too.
    help="The uncompromising code formatter.",
)
@click.option("-c", "--code", type=str, help="Format the code passed in as a string.")
@click.option(
    "-l",
    "--line-length",
    type=int,
    default=DEFAULT_LINE_LENGTH,
    help="How many characters per line to allow.",
    show_default=True,
)
@click.option(
    "-t",
    "--target-version",
    type=click.Choice([v.name.lower() for v in TargetVersion]),
    callback=target_version_option_callback,
    multiple=True,
    help=(
        "Python versions that should be supported by Black's output. [default: per-file"
        " auto-detection]"
    ),
)
@click.option(
    "--pyi",
    is_flag=True,
    help=(
        "Format all input files like typing stubs regardless of file extension (useful"
        " when piping source on standard input)."
    ),
)
@click.option(
    "--ipynb",
    is_flag=True,
    help=(
        "Format all input files like Jupyter Notebooks regardless of file extension "
        "(useful when piping source on standard input)."
    ),
)
@click.option(
    "--python-cell-magics",
    multiple=True,
    help=(
        "When processing Jupyter Notebooks, add the given magic to the list"
        f" of known python-magics ({', '.join(PYTHON_CELL_MAGICS)})."
        " Useful for formatting cells with custom python magics."
    ),
    default=[],
)
@click.option(
    "-S",
    "--skip-string-normalization",
    is_flag=True,
    help="Don't normalize string quotes or prefixes.",
)
@click.option(
    "-C",
    "--skip-magic-trailing-comma",
    is_flag=True,
    help="Don't use trailing commas as a reason to split lines.",
)
@click.option(
    "--experimental-string-processing",
    is_flag=True,
    hidden=True,
    help="(DEPRECATED and now included in --preview) Normalize string literals.",
)
@click.option(
    "--preview",
    is_flag=True,
    help=(
        "Enable potentially disruptive style changes that may be added to Black's main"
        " functionality in the next major release."
    ),
)
@click.option(
    "--check",
    is_flag=True,
    help=(
        "Don't write the files back, just return the status. Return code 0 means"
        " nothing would change. Return code 1 means some files would be reformatted."
        " Return code 123 means there was an internal error."
    ),
)
@click.option(
    "--diff",
    is_flag=True,
    help="Don't write the files back, just output a diff for each file on stdout.",
)
@click.option(
    "--color/--no-color",
    is_flag=True,
    help="Show colored diff. Only applies when `--diff` is given.",
)
@click.option(
    "--fast/--safe",
    is_flag=True,
    help="If --fast given, skip temporary sanity checks. [default: --safe]",
)
@click.option(
    "--required-version",
    type=str,
    help=(
        "Require a specific version of Black to be running (useful for unifying results"
        " across many environments e.g. with a pyproject.toml file). It can be"
        " either a major version number or an exact version."
    ),
)
@click.option(
    "--include",
    type=str,
    default=DEFAULT_INCLUDES,
    callback=validate_regex,
    help=(
        "A regular expression that matches files and directories that should be"
        " included on recursive searches. An empty value means all files are included"
        " regardless of the name. Use forward slashes for directories on all platforms"
        " (Windows, too). Exclusions are calculated first, inclusions later."
    ),
    show_default=True,
)
# NOTE: --exclude deliberately has no `default=`; the function body of
# get_sources() substitutes DEFAULT_EXCLUDES when the value is None, which lets
# it tell "not given" apart from an explicit pattern.
@click.option(
    "--exclude",
    type=str,
    callback=validate_regex,
    help=(
        "A regular expression that matches files and directories that should be"
        " excluded on recursive searches. An empty value means no paths are excluded."
        " Use forward slashes for directories on all platforms (Windows, too)."
        " Exclusions are calculated first, inclusions later. [default:"
        f" {DEFAULT_EXCLUDES}]"
    ),
    show_default=False,
)
@click.option(
    "--extend-exclude",
    type=str,
    callback=validate_regex,
    help=(
        "Like --exclude, but adds additional files and directories on top of the"
        " excluded ones. (Useful if you simply want to add to the default)"
    ),
)
@click.option(
    "--force-exclude",
    type=str,
    callback=validate_regex,
    help=(
        "Like --exclude, but files and directories matching this regex will be "
        "excluded even when they are passed explicitly as arguments."
    ),
)
@click.option(
    "--stdin-filename",
    type=str,
    help=(
        "The name of the file when passing it through stdin. Useful to make "
        "sure Black will respect --force-exclude option on some "
        "editors that rely on using stdin."
    ),
)
@click.option(
    "-W",
    "--workers",
    type=click.IntRange(min=1),
    default=None,
    help="Number of parallel workers [default: number of CPUs in the system]",
)
@click.option(
    "-q",
    "--quiet",
    is_flag=True,
    help=(
        "Don't emit non-error messages to stderr. Errors are still emitted; silence"
        " those with 2>/dev/null."
    ),
)
@click.option(
    "-v",
    "--verbose",
    is_flag=True,
    help=(
        "Also emit messages to stderr about files that were not changed or were ignored"
        " due to exclusion patterns."
    ),
)
@click.version_option(
    version=__version__,
    message=(
        f"%(prog)s, %(version)s (compiled: {'yes' if COMPILED else 'no'})\n"
        f"Python ({platform.python_implementation()}) {platform.python_version()}"
    ),
)
# `src` and `--config` are both eager: read_pyproject_toml() (the --config
# callback) reads the "src" parameter from the context and rewrites
# ctx.default_map, which has to happen before the remaining options resolve
# their defaults.
@click.argument(
    "src",
    nargs=-1,
    type=click.Path(
        exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
    ),
    is_eager=True,
    metavar="SRC ...",
)
@click.option(
    "--config",
    type=click.Path(
        exists=True,
        file_okay=True,
        dir_okay=False,
        readable=True,
        allow_dash=False,
        path_type=str,
    ),
    is_eager=True,
    callback=read_pyproject_toml,
    help="Read configuration from FILE path.",
)
@click.pass_context
def main(  # noqa: C901
    ctx: click.Context,
    code: Optional[str],
    line_length: int,
    target_version: List[TargetVersion],
    check: bool,
    diff: bool,
    color: bool,
    fast: bool,
    pyi: bool,
    ipynb: bool,
    python_cell_magics: Sequence[str],
    skip_string_normalization: bool,
    skip_magic_trailing_comma: bool,
    experimental_string_processing: bool,
    preview: bool,
    quiet: bool,
    verbose: bool,
    required_version: Optional[str],
    include: Pattern[str],
    exclude: Optional[Pattern[str]],
    extend_exclude: Optional[Pattern[str]],
    force_exclude: Optional[Pattern[str]],
    stdin_filename: Optional[str],
    workers: Optional[int],
    src: Tuple[str, ...],
    config: Optional[str],
) -> None:
    """The uncompromising code formatter."""
    # ctx.obj is used below to hand the project root to get_sources().
    ctx.ensure_object(dict)

    # Exactly one input kind is accepted: SRC paths or a --code string.
    if src and code is not None:
        out(
            main.get_usage(ctx)
            + "\n\n'SRC' and 'code' cannot be passed simultaneously."
        )
        ctx.exit(1)
    if not src and code is None:
        out(main.get_usage(ctx) + "\n\nOne of 'SRC' or 'code' is required.")
        ctx.exit(1)

    # A project root is only looked up for path input; with --code both
    # `root` and `method` are None.
    root, method = (
        find_project_root(src, stdin_filename) if code is None else (None, None)
    )
    ctx.obj["root"] = root

    if verbose:
        if root:
            out(
                f"Identified `{root}` as project root containing a {method}.",
                fg="blue",
            )

            # Pair every source with its normalized form; "-" (stdin) is kept
            # as is. A falsy normalized value marks the source as invalid below.
            normalized = [
                (source, source)
                if source == "-"
                else (normalize_path_maybe_ignore(Path(source), root), source)
                for source in src
            ]
            # The raw escape codes wrap invalid entries: "\033[31m" before and
            # "\033[34m" after (ANSI red, then back to blue to match fg="blue").
            srcs_string = ", ".join(
                [
                    f'"{_norm}"'
                    if _norm
                    else f'\033[31m"{source} (skipping - invalid)"\033[34m'
                    for _norm, source in normalized
                ]
            )
            out(f"Sources to be formatted: {srcs_string}", fg="blue")

        if config:
            # Report where the configuration came from: the user-level file,
            # a file discovered by default, or one given explicitly.
            config_source = ctx.get_parameter_source("config")
            user_level_config = str(find_user_pyproject_toml())
            if config == user_level_config:
                out(
                    "Using configuration from user-level config at "
                    f"'{user_level_config}'.",
                    fg="blue",
                )
            elif config_source in (
                ParameterSource.DEFAULT,
                ParameterSource.DEFAULT_MAP,
            ):
                out("Using configuration from project root.", fg="blue")
            else:
                out(f"Using configuration in '{config}'.", fg="blue")

    error_msg = "Oh no! 💥 💔 💥"
    # --required-version matches either the full running version or just its
    # first dot-separated component.
    if (
        required_version
        and required_version != __version__
        and required_version != __version__.split(".")[0]
    ):
        err(
            f"{error_msg} The required version `{required_version}` does not match"
            f" the running version `{__version__}`!"
        )
        ctx.exit(1)
    if ipynb and pyi:
        err("Cannot pass both `pyi` and `ipynb` flags!")
        ctx.exit(1)

    write_back = WriteBack.from_configuration(check=check, diff=diff, color=color)
    if target_version:
        versions = set(target_version)
    else:
        # We'll autodetect later.
        versions = set()
    mode = Mode(
        target_versions=versions,
        line_length=line_length,
        is_pyi=pyi,
        is_ipynb=ipynb,
        string_normalization=not skip_string_normalization,
        magic_trailing_comma=not skip_magic_trailing_comma,
        experimental_string_processing=experimental_string_processing,
        preview=preview,
        python_cell_magics=set(python_cell_magics),
    )

    if code is not None:
        # Run in quiet mode by default with -c; the extra output isn't useful.
        # You can still pass -v to get verbose output.
        quiet = True

    report = Report(check=check, diff=diff, quiet=quiet, verbose=verbose)

    if code is not None:
        reformat_code(
            content=code, fast=fast, write_back=write_back, mode=mode, report=report
        )
    else:
        try:
            sources = get_sources(
                ctx=ctx,
                src=src,
                quiet=quiet,
                verbose=verbose,
                include=include,
                exclude=exclude,
                extend_exclude=extend_exclude,
                force_exclude=force_exclude,
                report=report,
                stdin_filename=stdin_filename,
            )
        except GitWildMatchPatternError:
            # NOTE(review): presumably raised while loading a .gitignore with an
            # invalid pattern inside get_sources() -- confirm in black.files.
            ctx.exit(1)

        # Exits with status 0 (after an optional message) if nothing was collected.
        path_empty(
            sources,
            "No Python files are present to be formatted. Nothing to do 😴",
            quiet,
            verbose,
            ctx,
        )

        if len(sources) == 1:
            # A single source is formatted directly in this process.
            reformat_one(
                src=sources.pop(),
                fast=fast,
                write_back=write_back,
                mode=mode,
                report=report,
            )
        else:
            # Imported here, so black.concurrency is only loaded when several
            # sources have to be processed.
            from black.concurrency import reformat_many

            reformat_many(
                sources=sources,
                fast=fast,
                write_back=write_back,
                mode=mode,
                report=report,
                workers=workers,
            )

    if verbose or not quiet:
        # Blank separator line only for path input, and only if something was
        # reported before the summary.
        if code is None and (verbose or report.change_count or report.failure_count):
            out()
        out(error_msg if report.return_code else "All done! ✨ 🍰 ✨")
        if code is None:
            click.echo(str(report), err=True)
    ctx.exit(report.return_code)
601
602
def get_sources(
    *,
    ctx: click.Context,
    src: Tuple[str, ...],
    quiet: bool,
    verbose: bool,
    include: Pattern[str],
    exclude: Optional[Pattern[str]],
    extend_exclude: Optional[Pattern[str]],
    force_exclude: Optional[Pattern[str]],
    report: "Report",
    stdin_filename: Optional[str],
) -> Set[Path]:
    """Compute the set of files to be formatted.

    Every entry of `src` is handled according to what it points at:

    - a file, or "-" combined with `stdin_filename`: added unless it cannot be
      normalized against the project root or matches `force_exclude`; stdin
      entries are stored with the `STDIN_PLACEHOLDER` prefix;
    - a directory: expanded through `gen_python_files`;
    - a bare "-": added as is;
    - anything else: reported as an invalid path via `err`.

    When `exclude` is None, `DEFAULT_EXCLUDES` is used and the .gitignore files
    of the project root and of each directory are honored. That decision is
    made once, up front, so that it holds for every directory in `src` and not
    only for the first one encountered.
    """
    sources: Set[Path] = set()
    root = ctx.obj["root"]

    # Remember whether the caller left --exclude unset *before* replacing it
    # with the compiled default. Testing `exclude is None` inside the loop would
    # only be true for the first directory and silently drop .gitignore
    # handling for all the following ones.
    using_default_exclude = exclude is None
    if exclude is None:
        exclude = re_compile_maybe_verbose(DEFAULT_EXCLUDES)

    for s in src:
        if s == "-" and stdin_filename:
            p = Path(stdin_filename)
            is_stdin = True
        else:
            p = Path(s)
            is_stdin = False

        if is_stdin or p.is_file():
            normalized_path = normalize_path_maybe_ignore(p, root, report)
            if normalized_path is None:
                continue

            normalized_path = "/" + normalized_path
            # Hard-exclude any files that matches the `--force-exclude` regex.
            if force_exclude:
                force_exclude_match = force_exclude.search(normalized_path)
            else:
                force_exclude_match = None
            # An empty match (e.g. from an empty pattern) does not exclude.
            if force_exclude_match and force_exclude_match.group(0):
                report.path_ignored(p, "matches the --force-exclude regular expression")
                continue

            if is_stdin:
                p = Path(f"{STDIN_PLACEHOLDER}{str(p)}")

            if p.suffix == ".ipynb" and not jupyter_dependencies_are_installed(
                verbose=verbose, quiet=quiet
            ):
                continue

            sources.add(p)
        elif p.is_dir():
            if using_default_exclude:
                # Loaded per directory (not cached) because `+=` below extends
                # the root spec in place.
                gitignore = get_gitignore(root)
                p_gitignore = get_gitignore(p)
                # No need to use p's gitignore if it is identical to root's gitignore
                # (i.e. root and p point to the same directory).
                if gitignore != p_gitignore:
                    gitignore += p_gitignore
            else:
                gitignore = None
            sources.update(
                gen_python_files(
                    p.iterdir(),
                    root,
                    include,
                    exclude,
                    extend_exclude,
                    force_exclude,
                    report,
                    gitignore,
                    verbose=verbose,
                    quiet=quiet,
                )
            )
        elif s == "-":
            sources.add(p)
        else:
            err(f"invalid path: {s}")
    return sources
682
683
def path_empty(
    src: Sized, msg: str, quiet: bool, verbose: bool, ctx: click.Context
) -> None:
    """Exit with status 0 when `src` is empty.

    `msg` is printed first unless output is silenced (`quiet` set without
    `verbose`). A non-empty `src` makes this a no-op.
    """
    if src:
        return

    if verbose or not quiet:
        out(msg)
    ctx.exit(0)
694
695
def reformat_code(
    content: str, fast: bool, write_back: WriteBack, mode: Mode, report: Report
) -> None:
    """
    Reformat the string `content` in-process and record the outcome in `report`.

    The string counterpart of `reformat_one`: `content`, `fast`, `write_back`
    and `mode` are forwarded to :func:`format_stdin_to_stdout`, and the result
    is reported under the pseudo path ``<string>``. Any exception is turned
    into a reported failure (with a traceback if the report is verbose).
    """
    path = Path("<string>")
    try:
        was_changed = format_stdin_to_stdout(
            content=content, fast=fast, write_back=write_back, mode=mode
        )
        report.done(path, Changed.YES if was_changed else Changed.NO)
    except Exception as exc:
        if report.verbose:
            traceback.print_exc()
        report.failed(path, str(exc))
718
719
# diff-shades depends on being able to monkeypatch this function to operate. I
# know it's not ideal, but this shouldn't cause any issues ... hopefully. ~ichard26
@mypyc_attr(patchable=True)
def reformat_one(
    src: Path, fast: bool, write_back: WriteBack, mode: Mode, report: "Report"
) -> None:
    """Reformat the single source `src` in-process and record it in `report`.

    `src` may be "-" or a path carrying the `STDIN_PLACEHOLDER` prefix, in
    which case input is taken from stdin via :func:`format_stdin_to_stdout`;
    otherwise the file is handled by :func:`format_file_in_place`, consulting
    and updating the cache unless a diff was requested. `fast`, `write_back`
    and `mode` are forwarded to those functions.

    Any exception is turned into a reported failure (with a traceback if the
    report is verbose).
    """
    try:
        changed = Changed.NO

        src_str = str(src)
        is_stdin = src_str == "-" or src_str.startswith(STDIN_PLACEHOLDER)
        if is_stdin and src_str != "-":
            # Use the original name again in case we want to print something
            # to the user
            src = Path(src_str[len(STDIN_PLACEHOLDER) :])

        if is_stdin:
            # The (stdin) file name can still select stub or notebook mode.
            suffix = src.suffix
            if suffix == ".pyi":
                mode = replace(mode, is_pyi=True)
            elif suffix == ".ipynb":
                mode = replace(mode, is_ipynb=True)
            if format_stdin_to_stdout(fast=fast, write_back=write_back, mode=mode):
                changed = Changed.YES
        else:
            cache: Cache = {}
            wants_diff = write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF)
            if not wants_diff:
                cache = read_cache(mode)
                resolved = src.resolve()
                cache_key = str(resolved)
                if cache_key in cache and cache[cache_key] == get_cache_info(resolved):
                    changed = Changed.CACHED

            if changed is not Changed.CACHED:
                if format_file_in_place(
                    src, fast=fast, write_back=write_back, mode=mode
                ):
                    changed = Changed.YES

            # Record the file as up to date: after writing it back, or when
            # --check found nothing to change.
            written_back = write_back is WriteBack.YES and changed is not Changed.CACHED
            checked_clean = write_back is WriteBack.CHECK and changed is Changed.NO
            if written_back or checked_clean:
                write_cache(cache, [src], mode)
        report.done(src, changed)
    except Exception as exc:
        if report.verbose:
            traceback.print_exc()
        report.failed(src, str(exc))
772
773
def format_file_in_place(
    src: Path,
    fast: bool,
    mode: Mode,
    write_back: WriteBack = WriteBack.NO,
    lock: Any = None,  # multiprocessing.Manager().Lock() is some crazy proxy
) -> bool:
    """Format the file at `src`. Return True if its formatting changed.

    With `write_back` YES the reformatted code is written to the file; with
    DIFF or COLOR_DIFF a diff is written to stdout instead (while holding
    `lock`, if one is given). Any other value only computes the result.
    A ".pyi" or ".ipynb" suffix switches `mode` to stub or notebook mode.
    `mode` and `fast` are passed on to :func:`format_file_contents`.

    Raises ValueError if the contents cannot be decoded as JSON while
    formatting (reported as an invalid Jupyter notebook).
    """
    suffix = src.suffix
    if suffix == ".pyi":
        mode = replace(mode, is_pyi=True)
    elif suffix == ".ipynb":
        mode = replace(mode, is_ipynb=True)

    # Modification time of the original, used in the diff header.
    then = datetime.utcfromtimestamp(src.stat().st_mtime)
    with open(src, "rb") as raw:
        src_contents, encoding, newline = decode_bytes(raw.read())

    try:
        dst_contents = format_file_contents(src_contents, fast=fast, mode=mode)
    except NothingChanged:
        return False
    except JSONDecodeError:
        raise ValueError(
            f"File '{src}' cannot be parsed as valid Jupyter notebook."
        ) from None

    if write_back == WriteBack.YES:
        with open(src, "w", encoding=encoding, newline=newline) as target:
            target.write(dst_contents)
        return True

    if write_back not in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
        return True

    now = datetime.utcnow()
    src_name = f"{src}\t{then} +0000"
    dst_name = f"{src}\t{now} +0000"
    make_diff = ipynb_diff if mode.is_ipynb else diff
    diff_contents = make_diff(src_contents, dst_contents, src_name, dst_name)
    if write_back == WriteBack.COLOR_DIFF:
        diff_contents = color_diff(diff_contents)

    with lock or nullcontext():
        # Write through a temporary text layer over stdout's byte buffer, then
        # detach it so the buffer is not closed along with the wrapper.
        stream = io.TextIOWrapper(
            sys.stdout.buffer,
            encoding=encoding,
            newline=newline,
            write_through=True,
        )
        stream = wrap_stream_for_windows(stream)
        stream.write(diff_contents)
        stream.detach()

    return True
831
832
def format_stdin_to_stdout(
    fast: bool,
    *,
    content: Optional[str] = None,
    write_back: WriteBack = WriteBack.NO,
    mode: Mode,
) -> bool:
    """Format source given on stdin (or in `content`). Return True if changed.

    When `content` is None the source is read from ``sys.stdin``.

    With `write_back` YES the resulting code is written to stdout; with DIFF
    or COLOR_DIFF a diff is written instead. The output step runs even when
    nothing changed or formatting raised, in which case the unmodified source
    is used as the result. `mode` is passed to :func:`format_file_contents`.
    """
    then = datetime.utcnow()

    if content is not None:
        src, encoding, newline = content, "utf-8", ""
    else:
        src, encoding, newline = decode_bytes(sys.stdin.buffer.read())

    dst = src
    changed = False
    try:
        dst = format_file_contents(src, fast=fast, mode=mode)
        changed = True

    except NothingChanged:
        pass

    finally:
        stream = io.TextIOWrapper(
            sys.stdout.buffer, encoding=encoding, newline=newline, write_through=True
        )
        if write_back == WriteBack.YES:
            # Make sure there's a newline after the content
            if dst and not dst.endswith("\n"):
                dst += "\n"
            stream.write(dst)
        elif write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
            now = datetime.utcnow()
            src_name = f"STDIN\t{then} +0000"
            dst_name = f"STDOUT\t{now} +0000"
            d = diff(src, dst, src_name, dst_name)
            if write_back == WriteBack.COLOR_DIFF:
                d = color_diff(d)
                stream = wrap_stream_for_windows(stream)
            stream.write(d)
        # Detach so the wrapper does not take stdout's buffer down with it.
        stream.detach()

    return changed
882
883
def check_stability_and_equivalence(
    src_contents: str, dst_contents: str, *, mode: Mode
) -> None:
    """Run the equivalence check, then the stability check, on a result.

    Delegates to :func:`assert_equivalent` (source vs. formatted contents) and
    :func:`assert_stable` (which additionally receives `mode`). Per their
    contract an AssertionError signals that the formatted code is not
    equivalent to the source, or that formatting it again would produce a
    different result.
    """
    assert_equivalent(src_contents, dst_contents)
    assert_stable(src_contents, dst_contents, mode=mode)
895
896
def format_file_contents(src_contents: str, *, fast: bool, mode: Mode) -> FileContent:
    """Return the reformatted version of `src_contents`.

    Raises `NothingChanged` when the input is blank or already formatted.
    Notebook input (``mode.is_ipynb``) goes through
    :func:`format_ipynb_string`, everything else through :func:`format_str`.
    Unless `fast` is set, non-notebook results are additionally verified with
    :func:`check_stability_and_equivalence`.
    """
    if src_contents.strip() == "":
        raise NothingChanged

    is_notebook = mode.is_ipynb
    formatted = (
        format_ipynb_string(src_contents, fast=fast, mode=mode)
        if is_notebook
        else format_str(src_contents, mode=mode)
    )
    if formatted == src_contents:
        raise NothingChanged

    if not (fast or is_notebook):
        # Jupyter notebooks will already have been checked above.
        check_stability_and_equivalence(src_contents, formatted, mode=mode)
    return formatted
918
919
def validate_cell(src: str, mode: Mode) -> None:
    """Raise NothingChanged for cells that cannot be safely round-tripped.

    Two kinds of cells are rejected:

    - cells that already contain the output of a TransformerManager
      transformation.  A cell containing ``!ls`` is transformed to
      ``get_ipython().system('ls')``, but a cell that originally contained
      ``get_ipython().system('ls')`` transforms to the very same text:

        >>> TransformerManager().transform_cell("get_ipython().system('ls')")
        "get_ipython().system('ls')\n"
        >>> TransformerManager().transform_cell("!ls")
        "get_ipython().system('ls')\n"

      so the original form cannot be recovered afterwards;

    - cells starting with a ``%%`` cell magic that is neither in
      PYTHON_CELL_MAGICS nor in ``mode.python_cell_magics``, since their body
      might not be Python and could break the tokenizer because of
      indentation.
    """
    for transformed_magic in TRANSFORMED_MAGICS:
        if transformed_magic in src:
            raise NothingChanged

    if not src.startswith("%%"):
        return

    # The magic name is the first whitespace-delimited word minus the "%%".
    magic_name = src.split()[0][2:]
    python_magics = PYTHON_CELL_MAGICS | mode.python_cell_magics
    if magic_name not in python_magics:
        raise NothingChanged
944
945
def format_cell(src: str, *, fast: bool, mode: Mode) -> str:
    """Return the formatted source of a single Jupyter notebook code cell.

    The steps are:

      - validate the cell (see :func:`validate_cell`);
      - drop a trailing semicolon, remembering whether there was one;
      - mask IPython magics;
      - format the masked code (and verify it unless `fast` is True);
      - unmask the IPython magics;
      - put the trailing semicolon back if the cell had one;
      - strip trailing newlines.

    NothingChanged is raised when the cell is rejected by validation, when
    masking fails with a SyntaxError (such cells could be automagics or
    multi-line magics, which are currently not supported), or when the final
    result equals `src`.
    """
    validate_cell(src, mode)
    stripped_src, had_semicolon = remove_trailing_semicolon(src)
    try:
        masked_src, replacements = mask_cell(stripped_src)
    except SyntaxError:
        raise NothingChanged from None

    masked_dst = format_str(masked_src, mode=mode)
    if not fast:
        check_stability_and_equivalence(masked_src, masked_dst, mode=mode)

    unmasked_dst = unmask_cell(masked_dst, replacements)
    result = put_trailing_semicolon_back(unmasked_dst, had_semicolon).rstrip("\n")
    if result == src:
        raise NothingChanged from None
    return result
981
982
def validate_metadata(nb: MutableMapping[str, Any]) -> None:
    """Raise NothingChanged if the notebook declares a non-Python language.

    The language is read from ``nb["metadata"]["language_info"]["name"]``.
    Every notebook metadata field is optional (see
    https://nbformat.readthedocs.io/en/latest/format_description.html), so a
    notebook that does not declare a language is accepted and we will try to
    parse it anyway.
    """
    metadata = nb.get("metadata", {})
    language_info = metadata.get("language_info", {})
    language = language_info.get("name", None)
    if language is None or language == "python":
        return
    raise NothingChanged from None
993
994
def format_ipynb_string(src_contents: str, *, fast: bool, mode: Mode) -> FileContent:
    """Format Jupyter notebook.

    Operate cell-by-cell, only on code cells, only for Python notebooks.
    If the ``.ipynb`` originally had a trailing newline, it'll be preserved.

    `src_contents` is the notebook as a JSON string; the return value is the
    re-serialized notebook.  `fast` and `mode` are forwarded to
    :func:`format_cell` for every code cell.

    Raise NothingChanged if `src_contents` is empty, if the notebook metadata
    is rejected by :func:`validate_metadata`, or if no code cell was modified.
    """
    if not src_contents:
        # Guard the `src_contents[-1]` lookup below, which would otherwise fail
        # with an IndexError; an empty document has nothing to format.
        raise NothingChanged

    trailing_newline = src_contents[-1] == "\n"
    modified = False
    nb = json.loads(src_contents)
    validate_metadata(nb)
    for cell in nb["cells"]:
        if cell.get("cell_type", None) == "code":
            try:
                # A cell's source is stored as a list of lines.
                src = "".join(cell["source"])
                dst = format_cell(src, fast=fast, mode=mode)
            except NothingChanged:
                # Leave cells that are unchanged or cannot be formatted as-is.
                pass
            else:
                cell["source"] = dst.splitlines(keepends=True)
                modified = True
    if modified:
        dst_contents = json.dumps(nb, indent=1, ensure_ascii=False)
        if trailing_newline:
            dst_contents = dst_contents + "\n"
        return dst_contents
    else:
        raise NothingChanged
1022
1023
def format_str(src_contents: str, *, mode: Mode) -> str:
    """Reformat a string and return new contents.

    `mode` carries the formatting options, for instance the allowed number of
    characters per line.  Example:

    >>> import black
    >>> print(black.format_str("def f(arg:str='')->None:...", mode=black.Mode()))
    def f(arg: str = "") -> None:
        ...

    A more complex example:

    >>> print(
    ...   black.format_str(
    ...     "def f(arg:str='')->None: hey",
    ...     mode=black.Mode(
    ...       target_versions={black.TargetVersion.PY36},
    ...       line_length=10,
    ...       string_normalization=False,
    ...       is_pyi=False,
    ...     ),
    ...   ),
    ... )
    def f(
        arg: str = '',
    ) -> None:
        hey

    """
    first_pass = _format_str_once(src_contents, mode=mode)
    if first_pass == src_contents:
        return first_pass

    # Whenever the first pass changed something, a second pass is forced:
    # optional trailing commas become forced trailing commas on pass 2 and then
    # interact differently with optional parentheses.  Admittedly ugly.
    return _format_str_once(first_pass, mode=mode)
1061
1062
def _format_str_once(src_contents: str, *, mode: Mode) -> str:
    """Run a single formatting pass over `src_contents` and return the result.

    Leading whitespace is stripped from the input before parsing.  When `mode`
    does not name any target versions, they are detected from the parsed
    source.
    """
    src_node = lib2to3_parse(src_contents.lstrip(), mode.target_versions)
    versions = mode.target_versions
    if not versions:
        versions = detect_target_versions(
            src_node, future_imports=get_future_imports(src_node)
        )

    normalize_fmt_off(src_node, preview=mode.preview)
    line_generator = LineGenerator(mode=mode)
    tracker = EmptyLineTracker(is_pyi=mode.is_pyi)
    empty_line = Line(mode=mode)
    trailing_comma_features = (
        Feature.TRAILING_COMMA_IN_CALL,
        Feature.TRAILING_COMMA_IN_DEF,
    )
    split_line_features = {
        feature
        for feature in trailing_comma_features
        if supports_feature(versions, feature)
    }

    pieces: List[str] = []
    after = 0
    for current_line in line_generator.visit(src_node):
        # Emit the empty lines requested after the previous line first, then
        # the ones requested before the current line.
        pieces.append(str(empty_line) * after)
        before, after = tracker.maybe_empty_lines(current_line)
        pieces.append(str(empty_line) * before)
        pieces.extend(
            str(line)
            for line in transform_line(
                current_line, mode=mode, features=split_line_features
            )
        )
    return "".join(pieces)
1091
1092
def decode_bytes(src: bytes) -> Tuple[FileContent, Encoding, NewLine]:
    """Decode `src` and return (decoded_contents, encoding, newline).

    The encoding is detected with :func:`tokenize.detect_encoding`.  `newline`
    is CRLF when the first line of `src` ends with CRLF and LF otherwise;
    `decoded_contents` is read with universal newlines, so it only contains LF.
    Input for which no line could be read yields an empty string and LF.
    """
    buffer = io.BytesIO(src)
    encoding, first_lines = tokenize.detect_encoding(buffer.readline)
    if len(first_lines) == 0:
        return "", encoding, "\n"

    newline = "\r\n" if first_lines[0].endswith(b"\r\n") else "\n"
    # Encoding detection consumed part of the buffer; rewind before decoding.
    buffer.seek(0)
    with io.TextIOWrapper(buffer, encoding) as reader:
        contents = reader.read()
    return contents, encoding, newline
1108
1109
def get_features_used(  # noqa: C901
    node: Node, *, future_imports: Optional[Set[str]] = None
) -> Set[Feature]:
    """Return a set of (relatively) new Python features used in this file.

    `node` is the tree to inspect (visited in pre-order); `future_imports` is
    the optional set of names imported from ``__future__``.

    Currently looking for:
    - f-strings (with any spelling of the prefix, e.g. ``f``, ``RF`` or ``Rf``);
    - self-documenting expressions in f-strings (f"{x=}");
    - underscores in numeric literals;
    - trailing commas after * or ** in function signatures and calls;
    - positional only arguments in function signatures and lambdas;
    - assignment expression;
    - relaxed decorator syntax;
    - usage of __future__ flags listed in FUTURE_FLAG_TO_FEATURE;
    - star expressions in unparenthesized return / yield values;
    - annotated assignments whose right-hand side is a ``testlist_star_expr``;
    - ``except*`` clauses;
    - star expressions in subscripts and in starred parameter annotations.
    """
    features: Set[Feature] = set()
    if future_imports:
        features |= {
            FUTURE_FLAG_TO_FEATURE[future_import]
            for future_import in future_imports
            if future_import in FUTURE_FLAG_TO_FEATURE
        }

    for n in node.pre_order():
        if is_string_token(n):
            # String prefixes are case-insensitive, so normalize the case before
            # comparing; otherwise mixed-case spellings such as Rf"..." or
            # fR'...' would not be recognized as f-strings.
            value_head = n.value[:2].lower()
            if value_head in {'f"', "f'", "rf", "fr"}:
                features.add(Feature.F_STRINGS)
                if Feature.DEBUG_F_STRINGS not in features:
                    for span_beg, span_end in iter_fexpr_spans(n.value):
                        # An expression ending in "=" right before the closing
                        # brace is a self-documenting expression.
                        if n.value[span_beg : span_end - 1].rstrip().endswith("="):
                            features.add(Feature.DEBUG_F_STRINGS)
                            break

        elif is_number_token(n):
            if "_" in n.value:
                features.add(Feature.NUMERIC_UNDERSCORES)

        elif n.type == token.SLASH:
            # A slash directly inside an argument list marks positional-only
            # arguments (as opposed to a division operator).
            if n.parent and n.parent.type in {
                syms.typedargslist,
                syms.arglist,
                syms.varargslist,
            }:
                features.add(Feature.POS_ONLY_ARGUMENTS)

        elif n.type == token.COLONEQUAL:
            features.add(Feature.ASSIGNMENT_EXPRESSIONS)

        elif n.type == syms.decorator:
            if len(n.children) > 1 and not is_simple_decorator_expression(
                n.children[1]
            ):
                features.add(Feature.RELAXED_DECORATORS)

        elif (
            n.type in {syms.typedargslist, syms.arglist}
            and n.children
            and n.children[-1].type == token.COMMA
        ):
            # The argument list ends with a trailing comma; this only counts as
            # a feature when the list also contains * or ** unpacking.
            if n.type == syms.typedargslist:
                feature = Feature.TRAILING_COMMA_IN_DEF
            else:
                feature = Feature.TRAILING_COMMA_IN_CALL

            for ch in n.children:
                if ch.type in STARS:
                    features.add(feature)

                if ch.type == syms.argument:
                    for argch in ch.children:
                        if argch.type in STARS:
                            features.add(feature)

        elif (
            n.type in {syms.return_stmt, syms.yield_expr}
            and len(n.children) >= 2
            and n.children[1].type == syms.testlist_star_expr
            and any(child.type == syms.star_expr for child in n.children[1].children)
        ):
            features.add(Feature.UNPACKING_ON_FLOW)

        elif (
            n.type == syms.annassign
            and len(n.children) >= 4
            and n.children[3].type == syms.testlist_star_expr
        ):
            features.add(Feature.ANN_ASSIGN_EXTENDED_RHS)

        elif (
            n.type == syms.except_clause
            and len(n.children) >= 2
            and n.children[1].type == token.STAR
        ):
            features.add(Feature.EXCEPT_STAR)

        elif n.type in {syms.subscriptlist, syms.trailer} and any(
            child.type == syms.star_expr for child in n.children
        ):
            features.add(Feature.VARIADIC_GENERICS)

        elif (
            n.type == syms.tname_star
            and len(n.children) == 3
            and n.children[2].type == syms.star_expr
        ):
            features.add(Feature.VARIADIC_GENERICS)

    return features
1220
1221
def detect_target_versions(
    node: Node, *, future_imports: Optional[Set[str]] = None
) -> Set[TargetVersion]:
    """Return every target version that supports all features used in `node`.

    The used features are collected by :func:`get_features_used`; a version
    qualifies when they are a subset of ``VERSION_TO_FEATURES[version]``.
    """
    used_features = get_features_used(node, future_imports=future_imports)
    compatible: Set[TargetVersion] = set()
    for version in TargetVersion:
        if used_features <= VERSION_TO_FEATURES[version]:
            compatible.add(version)
    return compatible
1230
1231
def get_future_imports(node: Node) -> Set[str]:
    """Collect the names imported from ``__future__`` at the top of the file.

    Only the leading run of simple statements is inspected: docstring-like
    string statements are skipped, ``from __future__ import ...`` statements
    are collected, and anything else ends the search.
    """
    future_imports: Set[str] = set()

    def imported_names(nodes: List[LN]) -> Generator[str, None, None]:
        # Yield the original (pre-``as``) name of every imported item.
        for item in nodes:
            if isinstance(item, Leaf):
                # Leaves other than names are ignored.
                if item.type == token.NAME:
                    yield item.value
                continue

            if item.type == syms.import_as_names:
                yield from imported_names(item.children)
                continue

            if item.type != syms.import_as_name:
                raise AssertionError("Invalid syntax parsing imports")

            original_name = item.children[0]
            assert isinstance(original_name, Leaf), "Invalid syntax parsing imports"
            assert original_name.type == token.NAME, "Invalid syntax parsing imports"
            yield original_name.value

    for statement in node.children:
        if statement.type != syms.simple_stmt:
            break

        head = statement.children[0]
        if isinstance(head, Leaf):
            # Continue looking if we see a docstring; otherwise stop.
            is_docstring = (
                len(statement.children) == 2
                and head.type == token.STRING
                and statement.children[1].type == token.NEWLINE
            )
            if is_docstring:
                continue
            break

        if head.type != syms.import_from:
            break

        module_name = head.children[1]
        if not (isinstance(module_name, Leaf) and module_name.value == "__future__"):
            break

        # Everything from the fourth child onwards is passed on as the
        # imported items.
        future_imports.update(imported_names(head.children[3:]))

    return future_imports
1280
1281
def assert_equivalent(src: str, dst: str) -> None:
    """Raise AssertionError unless `src` and `dst` have the same AST.

    Both strings are parsed with `parse_ast` and their stringified ASTs are
    compared.  An AssertionError is also raised when either of them cannot be
    parsed; diagnostic material is written out with `dump_to_file` and
    referenced from the error message.
    """
    try:
        src_ast = parse_ast(src)
    except Exception as exc:
        message = (
            "cannot use --safe with this file; failed to parse source file AST: "
            f"{exc}\n"
            "This could be caused by running Black with an older Python version "
            "that does not support new syntax used in your source file."
        )
        raise AssertionError(message) from exc

    try:
        dst_ast = parse_ast(dst)
    except Exception as exc:
        # Dump the traceback together with the invalid output for bug reports.
        log = dump_to_file("".join(traceback.format_tb(exc.__traceback__)), dst)
        message = (
            f"INTERNAL ERROR: Black produced invalid code: {exc}. "
            "Please report a bug on https://github.com/psf/black/issues.  "
            f"This invalid output might be helpful: {log}"
        )
        raise AssertionError(message) from None

    src_ast_str = "\n".join(stringify_ast(src_ast))
    dst_ast_str = "\n".join(stringify_ast(dst_ast))
    if src_ast_str == dst_ast_str:
        return

    log = dump_to_file(diff(src_ast_str, dst_ast_str, "src", "dst"))
    message = (
        "INTERNAL ERROR: Black produced code that is not equivalent to the"
        " source.  Please report a bug on "
        f"https://github.com/psf/black/issues.  This diff might be helpful: {log}"
    )
    raise AssertionError(message) from None
1313
1314
def assert_stable(src: str, dst: str, mode: Mode) -> None:
    """Raise AssertionError if formatting `dst` once more changes it.

    `src` is only used to produce the diagnostic diff referenced from the
    error message.
    """
    # format_str() must not be used here: it formats the string twice and may
    # therefore hide a bug where the output bounces back and forth between two
    # versions.
    second_pass = _format_str_once(dst, mode=mode)
    if second_pass == dst:
        return

    log = dump_to_file(
        str(mode),
        diff(src, dst, "source", "first pass"),
        diff(dst, second_pass, "first pass", "second pass"),
    )
    message = (
        "INTERNAL ERROR: Black produced different code on the second pass of the"
        " formatter.  Please report a bug on https://github.com/psf/black/issues."
        f"  This diff might be helpful: {log}"
    )
    raise AssertionError(message) from None
1332
1333
@contextmanager
def nullcontext() -> Iterator[None]:
    """Provide a context manager that does nothing on entry or exit.

    Meant to be used like `nullcontext` from Python 3.7; the value bound by
    ``with ... as`` is None.
    """
    yield None
1341
1342
def patch_click() -> None:
    """Make Click not crash on Python 3.6 with LANG=C.

    On certain misconfigured environments, Python 3 selects the ASCII encoding
    as the default, which restricts the paths it can access during the lifetime
    of the application.  Click refuses to work in this scenario by raising a
    RuntimeError.

    For Black, non-ASCII characters are unlikely to show up in file paths since
    it deals with Python source code.  Moreover, this crash was spurious on
    Python 3.7 thanks to PEP 538 and PEP 540.

    The environment checks are disabled by replacing them with no-ops on every
    Click module that defines them.
    """
    patch_targets: List[Any] = []
    try:
        from click import core

        patch_targets.append(core)
    except ImportError:
        pass
    try:
        # Removed in Click 8.1.0 and newer; we keep this around for users who have
        # older versions installed.
        from click import _unicodefun  # type: ignore

        patch_targets.append(_unicodefun)
    except ImportError:
        pass

    for module in patch_targets:
        for check_name in ("_verify_python3_env", "_verify_python_env"):
            if hasattr(module, check_name):
                setattr(module, check_name, lambda: None)
1375
1376
def patched_main() -> None:
    """Run :func:`main` after applying the Click workaround.

    On a frozen Windows executable, `multiprocessing.freeze_support()` is
    called first.
    """
    if sys.platform == "win32" and getattr(sys, "frozen", False):
        # Imported lazily: only needed for frozen Windows builds.
        from multiprocessing import freeze_support

        freeze_support()

    # Disable Click's environment checks before handing over to the real CLI.
    patch_click()
    main()
1385
1386
# Run the patched entry point when this module is executed as a script.
if __name__ == "__main__":
    patched_main()