]> git.madduck.net Git - etc/vim.git/blob - src/black/__init__.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Fix not honouring pyproject.toml when using stdin and calling black from parent direc...
[etc/vim.git] / src / black / __init__.py
1 import io
2 import json
3 import platform
4 import re
5 import sys
6 import tokenize
7 import traceback
8 from contextlib import contextmanager
9 from dataclasses import replace
10 from datetime import datetime, timezone
11 from enum import Enum
12 from json.decoder import JSONDecodeError
13 from pathlib import Path
14 from typing import (
15     Any,
16     Dict,
17     Generator,
18     Iterator,
19     List,
20     MutableMapping,
21     Optional,
22     Pattern,
23     Sequence,
24     Set,
25     Sized,
26     Tuple,
27     Union,
28 )
29
30 import click
31 from click.core import ParameterSource
32 from mypy_extensions import mypyc_attr
33 from pathspec import PathSpec
34 from pathspec.patterns.gitwildmatch import GitWildMatchPatternError
35
36 from _black_version import version as __version__
37 from black.cache import Cache, get_cache_info, read_cache, write_cache
38 from black.comments import normalize_fmt_off
39 from black.const import (
40     DEFAULT_EXCLUDES,
41     DEFAULT_INCLUDES,
42     DEFAULT_LINE_LENGTH,
43     STDIN_PLACEHOLDER,
44 )
45 from black.files import (
46     find_project_root,
47     find_pyproject_toml,
48     find_user_pyproject_toml,
49     gen_python_files,
50     get_gitignore,
51     normalize_path_maybe_ignore,
52     parse_pyproject_toml,
53     wrap_stream_for_windows,
54 )
55 from black.handle_ipynb_magics import (
56     PYTHON_CELL_MAGICS,
57     TRANSFORMED_MAGICS,
58     jupyter_dependencies_are_installed,
59     mask_cell,
60     put_trailing_semicolon_back,
61     remove_trailing_semicolon,
62     unmask_cell,
63 )
64 from black.linegen import LN, LineGenerator, transform_line
65 from black.lines import EmptyLineTracker, LinesBlock
66 from black.mode import (
67     FUTURE_FLAG_TO_FEATURE,
68     VERSION_TO_FEATURES,
69     Feature,
70     Mode,
71     TargetVersion,
72     supports_feature,
73 )
74 from black.nodes import (
75     STARS,
76     is_number_token,
77     is_simple_decorator_expression,
78     is_string_token,
79     syms,
80 )
81 from black.output import color_diff, diff, dump_to_file, err, ipynb_diff, out
82 from black.parsing import InvalidInput  # noqa F401
83 from black.parsing import lib2to3_parse, parse_ast, stringify_ast
84 from black.report import Changed, NothingChanged, Report
85 from black.trans import iter_fexpr_spans
86 from blib2to3.pgen2 import token
87 from blib2to3.pytree import Leaf, Node
88
# True when this module was loaded from a compiled extension module (.so / .pyd)
# rather than from a plain Python source file; reported by `--version`.
COMPILED = Path(__file__).suffix in {".so", ".pyd"}

# types
FileContent = str  # text of a source file
Encoding = str  # name of a text encoding
NewLine = str  # newline sequence used by a file
95
96
class WriteBack(Enum):
    """What to do with formatting results."""

    NO = 0  # neither write files back nor print a diff
    YES = 1  # write reformatted code back (to the file, or to stdout for stdin)
    DIFF = 2  # print a unified diff to stdout
    CHECK = 3  # only report whether files would change
    COLOR_DIFF = 4  # print a colored diff to stdout

    @classmethod
    def from_configuration(
        cls, *, check: bool, diff: bool, color: bool = False
    ) -> "WriteBack":
        """Map the `--check`, `--diff` and `--color` flags to a WriteBack member.

        `--diff` wins over `--check`; `color` only matters together with `diff`.
        """
        if diff:
            return cls.COLOR_DIFF if color else cls.DIFF
        return cls.CHECK if check else cls.YES
115
116
# Legacy name, left for integrations: `FileMode` is kept as an alias of `Mode`
# so that existing code importing the old name keeps working.
FileMode = Mode
119
120
def read_pyproject_toml(
    ctx: click.Context, param: click.Parameter, value: Optional[str]
) -> Optional[str]:
    """Load Black settings from a "pyproject.toml" file into the defaults of `ctx`.

    This is the click callback of ``--config``. When no path was given, the file
    is looked up with :func:`find_pyproject_toml` based on the already-parsed
    ``src`` and ``stdin_filename`` parameters.

    Returns the path of the configuration file that was read, or None when no file
    was found or it contained no Black configuration.

    Raises:
        click.FileError: the file could not be read or parsed.
        click.BadOptionUsage: ``target-version`` in the file is not a list.
    """
    config_path = value
    if not config_path:
        config_path = find_pyproject_toml(
            ctx.params.get("src", ()), ctx.params.get("stdin_filename", None)
        )
        if config_path is None:
            return None

    try:
        raw_config = parse_pyproject_toml(config_path)
    except (OSError, ValueError) as exc:
        raise click.FileError(
            filename=config_path, hint=f"Error reading configuration file: {exc}"
        ) from None

    if not raw_config:
        return None

    # Sanitize the values to be Click friendly: scalars become strings, lists and
    # dicts are kept. For more information please see:
    # https://github.com/psf/black/issues/1458
    # https://github.com/pallets/click/issues/1567
    config = {
        key: val if isinstance(val, (list, dict)) else str(val)
        for key, val in raw_config.items()
    }

    target_version = config.get("target_version")
    if not (target_version is None or isinstance(target_version, list)):
        raise click.BadOptionUsage(
            "target-version", "Config key target-version must be a list"
        )

    # Layer the file's values on top of any defaults already present.
    ctx.default_map = {**(ctx.default_map or {}), **config}
    return config_path
167
168
def target_version_option_callback(
    c: click.Context, p: Union[click.Option, click.Parameter], v: Tuple[str, ...]
) -> List[TargetVersion]:
    """Convert the lowercase ``--target-version`` choices into TargetVersion members.

    Kept as a named function (not a lambda) because mypy could not infer the
    lambda's type correctly, which caused trouble for mypyc.
    """
    versions: List[TargetVersion] = []
    for name in v:
        versions.append(TargetVersion[name.upper()])
    return versions
178
179
def re_compile_maybe_verbose(regex: str) -> Pattern[str]:
    """Compile `regex`, switching on verbose mode when it spans several lines.

    Multi-line patterns get an inline ``(?x)`` flag prepended so that whitespace
    and ``#`` comments inside them are ignored.
    """
    pattern_text = f"(?x){regex}" if "\n" in regex else regex
    compiled: Pattern[str] = re.compile(pattern_text)
    return compiled
