git.madduck.net Git - etc/vim.git/blob - src/black/__init__.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access, can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Bump sphinx-copybutton from 0.5.0 to 0.5.1 in /docs (#3390)
[etc/vim.git] / src / black / __init__.py
1 import io
2 import json
3 import platform
4 import re
5 import sys
6 import tokenize
7 import traceback
8 from contextlib import contextmanager
9 from dataclasses import replace
10 from datetime import datetime
11 from enum import Enum
12 from json.decoder import JSONDecodeError
13 from pathlib import Path
14 from typing import (
15     Any,
16     Dict,
17     Generator,
18     Iterator,
19     List,
20     MutableMapping,
21     Optional,
22     Pattern,
23     Sequence,
24     Set,
25     Sized,
26     Tuple,
27     Union,
28 )
29
30 import click
31 from click.core import ParameterSource
32 from mypy_extensions import mypyc_attr
33 from pathspec import PathSpec
34 from pathspec.patterns.gitwildmatch import GitWildMatchPatternError
35
36 from _black_version import version as __version__
37 from black.cache import Cache, get_cache_info, read_cache, write_cache
38 from black.comments import normalize_fmt_off
39 from black.const import (
40     DEFAULT_EXCLUDES,
41     DEFAULT_INCLUDES,
42     DEFAULT_LINE_LENGTH,
43     STDIN_PLACEHOLDER,
44 )
45 from black.files import (
46     find_project_root,
47     find_pyproject_toml,
48     find_user_pyproject_toml,
49     gen_python_files,
50     get_gitignore,
51     normalize_path_maybe_ignore,
52     parse_pyproject_toml,
53     wrap_stream_for_windows,
54 )
55 from black.handle_ipynb_magics import (
56     PYTHON_CELL_MAGICS,
57     TRANSFORMED_MAGICS,
58     jupyter_dependencies_are_installed,
59     mask_cell,
60     put_trailing_semicolon_back,
61     remove_trailing_semicolon,
62     unmask_cell,
63 )
64 from black.linegen import LN, LineGenerator, transform_line
65 from black.lines import EmptyLineTracker, LinesBlock
66 from black.mode import (
67     FUTURE_FLAG_TO_FEATURE,
68     VERSION_TO_FEATURES,
69     Feature,
70     Mode,
71     TargetVersion,
72     supports_feature,
73 )
74 from black.nodes import (
75     STARS,
76     is_number_token,
77     is_simple_decorator_expression,
78     is_string_token,
79     syms,
80 )
81 from black.output import color_diff, diff, dump_to_file, err, ipynb_diff, out
82 from black.parsing import InvalidInput  # noqa F401
83 from black.parsing import lib2to3_parse, parse_ast, stringify_ast
84 from black.report import Changed, NothingChanged, Report
85 from black.trans import iter_fexpr_spans
86 from blib2to3.pgen2 import token
87 from blib2to3.pytree import Leaf, Node
88
# True when this module was loaded from a compiled extension file (".pyd" or
# ".so") instead of a plain Python source file.
COMPILED = Path(__file__).suffix in (".pyd", ".so")

# types
# Plain `str` aliases used in signatures to name the role a string plays.
FileContent = str
Encoding = str
NewLine = str
95
96
class WriteBack(Enum):
    """What should happen with the formatted output of a source."""

    NO = 0
    YES = 1
    DIFF = 2
    CHECK = 3
    COLOR_DIFF = 4

    @classmethod
    def from_configuration(
        cls, *, check: bool, diff: bool, color: bool = False
    ) -> "WriteBack":
        """Return the member that corresponds to the given command-line flags.

        `diff` takes precedence over `check`, and `color` only has an effect
        together with `diff`. Without `check` and `diff` the result is YES.
        """
        if diff:
            return cls.COLOR_DIFF if color else cls.DIFF

        return cls.CHECK if check else cls.YES
115
116
# Legacy name, left for integrations: `FileMode` is the very same class as
# `Mode`, not a subclass or a copy.
FileMode = Mode
119
120
def read_pyproject_toml(
    ctx: click.Context, param: click.Parameter, value: Optional[str]
) -> Optional[str]:
    """Load Black settings from a "pyproject.toml" into the defaults of `ctx`.

    When `value` is empty, the file is looked up with `find_pyproject_toml`
    using the "src" parameter already stored on `ctx`.

    Returns the path of the configuration file that was read, or None when no
    file was found or the file yielded no configuration.

    Raises click.FileError if the file cannot be read or parsed, and
    click.BadOptionUsage if its "target_version" entry is not a list.
    """
    if not value:
        value = find_pyproject_toml(ctx.params.get("src", ()))
    if value is None:
        return None

    try:
        config = parse_pyproject_toml(value)
    except (OSError, ValueError) as e:
        raise click.FileError(
            filename=value, hint=f"Error reading configuration file: {e}"
        ) from None

    if not config:
        return None

    # Sanitize the values to be Click friendly: everything except lists and
    # dicts is turned into a string. For more information please see:
    # https://github.com/psf/black/issues/1458
    # https://github.com/pallets/click/issues/1567
    sanitized = {
        key: setting if isinstance(setting, (list, dict)) else str(setting)
        for key, setting in config.items()
    }

    target_version = sanitized.get("target_version")
    if target_version is not None and not isinstance(target_version, list):
        raise click.BadOptionUsage(
            "target-version", "Config key target-version must be a list"
        )

    # Entries from the file override whatever defaults `ctx` already carried.
    merged: Dict[str, Any] = {}
    merged.update(ctx.default_map or {})
    merged.update(sanitized)

    ctx.default_map = merged
    return value
165
166
def target_version_option_callback(
    c: click.Context, p: Union[click.Option, click.Parameter], v: Tuple[str, ...]
) -> List[TargetVersion]:
    """Turn the values of a --target-version flag into TargetVersion members.

    Each value is upper-cased and looked up by name in `TargetVersion`.

    This is a named function rather than a lambda because mypy couldn't infer
    the type of the lambda correctly, which caused mypyc trouble.
    """
    versions: List[TargetVersion] = []
    for name in v:
        versions.append(TargetVersion[name.upper()])
    return versions
176
177
def re_compile_maybe_verbose(regex: str) -> Pattern[str]:
    """Return the compiled form of the regular expression string `regex`.

    A pattern that spans several lines is compiled in verbose mode by
    prefixing it with the inline "(?x)" flag.
    """
    source = "(?x)" + regex if "\n" in regex else regex
    return re.compile(source)
187
188
def validate_regex(
    ctx: click.Context,
    param: click.Parameter,
    value: Optional[str],
) -> Optional[Pattern[str]]:
    """Compile the regex string of an option, passing None through unchanged.

    Raises click.BadParameter when `value` is not a valid regular expression.
    """
    if value is None:
        return None

    try:
        return re_compile_maybe_verbose(value)
    except re.error as e:
        raise click.BadParameter(f"Not a valid regular expression: {e}") from None
198
199
@click.command(
    context_settings={"help_option_names": ["-h", "--help"]},
    # While Click does set this field automatically using the docstring, mypyc
    # (annoyingly) strips 'em so we need to set it here too.
    help="The uncompromising code formatter.",
)
@click.option("-c", "--code", type=str, help="Format the code passed in as a string.")
@click.option(
    "-l",
    "--line-length",
    type=int,
    default=DEFAULT_LINE_LENGTH,
    help="How many characters per line to allow.",
    show_default=True,
)
# The callback converts the chosen lower-case names into TargetVersion members.
@click.option(
    "-t",
    "--target-version",
    type=click.Choice([v.name.lower() for v in TargetVersion]),
    callback=target_version_option_callback,
    multiple=True,
    help=(
        "Python versions that should be supported by Black's output. [default: per-file"
        " auto-detection]"
    ),
)
@click.option(
    "--pyi",
    is_flag=True,
    help=(
        "Format all input files like typing stubs regardless of file extension (useful"
        " when piping source on standard input)."
    ),
)
@click.option(
    "--ipynb",
    is_flag=True,
    help=(
        "Format all input files like Jupyter Notebooks regardless of file extension "
        "(useful when piping source on standard input)."
    ),
)
@click.option(
    "--python-cell-magics",
    multiple=True,
    help=(
        "When processing Jupyter Notebooks, add the given magic to the list"
        f" of known python-magics ({', '.join(PYTHON_CELL_MAGICS)})."
        " Useful for formatting cells with custom python magics."
    ),
    default=[],
)
@click.option(
    "-x",
    "--skip-source-first-line",
    is_flag=True,
    help="Skip the first line of the source code.",
)
@click.option(
    "-S",
    "--skip-string-normalization",
    is_flag=True,
    help="Don't normalize string quotes or prefixes.",
)
@click.option(
    "-C",
    "--skip-magic-trailing-comma",
    is_flag=True,
    help="Don't use trailing commas as a reason to split lines.",
)
# Hidden from --help; the value is still forwarded to Mode below.
@click.option(
    "--experimental-string-processing",
    is_flag=True,
    hidden=True,
    help="(DEPRECATED and now included in --preview) Normalize string literals.",
)
@click.option(
    "--preview",
    is_flag=True,
    help=(
        "Enable potentially disruptive style changes that may be added to Black's main"
        " functionality in the next major release."
    ),
)
@click.option(
    "--check",
    is_flag=True,
    help=(
        "Don't write the files back, just return the status. Return code 0 means"
        " nothing would change. Return code 1 means some files would be reformatted."
        " Return code 123 means there was an internal error."
    ),
)
@click.option(
    "--diff",
    is_flag=True,
    help="Don't write the files back, just output a diff for each file on stdout.",
)
@click.option(
    "--color/--no-color",
    is_flag=True,
    help="Show colored diff. Only applies when `--diff` is given.",
)
@click.option(
    "--fast/--safe",
    is_flag=True,
    help="If --fast given, skip temporary sanity checks. [default: --safe]",
)
@click.option(
    "--required-version",
    type=str,
    help=(
        "Require a specific version of Black to be running (useful for unifying results"
        " across many environments e.g. with a pyproject.toml file). It can be"
        " either a major version number or an exact version."
    ),
)
# The four regex options below are compiled (and validated) by validate_regex,
# so main() receives compiled patterns or None rather than strings.
@click.option(
    "--include",
    type=str,
    default=DEFAULT_INCLUDES,
    callback=validate_regex,
    help=(
        "A regular expression that matches files and directories that should be"
        " included on recursive searches. An empty value means all files are included"
        " regardless of the name. Use forward slashes for directories on all platforms"
        " (Windows, too). Exclusions are calculated first, inclusions later."
    ),
    show_default=True,
)
# --exclude has no Click-level default: get_sources() falls back to
# DEFAULT_EXCLUDES when it receives None, and the help text spells the default
# out by hand.
@click.option(
    "--exclude",
    type=str,
    callback=validate_regex,
    help=(
        "A regular expression that matches files and directories that should be"
        " excluded on recursive searches. An empty value means no paths are excluded."
        " Use forward slashes for directories on all platforms (Windows, too)."
        " Exclusions are calculated first, inclusions later. [default:"
        f" {DEFAULT_EXCLUDES}]"
    ),
    show_default=False,
)
@click.option(
    "--extend-exclude",
    type=str,
    callback=validate_regex,
    help=(
        "Like --exclude, but adds additional files and directories on top of the"
        " excluded ones. (Useful if you simply want to add to the default)"
    ),
)
@click.option(
    "--force-exclude",
    type=str,
    callback=validate_regex,
    help=(
        "Like --exclude, but files and directories matching this regex will be "
        "excluded even when they are passed explicitly as arguments."
    ),
)
@click.option(
    "--stdin-filename",
    type=str,
    help=(
        "The name of the file when passing it through stdin. Useful to make "
        "sure Black will respect --force-exclude option on some "
        "editors that rely on using stdin."
    ),
)
@click.option(
    "-W",
    "--workers",
    type=click.IntRange(min=1),
    default=None,
    help="Number of parallel workers [default: number of CPUs in the system]",
)
@click.option(
    "-q",
    "--quiet",
    is_flag=True,
    help=(
        "Don't emit non-error messages to stderr. Errors are still emitted; silence"
        " those with 2>/dev/null."
    ),
)
@click.option(
    "-v",
    "--verbose",
    is_flag=True,
    help=(
        "Also emit messages to stderr about files that were not changed or were ignored"
        " due to exclusion patterns."
    ),
)
@click.version_option(
    version=__version__,
    message=(
        f"%(prog)s, %(version)s (compiled: {'yes' if COMPILED else 'no'})\n"
        f"Python ({platform.python_implementation()}) {platform.python_version()}"
    ),
)
# Both `src` and `--config` are marked eager. The `--config` callback
# (read_pyproject_toml) looks up "src" in ctx.params to locate a configuration
# file and stores the file's values in ctx.default_map.
@click.argument(
    "src",
    nargs=-1,
    type=click.Path(
        exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
    ),
    is_eager=True,
    metavar="SRC ...",
)
@click.option(
    "--config",
    type=click.Path(
        exists=True,
        file_okay=True,
        dir_okay=False,
        readable=True,
        allow_dash=False,
        path_type=str,
    ),
    is_eager=True,
    callback=read_pyproject_toml,
    help="Read configuration from FILE path.",
)
@click.pass_context
def main(  # noqa: C901
    ctx: click.Context,
    code: Optional[str],
    line_length: int,
    target_version: List[TargetVersion],
    check: bool,
    diff: bool,
    color: bool,
    fast: bool,
    pyi: bool,
    ipynb: bool,
    python_cell_magics: Sequence[str],
    skip_source_first_line: bool,
    skip_string_normalization: bool,
    skip_magic_trailing_comma: bool,
    experimental_string_processing: bool,
    preview: bool,
    quiet: bool,
    verbose: bool,
    required_version: Optional[str],
    include: Pattern[str],
    exclude: Optional[Pattern[str]],
    extend_exclude: Optional[Pattern[str]],
    force_exclude: Optional[Pattern[str]],
    stdin_filename: Optional[str],
    workers: Optional[int],
    src: Tuple[str, ...],
    config: Optional[str],
) -> None:
    """The uncompromising code formatter."""
    # ctx.obj is used as a dict; it carries the project root to get_sources().
    ctx.ensure_object(dict)

    # Exactly one of SRC and --code has to be given.
    if src and code is not None:
        out(
            main.get_usage(ctx)
            + "\n\n'SRC' and 'code' cannot be passed simultaneously."
        )
        ctx.exit(1)
    if not src and code is None:
        out(main.get_usage(ctx) + "\n\nOne of 'SRC' or 'code' is required.")
        ctx.exit(1)

    # No project root is looked up when formatting a --code string.
    root, method = (
        find_project_root(src, stdin_filename) if code is None else (None, None)
    )
    ctx.obj["root"] = root

    # Verbose mode reports the project root, the normalized sources and where
    # the configuration came from.
    if verbose:
        if root:
            out(
                f"Identified `{root}` as project root containing a {method}.",
                fg="blue",
            )

            # Pairs of (normalized path, original argument); "-" (stdin) is
            # kept as is.
            normalized = [
                (source, source)
                if source == "-"
                else (normalize_path_maybe_ignore(Path(source), root), source)
                for source in src
            ]
            # Sources without a normalized path are wrapped in ANSI escapes
            # that switch to red (31) and then back to blue (34), the color of
            # the surrounding message.
            srcs_string = ", ".join(
                [
                    f'"{_norm}"'
                    if _norm
                    else f'\033[31m"{source} (skipping - invalid)"\033[34m'
                    for _norm, source in normalized
                ]
            )
            out(f"Sources to be formatted: {srcs_string}", fg="blue")

        if config:
            config_source = ctx.get_parameter_source("config")
            user_level_config = str(find_user_pyproject_toml())
            if config == user_level_config:
                out(
                    (
                        "Using configuration from user-level config at "
                        f"'{user_level_config}'."
                    ),
                    fg="blue",
                )
            elif config_source in (
                ParameterSource.DEFAULT,
                ParameterSource.DEFAULT_MAP,
            ):
                # --config was not given explicitly on the command line.
                out("Using configuration from project root.", fg="blue")
            else:
                out(f"Using configuration in '{config}'.", fg="blue")

    error_msg = "Oh no! 💥 💔 💥"
    # --required-version is satisfied by either the full running version or
    # just its first dot-separated component.
    if (
        required_version
        and required_version != __version__
        and required_version != __version__.split(".")[0]
    ):
        err(
            f"{error_msg} The required version `{required_version}` does not match"
            f" the running version `{__version__}`!"
        )
        ctx.exit(1)
    if ipynb and pyi:
        err("Cannot pass both `pyi` and `ipynb` flags!")
        ctx.exit(1)

    write_back = WriteBack.from_configuration(check=check, diff=diff, color=color)
    if target_version:
        versions = set(target_version)
    else:
        # We'll autodetect later.
        versions = set()
    # The "skip" flags of the command line are inverted into the positive
    # `string_normalization` / `magic_trailing_comma` settings of Mode.
    mode = Mode(
        target_versions=versions,
        line_length=line_length,
        is_pyi=pyi,
        is_ipynb=ipynb,
        skip_source_first_line=skip_source_first_line,
        string_normalization=not skip_string_normalization,
        magic_trailing_comma=not skip_magic_trailing_comma,
        experimental_string_processing=experimental_string_processing,
        preview=preview,
        python_cell_magics=set(python_cell_magics),
    )

    if code is not None:
        # Run in quiet mode by default with -c; the extra output isn't useful.
        # You can still pass -v to get verbose output.
        quiet = True

    report = Report(check=check, diff=diff, quiet=quiet, verbose=verbose)

    if code is not None:
        reformat_code(
            content=code, fast=fast, write_back=write_back, mode=mode, report=report
        )
    else:
        try:
            sources = get_sources(
                ctx=ctx,
                src=src,
                quiet=quiet,
                verbose=verbose,
                include=include,
                exclude=exclude,
                extend_exclude=extend_exclude,
                force_exclude=force_exclude,
                report=report,
                stdin_filename=stdin_filename,
            )
        except GitWildMatchPatternError:
            # NOTE(review): presumably raised while a .gitignore file is parsed
            # inside get_sources() - confirm against black.files.get_gitignore.
            ctx.exit(1)

        # Exits with status 0 right here when nothing is left to format.
        path_empty(
            sources,
            "No Python files are present to be formatted. Nothing to do 😴",
            quiet,
            verbose,
            ctx,
        )

        # A single source is handled by reformat_one(); several sources go
        # through black.concurrency.reformat_many, which is imported only when
        # it is actually needed.
        if len(sources) == 1:
            reformat_one(
                src=sources.pop(),
                fast=fast,
                write_back=write_back,
                mode=mode,
                report=report,
            )
        else:
            from black.concurrency import reformat_many

            reformat_many(
                sources=sources,
                fast=fast,
                write_back=write_back,
                mode=mode,
                report=report,
                workers=workers,
            )

    # The closing summary is skipped only when --quiet is set without --verbose;
    # the detailed report is never printed for --code input.
    if verbose or not quiet:
        if code is None and (verbose or report.change_count or report.failure_count):
            out()
        out(error_msg if report.return_code else "All done! ✨ 🍰 ✨")
        if code is None:
            click.echo(str(report), err=True)
    ctx.exit(report.return_code)
612
613
def get_sources(
    *,
    ctx: click.Context,
    src: Tuple[str, ...],
    quiet: bool,
    verbose: bool,
    include: Pattern[str],
    exclude: Optional[Pattern[str]],
    extend_exclude: Optional[Pattern[str]],
    force_exclude: Optional[Pattern[str]],
    report: "Report",
    stdin_filename: Optional[str],
) -> Set[Path]:
    """Compute the set of files to be formatted.

    Every entry of `src` is handled according to what it names:

    - a file, or "-" combined with `stdin_filename`: added unless its
      normalized path is rejected or matches `force_exclude`; stdin entries
      are stored with the STDIN_PLACEHOLDER prefix, and ".ipynb" files are
      skipped when the Jupyter dependencies are not installed;
    - a directory: expanded through `gen_python_files` using the include and
      exclude patterns;
    - a bare "-": added as is;
    - anything else: reported as an invalid path and skipped.

    The project root is read from ``ctx.obj["root"]``. When `exclude` is None,
    DEFAULT_EXCLUDES is used instead and gitignore files are taken into
    account for directories.
    """
    sources: Set[Path] = set()
    root = ctx.obj["root"]

    using_default_exclude = exclude is None
    exclude = re_compile_maybe_verbose(DEFAULT_EXCLUDES) if exclude is None else exclude
    # Maps a directory to the PathSpec read from it. It stays None until a
    # directory is processed with the default excludes in effect.
    gitignore: Optional[Dict[Path, PathSpec]] = None
    root_gitignore = get_gitignore(root)

    for s in src:
        if s == "-" and stdin_filename:
            p = Path(stdin_filename)
            is_stdin = True
        else:
            p = Path(s)
            is_stdin = False

        if is_stdin or p.is_file():
            normalized_path = normalize_path_maybe_ignore(p, root, report)
            if normalized_path is None:
                continue

            # The patterns are matched against a path with a leading slash.
            normalized_path = "/" + normalized_path
            # Hard-exclude any files that matches the `--force-exclude` regex.
            if force_exclude:
                force_exclude_match = force_exclude.search(normalized_path)
            else:
                force_exclude_match = None
            # An empty match (e.g. from an empty regex) does not exclude.
            if force_exclude_match and force_exclude_match.group(0):
                report.path_ignored(p, "matches the --force-exclude regular expression")
                continue

            if is_stdin:
                # Mark the entry so reformat_one() knows to read from stdin.
                p = Path(f"{STDIN_PLACEHOLDER}{str(p)}")

            if p.suffix == ".ipynb" and not jupyter_dependencies_are_installed(
                verbose=verbose, quiet=quiet
            ):
                continue

            sources.add(p)
        elif p.is_dir():
            if using_default_exclude:
                gitignore = {
                    root: root_gitignore,
                    root / p: get_gitignore(p),
                }
            sources.update(
                gen_python_files(
                    p.iterdir(),
                    root,
                    include,
                    exclude,
                    extend_exclude,
                    force_exclude,
                    report,
                    gitignore,
                    verbose=verbose,
                    quiet=quiet,
                )
            )
        elif s == "-":
            sources.add(p)
        else:
            err(f"invalid path: {s}")
    return sources
693
694
def path_empty(
    src: Sized, msg: str, quiet: bool, verbose: bool, ctx: click.Context
) -> None:
    """Exit with status 0 if `src` is empty.

    Before exiting, `msg` is printed unless `quiet` is set without `verbose`.
    Nothing happens when `src` contains at least one item.
    """
    if src:
        return

    should_announce = verbose or not quiet
    if should_announce:
        out(msg)
    ctx.exit(0)
705
706
def reformat_code(
    content: str, fast: bool, write_back: WriteBack, mode: Mode, report: Report
) -> None:
    """Reformat the string `content` and print the result.

    This is the string counterpart of `reformat_one` and does not spawn child
    processes. `fast`, `write_back`, and `mode` are handed on to
    :func:`format_stdin_to_stdout`.

    The outcome is recorded on `report` under the path "<string>"; any
    exception is recorded as a failure instead of being raised.
    """
    path = Path("<string>")
    try:
        was_changed = format_stdin_to_stdout(
            content=content, fast=fast, write_back=write_back, mode=mode
        )
        report.done(path, Changed.YES if was_changed else Changed.NO)
    except Exception as exc:
        if report.verbose:
            traceback.print_exc()
        report.failed(path, str(exc))
729
730
# diff-shades depends on being able to monkeypatch this function to operate. I
# know it's not ideal, but this shouldn't cause any issues ... hopefully. ~ichard26
@mypyc_attr(patchable=True)
def reformat_one(
    src: Path, fast: bool, write_back: WriteBack, mode: Mode, report: "Report"
) -> None:
    """Reformat the single file `src` without spawning child processes.

    `src` may also be "-" or a path prefixed with STDIN_PLACEHOLDER, in which
    case the source is read from stdin. `fast`, `write_back`, and `mode` are
    handed on to :func:`format_file_in_place` or
    :func:`format_stdin_to_stdout`.

    The outcome is recorded on `report`; any exception is recorded as a
    failure instead of being raised.
    """
    try:
        changed = Changed.NO
        src_str = str(src)

        is_stdin = src_str == "-" or src_str.startswith(STDIN_PLACEHOLDER)
        if is_stdin and src_str != "-":
            # Use the original name again in case we want to print something
            # to the user
            src = Path(src_str[len(STDIN_PLACEHOLDER) :])

        if is_stdin:
            # The file name given for stdin decides the input flavor.
            if src.suffix == ".pyi":
                mode = replace(mode, is_pyi=True)
            elif src.suffix == ".ipynb":
                mode = replace(mode, is_ipynb=True)
            if format_stdin_to_stdout(fast=fast, write_back=write_back, mode=mode):
                changed = Changed.YES
        else:
            # The cache is not consulted when only a diff is requested.
            diff_only = write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF)
            cache: Cache = {}
            if not diff_only:
                cache = read_cache(mode)
                resolved = src.resolve()
                cache_key = str(resolved)
                if cache_key in cache and cache[cache_key] == get_cache_info(
                    resolved
                ):
                    changed = Changed.CACHED
            if changed is not Changed.CACHED and format_file_in_place(
                src, fast=fast, write_back=write_back, mode=mode
            ):
                changed = Changed.YES

            # Record the file in the cache when it was processed for writing,
            # or when a check found nothing to change.
            processed_for_write = (
                write_back is WriteBack.YES and changed is not Changed.CACHED
            )
            checked_and_clean = (
                write_back is WriteBack.CHECK and changed is Changed.NO
            )
            if processed_for_write or checked_and_clean:
                write_cache(cache, [src], mode)
        report.done(src, changed)
    except Exception as exc:
        if report.verbose:
            traceback.print_exc()
        report.failed(src, str(exc))
