git.madduck.net Git - etc/vim.git/blob - src/black/__init__.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Bump furo from 2022.3.4 to 2022.4.7 in /docs (#3003)
[etc/vim.git] / src / black / __init__.py
1 import asyncio
2 from json.decoder import JSONDecodeError
3 import json
4 from concurrent.futures import Executor, ThreadPoolExecutor, ProcessPoolExecutor
5 from contextlib import contextmanager
6 from datetime import datetime
7 from enum import Enum
8 import io
9 from multiprocessing import Manager, freeze_support
10 import os
11 from pathlib import Path
12 from pathspec.patterns.gitwildmatch import GitWildMatchPatternError
13 import platform
14 import re
15 import signal
16 import sys
17 import tokenize
18 import traceback
19 from typing import (
20     Any,
21     Dict,
22     Generator,
23     Iterator,
24     List,
25     MutableMapping,
26     Optional,
27     Pattern,
28     Sequence,
29     Set,
30     Sized,
31     Tuple,
32     Union,
33 )
34
35 import click
36 from click.core import ParameterSource
37 from dataclasses import replace
38 from mypy_extensions import mypyc_attr
39
40 from black.const import DEFAULT_LINE_LENGTH, DEFAULT_INCLUDES, DEFAULT_EXCLUDES
41 from black.const import STDIN_PLACEHOLDER
42 from black.nodes import STARS, syms, is_simple_decorator_expression
43 from black.nodes import is_string_token
44 from black.lines import Line, EmptyLineTracker
45 from black.linegen import transform_line, LineGenerator, LN
46 from black.comments import normalize_fmt_off
47 from black.mode import FUTURE_FLAG_TO_FEATURE, Mode, TargetVersion
48 from black.mode import Feature, supports_feature, VERSION_TO_FEATURES
49 from black.cache import read_cache, write_cache, get_cache_info, filter_cached, Cache
50 from black.concurrency import cancel, shutdown, maybe_install_uvloop
51 from black.output import dump_to_file, ipynb_diff, diff, color_diff, out, err
52 from black.report import Report, Changed, NothingChanged
53 from black.files import (
54     find_project_root,
55     find_pyproject_toml,
56     parse_pyproject_toml,
57     find_user_pyproject_toml,
58 )
59 from black.files import gen_python_files, get_gitignore, normalize_path_maybe_ignore
60 from black.files import wrap_stream_for_windows
61 from black.parsing import InvalidInput  # noqa F401
62 from black.parsing import lib2to3_parse, parse_ast, stringify_ast
63 from black.handle_ipynb_magics import (
64     mask_cell,
65     unmask_cell,
66     remove_trailing_semicolon,
67     put_trailing_semicolon_back,
68     TRANSFORMED_MAGICS,
69     PYTHON_CELL_MAGICS,
70     jupyter_dependencies_are_installed,
71 )
72
73
74 # lib2to3 fork
75 from blib2to3.pytree import Node, Leaf
76 from blib2to3.pgen2 import token
77
78 from _black_version import version as __version__
79
# True when this module's own file carries a compiled-extension suffix.
COMPILED = Path(__file__).suffix in {".pyd", ".so"}

# types
FileContent = str
Encoding = str
NewLine = str
86
87
class WriteBack(Enum):
    """Write-back modes, selectable from the check/diff/color settings."""

    NO = 0
    YES = 1
    DIFF = 2
    CHECK = 3
    COLOR_DIFF = 4

    @classmethod
    def from_configuration(
        cls, *, check: bool, diff: bool, color: bool = False
    ) -> "WriteBack":
        """Pick the member matching the given flag combination.

        `diff` takes precedence over `check`; `color` only matters when `diff`
        is set. With neither `check` nor `diff`, the result is `YES`.
        """
        if not diff:
            return cls.CHECK if check else cls.YES

        return cls.COLOR_DIFF if color else cls.DIFF
106
107
# Legacy name, left for integrations.
FileMode = Mode

# Default for the `--workers` option. Note that os.cpu_count() can return None
# when the CPU count cannot be determined.
DEFAULT_WORKERS = os.cpu_count()
112
113
def read_pyproject_toml(
    ctx: click.Context, param: click.Parameter, value: Optional[str]
) -> Optional[str]:
    """Inject Black configuration from "pyproject.toml" into defaults in `ctx`.

    When `value` is empty, the file is looked up with `find_pyproject_toml`
    using the "src" parameter already stored on `ctx`. The parsed settings are
    merged on top of any existing `ctx.default_map`.

    Returns the path to a successfully found and read configuration file, None
    otherwise.

    Raises `click.FileError` if the file cannot be read or parsed, and
    `click.BadOptionUsage` if "target_version" is present but not a list.
    """
    config_path = value or find_pyproject_toml(ctx.params.get("src", ()))
    if config_path is None:
        return None

    try:
        parsed = parse_pyproject_toml(config_path)
    except (OSError, ValueError) as e:
        raise click.FileError(
            filename=config_path, hint=f"Error reading configuration file: {e}"
        ) from None

    if not parsed:
        return None

    # Sanitize the values to be Click friendly. For more information please see:
    # https://github.com/psf/black/issues/1458
    # https://github.com/pallets/click/issues/1567
    config: Dict[str, Any] = {}
    for key, item in parsed.items():
        config[key] = item if isinstance(item, (list, dict)) else str(item)

    target_version = config.get("target_version")
    if not (target_version is None or isinstance(target_version, list)):
        raise click.BadOptionUsage(
            "target-version", "Config key target-version must be a list"
        )

    # Values from the configuration file win over pre-existing defaults.
    default_map: Dict[str, Any] = dict(ctx.default_map or {})
    default_map.update(config)

    ctx.default_map = default_map
    return config_path
158
159
def target_version_option_callback(
    c: click.Context, p: Union[click.Option, click.Parameter], v: Tuple[str, ...]
) -> List[TargetVersion]:
    """Compute the target versions from a --target-version flag.

    Each value in `v` is upper-cased and looked up by name in `TargetVersion`.

    This is its own function because mypy couldn't infer the type correctly
    when it was a lambda, causing mypyc trouble.
    """
    versions: List[TargetVersion] = []
    for name in v:
        versions.append(TargetVersion[name.upper()])
    return versions
169
170
def re_compile_maybe_verbose(regex: str) -> Pattern[str]:
    """Compile the regular expression string `regex`.

    A pattern containing a newline is compiled in verbose mode by prefixing
    it with the inline `(?x)` flag.
    """
    source = f"(?x){regex}" if "\n" in regex else regex
    compiled: Pattern[str] = re.compile(source)
    return compiled
180
181
def validate_regex(
    ctx: click.Context,
    param: click.Parameter,
    value: Optional[str],
) -> Optional[Pattern[str]]:
    """Click callback turning a regex option string into a compiled pattern.

    Returns None when `value` is None. Raises `click.BadParameter` when the
    string is not a valid regular expression.
    """
    if value is None:
        return None

    try:
        return re_compile_maybe_verbose(value)
    except re.error as e:
        raise click.BadParameter(f"Not a valid regular expression: {e}") from None
191
192
@click.command(
    context_settings={"help_option_names": ["-h", "--help"]},
    # While Click does set this field automatically using the docstring, mypyc
    # (annoyingly) strips 'em so we need to set it here too.
    help="The uncompromising code formatter.",
)
@click.option("-c", "--code", type=str, help="Format the code passed in as a string.")
@click.option(
    "-l",
    "--line-length",
    type=int,
    default=DEFAULT_LINE_LENGTH,
    help="How many characters per line to allow.",
    show_default=True,
)
@click.option(
    "-t",
    "--target-version",
    type=click.Choice([v.name.lower() for v in TargetVersion]),
    callback=target_version_option_callback,
    multiple=True,
    help=(
        "Python versions that should be supported by Black's output. [default: per-file"
        " auto-detection]"
    ),
)
@click.option(
    "--pyi",
    is_flag=True,
    help=(
        "Format all input files like typing stubs regardless of file extension (useful"
        " when piping source on standard input)."
    ),
)
@click.option(
    "--ipynb",
    is_flag=True,
    help=(
        "Format all input files like Jupyter Notebooks regardless of file extension "
        "(useful when piping source on standard input)."
    ),
)
@click.option(
    "--python-cell-magics",
    multiple=True,
    help=(
        "When processing Jupyter Notebooks, add the given magic to the list"
        f" of known python-magics ({', '.join(PYTHON_CELL_MAGICS)})."
        " Useful for formatting cells with custom python magics."
    ),
    default=[],
)
@click.option(
    "-S",
    "--skip-string-normalization",
    is_flag=True,
    help="Don't normalize string quotes or prefixes.",
)
@click.option(
    "-C",
    "--skip-magic-trailing-comma",
    is_flag=True,
    help="Don't use trailing commas as a reason to split lines.",
)
@click.option(
    "--experimental-string-processing",
    is_flag=True,
    hidden=True,
    help="(DEPRECATED and now included in --preview) Normalize string literals.",
)
@click.option(
    "--preview",
    is_flag=True,
    help=(
        "Enable potentially disruptive style changes that may be added to Black's main"
        " functionality in the next major release."
    ),
)
@click.option(
    "--check",
    is_flag=True,
    help=(
        "Don't write the files back, just return the status. Return code 0 means"
        " nothing would change. Return code 1 means some files would be reformatted."
        " Return code 123 means there was an internal error."
    ),
)
@click.option(
    "--diff",
    is_flag=True,
    help="Don't write the files back, just output a diff for each file on stdout.",
)
@click.option(
    "--color/--no-color",
    is_flag=True,
    help="Show colored diff. Only applies when `--diff` is given.",
)
@click.option(
    "--fast/--safe",
    is_flag=True,
    help="If --fast given, skip temporary sanity checks. [default: --safe]",
)
@click.option(
    "--required-version",
    type=str,
    help=(
        "Require a specific version of Black to be running (useful for unifying results"
        " across many environments e.g. with a pyproject.toml file). It can be"
        " either a major version number or an exact version."
    ),
)
@click.option(
    "--include",
    type=str,
    default=DEFAULT_INCLUDES,
    callback=validate_regex,
    help=(
        "A regular expression that matches files and directories that should be"
        " included on recursive searches. An empty value means all files are included"
        " regardless of the name. Use forward slashes for directories on all platforms"
        " (Windows, too). Exclusions are calculated first, inclusions later."
    ),
    show_default=True,
)
@click.option(
    "--exclude",
    type=str,
    callback=validate_regex,
    help=(
        "A regular expression that matches files and directories that should be"
        " excluded on recursive searches. An empty value means no paths are excluded."
        " Use forward slashes for directories on all platforms (Windows, too)."
        " Exclusions are calculated first, inclusions later. [default:"
        f" {DEFAULT_EXCLUDES}]"
    ),
    show_default=False,
)
@click.option(
    "--extend-exclude",
    type=str,
    callback=validate_regex,
    help=(
        "Like --exclude, but adds additional files and directories on top of the"
        " excluded ones. (Useful if you simply want to add to the default)"
    ),
)
@click.option(
    "--force-exclude",
    type=str,
    callback=validate_regex,
    help=(
        "Like --exclude, but files and directories matching this regex will be "
        "excluded even when they are passed explicitly as arguments."
    ),
)
@click.option(
    "--stdin-filename",
    type=str,
    help=(
        "The name of the file when passing it through stdin. Useful to make "
        "sure Black will respect --force-exclude option on some "
        "editors that rely on using stdin."
    ),
)
@click.option(
    "-W",
    "--workers",
    type=click.IntRange(min=1),
    default=DEFAULT_WORKERS,
    show_default=True,
    help="Number of parallel workers",
)
@click.option(
    "-q",
    "--quiet",
    is_flag=True,
    help=(
        "Don't emit non-error messages to stderr. Errors are still emitted; silence"
        " those with 2>/dev/null."
    ),
)
@click.option(
    "-v",
    "--verbose",
    is_flag=True,
    help=(
        "Also emit messages to stderr about files that were not changed or were ignored"
        " due to exclusion patterns."
    ),
)
@click.version_option(
    version=__version__,
    message=(
        f"%(prog)s, %(version)s (compiled: {'yes' if COMPILED else 'no'})\n"
        f"Python ({platform.python_implementation()}) {platform.python_version()}"
    ),
)
@click.argument(
    "src",
    nargs=-1,
    type=click.Path(
        exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
    ),
    is_eager=True,
    metavar="SRC ...",
)
@click.option(
    "--config",
    type=click.Path(
        exists=True,
        file_okay=True,
        dir_okay=False,
        readable=True,
        allow_dash=False,
        path_type=str,
    ),
    is_eager=True,
    callback=read_pyproject_toml,
    help="Read configuration from FILE path.",
)
@click.pass_context
def main(  # noqa: C901
    ctx: click.Context,
    code: Optional[str],
    line_length: int,
    target_version: List[TargetVersion],
    check: bool,
    diff: bool,
    color: bool,
    fast: bool,
    pyi: bool,
    ipynb: bool,
    python_cell_magics: Sequence[str],
    skip_string_normalization: bool,
    skip_magic_trailing_comma: bool,
    experimental_string_processing: bool,
    preview: bool,
    quiet: bool,
    verbose: bool,
    required_version: Optional[str],
    include: Pattern[str],
    exclude: Optional[Pattern[str]],
    extend_exclude: Optional[Pattern[str]],
    force_exclude: Optional[Pattern[str]],
    stdin_filename: Optional[str],
    workers: int,
    src: Tuple[str, ...],
    config: Optional[str],
) -> None:
    """The uncompromising code formatter.

    Checks the option combination, builds the `Mode` and the `Report`, formats
    either the string given with `--code` or the files resolved from `src`, and
    finally exits through `ctx.exit` with the report's return code.
    """
    ctx.ensure_object(dict)

    # Exactly one of SRC and --code must be provided.
    if code is not None and src:
        out(
            main.get_usage(ctx)
            + "\n\n'SRC' and 'code' cannot be passed simultaneously."
        )
        ctx.exit(1)
    if code is None and not src:
        out(main.get_usage(ctx) + "\n\nOne of 'SRC' or 'code' is required.")
        ctx.exit(1)

    root, method = find_project_root(src) if code is None else (None, None)
    ctx.obj["root"] = root

    if verbose:
        if root:
            out(
                f"Identified `{root}` as project root containing a {method}.",
                fg="blue",
            )

            rendered_sources = []
            for source in src:
                norm_path = normalize_path_maybe_ignore(Path(source), root)
                if norm_path:
                    rendered_sources.append(f'"{norm_path}"')
                else:
                    rendered_sources.append(
                        f'\033[31m"{source} (skipping - invalid)"\033[34m'
                    )
            srcs_string = ", ".join(rendered_sources)
            out(f"Sources to be formatted: {srcs_string}", fg="blue")

        if config:
            config_source = ctx.get_parameter_source("config")
            user_level_config = str(find_user_pyproject_toml())
            if config == user_level_config:
                config_note = (
                    "Using configuration from user-level config at "
                    f"'{user_level_config}'."
                )
            elif config_source in (
                ParameterSource.DEFAULT,
                ParameterSource.DEFAULT_MAP,
            ):
                config_note = "Using configuration from project root."
            else:
                config_note = f"Using configuration in '{config}'."
            out(config_note, fg="blue")

    error_msg = "Oh no! 💥 💔 💥"
    # `required_version` may name either the exact version or just the major one.
    if required_version and required_version not in (
        __version__,
        __version__.split(".")[0],
    ):
        err(
            f"{error_msg} The required version `{required_version}` does not match"
            f" the running version `{__version__}`!"
        )
        ctx.exit(1)
    if ipynb and pyi:
        err("Cannot pass both `pyi` and `ipynb` flags!")
        ctx.exit(1)

    write_back = WriteBack.from_configuration(check=check, diff=diff, color=color)
    # With no explicit targets the set stays empty: we'll autodetect later.
    versions = set(target_version) if target_version else set()
    mode = Mode(
        target_versions=versions,
        line_length=line_length,
        is_pyi=pyi,
        is_ipynb=ipynb,
        string_normalization=not skip_string_normalization,
        magic_trailing_comma=not skip_magic_trailing_comma,
        experimental_string_processing=experimental_string_processing,
        preview=preview,
        python_cell_magics=set(python_cell_magics),
    )

    if code is not None:
        # Run in quiet mode by default with -c; the extra output isn't useful.
        # You can still pass -v to get verbose output.
        quiet = True

    report = Report(check=check, diff=diff, quiet=quiet, verbose=verbose)

    if code is not None:
        reformat_code(
            content=code, fast=fast, write_back=write_back, mode=mode, report=report
        )
    else:
        try:
            sources = get_sources(
                ctx=ctx,
                src=src,
                quiet=quiet,
                verbose=verbose,
                include=include,
                exclude=exclude,
                extend_exclude=extend_exclude,
                force_exclude=force_exclude,
                report=report,
                stdin_filename=stdin_filename,
            )
        except GitWildMatchPatternError:
            ctx.exit(1)

        path_empty(
            sources,
            "No Python files are present to be formatted. Nothing to do 😴",
            quiet,
            verbose,
            ctx,
        )

        if len(sources) == 1:
            reformat_one(
                src=sources.pop(),
                fast=fast,
                write_back=write_back,
                mode=mode,
                report=report,
            )
        else:
            reformat_many(
                sources=sources,
                fast=fast,
                write_back=write_back,
                mode=mode,
                report=report,
                workers=workers,
            )

    if verbose or not quiet:
        if code is None and (verbose or report.change_count or report.failure_count):
            out()
        closing_line = error_msg if report.return_code else "All done! ✨ 🍰 ✨"
        out(closing_line)
        if code is None:
            click.echo(str(report), err=True)
    ctx.exit(report.return_code)
590
591
def get_sources(
    *,
    ctx: click.Context,
    src: Tuple[str, ...],
    quiet: bool,
    verbose: bool,
    include: Pattern[str],
    exclude: Optional[Pattern[str]],
    extend_exclude: Optional[Pattern[str]],
    force_exclude: Optional[Pattern[str]],
    report: "Report",
    stdin_filename: Optional[str],
) -> Set[Path]:
    """Compute the set of files to be formatted.

    Files (and stdin when `stdin_filename` is given) are checked against
    `force_exclude`; directories are expanded through `gen_python_files`.
    The gitignore is only consulted when no explicit `exclude` was passed.
    Entries that are neither a file, a directory nor "-" are reported with
    `err` and skipped.
    """
    sources: Set[Path] = set()

    gitignore = None
    if exclude is None:
        exclude = re_compile_maybe_verbose(DEFAULT_EXCLUDES)
        gitignore = get_gitignore(ctx.obj["root"])

    for s in src:
        p = Path(s)
        is_stdin = False
        if s == "-" and stdin_filename:
            # Stdin with a declared file name is handled like that file.
            p = Path(stdin_filename)
            is_stdin = True

        if not (is_stdin or p.is_file()):
            if p.is_dir():
                sources.update(
                    gen_python_files(
                        p.iterdir(),
                        ctx.obj["root"],
                        include,
                        exclude,
                        extend_exclude,
                        force_exclude,
                        report,
                        gitignore,
                        verbose=verbose,
                        quiet=quiet,
                    )
                )
            elif s == "-":
                sources.add(p)
            else:
                err(f"invalid path: {s}")
            continue

        relative = normalize_path_maybe_ignore(p, ctx.obj["root"], report)
        if relative is None:
            continue

        rooted = "/" + relative
        # Hard-exclude any files that matches the `--force-exclude` regex.
        hit = force_exclude.search(rooted) if force_exclude else None
        if hit and hit.group(0):
            report.path_ignored(p, "matches the --force-exclude regular expression")
            continue

        if is_stdin:
            p = Path(f"{STDIN_PLACEHOLDER}{str(p)}")

        if p.suffix == ".ipynb" and not jupyter_dependencies_are_installed(
            verbose=verbose, quiet=quiet
        ):
            continue

        sources.add(p)
    return sources
666
667
def path_empty(
    src: Sized, msg: str, quiet: bool, verbose: bool, ctx: click.Context
) -> None:
    """
    Exit with code 0 if there is no `src` provided for formatting.

    `msg` is printed first unless output is silenced by `quiet` without
    `verbose`. Does nothing when `src` is non-empty.
    """
    if src:
        return

    if verbose or not quiet:
        out(msg)
    ctx.exit(0)
678
679
def reformat_code(
    content: str, fast: bool, write_back: WriteBack, mode: Mode, report: Report
) -> None:
    """
    Reformat and print out `content` without spawning child processes.
    Similar to `reformat_one`, but for string content.

    `fast`, `write_back`, and `mode` options are passed to
    :func:`format_stdin_to_stdout`. The outcome is recorded on `report` under
    the placeholder path "<string>"; any exception is reported as a failure
    instead of being propagated.
    """
    path = Path("<string>")
    try:
        was_changed = format_stdin_to_stdout(
            content=content, fast=fast, write_back=write_back, mode=mode
        )
        report.done(path, Changed.YES if was_changed else Changed.NO)
    except Exception as exc:
        if report.verbose:
            traceback.print_exc()
        report.failed(path, str(exc))
702
703
# diff-shades depends on being able to monkeypatch this function to operate. I
# know it's not ideal, but this shouldn't cause any issues ... hopefully. ~ichard26
@mypyc_attr(patchable=True)
def reformat_one(
    src: Path, fast: bool, write_back: WriteBack, mode: Mode, report: "Report"
) -> None:
    """Reformat a single file under `src` without spawning child processes.

    `src` may be "-" or a path prefixed with `STDIN_PLACEHOLDER`; both mean
    the content comes from stdin. Regular files go through the cache unless a
    diff was requested.

    `fast`, `write_back`, and `mode` options are passed to
    :func:`format_file_in_place` or :func:`format_stdin_to_stdout`.
    """
    try:
        changed = Changed.NO

        src_name = str(src)
        is_stdin = src_name == "-" or src_name.startswith(STDIN_PLACEHOLDER)
        if is_stdin and src_name != "-":
            # Use the original name again in case we want to print something
            # to the user
            src = Path(src_name[len(STDIN_PLACEHOLDER) :])

        if is_stdin:
            if src.suffix == ".pyi":
                mode = replace(mode, is_pyi=True)
            elif src.suffix == ".ipynb":
                mode = replace(mode, is_ipynb=True)
            if format_stdin_to_stdout(fast=fast, write_back=write_back, mode=mode):
                changed = Changed.YES
        else:
            cache: Cache = {}
            if write_back not in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
                cache = read_cache(mode)
                resolved = src.resolve()
                cache_key = str(resolved)
                if cache_key in cache and cache[cache_key] == get_cache_info(resolved):
                    changed = Changed.CACHED
            if changed is not Changed.CACHED and format_file_in_place(
                src, fast=fast, write_back=write_back, mode=mode
            ):
                changed = Changed.YES

            processed_for_write = (
                write_back is WriteBack.YES and changed is not Changed.CACHED
            )
            checked_and_clean = write_back is WriteBack.CHECK and changed is Changed.NO
            if processed_for_write or checked_and_clean:
                write_cache(cache, [src], mode)
        report.done(src, changed)
    except Exception as exc:
        if report.verbose:
            traceback.print_exc()
        report.failed(src, str(exc))
756
757
# diff-shades depends on being able to monkeypatch this function to operate. I
# know it's not ideal, but this shouldn't cause any issues ... hopefully. ~ichard26
@mypyc_attr(patchable=True)
def reformat_many(
    sources: Set[Path],
    fast: bool,
    write_back: WriteBack,
    mode: Mode,
    report: "Report",
    workers: Optional[int],
) -> None:
    """Reformat multiple files using a ProcessPoolExecutor.

    `workers` is the requested pool size; None falls back to `DEFAULT_WORKERS`
    and, should that be None as well, to the CPU count or 1. When a process
    pool cannot be created, a single-threaded ThreadPoolExecutor is used.

    `sources`, `fast`, `write_back`, `mode` and `report` are forwarded to
    :func:`schedule_formatting`, which runs on a dedicated event loop that is
    shut down before returning.
    """
    executor: Executor
    worker_count = workers if workers is not None else DEFAULT_WORKERS
    if worker_count is None:
        # os.cpu_count() may return None; never hand None to min() below.
        worker_count = os.cpu_count() or 1
    if sys.platform == "win32":
        # Work around https://bugs.python.org/issue26903
        worker_count = min(worker_count, 60)
    try:
        executor = ProcessPoolExecutor(max_workers=worker_count)
    except (ImportError, NotImplementedError, OSError):
        # we arrive here if the underlying system does not support multi-processing
        # like in AWS Lambda or Termux, in which case we gracefully fallback to
        # a ThreadPoolExecutor with just a single worker (more workers would not do us
        # any good due to the Global Interpreter Lock)
        executor = ThreadPoolExecutor(max_workers=1)

    # Use a fresh loop: asyncio.get_event_loop() is deprecated when no loop is
    # running, and the loop handed to shutdown() below must not be reused.
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        loop.run_until_complete(
            schedule_formatting(
                sources=sources,
                fast=fast,
                write_back=write_back,
                mode=mode,
                report=report,
                loop=loop,
                executor=executor,
            )
        )
    finally:
        try:
            shutdown(loop)
        finally:
            # Always release the executor, even if shutting the loop down failed.
            asyncio.set_event_loop(None)
            executor.shutdown()
802
803
async def schedule_formatting(
    sources: Set[Path],
    fast: bool,
    write_back: WriteBack,
    mode: Mode,
    report: "Report",
    loop: asyncio.AbstractEventLoop,
    executor: Executor,
) -> None:
    """Run formatting of `sources` in parallel using the provided `executor`.

    (Use ProcessPoolExecutors for actual parallelism.)

    Unless a diff was requested, sources found in the cache are reported as
    cached and skipped, and successfully processed sources are written to the
    cache at the end.

    `write_back`, `fast`, and `mode` options are passed to
    :func:`format_file_in_place`.
    """
    wants_diff = write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF)

    cache: Cache = {}
    if not wants_diff:
        cache = read_cache(mode)
        sources, cached = filter_cached(cache, sources)
        for src in sorted(cached):
            report.done(src, Changed.CACHED)
    if not sources:
        return

    cancelled = []
    sources_to_cache = []
    lock = None
    if wants_diff:
        # For diff output, we need locks to ensure we don't interleave output
        # from different processes.
        manager = Manager()
        lock = manager.Lock()

    tasks = {}
    for src in sorted(sources):
        job = asyncio.ensure_future(
            loop.run_in_executor(
                executor, format_file_in_place, src, fast, mode, write_back, lock
            )
        )
        tasks[job] = src

    # Live view: it shrinks as finished tasks are popped from `tasks` below.
    pending = tasks.keys()
    try:
        for signum in (signal.SIGINT, signal.SIGTERM):
            loop.add_signal_handler(signum, cancel, pending)
    except NotImplementedError:
        # There are no good alternatives for these on Windows.
        pass

    while pending:
        done, _ = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
        for task in done:
            src = tasks.pop(task)
            if task.cancelled():
                cancelled.append(task)
                continue

            failure = task.exception()
            if failure:
                report.failed(src, str(failure))
                continue

            changed = Changed.YES if task.result() else Changed.NO
            # If the file was written back or was successfully checked as
            # well-formatted, store this information in the cache.
            if write_back is WriteBack.YES or (
                write_back is WriteBack.CHECK and changed is Changed.NO
            ):
                sources_to_cache.append(src)
            report.done(src, changed)

    if cancelled:
        if sys.version_info >= (3, 7):
            await asyncio.gather(*cancelled, return_exceptions=True)
        else:
            await asyncio.gather(*cancelled, loop=loop, return_exceptions=True)
    if sources_to_cache:
        write_cache(cache, sources_to_cache, mode)