189
190
def validate_regex(
    ctx: click.Context,
    param: click.Parameter,
    value: Optional[str],
) -> Optional[Pattern[str]]:
    """Click callback that compiles a regex option value.

    None passes through unchanged; an invalid pattern is reported to the user
    as a click.BadParameter.
    """
    if value is None:
        return None
    try:
        return re_compile_maybe_verbose(value)
    except re.error as exc:
        raise click.BadParameter(f"Not a valid regular expression: {exc}") from None
200
201
# Command-line entry point. Each decorator declares one CLI parameter. The
# `--config` callback (read_pyproject_toml) loads pyproject.toml values into
# ctx.default_map, which click uses as defaults for the options below.
@click.command(
    context_settings={"help_option_names": ["-h", "--help"]},
    # While Click does set this field automatically using the docstring, mypyc
    # (annoyingly) strips 'em so we need to set it here too.
    help="The uncompromising code formatter.",
)
@click.option("-c", "--code", type=str, help="Format the code passed in as a string.")
@click.option(
    "-l",
    "--line-length",
    type=int,
    default=DEFAULT_LINE_LENGTH,
    help="How many characters per line to allow.",
    show_default=True,
)
@click.option(
    "-t",
    "--target-version",
    type=click.Choice([v.name.lower() for v in TargetVersion]),
    # Converts the chosen lowercase names into TargetVersion members.
    callback=target_version_option_callback,
    multiple=True,
    help=(
        "Python versions that should be supported by Black's output. By default, Black"
        " will try to infer this from the project metadata in pyproject.toml. If this"
        " does not yield conclusive results, Black will use per-file auto-detection."
    ),
)
@click.option(
    "--pyi",
    is_flag=True,
    help=(
        "Format all input files like typing stubs regardless of file extension (useful"
        " when piping source on standard input)."
    ),
)
@click.option(
    "--ipynb",
    is_flag=True,
    help=(
        "Format all input files like Jupyter Notebooks regardless of file extension "
        "(useful when piping source on standard input)."
    ),
)
@click.option(
    "--python-cell-magics",
    multiple=True,
    help=(
        "When processing Jupyter Notebooks, add the given magic to the list"
        f" of known python-magics ({', '.join(sorted(PYTHON_CELL_MAGICS))})."
        " Useful for formatting cells with custom python magics."
    ),
    default=[],
)
@click.option(
    "-x",
    "--skip-source-first-line",
    is_flag=True,
    help="Skip the first line of the source code.",
)
@click.option(
    "-S",
    "--skip-string-normalization",
    is_flag=True,
    help="Don't normalize string quotes or prefixes.",
)
@click.option(
    "-C",
    "--skip-magic-trailing-comma",
    is_flag=True,
    help="Don't use trailing commas as a reason to split lines.",
)
@click.option(
    "--experimental-string-processing",
    is_flag=True,
    hidden=True,
    help="(DEPRECATED and now included in --preview) Normalize string literals.",
)
@click.option(
    "--preview",
    is_flag=True,
    help=(
        "Enable potentially disruptive style changes that may be added to Black's main"
        " functionality in the next major release."
    ),
)
@click.option(
    "--check",
    is_flag=True,
    help=(
        "Don't write the files back, just return the status. Return code 0 means"
        " nothing would change. Return code 1 means some files would be reformatted."
        " Return code 123 means there was an internal error."
    ),
)
@click.option(
    "--diff",
    is_flag=True,
    help="Don't write the files back, just output a diff for each file on stdout.",
)
@click.option(
    "--color/--no-color",
    is_flag=True,
    help="Show colored diff. Only applies when `--diff` is given.",
)
@click.option(
    "--fast/--safe",
    is_flag=True,
    help="If --fast given, skip temporary sanity checks. [default: --safe]",
)
@click.option(
    "--required-version",
    type=str,
    help=(
        "Require a specific version of Black to be running (useful for unifying results"
        " across many environments e.g. with a pyproject.toml file). It can be"
        " either a major version number or an exact version."
    ),
)
@click.option(
    "--include",
    type=str,
    default=DEFAULT_INCLUDES,
    # The regex options are compiled up front by validate_regex.
    callback=validate_regex,
    help=(
        "A regular expression that matches files and directories that should be"
        " included on recursive searches. An empty value means all files are included"
        " regardless of the name. Use forward slashes for directories on all platforms"
        " (Windows, too). Exclusions are calculated first, inclusions later."
    ),
    show_default=True,
)
@click.option(
    "--exclude",
    type=str,
    callback=validate_regex,
    help=(
        "A regular expression that matches files and directories that should be"
        " excluded on recursive searches. An empty value means no paths are excluded."
        " Use forward slashes for directories on all platforms (Windows, too)."
        " Exclusions are calculated first, inclusions later. [default:"
        f" {DEFAULT_EXCLUDES}]"
    ),
    show_default=False,
)
@click.option(
    "--extend-exclude",
    type=str,
    callback=validate_regex,
    help=(
        "Like --exclude, but adds additional files and directories on top of the"
        " excluded ones. (Useful if you simply want to add to the default)"
    ),
)
@click.option(
    "--force-exclude",
    type=str,
    callback=validate_regex,
    help=(
        "Like --exclude, but files and directories matching this regex will be "
        "excluded even when they are passed explicitly as arguments."
    ),
)
@click.option(
    "--stdin-filename",
    type=str,
    # Eager, presumably so that the value is already in ctx.params when
    # read_pyproject_toml (the --config callback) looks up "stdin_filename".
    is_eager=True,
    help=(
        "The name of the file when passing it through stdin. Useful to make "
        "sure Black will respect --force-exclude option on some "
        "editors that rely on using stdin."
    ),
)
@click.option(
    "-W",
    "--workers",
    type=click.IntRange(min=1),
    default=None,
    help="Number of parallel workers [default: number of CPUs in the system]",
)
@click.option(
    "-q",
    "--quiet",
    is_flag=True,
    help=(
        "Don't emit non-error messages to stderr. Errors are still emitted; silence"
        " those with 2>/dev/null."
    ),
)
@click.option(
    "-v",
    "--verbose",
    is_flag=True,
    help=(
        "Also emit messages to stderr about files that were not changed or were ignored"
        " due to exclusion patterns."
    ),
)
@click.version_option(
    version=__version__,
    message=(
        f"%(prog)s, %(version)s (compiled: {'yes' if COMPILED else 'no'})\n"
        f"Python ({platform.python_implementation()}) {platform.python_version()}"
    ),
)
@click.argument(
    "src",
    nargs=-1,
    type=click.Path(
        exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
    ),
    # Eager, like --stdin-filename: read_pyproject_toml reads ctx.params["src"].
    is_eager=True,
    metavar="SRC ...",
)
@click.option(
    "--config",
    type=click.Path(
        exists=True,
        file_okay=True,
        dir_okay=False,
        readable=True,
        allow_dash=False,
        path_type=str,
    ),
    is_eager=True,
    # Loads the configuration file (given or discovered) into ctx.default_map.
    callback=read_pyproject_toml,
    help="Read configuration from FILE path.",
)
@click.pass_context
def main(  # noqa: C901
    ctx: click.Context,
    code: Optional[str],
    line_length: int,
    target_version: List[TargetVersion],
    check: bool,
    diff: bool,
    color: bool,
    fast: bool,
    pyi: bool,
    ipynb: bool,
    python_cell_magics: Sequence[str],
    skip_source_first_line: bool,
    skip_string_normalization: bool,
    skip_magic_trailing_comma: bool,
    experimental_string_processing: bool,
    preview: bool,
    quiet: bool,
    verbose: bool,
    required_version: Optional[str],
    include: Pattern[str],
    exclude: Optional[Pattern[str]],
    extend_exclude: Optional[Pattern[str]],
    force_exclude: Optional[Pattern[str]],
    stdin_filename: Optional[str],
    workers: Optional[int],
    src: Tuple[str, ...],
    config: Optional[str],
) -> None:
    """The uncompromising code formatter."""
    # ctx.obj is a dict used to share state (the project root) with helpers
    # such as get_sources.
    ctx.ensure_object(dict)

    # Exactly one of SRC paths or --code must be provided.
    if src and code is not None:
        out(
            main.get_usage(ctx)
            + "\n\n'SRC' and 'code' cannot be passed simultaneously."
        )
        ctx.exit(1)
    if not src and code is None:
        out(main.get_usage(ctx) + "\n\nOne of 'SRC' or 'code' is required.")
        ctx.exit(1)

    # The project root is only determined for path-based input; `method` describes
    # what identified the root and is only used in the verbose message below.
    root, method = (
        find_project_root(src, stdin_filename) if code is None else (None, None)
    )
    ctx.obj["root"] = root

    if verbose:
        if root:
            out(
                f"Identified `{root}` as project root containing a {method}.",
                fg="blue",
            )

            # Pair every source with its root-relative form; "-" (stdin) is kept
            # as-is, and sources whose normalized form is falsy are shown in red
            # as skipped.
            normalized = [
                (
                    (source, source)
                    if source == "-"
                    else (normalize_path_maybe_ignore(Path(source), root), source)
                )
                for source in src
            ]
            srcs_string = ", ".join(
                [
                    (
                        f'"{_norm}"'
                        if _norm
                        else f'\033[31m"{source} (skipping - invalid)"\033[34m'
                    )
                    for _norm, source in normalized
                ]
            )
            out(f"Sources to be formatted: {srcs_string}", fg="blue")

        if config:
            # Say where the configuration came from: the user-level file, the
            # automatically discovered project file (no explicit --config), or an
            # explicitly passed --config path.
            config_source = ctx.get_parameter_source("config")
            user_level_config = str(find_user_pyproject_toml())
            if config == user_level_config:
                out(
                    "Using configuration from user-level config at "
                    f"'{user_level_config}'.",
                    fg="blue",
                )
            elif config_source in (
                ParameterSource.DEFAULT,
                ParameterSource.DEFAULT_MAP,
            ):
                out("Using configuration from project root.", fg="blue")
            else:
                out(f"Using configuration in '{config}'.", fg="blue")
            # Echo every loaded default.
            if ctx.default_map:
                for param, value in ctx.default_map.items():
                    out(f"{param}: {value}")

    error_msg = "Oh no! 💥 💔 💥"
    # --required-version accepts the exact version or just its major component.
    if (
        required_version
        and required_version != __version__
        and required_version != __version__.split(".")[0]
    ):
        err(
            f"{error_msg} The required version `{required_version}` does not match"
            f" the running version `{__version__}`!"
        )
        ctx.exit(1)
    if ipynb and pyi:
        err("Cannot pass both `pyi` and `ipynb` flags!")
        ctx.exit(1)

    write_back = WriteBack.from_configuration(check=check, diff=diff, color=color)
    if target_version:
        versions = set(target_version)
    else:
        # We'll autodetect later.
        versions = set()
    mode = Mode(
        target_versions=versions,
        line_length=line_length,
        is_pyi=pyi,
        is_ipynb=ipynb,
        skip_source_first_line=skip_source_first_line,
        string_normalization=not skip_string_normalization,
        magic_trailing_comma=not skip_magic_trailing_comma,
        experimental_string_processing=experimental_string_processing,
        preview=preview,
        python_cell_magics=set(python_cell_magics),
    )

    if code is not None:
        # Run in quiet mode by default with -c; the extra output isn't useful.
        # You can still pass -v to get verbose output.
        quiet = True

    report = Report(check=check, diff=diff, quiet=quiet, verbose=verbose)

    if code is not None:
        reformat_code(
            content=code, fast=fast, write_back=write_back, mode=mode, report=report
        )
    else:
        try:
            sources = get_sources(
                ctx=ctx,
                src=src,
                quiet=quiet,
                verbose=verbose,
                include=include,
                exclude=exclude,
                extend_exclude=extend_exclude,
                force_exclude=force_exclude,
                report=report,
                stdin_filename=stdin_filename,
            )
        except GitWildMatchPatternError:
            ctx.exit(1)

        # Exits with status 0 when there is nothing to format.
        path_empty(
            sources,
            "No Python files are present to be formatted. Nothing to do 😴",
            quiet,
            verbose,
            ctx,
        )

        # One source is formatted in-process; several are handed to
        # reformat_many (imported only when needed) together with --workers.
        if len(sources) == 1:
            reformat_one(
                src=sources.pop(),
                fast=fast,
                write_back=write_back,
                mode=mode,
                report=report,
            )
        else:
            from black.concurrency import reformat_many

            reformat_many(
                sources=sources,
                fast=fast,
                write_back=write_back,
                mode=mode,
                report=report,
                workers=workers,
            )

    if verbose or not quiet:
        if code is None and (verbose or report.change_count or report.failure_count):
            out()
        out(error_msg if report.return_code else "All done! ✨ 🍰 ✨")
        if code is None:
            click.echo(str(report), err=True)
    ctx.exit(report.return_code)