783
784
def format_file_in_place(
    src: Path,
    fast: bool,
    mode: Mode,
    write_back: WriteBack = WriteBack.NO,
    lock: Any = None,  # multiprocessing.Manager().Lock() is some crazy proxy
) -> bool:
    """Format the file at `src`. Return True if its contents would change.

    With `write_back` set to YES the reformatted code is written to the file;
    with DIFF or COLOR_DIFF a diff is written to stdout (while holding `lock`,
    if one is given). Any other value only computes the result.
    `mode` and `fast` are handed on to :func:`format_file_contents`.

    Raises ValueError when the contents cannot be decoded as JSON while being
    formatted as a Jupyter notebook.
    """
    # The file suffix overrides the input flavor selected in `mode`.
    if src.suffix == ".pyi":
        mode = replace(mode, is_pyi=True)
    elif src.suffix == ".ipynb":
        mode = replace(mode, is_ipynb=True)

    then = datetime.utcfromtimestamp(src.stat().st_mtime)
    header = b""
    with open(src, "rb") as buf:
        if mode.skip_source_first_line:
            # The first line is kept aside and never handed to the formatter.
            header = buf.readline()
        src_contents, encoding, newline = decode_bytes(buf.read())
    try:
        dst_contents = format_file_contents(src_contents, fast=fast, mode=mode)
    except NothingChanged:
        return False
    except JSONDecodeError:
        raise ValueError(
            f"File '{src}' cannot be parsed as valid Jupyter notebook."
        ) from None

    # Put the skipped first line back in front of both versions.
    header_text = header.decode(encoding)
    src_contents = header_text + src_contents
    dst_contents = header_text + dst_contents

    if write_back == WriteBack.YES:
        with open(src, "w", encoding=encoding, newline=newline) as f:
            f.write(dst_contents)
        return True

    if write_back not in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
        return True

    now = datetime.utcnow()
    src_name = f"{src}\t{then} +0000"
    dst_name = f"{src}\t{now} +0000"
    make_diff = ipynb_diff if mode.is_ipynb else diff
    diff_contents = make_diff(src_contents, dst_contents, src_name, dst_name)

    if write_back == WriteBack.COLOR_DIFF:
        diff_contents = color_diff(diff_contents)

    with lock or nullcontext():
        stream = io.TextIOWrapper(
            sys.stdout.buffer,
            encoding=encoding,
            newline=newline,
            write_through=True,
        )
        stream = wrap_stream_for_windows(stream)
        stream.write(diff_contents)
        # Detach so that dropping the wrapper does not close sys.stdout.buffer.
        stream.detach()

    return True