876
877
def format_file_in_place(
    src: Path,
    fast: bool,
    mode: Mode,
    write_back: WriteBack = WriteBack.NO,
    lock: Any = None,  # multiprocessing.Manager().Lock() is some crazy proxy
) -> bool:
    """Format file under `src` path. Return True if changed.

    If `write_back` is DIFF or COLOR_DIFF, write a diff to stdout (while
    holding `lock` when one is given). If it is YES, write reformatted code to
    the file.
    `mode` and `fast` options are passed to :func:`format_file_contents`.

    Raises ValueError when the content fails to decode as JSON while being
    formatted.
    """
    suffix = src.suffix
    if suffix == ".pyi":
        mode = replace(mode, is_pyi=True)
    elif suffix == ".ipynb":
        mode = replace(mode, is_ipynb=True)

    # Modification time, used for the "before" header of the diff.
    then = datetime.utcfromtimestamp(src.stat().st_mtime)
    with src.open("rb") as buf:
        src_contents, encoding, newline = decode_bytes(buf.read())
    try:
        dst_contents = format_file_contents(src_contents, fast=fast, mode=mode)
    except NothingChanged:
        return False
    except JSONDecodeError:
        raise ValueError(
            f"File '{src}' cannot be parsed as valid Jupyter notebook."
        ) from None

    if write_back == WriteBack.YES:
        with open(src, "w", encoding=encoding, newline=newline) as f:
            f.write(dst_contents)
        return True

    if write_back not in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
        return True

    now = datetime.utcnow()
    src_name = f"{src}\t{then} +0000"
    dst_name = f"{src}\t{now} +0000"
    make_diff = ipynb_diff if mode.is_ipynb else diff
    diff_contents = make_diff(src_contents, dst_contents, src_name, dst_name)

    if write_back == WriteBack.COLOR_DIFF:
        diff_contents = color_diff(diff_contents)

    with lock or nullcontext():
        # Wrap stdout's binary buffer so the diff uses the source file's
        # encoding and newline style; detach afterwards to leave it open.
        stream = io.TextIOWrapper(
            sys.stdout.buffer,
            encoding=encoding,
            newline=newline,
            write_through=True,
        )
        stream = wrap_stream_for_windows(stream)
        stream.write(diff_contents)
        stream.detach()

    return True