621
622
def get_sources(
    *,
    ctx: click.Context,
    src: Tuple[str, ...],
    quiet: bool,
    verbose: bool,
    include: Pattern[str],
    exclude: Optional[Pattern[str]],
    extend_exclude: Optional[Pattern[str]],
    force_exclude: Optional[Pattern[str]],
    report: "Report",
    stdin_filename: Optional[str],
) -> Set[Path]:
    """Compute the set of files to be formatted.

    Each entry of `src` is handled as follows:

    * ``-`` together with `stdin_filename`, or an existing file: normalized
      against the project root (``ctx.obj["root"]``), dropped if it matches
      `force_exclude`, and added. Stdin input is recorded under a
      ``STDIN_PLACEHOLDER``-prefixed path so that :func:`reformat_one` can tell
      it apart from a real file. ``.ipynb`` paths are skipped when the Jupyter
      dependencies are not installed.
    * a directory: expanded with :func:`gen_python_files` using the
      include/exclude patterns (and .gitignore files when the default exclude
      is in use).
    * a bare ``-`` (no `stdin_filename`): added as-is.
    * anything else: reported with ``err`` as an invalid path.

    Files and directories for which ``normalize_path_maybe_ignore`` returns None
    are skipped.
    """
    sources: Set[Path] = set()
    root = ctx.obj["root"]

    using_default_exclude = exclude is None
    exclude = re_compile_maybe_verbose(DEFAULT_EXCLUDES) if exclude is None else exclude
    gitignore: Optional[Dict[Path, PathSpec]] = None
    root_gitignore = get_gitignore(root)

    for s in src:
        if s == "-" and stdin_filename:
            p = Path(stdin_filename)
            is_stdin = True
        else:
            p = Path(s)
            is_stdin = False

        if is_stdin or p.is_file():
            normalized_path = normalize_path_maybe_ignore(p, root, report)
            if normalized_path is None:
                continue

            normalized_path = "/" + normalized_path
            # Hard-exclude any files that matches the `--force-exclude` regex.
            if force_exclude:
                force_exclude_match = force_exclude.search(normalized_path)
            else:
                force_exclude_match = None
            if force_exclude_match and force_exclude_match.group(0):
                report.path_ignored(p, "matches the --force-exclude regular expression")
                continue

            if is_stdin:
                p = Path(f"{STDIN_PLACEHOLDER}{str(p)}")

            if p.suffix == ".ipynb" and not jupyter_dependencies_are_installed(
                verbose=verbose, quiet=quiet
            ):
                continue

            sources.add(p)
        elif p.is_dir():
            p_relative = normalize_path_maybe_ignore(p, root, report)
            if p_relative is None:
                # Mirror the file branch: a None result means the path is to be
                # skipped. Previously this reached `root / None` and raised
                # TypeError.
                continue
            p = root / p_relative
            if using_default_exclude:
                gitignore = {
                    root: root_gitignore,
                    p: get_gitignore(p),
                }
            sources.update(
                gen_python_files(
                    p.iterdir(),
                    root,
                    include,
                    exclude,
                    extend_exclude,
                    force_exclude,
                    report,
                    gitignore,
                    verbose=verbose,
                    quiet=quiet,
                )
            )
        elif s == "-":
            sources.add(p)
        else:
            err(f"invalid path: {s}")
    return sources