847
848
def format_stdin_to_stdout(
    fast: bool,
    *,
    content: Optional[str] = None,
    write_back: WriteBack = WriteBack.NO,
    mode: Mode,
) -> bool:
    """Format `content`, or stdin when `content` is None. Return True if changed.

    With `write_back` set to YES the resulting code is written to stdout,
    newline-terminated; with DIFF or COLOR_DIFF a diff is written to stdout
    instead. The output step runs even when nothing changed or formatting
    failed, in which case the unmodified source is used. The `mode` argument
    is handed on to :func:`format_file_contents`.
    """
    then = datetime.utcnow()

    if content is not None:
        src, encoding, newline = content, "utf-8", ""
    else:
        src, encoding, newline = decode_bytes(sys.stdin.buffer.read())

    # Falls back to the input if format_file_contents() does not return.
    dst = src
    try:
        dst = format_file_contents(src, fast=fast, mode=mode)
        return True

    except NothingChanged:
        return False

    finally:
        stream = io.TextIOWrapper(
            sys.stdout.buffer, encoding=encoding, newline=newline, write_through=True
        )
        if write_back == WriteBack.YES:
            # Make sure there's a newline after the content
            if dst and not dst.endswith("\n"):
                dst += "\n"
            stream.write(dst)
        elif write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
            now = datetime.utcnow()
            src_name = f"STDIN\t{then} +0000"
            dst_name = f"STDOUT\t{now} +0000"
            diff_text = diff(src, dst, src_name, dst_name)
            if write_back == WriteBack.COLOR_DIFF:
                diff_text = color_diff(diff_text)
                stream = wrap_stream_for_windows(stream)
            stream.write(diff_text)
        # Detach so that dropping the wrapper does not close sys.stdout.buffer.
        stream.detach()