935
936
def format_stdin_to_stdout(
    fast: bool,
    *,
    content: Optional[str] = None,
    write_back: WriteBack = WriteBack.NO,
    mode: Mode,
) -> bool:
    """Reformat source arriving on stdin; return True when it changed.

    When `content` is given it is used instead of reading ``sys.stdin``.

    With `write_back` set to YES the (possibly unchanged) code is written to
    stdout; with DIFF or COLOR_DIFF a diff is written instead.  `mode` and
    `fast` are forwarded to :func:`format_file_contents`.
    """
    then = datetime.utcnow()

    if content is not None:
        src, encoding, newline = content, "utf-8", ""
    else:
        src, encoding, newline = decode_bytes(sys.stdin.buffer.read())

    # Falls back to the input when formatting changes nothing (or fails).
    dst = src
    try:
        dst = format_file_contents(src, fast=fast, mode=mode)
    except NothingChanged:
        return False
    else:
        return True
    finally:
        # Output is produced on every exit path, including the early returns.
        out = io.TextIOWrapper(
            sys.stdout.buffer, encoding=encoding, newline=newline, write_through=True
        )
        if write_back == WriteBack.YES:
            # Make sure the emitted code ends with a newline.
            if dst and not dst.endswith("\n"):
                dst += "\n"
            out.write(dst)
        elif write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
            now = datetime.utcnow()
            src_name = f"STDIN\t{then} +0000"
            dst_name = f"STDOUT\t{now} +0000"
            diff_text = diff(src, dst, src_name, dst_name)
            if write_back == WriteBack.COLOR_DIFF:
                diff_text = color_diff(diff_text)
                out = wrap_stream_for_windows(out)
            out.write(diff_text)
        # Detach so that discarding the wrapper does not close stdout.
        out.detach()