703
704
def path_empty(
    src: Sized, msg: str, quiet: bool, verbose: bool, ctx: click.Context
) -> None:
    """Exit successfully (status 0) when `src` holds nothing to format.

    `msg` is printed before exiting unless running quietly; `verbose` overrides
    `quiet`. Returns normally when `src` is non-empty.
    """
    if src:
        return
    if not quiet or verbose:
        out(msg)
    ctx.exit(0)
715
716
def reformat_code(
    content: str, fast: bool, write_back: WriteBack, mode: Mode, report: Report
) -> None:
    """Format the string `content` in this process and write the result to stdout.

    This is the ``--code`` counterpart of :func:`reformat_one`. The outcome is
    recorded in `report` under the pseudo-path ``<string>``; any exception is
    recorded as a failure (with a traceback when the report is verbose) instead
    of propagating.

    `fast`, `write_back`, and `mode` options are passed to
    :func:`format_stdin_to_stdout`.
    """
    path = Path("<string>")
    try:
        was_changed = format_stdin_to_stdout(
            content=content, fast=fast, write_back=write_back, mode=mode
        )
        report.done(path, Changed.YES if was_changed else Changed.NO)
    except Exception as exc:
        if report.verbose:
            traceback.print_exc()
        report.failed(path, str(exc))
739
740
# diff-shades depends on being able to monkeypatch this function to operate. I know
# it's not ideal, but this shouldn't cause any issues ... hopefully. ~ichard26
@mypyc_attr(patchable=True)
def reformat_one(
    src: Path, fast: bool, write_back: WriteBack, mode: Mode, report: "Report"
) -> None:
    """Reformat one source in this process and record the outcome in `report`.

    `src` is either ``-`` (stdin), a ``STDIN_PLACEHOLDER``-prefixed path (stdin
    with a user-supplied filename), or a regular file. For stdin, a ``.pyi`` /
    ``.ipynb`` suffix on that filename selects stub / notebook mode. Regular
    files consult the cache unless a diff is requested, and the cache is updated
    after a successful write or a clean ``--check``. Exceptions are recorded as
    failures rather than propagated.

    `fast`, `write_back`, and `mode` options are passed to
    :func:`format_file_in_place` or :func:`format_stdin_to_stdout`.
    """
    try:
        changed = Changed.NO

        raw_name = str(src)
        from_placeholder = raw_name != "-" and raw_name.startswith(STDIN_PLACEHOLDER)
        is_stdin = raw_name == "-" or from_placeholder
        if from_placeholder:
            # Use the original name again in case we want to print something
            # to the user.
            src = Path(raw_name[len(STDIN_PLACEHOLDER) :])

        if is_stdin:
            if src.suffix == ".pyi":
                mode = replace(mode, is_pyi=True)
            elif src.suffix == ".ipynb":
                mode = replace(mode, is_ipynb=True)
            if format_stdin_to_stdout(fast=fast, write_back=write_back, mode=mode):
                changed = Changed.YES
        else:
            cache: Cache = {}
            uses_cache = write_back not in (WriteBack.DIFF, WriteBack.COLOR_DIFF)
            if uses_cache:
                cache = read_cache(mode)
                resolved = src.resolve()
                cache_key = str(resolved)
                if cache_key in cache and cache[cache_key] == get_cache_info(resolved):
                    changed = Changed.CACHED
            if changed is not Changed.CACHED and format_file_in_place(
                src, fast=fast, write_back=write_back, mode=mode
            ):
                changed = Changed.YES
            wrote_new_contents = (
                write_back is WriteBack.YES and changed is not Changed.CACHED
            )
            checked_clean = write_back is WriteBack.CHECK and changed is Changed.NO
            if wrote_new_contents or checked_clean:
                write_cache(cache, [src], mode)
        report.done(src, changed)
    except Exception as exc:
        if report.verbose:
            traceback.print_exc()
        report.failed(src, str(exc))