898
899
def check_stability_and_equivalence(
    src_contents: str, dst_contents: str, *, mode: Mode
) -> None:
    """Run the equivalence check followed by the stability check.

    An AssertionError is raised when `dst_contents` is not equivalent to
    `src_contents`, or when formatting a second time would change the content
    again.
    """
    # Equivalence first: is the reformatted code still the same program?
    assert_equivalent(src_contents, dst_contents)
    # Then stability: does another pass under `mode` leave the result alone?
    assert_stable(src_contents, dst_contents, mode=mode)
911
912
def format_file_contents(src_contents: str, *, fast: bool, mode: Mode) -> FileContent:
    """Return the reformatted version of `src_contents`.

    NothingChanged is raised when formatting yields the input unchanged, and
    outside preview mode also when the input is empty or whitespace only.

    Unless `fast` is set, the result is additionally verified through
    :func:`check_stability_and_equivalence`. `mode` is passed to
    :func:`format_str`, or to :func:`format_ipynb_string` for notebooks.
    """
    if not (mode.preview or src_contents.strip()):
        raise NothingChanged

    if mode.is_ipynb:
        formatted = format_ipynb_string(src_contents, fast=fast, mode=mode)
    else:
        formatted = format_str(src_contents, mode=mode)
    if formatted == src_contents:
        raise NothingChanged

    # Jupyter notebooks will already have been checked above.
    needs_check = not (fast or mode.is_ipynb)
    if needs_check:
        check_stability_and_equivalence(src_contents, formatted, mode=mode)
    return formatted