986
987
def check_stability_and_equivalence(
    src_contents: str, dst_contents: str, *, mode: Mode
) -> None:
    """Run the two safety checks on a formatting result.

    First verifies that `dst_contents` is equivalent to `src_contents`
    (:func:`assert_equivalent`), then that formatting `dst_contents` again
    leaves it unchanged (:func:`assert_stable`).  Either check signals a
    failure by raising AssertionError.
    """
    assert_equivalent(src_contents, dst_contents)
    assert_stable(src_contents, dst_contents, mode=mode)
999
1000
def format_file_contents(src_contents: str, *, fast: bool, mode: Mode) -> FileContent:
    """Return the reformatted version of `src_contents`.

    Raises NothingChanged when the input is blank or already formatted.
    Unless `fast` is set, the result is additionally validated with
    :func:`check_stability_and_equivalence`.  `mode` is handed on to
    :func:`format_str` (or to :func:`format_ipynb_string` for notebooks).
    """
    if not src_contents.strip():
        raise NothingChanged

    is_notebook = mode.is_ipynb
    if is_notebook:
        formatted = format_ipynb_string(src_contents, fast=fast, mode=mode)
    else:
        formatted = format_str(src_contents, mode=mode)

    if formatted == src_contents:
        raise NothingChanged

    # Notebooks were already checked cell by cell while being formatted.
    if not (fast or mode.is_ipynb):
        check_stability_and_equivalence(src_contents, formatted, mode=mode)
    return formatted