793
794
def format_file_in_place(
    src: Path,
    fast: bool,
    mode: Mode,
    write_back: WriteBack = WriteBack.NO,
    lock: Any = None,  # multiprocessing.Manager().Lock() is some crazy proxy
) -> bool:
    """Format file under `src` path. Return True if changed.

    A ``.pyi`` / ``.ipynb`` suffix switches `mode` to stub / notebook formatting.
    When `mode.skip_source_first_line` is set, the first line is left untouched
    and re-attached to the output.

    If `write_back` is YES, write the reformatted code back to the file using the
    encoding and newline style detected when reading it. If it is DIFF or
    COLOR_DIFF, write a (colored) diff to stdout, holding `lock` (if given) while
    writing.
    `mode` and `fast` options are passed to :func:`format_file_contents`.

    Raises ValueError if a notebook's contents cannot be parsed as JSON.
    """
    if src.suffix == ".pyi":
        mode = replace(mode, is_pyi=True)
    elif src.suffix == ".ipynb":
        mode = replace(mode, is_ipynb=True)

    then = datetime.fromtimestamp(src.stat().st_mtime, timezone.utc)
    header = b""
    with open(src, "rb") as buf:
        if mode.skip_source_first_line:
            header = buf.readline()
        src_contents, encoding, newline = decode_bytes(buf.read())
    try:
        dst_contents = format_file_contents(src_contents, fast=fast, mode=mode)
    except NothingChanged:
        return False
    except JSONDecodeError:
        raise ValueError(
            f"File '{src}' cannot be parsed as valid Jupyter notebook."
        ) from None
    # Re-attach the skipped first line (empty unless skip_source_first_line).
    decoded_header = header.decode(encoding)
    src_contents = decoded_header + src_contents
    dst_contents = decoded_header + dst_contents

    if write_back == WriteBack.YES:
        with open(src, "w", encoding=encoding, newline=newline) as f:
            f.write(dst_contents)
    elif write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
        now = datetime.now(timezone.utc)
        src_name = f"{src}\t{then}"
        dst_name = f"{src}\t{now}"
        if mode.is_ipynb:
            diff_contents = ipynb_diff(src_contents, dst_contents, src_name, dst_name)
        else:
            diff_contents = diff(src_contents, dst_contents, src_name, dst_name)

        if write_back == WriteBack.COLOR_DIFF:
            diff_contents = color_diff(diff_contents)

        with lock or nullcontext():
            stdout_text = io.TextIOWrapper(
                sys.stdout.buffer,
                encoding=encoding,
                newline=newline,
                write_through=True,
            )
            try:
                wrap_stream_for_windows(stdout_text).write(diff_contents)
            finally:
                # Always detach, even if writing failed: a TextIOWrapper that is
                # garbage-collected while still attached closes its underlying
                # buffer, which here is sys.stdout.buffer.
                stdout_text.detach()

    return True
857
858
def format_stdin_to_stdout(
    fast: bool,
    *,
    content: Optional[str] = None,
    write_back: WriteBack = WriteBack.NO,
    mode: Mode,
) -> bool:
    """Format file on stdin. Return True if changed.

    If content is None, it's read from sys.stdin.

    If `write_back` is YES, write reformatted code back to stdout. If it is DIFF,
    write a diff to stdout. The `mode` argument is passed to
    :func:`format_file_contents`.

    Output is produced even when formatting raises: in that case the unmodified
    input is what gets written, and the exception still propagates.
    """
    then = datetime.now(timezone.utc)

    if content is None:
        src, encoding, newline = decode_bytes(sys.stdin.buffer.read())
    else:
        src, encoding, newline = content, "utf-8", ""

    dst = src
    try:
        dst = format_file_contents(src, fast=fast, mode=mode)
        return True

    except NothingChanged:
        return False

    finally:
        stdout_text = io.TextIOWrapper(
            sys.stdout.buffer, encoding=encoding, newline=newline, write_through=True
        )
        try:
            if write_back == WriteBack.YES:
                # Make sure there's a newline after the content
                if dst and dst[-1] != "\n":
                    dst += "\n"
                stdout_text.write(dst)
            elif write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
                now = datetime.now(timezone.utc)
                src_name = f"STDIN\t{then}"
                dst_name = f"STDOUT\t{now}"
                d = diff(src, dst, src_name, dst_name)
                target = stdout_text
                if write_back == WriteBack.COLOR_DIFF:
                    d = color_diff(d)
                    target = wrap_stream_for_windows(stdout_text)
                target.write(d)
        finally:
            # Detach even if writing failed, so that garbage-collecting the
            # wrapper cannot close sys.stdout.buffer.
            stdout_text.detach()
908
909
def check_stability_and_equivalence(
    src_contents: str, dst_contents: str, *, mode: Mode
) -> None:
    """Perform stability and equivalence checks.

    Raise AssertionError if source and destination contents are not
    equivalent, or if a second pass of the formatter would format the
    content differently.

    The equivalence check runs first, so a failure there is reported before
    stability is checked. `mode` is only needed by the stability check.
    """
    assert_equivalent(src_contents, dst_contents)
    assert_stable(src_contents, dst_contents, mode=mode)
921
922
def format_file_contents(src_contents: str, *, fast: bool, mode: Mode) -> FileContent:
    """Return the reformatted version of a whole file's contents.

    Jupyter notebooks (``mode.is_ipynb``) go through
    :func:`format_ipynb_string`; everything else goes through
    :func:`format_str`.  Raises NothingChanged when the output equals the
    input.  Unless `fast` is set, plain Python output is additionally checked
    for AST equivalence and stability.
    """
    if not mode.is_ipynb:
        dst_contents = format_str(src_contents, mode=mode)
    else:
        dst_contents = format_ipynb_string(src_contents, fast=fast, mode=mode)

    if dst_contents == src_contents:
        raise NothingChanged

    # Notebook cells are already verified one by one inside format_cell, so
    # the whole-file safety check only applies to plain Python sources.
    if not (fast or mode.is_ipynb):
        check_stability_and_equivalence(src_contents, dst_contents, mode=mode)
    return dst_contents