934
935
def validate_cell(src: str, mode: Mode) -> None:
    """Reject cells that cannot be safely round-tripped through the formatter.

    Two kinds of cells are rejected by raising NothingChanged:

    - cells that already contain the output of a TransformerManager
      transformation.  ``!ls`` is turned into ``get_ipython().system('ls')``,
      but a cell that literally contained ``get_ipython().system('ls')`` comes
      out looking exactly the same:

        >>> TransformerManager().transform_cell("get_ipython().system('ls')")
        "get_ipython().system('ls')\n"
        >>> TransformerManager().transform_cell("!ls")
        "get_ipython().system('ls')\n"

      so the original spelling could not be restored afterwards;

    - cells starting with a ``%%`` cell magic that is neither one of the
      built-in Python cell magics nor listed in ``mode.python_cell_magics``;
      such non-Python cells might break tokenizer_rt because of indentation.
    """
    for transformed_magic in TRANSFORMED_MAGICS:
        if transformed_magic in src:
            raise NothingChanged

    if not src.startswith("%%"):
        return

    # First whitespace-delimited token, minus its leading "%%".
    magic_name = src.split()[0][2:]
    if magic_name not in PYTHON_CELL_MAGICS | mode.python_cell_magics:
        raise NothingChanged
960
961
def format_cell(src: str, *, fast: bool, mode: Mode) -> str:
    """Return the formatted source of a single Jupyter notebook code cell.

    The steps are:

      1. validate the cell (see :func:`validate_cell`);
      2. remove a trailing semicolon, remembering whether there was one;
      3. mask IPython magics;
      4. format the masked code, and verify the result unless `fast` is set;
      5. put the IPython magics back;
      6. put the trailing semicolon back (if the cell had one);
      7. strip trailing newlines.

    Raise NothingChanged if the cell is rejected by validation, if masking
    fails with a SyntaxError (such cells could be automagics or multi-line
    magics, which are currently not supported), or if the final result is
    identical to `src`.
    """
    validate_cell(src, mode)
    stripped_src, had_semicolon = remove_trailing_semicolon(src)
    try:
        masked_src, replacements = mask_cell(stripped_src)
    except SyntaxError:
        raise NothingChanged from None

    masked_dst = format_str(masked_src, mode=mode)
    if not fast:
        check_stability_and_equivalence(masked_src, masked_dst, mode=mode)

    unmasked_dst = unmask_cell(masked_dst, replacements)
    result = put_trailing_semicolon_back(unmasked_dst, had_semicolon).rstrip("\n")
    if result == src:
        raise NothingChanged from None
    return result