1022
1023
def validate_cell(src: str, mode: Mode) -> None:
    """Reject cells that cannot be round-tripped safely.

    A cell is rejected (by raising NothingChanged) when it already contains
    the output of a TransformerManager transformation, or when it starts with
    a cell magic that is not known to hold Python code; the latter might cause
    tokenizer_rt to break because of indentations.

    If a cell contains ``!ls``, then it'll be transformed to
    ``get_ipython().system('ls')``. However, if the cell originally contained
    ``get_ipython().system('ls')``, then it would get transformed in the same way:

        >>> TransformerManager().transform_cell("get_ipython().system('ls')")
        "get_ipython().system('ls')\n"
        >>> TransformerManager().transform_cell("!ls")
        "get_ipython().system('ls')\n"

    Because the two cannot be told apart afterwards, such cells are left alone.
    """
    for transformed_magic in TRANSFORMED_MAGICS:
        if transformed_magic in src:
            raise NothingChanged

    if src.startswith("%%"):
        # Name of the cell magic, i.e. the first word without the "%%".
        magic_name = src.split()[0][2:]
        if magic_name not in PYTHON_CELL_MAGICS | mode.python_cell_magics:
            raise NothingChanged
1048
1049
def format_cell(src: str, *, fast: bool, mode: Mode) -> str:
    """Return the formatted source of a single Jupyter notebook code cell.

    Steps, in order:

      - drop a trailing semicolon, remembering whether there was one;
      - mask IPython magics;
      - format the masked code;
      - put the IPython magics back;
      - put the trailing semicolon back (if originally present);
      - strip trailing newlines.

    Raises NothingChanged when the cell is rejected by :func:`validate_cell`,
    when masking hits a syntax error (such cells could be automagics or
    multi-line magics, which are currently not supported), or when the
    result equals the input.
    """
    validate_cell(src, mode)
    bare_src, had_semicolon = remove_trailing_semicolon(src)
    try:
        masked_src, replacements = mask_cell(bare_src)
    except SyntaxError:
        raise NothingChanged from None

    masked_dst = format_str(masked_src, mode=mode)
    if not fast:
        check_stability_and_equivalence(masked_src, masked_dst, mode=mode)

    unmasked_dst = unmask_cell(masked_dst, replacements)
    dst = put_trailing_semicolon_back(unmasked_dst, had_semicolon).rstrip("\n")
    if dst == src:
        raise NothingChanged from None
    return dst