941
942
def validate_cell(src: str, mode: Mode) -> None:
    """Reject notebook cells that cannot be formatted safely.

    Raises NothingChanged if the cell already contains code that IPython's
    TransformerManager produces for magics, or if it starts with a cell magic
    (``%%name``) that is not a known Python cell magic.  Either could make
    tokenizer_rt break on indentation.

    Transformed magics are rejected because they cannot be round-tripped.
    ``!ls`` becomes ``get_ipython().system('ls')``, but so does a cell that
    literally contained ``get_ipython().system('ls')``:

        >>> TransformerManager().transform_cell("get_ipython().system('ls')")
        "get_ipython().system('ls')\n"
        >>> TransformerManager().transform_cell("!ls")
        "get_ipython().system('ls')\n"
    """
    if any(magic in src for magic in TRANSFORMED_MAGICS):
        raise NothingChanged

    if not src.startswith("%%"):
        return

    # Name of the cell magic, without the leading "%%".
    cell_magic = src.split()[0][2:]
    if cell_magic not in PYTHON_CELL_MAGICS | mode.python_cell_magics:
        raise NothingChanged
967
968
def format_cell(src: str, *, fast: bool, mode: Mode) -> str:
    """Format the source of a single Jupyter notebook code cell.

    Steps:

      1. reject unsupported cells via :func:`validate_cell`;
      2. strip a trailing semicolon, remembering whether one was present;
      3. mask IPython magics so the cell parses as Python;
      4. format (and, unless `fast`, verify) the masked code;
      5. restore the magics and the trailing semicolon;
      6. drop trailing newlines.

    A cell that fails to mask with a SyntaxError is left alone (raises
    NothingChanged), since it may be an automagic or a multi-line magic,
    which are not supported.  NothingChanged is also raised when the final
    result equals the input.
    """
    validate_cell(src, mode)
    stripped_src, had_semicolon = remove_trailing_semicolon(src)

    try:
        masked_src, replacements = mask_cell(stripped_src)
    except SyntaxError:
        raise NothingChanged from None

    masked_dst = format_str(masked_src, mode=mode)
    if not fast:
        check_stability_and_equivalence(masked_src, masked_dst, mode=mode)

    unmasked_dst = unmask_cell(masked_dst, replacements)
    dst = put_trailing_semicolon_back(unmasked_dst, had_semicolon).rstrip("\n")
    if dst == src:
        raise NothingChanged from None
    return dst
1004
1005
def validate_metadata(nb: MutableMapping[str, Any]) -> None:
    """Raise NothingChanged when the notebook declares a non-Python language.

    The language is read from ``metadata.language_info.name``.  Every
    metadata field is optional per the nbformat spec
    (https://nbformat.readthedocs.io/en/latest/format_description.html), so
    a notebook that declares no language is still treated as formattable.
    """
    metadata = nb.get("metadata", {})
    language_info = metadata.get("language_info", {})
    language = language_info.get("name", None)
    if language not in (None, "python"):
        raise NothingChanged from None
1016
1017
def format_ipynb_string(src_contents: str, *, fast: bool, mode: Mode) -> FileContent:
    """Format the code cells of a Jupyter notebook given as a JSON string.

    Only cells whose ``cell_type`` is ``"code"`` are touched, and only if
    :func:`validate_metadata` accepts the notebook.  The output is
    re-serialized with ``json.dumps(indent=1, ensure_ascii=False)``; a
    trailing newline on the input is preserved.  Raises NothingChanged for
    empty input or when no cell changed.
    """
    if not src_contents:
        raise NothingChanged

    keep_trailing_newline = src_contents.endswith("\n")
    notebook = json.loads(src_contents)
    validate_metadata(notebook)

    changed = False
    for cell in notebook["cells"]:
        if cell.get("cell_type", None) != "code":
            continue
        try:
            formatted = format_cell("".join(cell["source"]), fast=fast, mode=mode)
        except NothingChanged:
            continue
        cell["source"] = formatted.splitlines(keepends=True)
        changed = True

    if not changed:
        raise NothingChanged

    dst_contents = json.dumps(notebook, indent=1, ensure_ascii=False)
    if keep_trailing_newline:
        dst_contents += "\n"
    return dst_contents
1048
1049
def format_str(src_contents: str, *, mode: Mode) -> str:
    """Return `src_contents` reformatted according to `mode`.

    `mode` carries the formatting options, such as the maximum line length.
    Example:

    >>> import black
    >>> print(black.format_str("def f(arg:str='')->None:...", mode=black.Mode()))
    def f(arg: str = "") -> None:
        ...

    A more complex example:

    >>> print(
    ...   black.format_str(
    ...     "def f(arg:str='')->None: hey",
    ...     mode=black.Mode(
    ...       target_versions={black.TargetVersion.PY36},
    ...       line_length=10,
    ...       string_normalization=False,
    ...       is_pyi=False,
    ...     ),
    ...   ),
    ... )
    def f(
        arg: str = '',
    ) -> None:
        hey

    """
    first_pass = _format_str_once(src_contents, mode=mode)
    if first_pass == src_contents:
        return first_pass
    # A second pass is forced because optional trailing commas added by the
    # first pass become forced trailing commas on the next run, which then
    # interact differently with optional parentheses.  Admittedly ugly.
    return _format_str_once(first_pass, mode=mode)
1087
1088
def _format_str_once(src_contents: str, *, mode: Mode) -> str:
    """Run exactly one formatting pass over `src_contents` and return it.

    Unlike :func:`format_str`, no second pass is attempted; :func:`assert_stable`
    calls this directly so that output which keeps changing between passes
    is not hidden.
    """
    # Leading whitespace is stripped before parsing.
    src_node = lib2to3_parse(src_contents.lstrip(), mode.target_versions)
    dst_blocks: List[LinesBlock] = []
    if mode.target_versions:
        versions = mode.target_versions
    else:
        # No explicit targets: infer compatible versions from the syntax and
        # the __future__ imports the source actually uses.
        future_imports = get_future_imports(src_node)
        versions = detect_target_versions(src_node, future_imports=future_imports)

    # Parenthesized context managers may only be produced when
    # supports_feature() reports them as available for `versions`.
    context_manager_features = {
        feature
        for feature in {Feature.PARENTHESIZED_CONTEXT_MANAGERS}
        if supports_feature(versions, feature)
    }
    normalize_fmt_off(src_node)
    lines = LineGenerator(mode=mode, features=context_manager_features)
    elt = EmptyLineTracker(mode=mode)
    # Trailing-comma features gate how transform_line may split lines.
    split_line_features = {
        feature
        for feature in {Feature.TRAILING_COMMA_IN_CALL, Feature.TRAILING_COMMA_IN_DEF}
        if supports_feature(versions, feature)
    }
    block: Optional[LinesBlock] = None
    for current_line in lines.visit(src_node):
        # The tracker is consulted once per logical line, in source order,
        # and returns the block that will hold that line's output.
        block = elt.maybe_empty_lines(current_line)
        dst_blocks.append(block)
        # One logical line may be split into several output lines.
        for line in transform_line(
            current_line, mode=mode, features=split_line_features
        ):
            block.content_lines.append(str(line))
    if dst_blocks:
        # NOTE(review): `after` presumably counts blank lines emitted after a
        # block; zeroing it on the last block avoids trailing blank lines.
        dst_blocks[-1].after = 0
    dst_contents = []
    for block in dst_blocks:
        dst_contents.extend(block.all_lines())
    if not dst_contents:
        # Nothing was emitted (presumably an empty or whitespace-only source).
        # Use decode_bytes to retrieve the correct source newline (CRLF or LF),
        # and check if normalized_content has more than one line
        normalized_content, _, newline = decode_bytes(src_contents.encode("utf-8"))
        if "\n" in normalized_content:
            return newline
        return ""
    return "".join(dst_contents)