997
998
def validate_metadata(nb: MutableMapping[str, Any]) -> None:
    """Raise NothingChanged if the notebook declares a non-Python language.

    Every notebook metadata field is optional, see
    https://nbformat.readthedocs.io/en/latest/format_description.html.
    A notebook that does not state its language is therefore accepted, and
    formatting will be attempted anyway.
    """
    metadata = nb.get("metadata", {})
    language_info = metadata.get("language_info", {})
    language = language_info.get("name", None)
    if language is None or language == "python":
        return
    raise NothingChanged from None
1009
1010
def format_ipynb_string(src_contents: str, *, fast: bool, mode: Mode) -> FileContent:
    """Format Jupyter notebook.

    Operate cell-by-cell, only on code cells, only for Python notebooks.
    If the ``.ipynb`` originally had a trailing newline, it'll be preserved.

    `fast` and `mode` are forwarded to :func:`format_cell` for every code cell.

    Raise NothingChanged if `src_contents` is empty, if the notebook metadata
    is rejected by :func:`validate_metadata`, or if no code cell was modified.
    """
    # An empty string is not a notebook.  Bail out regardless of preview mode,
    # before the trailing-newline check and the JSON parser get to see it.
    if not src_contents:
        raise NothingChanged

    trailing_newline = src_contents.endswith("\n")
    modified = False
    nb = json.loads(src_contents)
    validate_metadata(nb)
    for cell in nb["cells"]:
        if cell.get("cell_type", None) == "code":
            try:
                # A cell's source is stored as a list of lines.
                src = "".join(cell["source"])
                dst = format_cell(src, fast=fast, mode=mode)
            except NothingChanged:
                # Leave cells that need no (or allow no) formatting untouched.
                pass
            else:
                cell["source"] = dst.splitlines(keepends=True)
                modified = True
    if modified:
        dst_contents = json.dumps(nb, indent=1, ensure_ascii=False)
        if trailing_newline:
            dst_contents = dst_contents + "\n"
        return dst_contents
    else:
        raise NothingChanged
1041
1042
def format_str(src_contents: str, *, mode: Mode) -> str:
    """Reformat a string and return new contents.

    `mode` carries the formatting options, e.g. the allowed line length.
    Example:

    >>> import black
    >>> print(black.format_str("def f(arg:str='')->None:...", mode=black.Mode()))
    def f(arg: str = "") -> None:
        ...

    A more complex example:

    >>> print(
    ...   black.format_str(
    ...     "def f(arg:str='')->None: hey",
    ...     mode=black.Mode(
    ...       target_versions={black.TargetVersion.PY36},
    ...       line_length=10,
    ...       string_normalization=False,
    ...       is_pyi=False,
    ...     ),
    ...   ),
    ... )
    def f(
        arg: str = '',
    ) -> None:
        hey

    """
    first_pass = _format_str_once(src_contents, mode=mode)
    if first_pass == src_contents:
        return first_pass

    # Whenever the first pass changed something, force a second one: optional
    # trailing commas become forced trailing commas on pass 2 and then interact
    # differently with optional parentheses.  Admittedly ugly.
    return _format_str_once(first_pass, mode=mode)