1085
1086
def validate_metadata(nb: MutableMapping[str, Any]) -> None:
    """Raise NothingChanged for notebooks that declare a non-Python language.

    Every notebook metadata field is optional (see
    https://nbformat.readthedocs.io/en/latest/format_description.html), so a
    notebook that does not name its language is accepted and formatting is
    attempted anyway.
    """
    metadata = nb.get("metadata", {})
    language_info = metadata.get("language_info", {})
    language = language_info.get("name", None)
    if language is None or language == "python":
        return
    raise NothingChanged from None
1097
1098
def format_ipynb_string(src_contents: str, *, fast: bool, mode: Mode) -> FileContent:
    """Format Jupyter notebook.

    Operate cell-by-cell, only on code cells, only for Python notebooks.
    If the ``.ipynb`` originally had a trailing newline, it'll be preserved.

    `fast` and `mode` are passed on to :func:`format_cell` for every code cell.

    Raises NothingChanged when `src_contents` is empty, when the notebook is
    rejected by :func:`validate_metadata`, or when no cell was modified.
    Invalid JSON propagates as the exception raised by ``json.loads``.
    """
    if not src_contents:
        # An empty document has no last character to inspect and no cells to
        # format; report it as unchanged instead of failing with IndexError.
        raise NothingChanged

    trailing_newline = src_contents.endswith("\n")
    modified = False
    nb = json.loads(src_contents)
    validate_metadata(nb)
    for cell in nb["cells"]:
        if cell.get("cell_type", None) == "code":
            try:
                # Cell source is stored as a list of lines; join before formatting.
                src = "".join(cell["source"])
                dst = format_cell(src, fast=fast, mode=mode)
            except NothingChanged:
                pass
            else:
                cell["source"] = dst.splitlines(keepends=True)
                modified = True
    if modified:
        dst_contents = json.dumps(nb, indent=1, ensure_ascii=False)
        if trailing_newline:
            dst_contents = dst_contents + "\n"
        return dst_contents
    else:
        raise NothingChanged