1132
1133
def decode_bytes(src: bytes) -> Tuple[FileContent, Encoding, NewLine]:
    """Decode raw source bytes into ``(decoded_contents, encoding, newline)``.

    The encoding is determined by :func:`tokenize.detect_encoding` (BOM or
    coding cookie, falling back to utf-8).  `newline` is CRLF when the first
    line ends in CRLF and LF otherwise, while `decoded_contents` is read with
    universal newlines and therefore only ever contains LF.  Empty input
    yields empty contents with an LF newline.
    """
    buffer = io.BytesIO(src)
    encoding, first_lines = tokenize.detect_encoding(buffer.readline)
    if not first_lines:
        return "", encoding, "\n"

    newline = "\r\n" if first_lines[0].endswith(b"\r\n") else "\n"
    buffer.seek(0)
    with io.TextIOWrapper(buffer, encoding) as wrapper:
        contents = wrapper.read()
    return contents, encoding, newline
1149
1150
def get_features_used(  # noqa: C901
    node: Node, *, future_imports: Optional[Set[str]] = None
) -> Set[Feature]:
    """Return a set of (relatively) new Python features used in this file.

    Currently looking for:
    - f-strings, with any casing of the f / rf / fr prefixes;
    - self-documenting expressions in f-strings (f"{x=}");
    - underscores in numeric literals;
    - trailing commas after * or ** in function signatures and calls;
    - positional only arguments in function signatures and lambdas;
    - assignment expression;
    - relaxed decorator syntax;
    - usage of __future__ flags that appear in FUTURE_FLAG_TO_FEATURE;
    - star-unpacking in return / yield (``return *a, b``);
    - unparenthesized tuples on the right-hand side of annotated assignments;
    - parenthesized context managers;
    - match statements;
    - except* clause;
    - variadic generics;
    - type parameter lists and ``type`` statements.
    """
    features: Set[Feature] = set()
    if future_imports:
        features |= {
            FUTURE_FLAG_TO_FEATURE[future_import]
            for future_import in future_imports
            if future_import in FUTURE_FLAG_TO_FEATURE
        }

    for n in node.pre_order():
        if is_string_token(n):
            # String prefixes are case-insensitive: Rf"", rF"", Fr"" and fR""
            # are valid f-strings too, so compare the lowercased prefix.
            value_head = n.value[:2].lower()
            if value_head in {'f"', "f'", "rf", "fr"}:
                features.add(Feature.F_STRINGS)
                if Feature.DEBUG_F_STRINGS not in features:
                    for span_beg, span_end in iter_fexpr_spans(n.value):
                        if n.value[span_beg : span_end - 1].rstrip().endswith("="):
                            features.add(Feature.DEBUG_F_STRINGS)
                            break

        elif is_number_token(n):
            if "_" in n.value:
                features.add(Feature.NUMERIC_UNDERSCORES)

        elif n.type == token.SLASH:
            if n.parent and n.parent.type in {
                syms.typedargslist,
                syms.arglist,
                syms.varargslist,
            }:
                features.add(Feature.POS_ONLY_ARGUMENTS)

        elif n.type == token.COLONEQUAL:
            features.add(Feature.ASSIGNMENT_EXPRESSIONS)

        elif n.type == syms.decorator:
            if len(n.children) > 1 and not is_simple_decorator_expression(
                n.children[1]
            ):
                features.add(Feature.RELAXED_DECORATORS)

        elif (
            n.type in {syms.typedargslist, syms.arglist}
            and n.children
            and n.children[-1].type == token.COMMA
        ):
            if n.type == syms.typedargslist:
                feature = Feature.TRAILING_COMMA_IN_DEF
            else:
                feature = Feature.TRAILING_COMMA_IN_CALL

            for ch in n.children:
                if ch.type in STARS:
                    features.add(feature)

                if ch.type == syms.argument:
                    for argch in ch.children:
                        if argch.type in STARS:
                            features.add(feature)

        elif (
            n.type in {syms.return_stmt, syms.yield_expr}
            and len(n.children) >= 2
            and n.children[1].type == syms.testlist_star_expr
            and any(child.type == syms.star_expr for child in n.children[1].children)
        ):
            features.add(Feature.UNPACKING_ON_FLOW)

        elif (
            n.type == syms.annassign
            and len(n.children) >= 4
            and n.children[3].type == syms.testlist_star_expr
        ):
            features.add(Feature.ANN_ASSIGN_EXTENDED_RHS)

        elif (
            n.type == syms.with_stmt
            and len(n.children) > 2
            and n.children[1].type == syms.atom
        ):
            atom_children = n.children[1].children
            if (
                len(atom_children) == 3
                and atom_children[0].type == token.LPAR
                and atom_children[1].type == syms.testlist_gexp
                and atom_children[2].type == token.RPAR
            ):
                features.add(Feature.PARENTHESIZED_CONTEXT_MANAGERS)

        elif n.type == syms.match_stmt:
            features.add(Feature.PATTERN_MATCHING)

        elif (
            n.type == syms.except_clause
            and len(n.children) >= 2
            and n.children[1].type == token.STAR
        ):
            features.add(Feature.EXCEPT_STAR)

        elif n.type in {syms.subscriptlist, syms.trailer} and any(
            child.type == syms.star_expr for child in n.children
        ):
            features.add(Feature.VARIADIC_GENERICS)

        elif (
            n.type == syms.tname_star
            and len(n.children) == 3
            and n.children[2].type == syms.star_expr
        ):
            features.add(Feature.VARIADIC_GENERICS)

        elif n.type in (syms.type_stmt, syms.typeparams):
            features.add(Feature.TYPE_PARAMS)

    return features