1080
1081
def _format_str_once(src_contents: str, *, mode: Mode) -> str:
    """Run exactly one formatting pass over `src_contents` and return the result.

    The target versions come from `mode` when given there; otherwise they are
    detected from the parsed source.
    """
    src_node = lib2to3_parse(src_contents.lstrip(), mode.target_versions)
    if not mode.target_versions:
        # No explicit targets: infer them from the syntax the source uses.
        versions = detect_target_versions(
            src_node, future_imports=get_future_imports(src_node)
        )
    else:
        versions = mode.target_versions

    normalize_fmt_off(src_node, preview=mode.preview)
    line_generator = LineGenerator(mode=mode)
    empty_line_tracker = EmptyLineTracker(mode=mode)
    trailing_comma_features = (
        Feature.TRAILING_COMMA_IN_CALL,
        Feature.TRAILING_COMMA_IN_DEF,
    )
    split_line_features = {
        feature
        for feature in trailing_comma_features
        if supports_feature(versions, feature)
    }

    dst_blocks: List[LinesBlock] = []
    for current_line in line_generator.visit(src_node):
        current_block = empty_line_tracker.maybe_empty_lines(current_line)
        dst_blocks.append(current_block)
        for line in transform_line(
            current_line, mode=mode, features=split_line_features
        ):
            current_block.content_lines.append(str(line))
    if dst_blocks:
        dst_blocks[-1].after = 0

    output_lines = [
        text for finished_block in dst_blocks for text in finished_block.all_lines()
    ]
    if output_lines or not mode.preview:
        return "".join(output_lines)

    # Preview mode with no output at all.  Use decode_bytes to retrieve the
    # correct source newline (CRLF or LF), and check if the normalized content
    # has more than one line.
    normalized_content, _, newline = decode_bytes(src_contents.encode("utf-8"))
    return newline if "\n" in normalized_content else ""
1120
1121
def decode_bytes(src: bytes) -> Tuple[FileContent, Encoding, NewLine]:
    """Decode `src` and return ``(decoded_contents, encoding, newline)``.

    `encoding` is whatever :func:`tokenize.detect_encoding` reports for `src`.
    `newline` is CRLF if the first line ends with CRLF and LF otherwise.
    `decoded_contents` is read with universal newlines, i.e. it only ever
    contains LF.  Empty input yields ``("", encoding, "\\n")``.
    """
    stream = io.BytesIO(src)
    encoding, first_lines = tokenize.detect_encoding(stream.readline)
    if not first_lines:
        return "", encoding, "\n"

    newline = "\r\n" if first_lines[0].endswith(b"\r\n") else "\n"
    # detect_encoding() consumed the beginning of the stream; start over.
    stream.seek(0)
    with io.TextIOWrapper(stream, encoding) as reader:
        contents = reader.read()
    return contents, encoding, newline
1137
1138
def get_features_used(  # noqa: C901
    node: Node, *, future_imports: Optional[Set[str]] = None
) -> Set[Feature]:
    """Return a set of (relatively) new Python features used in this file.

    `node` is walked in pre-order.  `future_imports`, if given, is the set of
    names imported from ``__future__``; those with an entry in
    FUTURE_FLAG_TO_FEATURE are reported as features as well.

    Currently looking for:
    - f-strings;
    - self-documenting expressions in f-strings (f"{x=}");
    - underscores in numeric literals;
    - trailing commas after * or ** in function signatures and calls;
    - positional only arguments in function signatures and lambdas;
    - assignment expression;
    - relaxed decorator syntax;
    - usage of __future__ flags (annotations);
    - star-unpacking in ``return`` and ``yield``;
    - annotated assignments whose right-hand side is an unparenthesized
      tuple;
    - ``except*`` clauses;
    - star expressions in subscripts and in ``*args`` annotations.
    """
    features: Set[Feature] = set()
    if future_imports:
        features |= {
            FUTURE_FLAG_TO_FEATURE[future_import]
            for future_import in future_imports
            if future_import in FUTURE_FLAG_TO_FEATURE
        }

    for n in node.pre_order():
        if is_string_token(n):
            # String prefixes are case-insensitive, so compare the lower-cased
            # head; this also catches mixed-case raw f-strings such as
            # ``Rf"..."`` or ``fR"..."``.
            value_head = n.value[:2].lower()
            if value_head in {'f"', "f'", "rf", "fr"}:
                features.add(Feature.F_STRINGS)
                if Feature.DEBUG_F_STRINGS not in features:
                    for span_beg, span_end in iter_fexpr_spans(n.value):
                        if n.value[span_beg : span_end - 1].rstrip().endswith("="):
                            features.add(Feature.DEBUG_F_STRINGS)
                            break

        elif is_number_token(n):
            if "_" in n.value:
                features.add(Feature.NUMERIC_UNDERSCORES)

        elif n.type == token.SLASH:
            # A slash directly inside an argument list (as opposed to a
            # division operator) marks positional-only parameters.
            if n.parent and n.parent.type in {
                syms.typedargslist,
                syms.arglist,
                syms.varargslist,
            }:
                features.add(Feature.POS_ONLY_ARGUMENTS)

        elif n.type == token.COLONEQUAL:
            features.add(Feature.ASSIGNMENT_EXPRESSIONS)

        elif n.type == syms.decorator:
            if len(n.children) > 1 and not is_simple_decorator_expression(
                n.children[1]
            ):
                features.add(Feature.RELAXED_DECORATORS)

        elif (
            n.type in {syms.typedargslist, syms.arglist}
            and n.children
            and n.children[-1].type == token.COMMA
        ):
            # The argument list ends with a trailing comma; that only counts
            # as a feature if the list also contains * or ** unpacking.
            if n.type == syms.typedargslist:
                feature = Feature.TRAILING_COMMA_IN_DEF
            else:
                feature = Feature.TRAILING_COMMA_IN_CALL

            for ch in n.children:
                if ch.type in STARS:
                    features.add(feature)

                if ch.type == syms.argument:
                    for argch in ch.children:
                        if argch.type in STARS:
                            features.add(feature)

        elif (
            n.type in {syms.return_stmt, syms.yield_expr}
            and len(n.children) >= 2
            and n.children[1].type == syms.testlist_star_expr
            and any(child.type == syms.star_expr for child in n.children[1].children)
        ):
            features.add(Feature.UNPACKING_ON_FLOW)

        elif (
            n.type == syms.annassign
            and len(n.children) >= 4
            and n.children[3].type == syms.testlist_star_expr
        ):
            features.add(Feature.ANN_ASSIGN_EXTENDED_RHS)

        elif (
            n.type == syms.except_clause
            and len(n.children) >= 2
            and n.children[1].type == token.STAR
        ):
            features.add(Feature.EXCEPT_STAR)

        elif n.type in {syms.subscriptlist, syms.trailer} and any(
            child.type == syms.star_expr for child in n.children
        ):
            features.add(Feature.VARIADIC_GENERICS)

        elif (
            n.type == syms.tname_star
            and len(n.children) == 3
            and n.children[2].type == syms.star_expr
        ):
            features.add(Feature.VARIADIC_GENERICS)

    return features