1126
1127
def format_str(src_contents: str, *, mode: Mode) -> str:
    """Return `src_contents` reformatted according to `mode`.

    `mode` carries the formatting options, for instance the maximum number of
    characters allowed per line.  Example:

    >>> import black
    >>> print(black.format_str("def f(arg:str='')->None:...", mode=black.Mode()))
    def f(arg: str = "") -> None:
        ...

    A more complex example:

    >>> print(
    ...   black.format_str(
    ...     "def f(arg:str='')->None: hey",
    ...     mode=black.Mode(
    ...       target_versions={black.TargetVersion.PY36},
    ...       line_length=10,
    ...       string_normalization=False,
    ...       is_pyi=False,
    ...     ),
    ...   ),
    ... )
    def f(
        arg: str = '',
    ) -> None:
        hey

    """
    first_pass = _format_str_once(src_contents, mode=mode)
    if first_pass == src_contents:
        return first_pass
    # Whenever the first pass changed something, format once more: optional
    # trailing commas become forced trailing commas on the second pass and
    # then interact differently with optional parentheses.  Admittedly ugly.
    return _format_str_once(first_pass, mode=mode)
1165
1166
def _format_str_once(src_contents: str, *, mode: Mode) -> str:
    """Run a single formatting pass over `src_contents` and return the result.

    Leading whitespace is stripped before parsing.  Target versions come from
    `mode` when set there and are otherwise detected from the parsed source.
    """
    src_node = lib2to3_parse(src_contents.lstrip(), mode.target_versions)
    future_imports = get_future_imports(src_node)
    versions = mode.target_versions or detect_target_versions(
        src_node, future_imports=future_imports
    )

    normalize_fmt_off(src_node, preview=mode.preview)
    line_generator = LineGenerator(mode=mode)
    tracker = EmptyLineTracker(is_pyi=mode.is_pyi)
    blank = str(Line(mode=mode))
    # Only these two features influence how lines are split.
    split_line_features = {
        feature
        for feature in {Feature.TRAILING_COMMA_IN_CALL, Feature.TRAILING_COMMA_IN_DEF}
        if supports_feature(versions, feature)
    }

    chunks = []
    after = 0
    for current_line in line_generator.visit(src_node):
        # Blank lines requested after the previous line, then those before this one.
        chunks.append(blank * after)
        before, after = tracker.maybe_empty_lines(current_line)
        chunks.append(blank * before)
        chunks.extend(
            str(line)
            for line in transform_line(
                current_line, mode=mode, features=split_line_features
            )
        )
    return "".join(chunks)
1195
1196
def decode_bytes(src: bytes) -> Tuple[FileContent, Encoding, NewLine]:
    """Decode `src` and return ``(decoded_contents, encoding, newline)``.

    The encoding is detected with :func:`tokenize.detect_encoding`.
    `newline` is CRLF when the first line ends with CRLF and LF otherwise,
    while `decoded_contents` is read with universal newlines (i.e. it only
    contains LF).
    """
    buffer = io.BytesIO(src)
    encoding, first_lines = tokenize.detect_encoding(buffer.readline)
    if len(first_lines) == 0:
        return "", encoding, "\n"

    if first_lines[0].endswith(b"\r\n"):
        newline = "\r\n"
    else:
        newline = "\n"

    # Encoding detection consumed part of the buffer; decode from the start.
    buffer.seek(0)
    with io.TextIOWrapper(buffer, encoding) as reader:
        decoded = reader.read()
    return decoded, encoding, newline
1212
1213
def get_features_used(  # noqa: C901
    node: Node, *, future_imports: Optional[Set[str]] = None
) -> Set[Feature]:
    """Return a set of (relatively) new Python features used in this file.

    `node` is walked in pre-order.  `future_imports`, when given, is the set
    of names imported from ``__future__``; names that have an entry in
    FUTURE_FLAG_TO_FEATURE contribute the mapped feature.

    Currently looking for:
    - f-strings (with any capitalisation of the string prefix);
    - underscores in numeric literals;
    - trailing commas after * or ** in function signatures and calls;
    - positional only arguments in function signatures and lambdas;
    - assignment expression;
    - relaxed decorator syntax;
    - usage of __future__ flags (annotations);
    - star-unpacking in the value of a return or yield;
    - a tuple with star-unpacking syntax as the value of an annotated
      assignment.
    """
    features: Set[Feature] = set()
    if future_imports:
        features |= {
            FUTURE_FLAG_TO_FEATURE[future_import]
            for future_import in future_imports
            if future_import in FUTURE_FLAG_TO_FEATURE
        }

    for n in node.pre_order():
        if is_string_token(n):
            # String prefixes are case-insensitive, so compare the lowered
            # head; this also catches mixed-case forms such as Rf"..." or
            # fR"...".
            value_head = n.value[:2].lower()
            if value_head in {'f"', "f'", "rf", "fr"}:
                features.add(Feature.F_STRINGS)

        elif n.type == token.NUMBER:
            assert isinstance(n, Leaf)
            if "_" in n.value:
                features.add(Feature.NUMERIC_UNDERSCORES)

        elif n.type == token.SLASH:
            # A bare "/" directly inside an argument list marks
            # positional-only parameters.
            if n.parent and n.parent.type in {
                syms.typedargslist,
                syms.arglist,
                syms.varargslist,
            }:
                features.add(Feature.POS_ONLY_ARGUMENTS)

        elif n.type == token.COLONEQUAL:
            features.add(Feature.ASSIGNMENT_EXPRESSIONS)

        elif n.type == syms.decorator:
            if len(n.children) > 1 and not is_simple_decorator_expression(
                n.children[1]
            ):
                features.add(Feature.RELAXED_DECORATORS)

        elif (
            n.type in {syms.typedargslist, syms.arglist}
            and n.children
            and n.children[-1].type == token.COMMA
        ):
            # The argument list ends with a comma; that only counts as a
            # feature when the list also contains * or ** (checked below).
            if n.type == syms.typedargslist:
                feature = Feature.TRAILING_COMMA_IN_DEF
            else:
                feature = Feature.TRAILING_COMMA_IN_CALL

            for ch in n.children:
                if ch.type in STARS:
                    features.add(feature)

                if ch.type == syms.argument:
                    for argch in ch.children:
                        if argch.type in STARS:
                            features.add(feature)

        elif (
            n.type in {syms.return_stmt, syms.yield_expr}
            and len(n.children) >= 2
            and n.children[1].type == syms.testlist_star_expr
            and any(child.type == syms.star_expr for child in n.children[1].children)
        ):
            features.add(Feature.UNPACKING_ON_FLOW)

        elif (
            n.type == syms.annassign
            and len(n.children) >= 4
            and n.children[3].type == syms.testlist_star_expr
        ):
            features.add(Feature.ANN_ASSIGN_EXTENDED_RHS)

    return features