1285
1286
def detect_target_versions(
    node: Node, *, future_imports: Optional[Set[str]] = None
) -> Set[TargetVersion]:
    """Return every TargetVersion that supports all features used by `node`."""
    used = get_features_used(node, future_imports=future_imports)
    compatible: Set[TargetVersion] = set()
    for version in TargetVersion:
        # A version qualifies when the used features are a subset of its own.
        if used <= VERSION_TO_FEATURES[version]:
            compatible.add(version)
    return compatible
1295
1296
def get_future_imports(node: Node) -> Set[str]:
    """Collect the names imported via ``from __future__ import ...``.

    Only the leading simple statements of the module are scanned: docstrings
    are skipped over, and the scan stops at the first statement that is
    neither a docstring nor a ``__future__`` import.
    """
    imports: Set[str] = set()

    def names_in(children: List[LN]) -> Generator[str, None, None]:
        # Yield the original (pre-"as") names from an import's name list.
        for child in children:
            if isinstance(child, Leaf):
                # Punctuation such as "(" or "," is ignored.
                if child.type == token.NAME:
                    yield child.value
                continue

            if child.type == syms.import_as_name:
                original = child.children[0]
                assert isinstance(original, Leaf), "Invalid syntax parsing imports"
                assert original.type == token.NAME, "Invalid syntax parsing imports"
                yield original.value
            elif child.type == syms.import_as_names:
                yield from names_in(child.children)
            else:
                raise AssertionError("Invalid syntax parsing imports")

    for stmt in node.children:
        if stmt.type != syms.simple_stmt:
            break

        head = stmt.children[0]
        if isinstance(head, Leaf):
            is_docstring = (
                len(stmt.children) == 2
                and head.type == token.STRING
                and stmt.children[1].type == token.NEWLINE
            )
            if not is_docstring:
                break
            continue

        if head.type != syms.import_from:
            break

        module_name = head.children[1]
        if not isinstance(module_name, Leaf) or module_name.value != "__future__":
            break

        # Skip "from", "__future__" and "import"; the rest holds the names.
        imports |= set(names_in(head.children[3:]))

    return imports
1345
1346
def assert_equivalent(src: str, dst: str) -> None:
    """Ensure `src` and `dst` parse to equivalent ASTs; raise AssertionError if not.

    A parse failure of `src` is blamed on the input (or an older Python
    running Black).  A parse failure of `dst`, or any AST difference, is
    reported as an internal error with a dump file for debugging.
    """
    try:
        src_ast = parse_ast(src)
    except Exception as exc:
        raise AssertionError(
            "cannot use --safe with this file; failed to parse source file AST: "
            f"{exc}\n"
            "This could be caused by running Black with an older Python version "
            "that does not support new syntax used in your source file."
        ) from exc

    try:
        dst_ast = parse_ast(dst)
    except Exception as exc:
        tb_text = "".join(traceback.format_tb(exc.__traceback__))
        log = dump_to_file(tb_text, dst)
        raise AssertionError(
            f"INTERNAL ERROR: Black produced invalid code: {exc}. "
            "Please report a bug on https://github.com/psf/black/issues.  "
            f"This invalid output might be helpful: {log}"
        ) from None

    src_dump = "\n".join(stringify_ast(src_ast))
    dst_dump = "\n".join(stringify_ast(dst_ast))
    if src_dump == dst_dump:
        return

    log = dump_to_file(diff(src_dump, dst_dump, "src", "dst"))
    raise AssertionError(
        "INTERNAL ERROR: Black produced code that is not equivalent to the"
        " source.  Please report a bug on "
        f"https://github.com/psf/black/issues.  This diff might be helpful: {log}"
    ) from None
1378
1379
def assert_stable(src: str, dst: str, mode: Mode) -> None:
    """Ensure one more formatting pass leaves `dst` unchanged.

    On failure, dumps the mode plus both diffs (source -> first pass and
    first pass -> second pass) to a file and raises AssertionError.
    """
    # format_str() runs two passes and could mask a bug where output bounces
    # between two versions, so exactly one extra pass is used here.
    second_pass = _format_str_once(dst, mode=mode)
    if second_pass == dst:
        return

    log = dump_to_file(
        str(mode),
        diff(src, dst, "source", "first pass"),
        diff(dst, second_pass, "first pass", "second pass"),
    )
    raise AssertionError(
        "INTERNAL ERROR: Black produced different code on the second pass of the"
        " formatter.  Please report a bug on https://github.com/psf/black/issues."
        f"  This diff might be helpful: {log}"
    ) from None
1397
1398
@contextmanager
def nullcontext() -> Iterator[None]:
    """Context manager that does nothing on entry or exit.

    A stand-in for :func:`contextlib.nullcontext` (added in Python 3.7); the
    ``with`` target, if any, is bound to ``None``.
    """
    yield None
1406
1407
def patch_click() -> None:
    """Stop Click from crashing on Python 3.6 with LANG=C.

    On some misconfigured environments Python 3 picks ASCII as the default
    encoding, which limits the paths it can access.  Click refuses to run in
    that situation by raising a RuntimeError.

    For Black, non-ASCII characters in file paths are unlikely, since it only
    handles Python source code.  The crash was also spurious on Python 3.7
    thanks to PEP 538 and PEP 540.  The check is therefore replaced with a
    no-op on every Click module that defines it.
    """
    targets: List[Any] = []

    try:
        from click import core
    except ImportError:
        pass
    else:
        targets.append(core)

    try:
        # Removed in Click 8.1.0 and newer; still patched for older installs.
        from click import _unicodefun  # type: ignore
    except ImportError:
        pass
    else:
        targets.append(_unicodefun)

    for module in targets:
        for check_name in ("_verify_python3_env", "_verify_python_env"):
            if hasattr(module, check_name):
                setattr(module, check_name, lambda: None)
1440
1441
def patched_main() -> None:
    """Console entry point: prepare the runtime, then run :func:`main`."""
    # PyInstaller patches multiprocessing so that frozen builds need
    # freeze_support() even outside Windows; call it whenever we are frozen.
    is_frozen = getattr(sys, "frozen", False)
    if is_frozen:
        from multiprocessing import freeze_support

        freeze_support()

    patch_click()
    main()
1452
1453
# Running the module directly as a script goes through the same
# patched_main() entry point.
if __name__ == "__main__":
    patched_main()