git.madduck.net Git - etc/vim.git/blob - src/black/__init__.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Enable `PYTHONWARNDEFAULTENCODING = 1` in CI (#3763)
[etc/vim.git] / src / black / __init__.py
1 import io
2 import json
3 import platform
4 import re
5 import sys
6 import tokenize
7 import traceback
8 from contextlib import contextmanager
9 from dataclasses import replace
10 from datetime import datetime, timezone
11 from enum import Enum
12 from json.decoder import JSONDecodeError
13 from pathlib import Path
14 from typing import (
15     Any,
16     Dict,
17     Generator,
18     Iterator,
19     List,
20     MutableMapping,
21     Optional,
22     Pattern,
23     Sequence,
24     Set,
25     Sized,
26     Tuple,
27     Union,
28 )
29
30 import click
31 from click.core import ParameterSource
32 from mypy_extensions import mypyc_attr
33 from pathspec import PathSpec
34 from pathspec.patterns.gitwildmatch import GitWildMatchPatternError
35
36 from _black_version import version as __version__
37 from black.cache import Cache, get_cache_info, read_cache, write_cache
38 from black.comments import normalize_fmt_off
39 from black.const import (
40     DEFAULT_EXCLUDES,
41     DEFAULT_INCLUDES,
42     DEFAULT_LINE_LENGTH,
43     STDIN_PLACEHOLDER,
44 )
45 from black.files import (
46     find_project_root,
47     find_pyproject_toml,
48     find_user_pyproject_toml,
49     gen_python_files,
50     get_gitignore,
51     normalize_path_maybe_ignore,
52     parse_pyproject_toml,
53     wrap_stream_for_windows,
54 )
55 from black.handle_ipynb_magics import (
56     PYTHON_CELL_MAGICS,
57     TRANSFORMED_MAGICS,
58     jupyter_dependencies_are_installed,
59     mask_cell,
60     put_trailing_semicolon_back,
61     remove_trailing_semicolon,
62     unmask_cell,
63 )
64 from black.linegen import LN, LineGenerator, transform_line
65 from black.lines import EmptyLineTracker, LinesBlock
66 from black.mode import (
67     FUTURE_FLAG_TO_FEATURE,
68     VERSION_TO_FEATURES,
69     Feature,
70     Mode,
71     TargetVersion,
72     supports_feature,
73 )
74 from black.nodes import (
75     STARS,
76     is_number_token,
77     is_simple_decorator_expression,
78     is_string_token,
79     syms,
80 )
81 from black.output import color_diff, diff, dump_to_file, err, ipynb_diff, out
82 from black.parsing import InvalidInput  # noqa F401
83 from black.parsing import lib2to3_parse, parse_ast, stringify_ast
84 from black.report import Changed, NothingChanged, Report
85 from black.trans import iter_fexpr_spans
86 from blib2to3.pgen2 import token
87 from blib2to3.pytree import Leaf, Node
88
# True when this module was loaded from a compiled extension (.pyd/.so) rather
# than from a .py file; shown in the `--version` output.
COMPILED = Path(__file__).suffix in (".pyd", ".so")

# types
# Aliases of plain `str`, used to make signatures more descriptive.
FileContent = str
Encoding = str
NewLine = str
95
96
class WriteBack(Enum):
    """What to do with formatting results: write back, print a diff, or only check."""

    NO = 0
    YES = 1
    DIFF = 2
    CHECK = 3
    COLOR_DIFF = 4

    @classmethod
    def from_configuration(
        cls, *, check: bool, diff: bool, color: bool = False
    ) -> "WriteBack":
        """Translate the `--check`, `--diff` and `--color` flags into a member.

        `--diff` takes precedence over `--check`; with `--color` it yields
        COLOR_DIFF, otherwise DIFF. Without `--diff`, `--check` yields CHECK and
        the absence of both yields YES (write files back). `color` only matters
        together with `diff`.
        """
        if diff:
            return cls.COLOR_DIFF if color else cls.DIFF
        return cls.CHECK if check else cls.YES
115
116
# Legacy name, left for integrations: `FileMode` is simply an alias of `Mode`,
# kept so code that still imports the old name keeps working.
FileMode = Mode
119
120
def read_pyproject_toml(
    ctx: click.Context, param: click.Parameter, value: Optional[str]
) -> Optional[str]:
    """Inject Black configuration from "pyproject.toml" into defaults in `ctx`.

    This is the click callback of `--config`. When `value` is empty, a
    configuration file is searched for using the `src` and `stdin_filename`
    values currently present in `ctx.params`.

    Returns the path to a successfully found and read configuration file, None
    otherwise.

    Raises:
        click.FileError: the configuration file could not be read or parsed.
        click.BadOptionUsage: a configuration key has the wrong type
            (`target-version` must be a list; the regex options `include`,
            `exclude`, `extend-exclude` and `force-exclude` must be strings).
    """
    if not value:
        value = find_pyproject_toml(
            ctx.params.get("src", ()), ctx.params.get("stdin_filename", None)
        )
        if value is None:
            return None

    try:
        config = parse_pyproject_toml(value)
    except (OSError, ValueError) as e:
        raise click.FileError(
            filename=value, hint=f"Error reading configuration file: {e}"
        ) from None

    if not config:
        return None

    # Sanitize the values to be Click friendly. For more information please see:
    # https://github.com/psf/black/issues/1458
    # https://github.com/pallets/click/issues/1567
    # Scalars become strings, while lists and dicts are kept as-is. Their shape
    # therefore has to be validated explicitly below.
    config = {
        k: str(v) if not isinstance(v, (list, dict)) else v
        for k, v in config.items()
    }

    target_version = config.get("target_version")
    if target_version is not None and not isinstance(target_version, list):
        raise click.BadOptionUsage(
            "target-version", "Config key target-version must be a list"
        )

    # The regex options are compiled by `validate_regex`, which only handles
    # `re.error`. A list or table value would otherwise surface as an uncaught
    # TypeError from `re.compile` instead of a clean usage error.
    for key in ("include", "exclude", "extend_exclude", "force_exclude"):
        regex_value = config.get(key)
        if regex_value is not None and not isinstance(regex_value, str):
            option_name = key.replace("_", "-")
            raise click.BadOptionUsage(
                option_name, f"Config key {option_name} must be a string"
            )

    # Layer the file's values on top of any defaults already set on the context.
    default_map: Dict[str, Any] = {}
    if ctx.default_map:
        default_map.update(ctx.default_map)
    default_map.update(config)

    ctx.default_map = default_map
    return value
167
168
def target_version_option_callback(
    c: click.Context, p: Union[click.Option, click.Parameter], v: Tuple[str, ...]
) -> List[TargetVersion]:
    """Turn the `--target-version` choices into `TargetVersion` members.

    Each chosen name is upper-cased and looked up by member name. This lives in a
    named function rather than a lambda because mypy could not infer the
    lambda's type, which caused trouble for mypyc.
    """
    versions: List[TargetVersion] = []
    for name in v:
        versions.append(TargetVersion[name.upper()])
    return versions
178
179
def re_compile_maybe_verbose(regex: str) -> Pattern[str]:
    """Compile `regex`, enabling verbose mode when it spans several lines.

    A pattern containing a newline is prefixed with the inline `(?x)` flag, so
    whitespace and `#` comments inside it are ignored by the regex engine.
    """
    verbose_flag = "(?x)" if "\n" in regex else ""
    return re.compile(f"{verbose_flag}{regex}")
189
190
def validate_regex(
    ctx: click.Context,
    param: click.Parameter,
    value: Optional[str],
) -> Optional[Pattern[str]]:
    """Click callback that compiles a regex option value; None passes through.

    Raises click.BadParameter when `value` is not a valid regular expression.
    """
    if value is None:
        return None
    try:
        return re_compile_maybe_verbose(value)
    except re.error as e:
        raise click.BadParameter(f"Not a valid regular expression: {e}") from None
200
201
# Entry point of the `black` command. Each `@click.option` / `@click.argument`
# below becomes a keyword parameter of `main`. `--config` is processed by the
# `read_pyproject_toml` callback. That callback reads `src` / `stdin_filename`
# from `ctx.params` and merges the configuration file's values into
# `ctx.default_map`, so they act as defaults for the other options.
@click.command(
    context_settings={"help_option_names": ["-h", "--help"]},
    # While Click does set this field automatically using the docstring, mypyc
    # (annoyingly) strips 'em so we need to set it here too.
    help="The uncompromising code formatter.",
)
@click.option("-c", "--code", type=str, help="Format the code passed in as a string.")
@click.option(
    "-l",
    "--line-length",
    type=int,
    default=DEFAULT_LINE_LENGTH,
    help="How many characters per line to allow.",
    show_default=True,
)
@click.option(
    "-t",
    "--target-version",
    type=click.Choice([v.name.lower() for v in TargetVersion]),
    # Converts the chosen lower-case names back into TargetVersion members.
    callback=target_version_option_callback,
    multiple=True,
    help=(
        "Python versions that should be supported by Black's output. By default, Black"
        " will try to infer this from the project metadata in pyproject.toml. If this"
        " does not yield conclusive results, Black will use per-file auto-detection."
    ),
)
@click.option(
    "--pyi",
    is_flag=True,
    help=(
        "Format all input files like typing stubs regardless of file extension (useful"
        " when piping source on standard input)."
    ),
)
@click.option(
    "--ipynb",
    is_flag=True,
    help=(
        "Format all input files like Jupyter Notebooks regardless of file extension "
        "(useful when piping source on standard input)."
    ),
)
@click.option(
    "--python-cell-magics",
    multiple=True,
    help=(
        "When processing Jupyter Notebooks, add the given magic to the list"
        f" of known python-magics ({', '.join(sorted(PYTHON_CELL_MAGICS))})."
        " Useful for formatting cells with custom python magics."
    ),
    default=[],
)
@click.option(
    "-x",
    "--skip-source-first-line",
    is_flag=True,
    help="Skip the first line of the source code.",
)
@click.option(
    "-S",
    "--skip-string-normalization",
    is_flag=True,
    help="Don't normalize string quotes or prefixes.",
)
@click.option(
    "-C",
    "--skip-magic-trailing-comma",
    is_flag=True,
    help="Don't use trailing commas as a reason to split lines.",
)
@click.option(
    "--experimental-string-processing",
    is_flag=True,
    hidden=True,
    help="(DEPRECATED and now included in --preview) Normalize string literals.",
)
@click.option(
    "--preview",
    is_flag=True,
    help=(
        "Enable potentially disruptive style changes that may be added to Black's main"
        " functionality in the next major release."
    ),
)
@click.option(
    "--check",
    is_flag=True,
    help=(
        "Don't write the files back, just return the status. Return code 0 means"
        " nothing would change. Return code 1 means some files would be reformatted."
        " Return code 123 means there was an internal error."
    ),
)
@click.option(
    "--diff",
    is_flag=True,
    help="Don't write the files back, just output a diff for each file on stdout.",
)
@click.option(
    "--color/--no-color",
    is_flag=True,
    help="Show colored diff. Only applies when `--diff` is given.",
)
@click.option(
    "--fast/--safe",
    is_flag=True,
    help="If --fast given, skip temporary sanity checks. [default: --safe]",
)
@click.option(
    "--required-version",
    type=str,
    help=(
        "Require a specific version of Black to be running (useful for unifying results"
        " across many environments e.g. with a pyproject.toml file). It can be"
        " either a major version number or an exact version."
    ),
)
# The regex options below are compiled (or rejected) by `validate_regex`.
@click.option(
    "--include",
    type=str,
    default=DEFAULT_INCLUDES,
    callback=validate_regex,
    help=(
        "A regular expression that matches files and directories that should be"
        " included on recursive searches. An empty value means all files are included"
        " regardless of the name. Use forward slashes for directories on all platforms"
        " (Windows, too). Exclusions are calculated first, inclusions later."
    ),
    show_default=True,
)
@click.option(
    "--exclude",
    type=str,
    callback=validate_regex,
    help=(
        "A regular expression that matches files and directories that should be"
        " excluded on recursive searches. An empty value means no paths are excluded."
        " Use forward slashes for directories on all platforms (Windows, too)."
        " Exclusions are calculated first, inclusions later. [default:"
        f" {DEFAULT_EXCLUDES}]"
    ),
    # No click default: get_sources treats None as "use DEFAULT_EXCLUDES and
    # also honour .gitignore files".
    show_default=False,
)
@click.option(
    "--extend-exclude",
    type=str,
    callback=validate_regex,
    help=(
        "Like --exclude, but adds additional files and directories on top of the"
        " excluded ones. (Useful if you simply want to add to the default)"
    ),
)
@click.option(
    "--force-exclude",
    type=str,
    callback=validate_regex,
    help=(
        "Like --exclude, but files and directories matching this regex will be "
        "excluded even when they are passed explicitly as arguments."
    ),
)
@click.option(
    "--stdin-filename",
    type=str,
    is_eager=True,
    help=(
        "The name of the file when passing it through stdin. Useful to make "
        "sure Black will respect --force-exclude option on some "
        "editors that rely on using stdin."
    ),
)
@click.option(
    "-W",
    "--workers",
    type=click.IntRange(min=1),
    default=None,
    help=(
        "Number of parallel workers [default: BLACK_NUM_WORKERS environment variable "
        "or number of CPUs in the system]"
    ),
)
@click.option(
    "-q",
    "--quiet",
    is_flag=True,
    help=(
        "Don't emit non-error messages to stderr. Errors are still emitted; silence"
        " those with 2>/dev/null."
    ),
)
@click.option(
    "-v",
    "--verbose",
    is_flag=True,
    help=(
        "Also emit messages to stderr about files that were not changed or were ignored"
        " due to exclusion patterns."
    ),
)
@click.version_option(
    version=__version__,
    message=(
        f"%(prog)s, %(version)s (compiled: {'yes' if COMPILED else 'no'})\n"
        f"Python ({platform.python_implementation()}) {platform.python_version()}"
    ),
)
@click.argument(
    "src",
    nargs=-1,
    type=click.Path(
        exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
    ),
    is_eager=True,
    metavar="SRC ...",
)
@click.option(
    "--config",
    type=click.Path(
        exists=True,
        file_okay=True,
        dir_okay=False,
        readable=True,
        allow_dash=False,
        path_type=str,
    ),
    is_eager=True,
    callback=read_pyproject_toml,
    help="Read configuration from FILE path.",
)
@click.pass_context
def main(  # noqa: C901
    ctx: click.Context,
    code: Optional[str],
    line_length: int,
    target_version: List[TargetVersion],
    check: bool,
    diff: bool,
    color: bool,
    fast: bool,
    pyi: bool,
    ipynb: bool,
    python_cell_magics: Sequence[str],
    skip_source_first_line: bool,
    skip_string_normalization: bool,
    skip_magic_trailing_comma: bool,
    experimental_string_processing: bool,
    preview: bool,
    quiet: bool,
    verbose: bool,
    required_version: Optional[str],
    include: Pattern[str],
    exclude: Optional[Pattern[str]],
    extend_exclude: Optional[Pattern[str]],
    force_exclude: Optional[Pattern[str]],
    stdin_filename: Optional[str],
    workers: Optional[int],
    src: Tuple[str, ...],
    config: Optional[str],
) -> None:
    """The uncompromising code formatter."""
    # ctx.obj is a dict shared with helpers; get_sources reads ctx.obj["root"].
    ctx.ensure_object(dict)

    # Exactly one of SRC or --code must be supplied.
    if src and code is not None:
        out(
            main.get_usage(ctx)
            + "\n\n'SRC' and 'code' cannot be passed simultaneously."
        )
        ctx.exit(1)
    if not src and code is None:
        out(main.get_usage(ctx) + "\n\nOne of 'SRC' or 'code' is required.")
        ctx.exit(1)

    # A project root is only looked up when formatting sources, not --code.
    root, method = (
        find_project_root(src, stdin_filename) if code is None else (None, None)
    )
    ctx.obj["root"] = root

    # In verbose mode, explain where the root and configuration came from and
    # dump the configuration values that were loaded into ctx.default_map.
    if verbose:
        if root:
            out(
                f"Identified `{root}` as project root containing a {method}.",
                fg="blue",
            )

        if config:
            config_source = ctx.get_parameter_source("config")
            user_level_config = str(find_user_pyproject_toml())
            if config == user_level_config:
                out(
                    "Using configuration from user-level config at "
                    f"'{user_level_config}'.",
                    fg="blue",
                )
            elif config_source in (
                ParameterSource.DEFAULT,
                ParameterSource.DEFAULT_MAP,
            ):
                out("Using configuration from project root.", fg="blue")
            else:
                out(f"Using configuration in '{config}'.", fg="blue")
            if ctx.default_map:
                for param, value in ctx.default_map.items():
                    out(f"{param}: {value}")

    error_msg = "Oh no! 💥 💔 💥"
    # --required-version matches either the full version string or its first
    # dot-separated component (the major version).
    if (
        required_version
        and required_version != __version__
        and required_version != __version__.split(".")[0]
    ):
        err(
            f"{error_msg} The required version `{required_version}` does not match"
            f" the running version `{__version__}`!"
        )
        ctx.exit(1)
    if ipynb and pyi:
        err("Cannot pass both `pyi` and `ipynb` flags!")
        ctx.exit(1)

    write_back = WriteBack.from_configuration(check=check, diff=diff, color=color)
    if target_version:
        versions = set(target_version)
    else:
        # We'll autodetect later.
        versions = set()
    mode = Mode(
        target_versions=versions,
        line_length=line_length,
        is_pyi=pyi,
        is_ipynb=ipynb,
        skip_source_first_line=skip_source_first_line,
        string_normalization=not skip_string_normalization,
        magic_trailing_comma=not skip_magic_trailing_comma,
        experimental_string_processing=experimental_string_processing,
        preview=preview,
        python_cell_magics=set(python_cell_magics),
    )

    if code is not None:
        # Run in quiet mode by default with -c; the extra output isn't useful.
        # You can still pass -v to get verbose output.
        quiet = True

    report = Report(check=check, diff=diff, quiet=quiet, verbose=verbose)

    if code is not None:
        reformat_code(
            content=code, fast=fast, write_back=write_back, mode=mode, report=report
        )
    else:
        try:
            sources = get_sources(
                ctx=ctx,
                src=src,
                quiet=quiet,
                verbose=verbose,
                include=include,
                exclude=exclude,
                extend_exclude=extend_exclude,
                force_exclude=force_exclude,
                report=report,
                stdin_filename=stdin_filename,
            )
        except GitWildMatchPatternError:
            # NOTE(review): presumably raised by pathspec while parsing a
            # .gitignore (get_sources calls get_gitignore); exit with status 1.
            ctx.exit(1)

        # Exits with status 0 when there is nothing to format.
        path_empty(
            sources,
            "No Python files are present to be formatted. Nothing to do 😴",
            quiet,
            verbose,
            ctx,
        )

        # One source is formatted in this process; several sources go through
        # black.concurrency.reformat_many, which is only imported when needed.
        if len(sources) == 1:
            reformat_one(
                src=sources.pop(),
                fast=fast,
                write_back=write_back,
                mode=mode,
                report=report,
            )
        else:
            from black.concurrency import reformat_many

            reformat_many(
                sources=sources,
                fast=fast,
                write_back=write_back,
                mode=mode,
                report=report,
                workers=workers,
            )

    # Final summary; the report's return code becomes the process exit status.
    if verbose or not quiet:
        if code is None and (verbose or report.change_count or report.failure_count):
            out()
        out(error_msg if report.return_code else "All done! ✨ 🍰 ✨")
        if code is None:
            click.echo(str(report), err=True)
    ctx.exit(report.return_code)
604
605
def get_sources(
    *,
    ctx: click.Context,
    src: Tuple[str, ...],
    quiet: bool,
    verbose: bool,
    include: Pattern[str],
    exclude: Optional[Pattern[str]],
    extend_exclude: Optional[Pattern[str]],
    force_exclude: Optional[Pattern[str]],
    report: "Report",
    stdin_filename: Optional[str],
) -> Set[Path]:
    """Compute the set of files to be formatted.

    Each entry of `src` is handled as follows:

    * ``-`` combined with `stdin_filename`, or an existing file: kept unless it
      cannot be normalized against the project root or matches `force_exclude`.
      Stdin entries are stored as `STDIN_PLACEHOLDER` + the given filename.
    * an existing directory: expanded with `gen_python_files` using the
      include/exclude regexes. When no `exclude` was given (the default exclude
      is used), the root's and the directory's `.gitignore` are honoured as well.
    * a bare ``-``: stdin, added as-is.
    * anything else: reported through `err` as an invalid path.

    `.ipynb` files are dropped when the Jupyter dependencies are not installed.
    """
    sources: Set[Path] = set()
    root = ctx.obj["root"]

    using_default_exclude = exclude is None
    exclude = re_compile_maybe_verbose(DEFAULT_EXCLUDES) if exclude is None else exclude
    gitignore: Optional[Dict[Path, PathSpec]] = None
    root_gitignore = get_gitignore(root)

    for s in src:
        if s == "-" and stdin_filename:
            p = Path(stdin_filename)
            is_stdin = True
        else:
            p = Path(s)
            is_stdin = False

        if is_stdin or p.is_file():
            normalized_path: Optional[str] = normalize_path_maybe_ignore(
                p, root, report
            )
            if normalized_path is None:
                # Name the offending input; `normalized_path` itself is None here.
                if verbose:
                    out(f'Skipping invalid source: "{p}"', fg="red")
                continue
            if verbose:
                out(f'Found input source: "{normalized_path}"', fg="blue")

            # Regexes are matched against the root-relative path with a leading "/".
            normalized_path = "/" + normalized_path
            # Hard-exclude any files that matches the `--force-exclude` regex.
            if force_exclude:
                force_exclude_match = force_exclude.search(normalized_path)
            else:
                force_exclude_match = None
            if force_exclude_match and force_exclude_match.group(0):
                report.path_ignored(p, "matches the --force-exclude regular expression")
                continue

            if is_stdin:
                p = Path(f"{STDIN_PLACEHOLDER}{str(p)}")

            if p.suffix == ".ipynb" and not jupyter_dependencies_are_installed(
                verbose=verbose, quiet=quiet
            ):
                continue

            sources.add(p)
        elif p.is_dir():
            p_relative = normalize_path_maybe_ignore(p, root, report)
            if p_relative is None:
                # The helper may return None (see the file branch above). Skip the
                # directory instead of crashing on `root / None`.
                if verbose:
                    out(f'Skipping invalid source: "{p}"', fg="red")
                continue
            p = root / p_relative
            if verbose:
                out(f'Found input source directory: "{p}"', fg="blue")

            if using_default_exclude:
                gitignore = {
                    root: root_gitignore,
                    p: get_gitignore(p),
                }
            sources.update(
                gen_python_files(
                    p.iterdir(),
                    root,
                    include,
                    exclude,
                    extend_exclude,
                    force_exclude,
                    report,
                    gitignore,
                    verbose=verbose,
                    quiet=quiet,
                )
            )
        elif s == "-":
            if verbose:
                out("Found input source stdin", fg="blue")
            sources.add(p)
        else:
            err(f"invalid path: {s}")

    return sources
698
699
def path_empty(
    src: Sized, msg: str, quiet: bool, verbose: bool, ctx: click.Context
) -> None:
    """Exit with status 0 when `src` is empty, i.e. there is nothing to format.

    Before exiting, `msg` is printed unless output is quiet (`verbose` overrides
    `quiet`). When `src` is non-empty the function simply returns.
    """
    if src:
        return
    if verbose or not quiet:
        out(msg)
    ctx.exit(0)
710
711
def reformat_code(
    content: str, fast: bool, write_back: WriteBack, mode: Mode, report: Report
) -> None:
    """Format the string `content` in this process and emit the result on stdout.

    This is the `--code` counterpart of `reformat_one`. `content`, `fast`,
    `write_back` and `mode` are handed to :func:`format_stdin_to_stdout`, and the
    outcome is recorded on `report` under the pseudo-path ``<string>``. Any
    exception is recorded as a failure (with a traceback when the report is
    verbose) instead of propagating.
    """
    path = Path("<string>")
    try:
        was_changed = format_stdin_to_stdout(
            content=content, fast=fast, write_back=write_back, mode=mode
        )
        report.done(path, Changed.YES if was_changed else Changed.NO)
    except Exception as exc:
        if report.verbose:
            traceback.print_exc()
        report.failed(path, str(exc))
734
735
# diff-shades depends on being able to monkeypatch this function to operate. I know
# it's not ideal, but this shouldn't cause any issues ... hopefully. ~ichard26
@mypyc_attr(patchable=True)
def reformat_one(
    src: Path, fast: bool, write_back: WriteBack, mode: Mode, report: "Report"
) -> None:
    """Reformat a single source in this process and record the outcome on `report`.

    `src` is either a real file, ``-`` (stdin), or a stdin filename prefixed with
    `STDIN_PLACEHOLDER` (as built by `get_sources`). Stdin sources go to
    :func:`format_stdin_to_stdout`, with a `.pyi` / `.ipynb` suffix switching the
    mode accordingly. Files go to :func:`format_file_in_place`. For files, the
    cache is read and updated unless a diff is being printed. Exceptions are
    reported as failures rather than propagated.
    """
    try:
        changed = Changed.NO

        src_str = str(src)
        if src_str != "-" and src_str.startswith(STDIN_PLACEHOLDER):
            # Restore the user-supplied name so any messages refer to it.
            src = Path(src_str[len(STDIN_PLACEHOLDER) :])
            is_stdin = True
        else:
            is_stdin = src_str == "-"

        if is_stdin:
            if src.suffix == ".pyi":
                mode = replace(mode, is_pyi=True)
            elif src.suffix == ".ipynb":
                mode = replace(mode, is_ipynb=True)
            if format_stdin_to_stdout(fast=fast, write_back=write_back, mode=mode):
                changed = Changed.YES
        else:
            cache: Cache = {}
            uses_cache = write_back not in (WriteBack.DIFF, WriteBack.COLOR_DIFF)
            if uses_cache:
                cache = read_cache(mode)
                resolved = src.resolve()
                cache_key = str(resolved)
                if cache_key in cache and cache[cache_key] == get_cache_info(resolved):
                    changed = Changed.CACHED
            if changed is not Changed.CACHED:
                if format_file_in_place(
                    src, fast=fast, write_back=write_back, mode=mode
                ):
                    changed = Changed.YES
            wrote_file = write_back is WriteBack.YES and changed is not Changed.CACHED
            check_passed = write_back is WriteBack.CHECK and changed is Changed.NO
            if wrote_file or check_passed:
                write_cache(cache, [src], mode)
        report.done(src, changed)
    except Exception as exc:
        if report.verbose:
            traceback.print_exc()
        report.failed(src, str(exc))
788
789
def format_file_in_place(
    src: Path,
    fast: bool,
    mode: Mode,
    write_back: WriteBack = WriteBack.NO,
    lock: Any = None,  # multiprocessing.Manager().Lock() is some crazy proxy
) -> bool:
    """Format the file under `src`. Return True if its contents changed.

    The file suffix switches `mode` to stub (`.pyi`) or notebook (`.ipynb`)
    handling. With `mode.skip_source_first_line`, the first line is set aside
    before formatting and re-attached to both the old and new contents.

    `write_back` selects the output. YES rewrites the file using the encoding and
    newline style detected by `decode_bytes`. DIFF / COLOR_DIFF print a diff to
    stdout while holding `lock` if one is given (presumably so parallel workers
    do not interleave output). Any other value writes nothing. `mode` and `fast`
    are passed to :func:`format_file_contents`.

    Raises ValueError when a notebook cannot be decoded as JSON.
    """
    suffix = src.suffix
    if suffix == ".pyi":
        mode = replace(mode, is_pyi=True)
    elif suffix == ".ipynb":
        mode = replace(mode, is_ipynb=True)

    then = datetime.fromtimestamp(src.stat().st_mtime, timezone.utc)
    with open(src, "rb") as buf:
        header = buf.readline() if mode.skip_source_first_line else b""
        src_contents, encoding, newline = decode_bytes(buf.read())

    try:
        dst_contents = format_file_contents(src_contents, fast=fast, mode=mode)
    except NothingChanged:
        return False
    except JSONDecodeError:
        raise ValueError(
            f"File '{src}' cannot be parsed as valid Jupyter notebook."
        ) from None

    header_text = header.decode(encoding)
    src_contents = header_text + src_contents
    dst_contents = header_text + dst_contents

    if write_back == WriteBack.YES:
        with open(src, "w", encoding=encoding, newline=newline) as f:
            f.write(dst_contents)
    elif write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
        now = datetime.now(timezone.utc)
        src_name = f"{src}\t{then}"
        dst_name = f"{src}\t{now}"
        make_diff = ipynb_diff if mode.is_ipynb else diff
        diff_contents = make_diff(src_contents, dst_contents, src_name, dst_name)

        if write_back == WriteBack.COLOR_DIFF:
            diff_contents = color_diff(diff_contents)

        with lock or nullcontext():
            stream = io.TextIOWrapper(
                sys.stdout.buffer,
                encoding=encoding,
                newline=newline,
                write_through=True,
            )
            stream = wrap_stream_for_windows(stream)
            stream.write(diff_contents)
            # Detach so that discarding the wrapper leaves sys.stdout.buffer open.
            stream.detach()

    return True
852
853
def format_stdin_to_stdout(
    fast: bool,
    *,
    content: Optional[str] = None,
    write_back: WriteBack = WriteBack.NO,
    mode: Mode,
) -> bool:
    """Format source from stdin (or from `content`). Return True if it changed.

    When `content` is None, the source is read as bytes from ``sys.stdin`` and
    decoded with `decode_bytes`. Otherwise `content` is used directly, with UTF-8
    output and no newline translation (`newline=""`).

    Output goes to stdout according to `write_back`:
      * YES: the formatted code is written, or the unchanged input if
        formatting did not produce a result. A trailing newline is added to
        non-empty output that lacks one.
      * DIFF / COLOR_DIFF: a diff between input and output is written
        (colorized for COLOR_DIFF).
      * anything else: nothing is written.

    The output is produced in a ``finally`` clause, so it is written even when
    formatting raises. `fast` and `mode` go to :func:`format_file_contents`.
    """
    then = datetime.now(timezone.utc)

    if content is not None:
        src, encoding, newline = content, "utf-8", ""
    else:
        src, encoding, newline = decode_bytes(sys.stdin.buffer.read())

    # Until formatting succeeds, the "result" is the untouched input.
    dst = src
    try:
        dst = format_file_contents(src, fast=fast, mode=mode)
        return True

    except NothingChanged:
        return False

    finally:
        stream = io.TextIOWrapper(
            sys.stdout.buffer, encoding=encoding, newline=newline, write_through=True
        )
        if write_back == WriteBack.YES:
            if dst and not dst.endswith("\n"):
                dst += "\n"
            stream.write(dst)
        elif write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
            src_name = f"STDIN\t{then}"
            dst_name = f"STDOUT\t{datetime.now(timezone.utc)}"
            patch = diff(src, dst, src_name, dst_name)
            if write_back == WriteBack.COLOR_DIFF:
                patch = color_diff(patch)
                stream = wrap_stream_for_windows(stream)
            stream.write(patch)
        # Detach so that discarding the wrapper leaves sys.stdout.buffer open.
        stream.detach()
903
904
def check_stability_and_equivalence(
    src_contents: str, dst_contents: str, *, mode: Mode
) -> None:
    """Verify that turning `src_contents` into `dst_contents` was safe.

    Two checks run, in this order:
      1. :func:`assert_equivalent`: the output must be equivalent to the input.
      2. :func:`assert_stable`: another formatting pass with `mode` must not
         change the output any further.

    Raises AssertionError if either check fails.
    """
    # Equivalence is checked before stability.
    assert_equivalent(src_contents, dst_contents)
    assert_stable(src_contents, dst_contents, mode=mode)
916
917
def format_file_contents(src_contents: str, *, fast: bool, mode: Mode) -> FileContent:
    """Return the reformatted version of a file's contents.

    Notebooks (``mode.is_ipynb``) go through :func:`format_ipynb_string`.
    Everything else goes through :func:`format_str`.

    Raises NothingChanged when the formatted text equals the input. Unless
    ``fast`` is set, non-notebook output is also validated with
    :func:`check_stability_and_equivalence`.
    """
    formatted = (
        format_ipynb_string(src_contents, fast=fast, mode=mode)
        if mode.is_ipynb
        else format_str(src_contents, mode=mode)
    )
    if formatted == src_contents:
        raise NothingChanged

    # Notebook cells are already checked one by one inside format_cell (when
    # not fast), so only plain source files need the whole-file check here.
    needs_checks = not (fast or mode.is_ipynb)
    if needs_checks:
        check_stability_and_equivalence(src_contents, formatted, mode=mode)
    return formatted
936
937
def validate_cell(src: str, mode: Mode) -> None:
    """Raise NothingChanged for notebook cells that must be left untouched.

    A cell is rejected in two cases.

    First, it may already contain text that IPython's TransformerManager
    produces from magics. ``!ls`` becomes ``get_ipython().system('ls')``, but
    a cell that literally contains ``get_ipython().system('ls')`` is
    transformed to the very same thing:

        >>> TransformerManager().transform_cell("get_ipython().system('ls')")
        "get_ipython().system('ls')\n"
        >>> TransformerManager().transform_cell("!ls")
        "get_ipython().system('ls')\n"

    Such cells cannot be round-tripped safely, so they are skipped.

    Second, it may start with a ``%%`` cell magic that is not a known Python
    cell magic (neither in PYTHON_CELL_MAGICS nor in
    ``mode.python_cell_magics``). The bodies of such cells are not Python, and
    their indentation might break tokenization.
    """
    if any(magic in src for magic in TRANSFORMED_MAGICS):
        raise NothingChanged
    if not src.startswith("%%"):
        return
    cell_magic = src.split()[0][2:]
    if cell_magic not in PYTHON_CELL_MAGICS | mode.python_cell_magics:
        raise NothingChanged
962
963
def format_cell(src: str, *, fast: bool, mode: Mode) -> str:
    """Format the source of one Jupyter notebook code cell.

    Steps:

      - reject cells that cannot be handled (see :func:`validate_cell`);
      - drop a trailing semicolon, remembering whether it was there;
      - mask IPython magics so the cell parses as plain Python;
      - format the masked code, and check it unless ``fast`` is set;
      - restore the magics, then the trailing semicolon if there was one;
      - strip trailing newlines.

    A cell that fails to mask with a SyntaxError is skipped: it might be an
    automagic or a multi-line magic, which are not supported. Raises
    NothingChanged when the cell is skipped or the result equals ``src``.
    """
    validate_cell(src, mode)
    stripped_src, had_semicolon = remove_trailing_semicolon(src)
    try:
        masked_src, replacements = mask_cell(stripped_src)
    except SyntaxError:
        raise NothingChanged from None

    masked_dst = format_str(masked_src, mode=mode)
    if not fast:
        check_stability_and_equivalence(masked_src, masked_dst, mode=mode)

    unmasked_dst = unmask_cell(masked_dst, replacements)
    result = put_trailing_semicolon_back(unmasked_dst, had_semicolon).rstrip("\n")
    if result == src:
        raise NothingChanged from None
    return result
999
1000
def validate_metadata(nb: MutableMapping[str, Any]) -> None:
    """Raise NothingChanged when the notebook declares a non-Python language.

    Every nbformat metadata field is optional
    (https://nbformat.readthedocs.io/en/latest/format_description.html).
    A notebook that does not state its language is therefore treated as
    Python and formatted anyway.
    """
    language_info = nb.get("metadata", {}).get("language_info", {})
    language = language_info.get("name")
    if language not in (None, "python"):
        raise NothingChanged from None
1011
1012
def format_ipynb_string(src_contents: str, *, fast: bool, mode: Mode) -> FileContent:
    """Format the code cells of a Jupyter notebook given as JSON text.

    Only cells whose ``cell_type`` is ``"code"`` are touched, and only when the
    notebook's metadata does not declare a non-Python language. The result is
    re-serialised with one-space indentation and non-ASCII text kept as-is. A
    trailing newline in the input is kept in the output.

    Raises NothingChanged for empty input or when no cell changed.
    """
    if not src_contents:
        raise NothingChanged

    had_trailing_newline = src_contents.endswith("\n")
    notebook = json.loads(src_contents)
    validate_metadata(notebook)

    changed_any = False
    for cell in notebook["cells"]:
        if cell.get("cell_type") != "code":
            continue
        try:
            formatted = format_cell("".join(cell["source"]), fast=fast, mode=mode)
        except NothingChanged:
            continue
        cell["source"] = formatted.splitlines(keepends=True)
        changed_any = True

    if not changed_any:
        raise NothingChanged
    result = json.dumps(notebook, indent=1, ensure_ascii=False)
    return result + "\n" if had_trailing_newline else result
1043
1044
def format_str(src_contents: str, *, mode: Mode) -> str:
    """Reformat a string and return the new contents.

    `mode` selects formatting options, for instance the maximum number of
    characters per line.  Example:

    >>> import black
    >>> print(black.format_str("def f(arg:str='')->None:...", mode=black.Mode()))
    def f(arg: str = "") -> None:
        ...

    A more complex example:

    >>> print(
    ...   black.format_str(
    ...     "def f(arg:str='')->None: hey",
    ...     mode=black.Mode(
    ...       target_versions={black.TargetVersion.PY36},
    ...       line_length=10,
    ...       string_normalization=False,
    ...       is_pyi=False,
    ...     ),
    ...   ),
    ... )
    def f(
        arg: str = '',
    ) -> None:
        hey

    """
    first_pass = _format_str_once(src_contents, mode=mode)
    if first_pass == src_contents:
        # Already formatted: nothing for a second pass to do.
        return first_pass
    # Forced second pass. Optional trailing commas from the first pass become
    # forced trailing commas on the next one and can interact differently with
    # optional parentheses. Admittedly a workaround.
    return _format_str_once(first_pass, mode=mode)
1082
1083
def _format_str_once(src_contents: str, *, mode: Mode) -> str:
    """Run exactly one formatting pass over `src_contents` and return the result.

    Leading whitespace is stripped before parsing. Target versions come from
    ``mode.target_versions`` when set; otherwise they are detected from the
    parsed tree, including its ``__future__`` imports.

    If the pass produces no lines at all, the result depends on the source. It
    is a single newline (CRLF or LF, as detected by :func:`decode_bytes`) when
    the source contains a line break, and the empty string otherwise.
    """
    src_node = lib2to3_parse(src_contents.lstrip(), mode.target_versions)
    versions = mode.target_versions
    if not versions:
        future_imports = get_future_imports(src_node)
        versions = detect_target_versions(src_node, future_imports=future_imports)

    context_manager_features = {
        feat
        for feat in {Feature.PARENTHESIZED_CONTEXT_MANAGERS}
        if supports_feature(versions, feat)
    }
    normalize_fmt_off(src_node)
    line_generator = LineGenerator(mode=mode, features=context_manager_features)
    empty_line_tracker = EmptyLineTracker(mode=mode)
    split_line_features = {
        feat
        for feat in {Feature.TRAILING_COMMA_IN_CALL, Feature.TRAILING_COMMA_IN_DEF}
        if supports_feature(versions, feat)
    }

    dst_blocks: List[LinesBlock] = []
    for current_line in line_generator.visit(src_node):
        # Register the block before filling it, so the tracker's block list
        # order matches the order in which lines are visited.
        block = empty_line_tracker.maybe_empty_lines(current_line)
        dst_blocks.append(block)
        for split_line in transform_line(
            current_line, mode=mode, features=split_line_features
        ):
            block.content_lines.append(str(split_line))
    if dst_blocks:
        # Reset `after` on the final block (presumably its trailing blank-line
        # count).
        dst_blocks[-1].after = 0

    dst_lines = [text for finished in dst_blocks for text in finished.all_lines()]
    if dst_lines:
        return "".join(dst_lines)

    # Nothing was emitted. Re-decode the source only to learn its newline style
    # and whether it contained any line break at all.
    normalized_content, _, newline = decode_bytes(src_contents.encode("utf-8"))
    return newline if "\n" in normalized_content else ""
1127
1128
def decode_bytes(src: bytes) -> Tuple[FileContent, Encoding, NewLine]:
    """Decode raw source bytes into ``(contents, encoding, newline)``.

    The encoding is found with :func:`tokenize.detect_encoding` (BOM or coding
    cookie, defaulting to UTF-8). ``newline`` is ``"\\r\\n"`` when the first
    line ends in CRLF and ``"\\n"`` otherwise. ``contents`` is decoded with
    universal newlines, so it only ever contains LF line endings. Empty input
    yields ``("", encoding, "\\n")``.
    """
    buffer = io.BytesIO(src)
    encoding, first_lines = tokenize.detect_encoding(buffer.readline)
    if not first_lines:
        return "", encoding, "\n"

    newline = "\r\n" if first_lines[0].endswith(b"\r\n") else "\n"
    # detect_encoding consumed part of the buffer; rewind before full decoding.
    buffer.seek(0)
    with io.TextIOWrapper(buffer, encoding) as reader:
        contents = reader.read()
    return contents, encoding, newline
1144
1145
def get_features_used(  # noqa: C901
    node: Node, *, future_imports: Optional[Set[str]] = None
) -> Set[Feature]:
    """Return a set of (relatively) new Python features used in this file.

    Currently looking for:
    - f-strings (any letter case of the ``f``, ``rf`` or ``fr`` prefix);
    - self-documenting expressions in f-strings (f"{x=}");
    - underscores in numeric literals;
    - trailing commas after * or ** in function signatures and calls;
    - positional only arguments in function signatures and lambdas;
    - assignment expression;
    - relaxed decorator syntax;
    - usage of __future__ flags (annotations);
    - unpacking in return / yield values;
    - starred expressions on the right-hand side of annotated assignments;
    - parenthesized context managers;
    - match statements;
    - except* clause;
    - variadic generics;
    - type parameter syntax (``type`` statements and type parameter lists);

    `future_imports` names that appear in FUTURE_FLAG_TO_FEATURE contribute
    their mapped features as well.
    """
    features: Set[Feature] = set()
    if future_imports:
        features |= {
            FUTURE_FLAG_TO_FEATURE[future_import]
            for future_import in future_imports
            if future_import in FUTURE_FLAG_TO_FEATURE
        }

    for n in node.pre_order():
        if is_string_token(n):
            # String prefixes are case-insensitive: "Rf", "rF", "fR" and "Fr"
            # are as valid as "rf"/"RF". Normalise the case before comparing
            # so that every f-string prefix is recognised.
            value_head = n.value[:2].lower()
            if value_head in {'f"', "f'", "rf", "fr"}:
                features.add(Feature.F_STRINGS)
                if Feature.DEBUG_F_STRINGS not in features:
                    for span_beg, span_end in iter_fexpr_spans(n.value):
                        if n.value[span_beg : span_end - 1].rstrip().endswith("="):
                            features.add(Feature.DEBUG_F_STRINGS)
                            break

        elif is_number_token(n):
            if "_" in n.value:
                features.add(Feature.NUMERIC_UNDERSCORES)

        elif n.type == token.SLASH:
            if n.parent and n.parent.type in {
                syms.typedargslist,
                syms.arglist,
                syms.varargslist,
            }:
                features.add(Feature.POS_ONLY_ARGUMENTS)

        elif n.type == token.COLONEQUAL:
            features.add(Feature.ASSIGNMENT_EXPRESSIONS)

        elif n.type == syms.decorator:
            if len(n.children) > 1 and not is_simple_decorator_expression(
                n.children[1]
            ):
                features.add(Feature.RELAXED_DECORATORS)

        elif (
            n.type in {syms.typedargslist, syms.arglist}
            and n.children
            and n.children[-1].type == token.COMMA
        ):
            if n.type == syms.typedargslist:
                feature = Feature.TRAILING_COMMA_IN_DEF
            else:
                feature = Feature.TRAILING_COMMA_IN_CALL

            for ch in n.children:
                if ch.type in STARS:
                    features.add(feature)

                if ch.type == syms.argument:
                    for argch in ch.children:
                        if argch.type in STARS:
                            features.add(feature)

        elif (
            n.type in {syms.return_stmt, syms.yield_expr}
            and len(n.children) >= 2
            and n.children[1].type == syms.testlist_star_expr
            and any(child.type == syms.star_expr for child in n.children[1].children)
        ):
            features.add(Feature.UNPACKING_ON_FLOW)

        elif (
            n.type == syms.annassign
            and len(n.children) >= 4
            and n.children[3].type == syms.testlist_star_expr
        ):
            features.add(Feature.ANN_ASSIGN_EXTENDED_RHS)

        elif (
            n.type == syms.with_stmt
            and len(n.children) > 2
            and n.children[1].type == syms.atom
        ):
            atom_children = n.children[1].children
            if (
                len(atom_children) == 3
                and atom_children[0].type == token.LPAR
                and atom_children[1].type == syms.testlist_gexp
                and atom_children[2].type == token.RPAR
            ):
                features.add(Feature.PARENTHESIZED_CONTEXT_MANAGERS)

        elif n.type == syms.match_stmt:
            features.add(Feature.PATTERN_MATCHING)

        elif (
            n.type == syms.except_clause
            and len(n.children) >= 2
            and n.children[1].type == token.STAR
        ):
            features.add(Feature.EXCEPT_STAR)

        elif n.type in {syms.subscriptlist, syms.trailer} and any(
            child.type == syms.star_expr for child in n.children
        ):
            features.add(Feature.VARIADIC_GENERICS)

        elif (
            n.type == syms.tname_star
            and len(n.children) == 3
            and n.children[2].type == syms.star_expr
        ):
            features.add(Feature.VARIADIC_GENERICS)

        elif n.type in (syms.type_stmt, syms.typeparams):
            features.add(Feature.TYPE_PARAMS)

    return features
1280
1281
def detect_target_versions(
    node: Node, *, future_imports: Optional[Set[str]] = None
) -> Set[TargetVersion]:
    """Return every target version that supports all features found in `node`.

    Features are collected by :func:`get_features_used`, taking
    `future_imports` into account. A version qualifies when its entry in
    ``VERSION_TO_FEATURES`` contains every one of them.
    """
    used = get_features_used(node, future_imports=future_imports)
    compatible: Set[TargetVersion] = set()
    for target in TargetVersion:
        if used <= VERSION_TO_FEATURES[target]:
            compatible.add(target)
    return compatible
1290
1291
def get_future_imports(node: Node) -> Set[str]:
    """Return the names imported via ``from __future__ import ...`` in `node`.

    Only the leading run of simple statements is scanned. Docstring-only
    statements are skipped over. Scanning stops at the first statement that is
    neither such a string nor a ``from __future__`` import.
    """
    imports: Set[str] = set()

    def _collect_names(children: List[LN]) -> Generator[str, None, None]:
        # Yield the original (pre-``as``) names from an import list, recursing
        # into nested ``import_as_names`` nodes. Punctuation leaves are ignored.
        for child in children:
            if isinstance(child, Leaf):
                if child.type == token.NAME:
                    yield child.value
                continue

            if child.type == syms.import_as_name:
                orig_name = child.children[0]
                assert isinstance(orig_name, Leaf), "Invalid syntax parsing imports"
                assert orig_name.type == token.NAME, "Invalid syntax parsing imports"
                yield orig_name.value
            elif child.type == syms.import_as_names:
                yield from _collect_names(child.children)
            else:
                raise AssertionError("Invalid syntax parsing imports")

    for stmt in node.children:
        if stmt.type != syms.simple_stmt:
            break

        first = stmt.children[0]
        if isinstance(first, Leaf):
            # A bare string statement (e.g. a module docstring) may precede
            # the __future__ imports; any other leaf ends the scan.
            is_docstring = (
                len(stmt.children) == 2
                and first.type == token.STRING
                and stmt.children[1].type == token.NEWLINE
            )
            if not is_docstring:
                break
            continue

        if first.type != syms.import_from:
            break
        module_name = first.children[1]
        if not isinstance(module_name, Leaf) or module_name.value != "__future__":
            break
        # children[3:] is everything after ``from __future__ import``.
        imports.update(_collect_names(first.children[3:]))

    return imports
1340
1341
def assert_equivalent(src: str, dst: str) -> None:
    """Raise AssertionError unless `src` and `dst` parse to equivalent ASTs.

    Failures are reported in three ways:

    - An unparsable `src` is reported as unsupported by ``--safe``.
    - An unparsable `dst` is reported as an internal error, with the traceback
      and output passed to :func:`dump_to_file`.
    - A mismatch between the stringified ASTs is reported as an internal
      error, with their diff passed to :func:`dump_to_file`.
    """
    try:
        src_ast = parse_ast(src)
    except Exception as exc:
        raise AssertionError(
            "cannot use --safe with this file; failed to parse source file AST: "
            f"{exc}\n"
            "This could be caused by running Black with an older Python version "
            "that does not support new syntax used in your source file."
        ) from exc

    try:
        dst_ast = parse_ast(dst)
    except Exception as exc:
        traceback_text = "".join(traceback.format_tb(exc.__traceback__))
        log = dump_to_file(traceback_text, dst)
        raise AssertionError(
            f"INTERNAL ERROR: Black produced invalid code: {exc}. "
            "Please report a bug on https://github.com/psf/black/issues.  "
            f"This invalid output might be helpful: {log}"
        ) from None

    src_dump = "\n".join(stringify_ast(src_ast))
    dst_dump = "\n".join(stringify_ast(dst_ast))
    if src_dump == dst_dump:
        return

    log = dump_to_file(diff(src_dump, dst_dump, "src", "dst"))
    raise AssertionError(
        "INTERNAL ERROR: Black produced code that is not equivalent to the"
        " source.  Please report a bug on "
        f"https://github.com/psf/black/issues.  This diff might be helpful: {log}"
    ) from None
1373
1374
def assert_stable(src: str, dst: str, mode: Mode) -> None:
    """Raise AssertionError if formatting `dst` once more would change it.

    A single pass (:func:`_format_str_once`) is used on purpose.
    :func:`format_str` formats twice, which could hide output that keeps
    bouncing between two versions.
    """
    second_pass = _format_str_once(dst, mode=mode)
    if second_pass == dst:
        return

    log = dump_to_file(
        str(mode),
        diff(src, dst, "source", "first pass"),
        diff(dst, second_pass, "first pass", "second pass"),
    )
    raise AssertionError(
        "INTERNAL ERROR: Black produced different code on the second pass of the"
        " formatter.  Please report a bug on https://github.com/psf/black/issues."
        f"  This diff might be helpful: {log}"
    ) from None
1392
1393
@contextmanager
def nullcontext() -> Iterator[None]:
    """A do-nothing context manager: entering yields None, exiting does nothing.

    Stand-in for ``contextlib.nullcontext``, which is only available from
    Python 3.7.
    """
    yield None
1401
1402
def patch_click() -> None:
    """Stop Click from crashing on Python 3.6 with LANG=C.

    On some misconfigured systems Python 3 chooses ASCII as its default
    encoding, which limits the paths it can access. Click refuses to run in
    that situation and raises RuntimeError.

    For Black, non-ASCII characters in file paths are unlikely (it formats
    Python source code), and PEP 538 and PEP 540 already made the crash
    spurious on Python 3.7. Click's environment checks are therefore replaced
    with no-ops in whichever of its modules define them.
    """
    patchable: List[Any] = []
    try:
        from click import core
    except ImportError:
        pass
    else:
        patchable.append(core)
    try:
        # `_unicodefun` was removed in Click 8.1.0; it is still patched for
        # users who have older Click releases installed.
        from click import _unicodefun  # type: ignore
    except ImportError:
        pass
    else:
        patchable.append(_unicodefun)

    for module in patchable:
        for check_name in ("_verify_python3_env", "_verify_python_env"):
            if hasattr(module, check_name):
                setattr(module, check_name, lambda: None)
1435
1436
def patched_main() -> None:
    """Run :func:`main` after applying environment workarounds."""
    is_frozen = getattr(sys, "frozen", False)
    if is_frozen:
        # PyInstaller patches multiprocessing so that freeze_support() is
        # needed even outside Windows; call it whenever we are frozen.
        from multiprocessing import freeze_support

        freeze_support()

    patch_click()
    main()
1447
1448
# Running the module directly goes through the same patched entry point.
if __name__ == "__main__":
    patched_main()