1300
1301
def detect_target_versions(
    node: Node, *, future_imports: Optional[Set[str]] = None
) -> Set[TargetVersion]:
    """Return every target version that supports all features used in `node`.

    The features are collected by :func:`get_features_used`; a version is
    kept when that set is a subset of its entry in VERSION_TO_FEATURES.
    """
    used_features = get_features_used(node, future_imports=future_imports)
    compatible: Set[TargetVersion] = set()
    for version in TargetVersion:
        if used_features <= VERSION_TO_FEATURES[version]:
            compatible.add(version)
    return compatible
1310
1311
def get_future_imports(node: Node) -> Set[str]:
    """Collect the names imported from ``__future__`` at the top of the file.

    Only the leading run of simple statements is inspected: docstring-like
    string statements are skipped, ``from __future__ import ...`` statements
    are collected, and anything else ends the scan.
    """
    found: Set[str] = set()

    def imported_names(children: List[LN]) -> Generator[str, None, None]:
        # Yield the original (pre-"as") name of every imported item.
        for item in children:
            if isinstance(item, Leaf):
                if item.type == token.NAME:
                    yield item.value

            elif item.type == syms.import_as_name:
                orig_name = item.children[0]
                assert isinstance(orig_name, Leaf), "Invalid syntax parsing imports"
                assert orig_name.type == token.NAME, "Invalid syntax parsing imports"
                yield orig_name.value

            elif item.type == syms.import_as_names:
                yield from imported_names(item.children)

            else:
                raise AssertionError("Invalid syntax parsing imports")

    for stmt in node.children:
        if stmt.type != syms.simple_stmt:
            break

        head = stmt.children[0]
        if isinstance(head, Leaf):
            # Continue looking if we see a docstring; otherwise stop.
            is_docstring = (
                len(stmt.children) == 2
                and head.type == token.STRING
                and stmt.children[1].type == token.NEWLINE
            )
            if is_docstring:
                continue
            break

        if head.type != syms.import_from:
            break

        module_name = head.children[1]
        if not (isinstance(module_name, Leaf) and module_name.value == "__future__"):
            break

        # The imported names start at the fourth child of the statement.
        found.update(imported_names(head.children[3:]))

    return found
1360
1361
def assert_equivalent(src: str, dst: str) -> None:
    """Check that `src` and `dst` parse to the same AST.

    Raises AssertionError when either text fails to parse or when the
    stringified ASTs differ.  For problems with `dst`, the message names a
    file written by ``dump_to_file`` that holds the details.
    """
    try:
        src_ast = parse_ast(src)
    except Exception as exc:
        message = (
            "cannot use --safe with this file; failed to parse source file AST: "
            f"{exc}\n"
            "This could be caused by running Black with an older Python version "
            "that does not support new syntax used in your source file."
        )
        raise AssertionError(message) from exc

    try:
        dst_ast = parse_ast(dst)
    except Exception as exc:
        # Dump the traceback together with the invalid output.
        trace = "".join(traceback.format_tb(exc.__traceback__))
        log = dump_to_file(trace, dst)
        message = (
            f"INTERNAL ERROR: Black produced invalid code: {exc}. "
            "Please report a bug on https://github.com/psf/black/issues.  "
            f"This invalid output might be helpful: {log}"
        )
        raise AssertionError(message) from None

    src_ast_str = "\n".join(stringify_ast(src_ast))
    dst_ast_str = "\n".join(stringify_ast(dst_ast))
    if src_ast_str == dst_ast_str:
        return

    log = dump_to_file(diff(src_ast_str, dst_ast_str, "src", "dst"))
    message = (
        "INTERNAL ERROR: Black produced code that is not equivalent to the"
        " source.  Please report a bug on "
        f"https://github.com/psf/black/issues.  This diff might be helpful: {log}"
    )
    raise AssertionError(message) from None
1393
1394
def assert_stable(src: str, dst: str, mode: Mode) -> None:
    """Check that formatting `dst` once more leaves it unchanged.

    Raises AssertionError otherwise; the message names a file written by
    ``dump_to_file`` that holds the mode and the diffs of both passes.
    """
    # Deliberately a single pass: format_str() formats the string twice and
    # may hide a bug where we bounce back and forth between two versions.
    second_pass = _format_str_once(dst, mode=mode)
    if second_pass == dst:
        return

    log = dump_to_file(
        str(mode),
        diff(src, dst, "source", "first pass"),
        diff(dst, second_pass, "first pass", "second pass"),
    )
    message = (
        "INTERNAL ERROR: Black produced different code on the second pass of the"
        " formatter.  Please report a bug on https://github.com/psf/black/issues."
        f"  This diff might be helpful: {log}"
    )
    raise AssertionError(message) from None
1412
1413
@contextmanager
def nullcontext() -> Iterator[None]:
    """Provide a context manager that does nothing on entry or exit.

    Stand-in for `nullcontext` from Python 3.7's contextlib; the ``as``
    target of the ``with`` statement is bound to None.
    """
    yield None
1421
1422
def patch_click() -> None:
    """Make Click not crash on Python 3.6 with LANG=C.

    Some misconfigured environments make Python 3 pick ASCII as the default
    encoding, which limits the paths it can access for the lifetime of the
    application.  Click refuses to run in that situation and raises a
    RuntimeError.

    For Black, non-ASCII characters in file paths are unlikely since it deals
    with Python source code.  Moreover, this crash was spurious on Python 3.7
    thanks to PEP 538 and PEP 540.  The environment checks are therefore
    replaced with no-ops in every Click module that defines them.
    """
    patchable: List[Any] = []
    try:
        from click import core

        patchable.append(core)
    except ImportError:
        pass
    try:
        # Removed in Click 8.1.0 and newer; we keep this around for users who have
        # older versions installed.
        from click import _unicodefun  # type: ignore

        patchable.append(_unicodefun)
    except ImportError:
        pass

    for module in patchable:
        for hook in ("_verify_python3_env", "_verify_python_env"):
            if hasattr(module, hook):
                setattr(module, hook, lambda: None)
1455
1456
def patched_main() -> None:
    """Run the start-up preparations, then hand control to :func:`main`.

    In order: optional uvloop installation, multiprocessing freeze support,
    the Click environment-check patch, and finally the command itself.
    """
    startup_sequence = (maybe_install_uvloop, freeze_support, patch_click, main)
    for step in startup_sequence:
        step()
1462
1463
# Script entry point: only runs when this module is executed directly,
# not when it is imported.
if __name__ == "__main__":
    patched_main()