1249
1250
def detect_target_versions(
    node: Node, *, future_imports: Optional[Set[str]] = None
) -> Set[TargetVersion]:
    """Return every target version that supports all features used in `node`.

    The features are collected with :func:`get_features_used`; a version
    qualifies when they form a subset of ``VERSION_TO_FEATURES[version]``.
    """
    used_features = get_features_used(node, future_imports=future_imports)
    compatible: Set[TargetVersion] = set()
    for candidate in TargetVersion:
        if used_features <= VERSION_TO_FEATURES[candidate]:
            compatible.add(candidate)
    return compatible
1259
1260
def get_future_imports(node: Node) -> Set[str]:
    """Collect the names imported from ``__future__`` at the top of the file.

    Only the leading run of simple statements is inspected: docstring-like
    string statements are skipped over, ``from __future__ import ...``
    statements are collected, and anything else ends the search.
    """

    def imported_names(children: List[LN]) -> Generator[str, None, None]:
        # Yield the original (pre-``as``) name of every imported item.
        for item in children:
            if isinstance(item, Leaf):
                # Non-name leaves are ignored.
                if item.type == token.NAME:
                    yield item.value
                continue

            if item.type == syms.import_as_name:
                orig_name = item.children[0]
                assert isinstance(orig_name, Leaf), "Invalid syntax parsing imports"
                assert orig_name.type == token.NAME, "Invalid syntax parsing imports"
                yield orig_name.value
            elif item.type == syms.import_as_names:
                yield from imported_names(item.children)
            else:
                raise AssertionError("Invalid syntax parsing imports")

    found: Set[str] = set()
    for stmt in node.children:
        if stmt.type != syms.simple_stmt:
            break

        head = stmt.children[0]
        if isinstance(head, Leaf):
            # Continue looking if we see a docstring; otherwise stop.
            is_docstring = (
                len(stmt.children) == 2
                and head.type == token.STRING
                and stmt.children[1].type == token.NEWLINE
            )
            if is_docstring:
                continue
            break

        if head.type != syms.import_from:
            break

        module_name = head.children[1]
        if not isinstance(module_name, Leaf) or module_name.value != "__future__":
            break

        # Everything after ``from __future__ import``.
        found.update(imported_names(head.children[3:]))

    return found
1309
1310
def assert_equivalent(src: str, dst: str) -> None:
    """Raise AssertionError if `src` and `dst` aren't equivalent.

    Both strings are parsed with `parse_ast` and their stringified ASTs are
    compared.  AssertionError is raised if either string fails to parse or if
    the stringified ASTs differ.
    """
    try:
        src_ast = parse_ast(src)
    except Exception as exc:
        message = (
            "cannot use --safe with this file; failed to parse source file AST: "
            f"{exc}\n"
            "This could be caused by running Black with an older Python version "
            "that does not support new syntax used in your source file."
        )
        raise AssertionError(message) from exc

    try:
        dst_ast = parse_ast(dst)
    except Exception as exc:
        # Dump the traceback together with the offending output to a file.
        traceback_text = "".join(traceback.format_tb(exc.__traceback__))
        log = dump_to_file(traceback_text, dst)
        message = (
            f"INTERNAL ERROR: Black produced invalid code: {exc}. "
            "Please report a bug on https://github.com/psf/black/issues.  "
            f"This invalid output might be helpful: {log}"
        )
        raise AssertionError(message) from None

    src_ast_str = "\n".join(stringify_ast(src_ast))
    dst_ast_str = "\n".join(stringify_ast(dst_ast))
    if src_ast_str == dst_ast_str:
        return

    log = dump_to_file(diff(src_ast_str, dst_ast_str, "src", "dst"))
    message = (
        "INTERNAL ERROR: Black produced code that is not equivalent to the"
        " source.  Please report a bug on "
        f"https://github.com/psf/black/issues.  This diff might be helpful: {log}"
    )
    raise AssertionError(message) from None
1342
1343
def assert_stable(src: str, dst: str, mode: Mode) -> None:
    """Raise AssertionError if `dst` reformats differently the second time.

    `src` is only used to make the diagnostic dump more helpful.
    """
    # Deliberately not format_str(): it formats the string twice and may hide
    # a bug where we bounce back and forth between two versions.
    second_pass = _format_str_once(dst, mode=mode)
    if second_pass == dst:
        return

    log = dump_to_file(
        str(mode),
        diff(src, dst, "source", "first pass"),
        diff(dst, second_pass, "first pass", "second pass"),
    )
    message = (
        "INTERNAL ERROR: Black produced different code on the second pass of the"
        " formatter.  Please report a bug on https://github.com/psf/black/issues."
        f"  This diff might be helpful: {log}"
    )
    raise AssertionError(message) from None
1361
1362
@contextmanager
def nullcontext() -> Iterator[None]:
    """Context manager that does nothing on entry or exit.

    Meant as a stand-in for `nullcontext` in Python 3.7; the ``with`` target,
    if any, is bound to None.
    """
    yield None
1370
1371
def patch_click() -> None:
    """Make Click not crash on Python 3.6 with LANG=C.

    In some misconfigured environments Python 3 picks ASCII as the default
    encoding, which limits the paths it can access for the lifetime of the
    application.  Click refuses to work in this scenario by raising a
    RuntimeError.

    For Black it is very unlikely that file paths contain non-ASCII characters,
    since it deals with Python source code.  Moreover, this crash was spurious
    on Python 3.7 thanks to PEP 538 and PEP 540.

    The environment check is therefore replaced with a no-op in every Click
    module that defines it.
    """
    candidates: List[Any] = []
    try:
        from click import core

        candidates.append(core)
    except ImportError:
        pass
    try:
        # Removed in Click 8.1.0 and newer; we keep this around for users who have
        # older versions installed.
        from click import _unicodefun  # type: ignore

        candidates.append(_unicodefun)
    except ImportError:
        pass

    for module in candidates:
        for check_name in ("_verify_python3_env", "_verify_python_env"):
            if hasattr(module, check_name):
                setattr(module, check_name, lambda: None)
1404
1405
def patched_main() -> None:
    """Run :func:`main` after applying the Click patch from :func:`patch_click`.

    In a frozen executable, multiprocessing's freeze support is enabled first.
    """
    running_frozen = getattr(sys, "frozen", False)
    if running_frozen:
        # PyInstaller patches multiprocessing to need freeze_support() even in
        # non-Windows environments so just assume we always need to call it if
        # frozen.
        import multiprocessing

        multiprocessing.freeze_support()

    patch_click()
    main()
1416
1417
# Script entry point: run the CLI through patched_main(), which patches Click
# before calling main().
if __name__ == "__main__":
    patched_main()