[etc/vim.git] / src / black / __init__.py
import asyncio
from json.decoder import JSONDecodeError
import json
from concurrent.futures import Executor, ThreadPoolExecutor, ProcessPoolExecutor
from contextlib import contextmanager
from datetime import datetime
from enum import Enum
import io
from multiprocessing import Manager, freeze_support
import os
from pathlib import Path
from pathspec.patterns.gitwildmatch import GitWildMatchPatternError
import re
import signal
import sys
import tokenize
import traceback
from typing import (
    Any,
    Dict,
    Generator,
    Iterator,
    List,
    MutableMapping,
    Optional,
    Pattern,
    Sequence,
    Set,
    Sized,
    Tuple,
    Union,
)

import click
from click.core import ParameterSource
from dataclasses import replace
from mypy_extensions import mypyc_attr

from black.const import DEFAULT_LINE_LENGTH, DEFAULT_INCLUDES, DEFAULT_EXCLUDES
from black.const import STDIN_PLACEHOLDER
from black.nodes import STARS, syms, is_simple_decorator_expression
from black.nodes import is_string_token
from black.lines import Line, EmptyLineTracker
from black.linegen import transform_line, LineGenerator, LN
from black.comments import normalize_fmt_off
from black.mode import FUTURE_FLAG_TO_FEATURE, Mode, TargetVersion
from black.mode import Feature, supports_feature, VERSION_TO_FEATURES
from black.cache import read_cache, write_cache, get_cache_info, filter_cached, Cache
from black.concurrency import cancel, shutdown, maybe_install_uvloop
from black.output import dump_to_file, ipynb_diff, diff, color_diff, out, err
from black.report import Report, Changed, NothingChanged
from black.files import (
    find_project_root,
    find_pyproject_toml,
    parse_pyproject_toml,
    find_user_pyproject_toml,
)
from black.files import gen_python_files, get_gitignore, normalize_path_maybe_ignore
from black.files import wrap_stream_for_windows
from black.parsing import InvalidInput  # noqa F401
from black.parsing import lib2to3_parse, parse_ast, stringify_ast
from black.handle_ipynb_magics import (
    mask_cell,
    unmask_cell,
    remove_trailing_semicolon,
    put_trailing_semicolon_back,
    TRANSFORMED_MAGICS,
    PYTHON_CELL_MAGICS,
    jupyter_dependencies_are_installed,
)


# lib2to3 fork
from blib2to3.pytree import Node, Leaf
from blib2to3.pgen2 import token

from _black_version import version as __version__

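# True when running as a mypyc-compiled extension module (".so" on POSIX, ".pyd" on
# Windows) rather than as plain Python source.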
COMPILED = Path(__file__).suffix in (".pyd", ".so")

# types
FileContent = str
Encoding = str
NewLine = str


class WriteBack(Enum):
    NO = 0
    YES = 1
    DIFF = 2
    CHECK = 3
    COLOR_DIFF = 4

    @classmethod
    def from_configuration(
        cls, *, check: bool, diff: bool, color: bool = False
    ) -> "WriteBack":
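        # Roughly (illustrative): --check alone maps to CHECK, --diff to DIFF
        # (or COLOR_DIFF when --color is also given), and plain formatting runs
        # map to YES. Note that --check combined with --diff yields DIFF, so the
        # diff is still printed.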
        if check and not diff:
            return cls.CHECK

        if diff and color:
            return cls.COLOR_DIFF

        return cls.DIFF if diff else cls.YES


# Legacy name, left for integrations.
FileMode = Mode

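# os.cpu_count() may return None if the CPU count cannot be determined; in that case
# reformat_many() ends up letting ProcessPoolExecutor choose its own default.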
DEFAULT_WORKERS = os.cpu_count()


def read_pyproject_toml(
    ctx: click.Context, param: click.Parameter, value: Optional[str]
) -> Optional[str]:
    """Inject Black configuration from "pyproject.toml" into defaults in `ctx`.

    Returns the path to a successfully found and read configuration file, None
    otherwise.
    """
    if not value:
        value = find_pyproject_toml(ctx.params.get("src", ()))
        if value is None:
            return None

    try:
        config = parse_pyproject_toml(value)
    except (OSError, ValueError) as e:
        raise click.FileError(
            filename=value, hint=f"Error reading configuration file: {e}"
        ) from None

    if not config:
        return None
    else:
        # Sanitize the values to be Click friendly. For more information please see:
        # https://github.com/psf/black/issues/1458
        # https://github.com/pallets/click/issues/1567
        config = {
            k: str(v) if not isinstance(v, (list, dict)) else v
            for k, v in config.items()
        }

    target_version = config.get("target_version")
    if target_version is not None and not isinstance(target_version, list):
        raise click.BadOptionUsage(
            "target-version", "Config key target-version must be a list"
        )

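    # At this point `config` holds the sanitized [tool.black] table, e.g.
    # (illustrative) {"line_length": "100", "target_version": ["py38"]}. Merging it
    # into ctx.default_map below makes pyproject.toml values act only as defaults:
    # options passed explicitly on the command line still take precedence.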
    default_map: Dict[str, Any] = {}
    if ctx.default_map:
        default_map.update(ctx.default_map)
    default_map.update(config)

    ctx.default_map = default_map
    return value


def target_version_option_callback(
    c: click.Context, p: Union[click.Option, click.Parameter], v: Tuple[str, ...]
) -> List[TargetVersion]:
    """Compute the target versions from a --target-version flag.

    This is its own function because mypy couldn't infer the type correctly
    when it was a lambda, causing mypyc trouble.
    """
    return [TargetVersion[val.upper()] for val in v]


def re_compile_maybe_verbose(regex: str) -> Pattern[str]:
    """Compile a regular expression string in `regex`.

    If it contains newlines, use verbose mode.
    """
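    # For example (illustrative), a multi-line value such as
    #
    #     (
    #       ^/foo\.py
    #       | ^/generated/
    #     )
    #
    # contains newlines and is therefore compiled with "(?x)" (verbose mode), which
    # ignores insignificant whitespace and lets the pattern be spread over several
    # lines.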
    if "\n" in regex:
        regex = "(?x)" + regex
    compiled: Pattern[str] = re.compile(regex)
    return compiled


def validate_regex(
    ctx: click.Context,
    param: click.Parameter,
    value: Optional[str],
) -> Optional[Pattern[str]]:
    try:
        return re_compile_maybe_verbose(value) if value is not None else None
    except re.error as e:
        raise click.BadParameter(f"Not a valid regular expression: {e}") from None


@click.command(
    context_settings={"help_option_names": ["-h", "--help"]},
    # While Click does set this field automatically using the docstring, mypyc
    # (annoyingly) strips 'em so we need to set it here too.
    help="The uncompromising code formatter.",
)
@click.option("-c", "--code", type=str, help="Format the code passed in as a string.")
@click.option(
    "-l",
    "--line-length",
    type=int,
    default=DEFAULT_LINE_LENGTH,
    help="How many characters per line to allow.",
    show_default=True,
)
@click.option(
    "-t",
    "--target-version",
    type=click.Choice([v.name.lower() for v in TargetVersion]),
    callback=target_version_option_callback,
    multiple=True,
    help=(
        "Python versions that should be supported by Black's output. [default: per-file"
        " auto-detection]"
    ),
)
@click.option(
    "--pyi",
    is_flag=True,
    help=(
        "Format all input files like typing stubs regardless of file extension (useful"
        " when piping source on standard input)."
    ),
)
@click.option(
    "--ipynb",
    is_flag=True,
    help=(
        "Format all input files like Jupyter Notebooks regardless of file extension "
        "(useful when piping source on standard input)."
    ),
)
@click.option(
    "--python-cell-magics",
    multiple=True,
    help=(
        "When processing Jupyter Notebooks, add the given magic to the list"
        f" of known python-magics ({', '.join(PYTHON_CELL_MAGICS)})."
        " Useful for formatting cells with custom python magics."
    ),
    default=[],
)
@click.option(
    "-S",
    "--skip-string-normalization",
    is_flag=True,
    help="Don't normalize string quotes or prefixes.",
)
@click.option(
    "-C",
    "--skip-magic-trailing-comma",
    is_flag=True,
    help="Don't use trailing commas as a reason to split lines.",
)
@click.option(
    "--experimental-string-processing",
    is_flag=True,
    hidden=True,
    help="(DEPRECATED and now included in --preview) Normalize string literals.",
)
@click.option(
    "--preview",
    is_flag=True,
    help=(
        "Enable potentially disruptive style changes that may be added to Black's main"
        " functionality in the next major release."
    ),
)
@click.option(
    "--check",
    is_flag=True,
    help=(
        "Don't write the files back, just return the status. Return code 0 means"
        " nothing would change. Return code 1 means some files would be reformatted."
        " Return code 123 means there was an internal error."
    ),
)
@click.option(
    "--diff",
    is_flag=True,
    help="Don't write the files back, just output a diff for each file on stdout.",
)
@click.option(
    "--color/--no-color",
    is_flag=True,
    help="Show colored diff. Only applies when `--diff` is given.",
)
@click.option(
    "--fast/--safe",
    is_flag=True,
    help="If --fast given, skip temporary sanity checks. [default: --safe]",
)
@click.option(
    "--required-version",
    type=str,
    help=(
        "Require a specific version of Black to be running (useful for unifying results"
        " across many environments e.g. with a pyproject.toml file). It can be"
        " either a major version number or an exact version."
    ),
)
@click.option(
    "--include",
    type=str,
    default=DEFAULT_INCLUDES,
    callback=validate_regex,
    help=(
        "A regular expression that matches files and directories that should be"
        " included on recursive searches. An empty value means all files are included"
        " regardless of the name. Use forward slashes for directories on all platforms"
        " (Windows, too). Exclusions are calculated first, inclusions later."
    ),
    show_default=True,
)
@click.option(
    "--exclude",
    type=str,
    callback=validate_regex,
    help=(
        "A regular expression that matches files and directories that should be"
        " excluded on recursive searches. An empty value means no paths are excluded."
        " Use forward slashes for directories on all platforms (Windows, too)."
        " Exclusions are calculated first, inclusions later. [default:"
        f" {DEFAULT_EXCLUDES}]"
    ),
    show_default=False,
)
@click.option(
    "--extend-exclude",
    type=str,
    callback=validate_regex,
    help=(
        "Like --exclude, but adds additional files and directories on top of the"
        " excluded ones. (Useful if you simply want to add to the default)"
    ),
)
@click.option(
    "--force-exclude",
    type=str,
    callback=validate_regex,
    help=(
        "Like --exclude, but files and directories matching this regex will be "
        "excluded even when they are passed explicitly as arguments."
    ),
)
@click.option(
    "--stdin-filename",
    type=str,
    help=(
        "The name of the file when passing it through stdin. Useful to make "
        "sure Black will respect --force-exclude option on some "
        "editors that rely on using stdin."
    ),
)
@click.option(
    "-W",
    "--workers",
    type=click.IntRange(min=1),
    default=DEFAULT_WORKERS,
    show_default=True,
    help="Number of parallel workers",
)
@click.option(
    "-q",
    "--quiet",
    is_flag=True,
    help=(
        "Don't emit non-error messages to stderr. Errors are still emitted; silence"
        " those with 2>/dev/null."
    ),
)
@click.option(
    "-v",
    "--verbose",
    is_flag=True,
    help=(
        "Also emit messages to stderr about files that were not changed or were ignored"
        " due to exclusion patterns."
    ),
)
@click.version_option(
    version=__version__,
    message=f"%(prog)s, %(version)s (compiled: {'yes' if COMPILED else 'no'})",
)
@click.argument(
    "src",
    nargs=-1,
    type=click.Path(
        exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
    ),
    is_eager=True,
    metavar="SRC ...",
)
@click.option(
    "--config",
    type=click.Path(
        exists=True,
        file_okay=True,
        dir_okay=False,
        readable=True,
        allow_dash=False,
        path_type=str,
    ),
    is_eager=True,
    callback=read_pyproject_toml,
    help="Read configuration from FILE path.",
)
@click.pass_context
def main(  # noqa: C901
    ctx: click.Context,
    code: Optional[str],
    line_length: int,
    target_version: List[TargetVersion],
    check: bool,
    diff: bool,
    color: bool,
    fast: bool,
    pyi: bool,
    ipynb: bool,
    python_cell_magics: Sequence[str],
    skip_string_normalization: bool,
    skip_magic_trailing_comma: bool,
    experimental_string_processing: bool,
    preview: bool,
    quiet: bool,
    verbose: bool,
    required_version: Optional[str],
    include: Pattern[str],
    exclude: Optional[Pattern[str]],
    extend_exclude: Optional[Pattern[str]],
    force_exclude: Optional[Pattern[str]],
    stdin_filename: Optional[str],
    workers: int,
    src: Tuple[str, ...],
    config: Optional[str],
) -> None:
    """The uncompromising code formatter."""
    ctx.ensure_object(dict)

    if src and code is not None:
        out(
            main.get_usage(ctx)
            + "\n\n'SRC' and 'code' cannot be passed simultaneously."
        )
        ctx.exit(1)
    if not src and code is None:
        out(main.get_usage(ctx) + "\n\nOne of 'SRC' or 'code' is required.")
        ctx.exit(1)

    root, method = find_project_root(src) if code is None else (None, None)
    ctx.obj["root"] = root

    if verbose:
        if root:
            out(
                f"Identified `{root}` as project root containing a {method}.",
                fg="blue",
            )

            normalized = [
                (normalize_path_maybe_ignore(Path(source), root), source)
                for source in src
            ]
            srcs_string = ", ".join(
                [
                    f'"{_norm}"'
                    if _norm
                    else f'\033[31m"{source} (skipping - invalid)"\033[34m'
                    for _norm, source in normalized
                ]
            )
            out(f"Sources to be formatted: {srcs_string}", fg="blue")

        if config:
            config_source = ctx.get_parameter_source("config")
            user_level_config = str(find_user_pyproject_toml())
            if config == user_level_config:
                out(
                    "Using configuration from user-level config at "
                    f"'{user_level_config}'.",
                    fg="blue",
                )
            elif config_source in (
                ParameterSource.DEFAULT,
                ParameterSource.DEFAULT_MAP,
            ):
                out("Using configuration from project root.", fg="blue")
            else:
                out(f"Using configuration in '{config}'.", fg="blue")

    error_msg = "Oh no! 💥 💔 💥"
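    # --required-version accepts either an exact version string or just a major
    # version: e.g. (illustrative) "22" matches any running 22.x release, while
    # "22.3.0" must match the running version exactly.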
    if (
        required_version
        and required_version != __version__
        and required_version != __version__.split(".")[0]
    ):
        err(
            f"{error_msg} The required version `{required_version}` does not match"
            f" the running version `{__version__}`!"
        )
        ctx.exit(1)
    if ipynb and pyi:
        err("Cannot pass both `pyi` and `ipynb` flags!")
        ctx.exit(1)

    write_back = WriteBack.from_configuration(check=check, diff=diff, color=color)
    if target_version:
        versions = set(target_version)
    else:
        # We'll autodetect later.
        versions = set()
    mode = Mode(
        target_versions=versions,
        line_length=line_length,
        is_pyi=pyi,
        is_ipynb=ipynb,
        string_normalization=not skip_string_normalization,
        magic_trailing_comma=not skip_magic_trailing_comma,
        experimental_string_processing=experimental_string_processing,
        preview=preview,
        python_cell_magics=set(python_cell_magics),
    )

    if code is not None:
        # Run in quiet mode by default with -c; the extra output isn't useful.
        # You can still pass -v to get verbose output.
        quiet = True

    report = Report(check=check, diff=diff, quiet=quiet, verbose=verbose)

    if code is not None:
        reformat_code(
            content=code, fast=fast, write_back=write_back, mode=mode, report=report
        )
    else:
        try:
            sources = get_sources(
                ctx=ctx,
                src=src,
                quiet=quiet,
                verbose=verbose,
                include=include,
                exclude=exclude,
                extend_exclude=extend_exclude,
                force_exclude=force_exclude,
                report=report,
                stdin_filename=stdin_filename,
            )
        except GitWildMatchPatternError:
            ctx.exit(1)

        path_empty(
            sources,
            "No Python files are present to be formatted. Nothing to do 😴",
            quiet,
            verbose,
            ctx,
        )

        if len(sources) == 1:
            reformat_one(
                src=sources.pop(),
                fast=fast,
                write_back=write_back,
                mode=mode,
                report=report,
            )
        else:
            reformat_many(
                sources=sources,
                fast=fast,
                write_back=write_back,
                mode=mode,
                report=report,
                workers=workers,
            )

    if verbose or not quiet:
        if code is None and (verbose or report.change_count or report.failure_count):
            out()
        out(error_msg if report.return_code else "All done! ✨ 🍰 ✨")
        if code is None:
            click.echo(str(report), err=True)
    ctx.exit(report.return_code)


def get_sources(
    *,
    ctx: click.Context,
    src: Tuple[str, ...],
    quiet: bool,
    verbose: bool,
    include: Pattern[str],
    exclude: Optional[Pattern[str]],
    extend_exclude: Optional[Pattern[str]],
    force_exclude: Optional[Pattern[str]],
    report: "Report",
    stdin_filename: Optional[str],
) -> Set[Path]:
    """Compute the set of files to be formatted."""
    sources: Set[Path] = set()

    if exclude is None:
        exclude = re_compile_maybe_verbose(DEFAULT_EXCLUDES)
        gitignore = get_gitignore(ctx.obj["root"])
    else:
        gitignore = None

    for s in src:
        if s == "-" and stdin_filename:
            p = Path(stdin_filename)
            is_stdin = True
        else:
            p = Path(s)
            is_stdin = False

        if is_stdin or p.is_file():
            normalized_path = normalize_path_maybe_ignore(p, ctx.obj["root"], report)
            if normalized_path is None:
                continue

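            # The leading "/" makes the path root-anchored, so a pattern written
            # against the project root (e.g., illustratively, "^/generated/") matches
            # explicitly passed files the same way it matches files discovered by
            # directory traversal.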
            normalized_path = "/" + normalized_path
            # Hard-exclude any files that match the `--force-exclude` regex.
            if force_exclude:
                force_exclude_match = force_exclude.search(normalized_path)
            else:
                force_exclude_match = None
            if force_exclude_match and force_exclude_match.group(0):
                report.path_ignored(p, "matches the --force-exclude regular expression")
                continue

            if is_stdin:
                p = Path(f"{STDIN_PLACEHOLDER}{str(p)}")

            if p.suffix == ".ipynb" and not jupyter_dependencies_are_installed(
                verbose=verbose, quiet=quiet
            ):
                continue

            sources.add(p)
        elif p.is_dir():
            sources.update(
                gen_python_files(
                    p.iterdir(),
                    ctx.obj["root"],
                    include,
                    exclude,
                    extend_exclude,
                    force_exclude,
                    report,
                    gitignore,
                    verbose=verbose,
                    quiet=quiet,
                )
            )
        elif s == "-":
            sources.add(p)
        else:
            err(f"invalid path: {s}")
    return sources


def path_empty(
    src: Sized, msg: str, quiet: bool, verbose: bool, ctx: click.Context
) -> None:
    """
    Exit if there is no `src` provided for formatting
    """
    if not src:
        if verbose or not quiet:
            out(msg)
        ctx.exit(0)


def reformat_code(
    content: str, fast: bool, write_back: WriteBack, mode: Mode, report: Report
) -> None:
    """
    Reformat and print out `content` without spawning child processes.
    Similar to `reformat_one`, but for string content.

    `fast`, `write_back`, and `mode` options are passed to
    :func:`format_file_in_place` or :func:`format_stdin_to_stdout`.
    """
    path = Path("<string>")
    try:
        changed = Changed.NO
        if format_stdin_to_stdout(
            content=content, fast=fast, write_back=write_back, mode=mode
        ):
            changed = Changed.YES
        report.done(path, changed)
    except Exception as exc:
        if report.verbose:
            traceback.print_exc()
        report.failed(path, str(exc))


# diff-shades depends on being able to monkeypatch this function to operate. I know
# it's not ideal, but this shouldn't cause any issues ... hopefully. ~ichard26
@mypyc_attr(patchable=True)
def reformat_one(
    src: Path, fast: bool, write_back: WriteBack, mode: Mode, report: "Report"
) -> None:
    """Reformat a single file under `src` without spawning child processes.

    `fast`, `write_back`, and `mode` options are passed to
    :func:`format_file_in_place` or :func:`format_stdin_to_stdout`.
    """
    try:
        changed = Changed.NO

        if str(src) == "-":
            is_stdin = True
        elif str(src).startswith(STDIN_PLACEHOLDER):
            is_stdin = True
            # Use the original name again in case we want to print something
            # to the user
            src = Path(str(src)[len(STDIN_PLACEHOLDER) :])
        else:
            is_stdin = False

        if is_stdin:
            if src.suffix == ".pyi":
                mode = replace(mode, is_pyi=True)
            elif src.suffix == ".ipynb":
                mode = replace(mode, is_ipynb=True)
            if format_stdin_to_stdout(fast=fast, write_back=write_back, mode=mode):
                changed = Changed.YES
        else:
            cache: Cache = {}
            if write_back not in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
                cache = read_cache(mode)
                res_src = src.resolve()
                res_src_s = str(res_src)
                if res_src_s in cache and cache[res_src_s] == get_cache_info(res_src):
                    changed = Changed.CACHED
            if changed is not Changed.CACHED and format_file_in_place(
                src, fast=fast, write_back=write_back, mode=mode
            ):
                changed = Changed.YES
            if (write_back is WriteBack.YES and changed is not Changed.CACHED) or (
                write_back is WriteBack.CHECK and changed is Changed.NO
            ):
                write_cache(cache, [src], mode)
        report.done(src, changed)
    except Exception as exc:
        if report.verbose:
            traceback.print_exc()
        report.failed(src, str(exc))


# diff-shades depends on being able to monkeypatch this function to operate. I know
# it's not ideal, but this shouldn't cause any issues ... hopefully. ~ichard26
@mypyc_attr(patchable=True)
def reformat_many(
    sources: Set[Path],
    fast: bool,
    write_back: WriteBack,
    mode: Mode,
    report: "Report",
    workers: Optional[int],
) -> None:
    """Reformat multiple files using a ProcessPoolExecutor."""
    executor: Executor
    loop = asyncio.get_event_loop()
    worker_count = workers if workers is not None else DEFAULT_WORKERS
    if sys.platform == "win32":
        # Work around https://bugs.python.org/issue26903
        assert worker_count is not None
        worker_count = min(worker_count, 60)
    try:
        executor = ProcessPoolExecutor(max_workers=worker_count)
    except (ImportError, NotImplementedError, OSError):
        # We arrive here if the underlying system does not support multiprocessing
        # (e.g. AWS Lambda or Termux), in which case we gracefully fall back to a
        # ThreadPoolExecutor with just a single worker (more workers would not do us
        # any good due to the Global Interpreter Lock)
        executor = ThreadPoolExecutor(max_workers=1)

    try:
        loop.run_until_complete(
            schedule_formatting(
                sources=sources,
                fast=fast,
                write_back=write_back,
                mode=mode,
                report=report,
                loop=loop,
                executor=executor,
            )
        )
    finally:
        shutdown(loop)
        if executor is not None:
            executor.shutdown()


async def schedule_formatting(
    sources: Set[Path],
    fast: bool,
    write_back: WriteBack,
    mode: Mode,
    report: "Report",
    loop: asyncio.AbstractEventLoop,
    executor: Executor,
) -> None:
    """Run formatting of `sources` in parallel using the provided `executor`.

    (Use ProcessPoolExecutors for actual parallelism.)

    `write_back`, `fast`, and `mode` options are passed to
    :func:`format_file_in_place`.
    """
    cache: Cache = {}
    if write_back not in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
        cache = read_cache(mode)
        sources, cached = filter_cached(cache, sources)
        for src in sorted(cached):
            report.done(src, Changed.CACHED)
    if not sources:
        return

    cancelled = []
    sources_to_cache = []
    lock = None
    if write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
        # For diff output, we need locks to ensure we don't interleave output
        # from different processes.
        manager = Manager()
        lock = manager.Lock()
    tasks = {
        asyncio.ensure_future(
            loop.run_in_executor(
                executor, format_file_in_place, src, fast, mode, write_back, lock
            )
        ): src
        for src in sorted(sources)
    }
    pending = tasks.keys()
    try:
        loop.add_signal_handler(signal.SIGINT, cancel, pending)
        loop.add_signal_handler(signal.SIGTERM, cancel, pending)
    except NotImplementedError:
        # There are no good alternatives for these on Windows.
        pass
    while pending:
        done, _ = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
        for task in done:
            src = tasks.pop(task)
            if task.cancelled():
                cancelled.append(task)
            elif task.exception():
                report.failed(src, str(task.exception()))
            else:
                changed = Changed.YES if task.result() else Changed.NO
                # If the file was written back or was successfully checked as
                # well-formatted, store this information in the cache.
                if write_back is WriteBack.YES or (
                    write_back is WriteBack.CHECK and changed is Changed.NO
                ):
                    sources_to_cache.append(src)
                report.done(src, changed)
    if cancelled:
        if sys.version_info >= (3, 7):
            await asyncio.gather(*cancelled, return_exceptions=True)
        else:
            await asyncio.gather(*cancelled, loop=loop, return_exceptions=True)
    if sources_to_cache:
        write_cache(cache, sources_to_cache, mode)


def format_file_in_place(
    src: Path,
    fast: bool,
    mode: Mode,
    write_back: WriteBack = WriteBack.NO,
    lock: Any = None,  # multiprocessing.Manager().Lock() is some crazy proxy
) -> bool:
    """Format file under `src` path. Return True if changed.

    If `write_back` is DIFF, write a diff to stdout. If it is YES, write reformatted
    code to the file.
    `mode` and `fast` options are passed to :func:`format_file_contents`.
    """
    if src.suffix == ".pyi":
        mode = replace(mode, is_pyi=True)
    elif src.suffix == ".ipynb":
        mode = replace(mode, is_ipynb=True)

    then = datetime.utcfromtimestamp(src.stat().st_mtime)
    with open(src, "rb") as buf:
        src_contents, encoding, newline = decode_bytes(buf.read())
    try:
        dst_contents = format_file_contents(src_contents, fast=fast, mode=mode)
    except NothingChanged:
        return False
    except JSONDecodeError:
        raise ValueError(
            f"File '{src}' cannot be parsed as valid Jupyter notebook."
        ) from None

    if write_back == WriteBack.YES:
        with open(src, "w", encoding=encoding, newline=newline) as f:
            f.write(dst_contents)
    elif write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
        now = datetime.utcnow()
        src_name = f"{src}\t{then} +0000"
        dst_name = f"{src}\t{now} +0000"
        if mode.is_ipynb:
            diff_contents = ipynb_diff(src_contents, dst_contents, src_name, dst_name)
        else:
            diff_contents = diff(src_contents, dst_contents, src_name, dst_name)

        if write_back == WriteBack.COLOR_DIFF:
            diff_contents = color_diff(diff_contents)

        with lock or nullcontext():
            f = io.TextIOWrapper(
                sys.stdout.buffer,
                encoding=encoding,
                newline=newline,
                write_through=True,
            )
            f = wrap_stream_for_windows(f)
            f.write(diff_contents)
            f.detach()

    return True


def format_stdin_to_stdout(
    fast: bool,
    *,
    content: Optional[str] = None,
    write_back: WriteBack = WriteBack.NO,
    mode: Mode,
) -> bool:
    """Format file on stdin. Return True if changed.

    If content is None, it's read from sys.stdin.

    If `write_back` is YES, write reformatted code back to stdout. If it is DIFF,
    write a diff to stdout. The `mode` argument is passed to
    :func:`format_file_contents`.
    """
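    # When `content` is given (the --code path), nothing is read or decoded: the
    # text is assumed to be UTF-8 and newline="" disables newline translation, so
    # the formatted result is written out exactly as produced.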
    then = datetime.utcnow()

    if content is None:
        src, encoding, newline = decode_bytes(sys.stdin.buffer.read())
    else:
        src, encoding, newline = content, "utf-8", ""

    dst = src
    try:
        dst = format_file_contents(src, fast=fast, mode=mode)
        return True

    except NothingChanged:
        return False

    finally:
        f = io.TextIOWrapper(
            sys.stdout.buffer, encoding=encoding, newline=newline, write_through=True
        )
        if write_back == WriteBack.YES:
            # Make sure there's a newline after the content
            if dst and dst[-1] != "\n":
                dst += "\n"
            f.write(dst)
        elif write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
            now = datetime.utcnow()
            src_name = f"STDIN\t{then} +0000"
            dst_name = f"STDOUT\t{now} +0000"
            d = diff(src, dst, src_name, dst_name)
            if write_back == WriteBack.COLOR_DIFF:
                d = color_diff(d)
                f = wrap_stream_for_windows(f)
            f.write(d)
        f.detach()


def check_stability_and_equivalence(
    src_contents: str, dst_contents: str, *, mode: Mode
) -> None:
    """Perform stability and equivalence checks.

    Raise AssertionError if source and destination contents are not
    equivalent, or if a second pass of the formatter would format the
    content differently.
    """
    assert_equivalent(src_contents, dst_contents)
    assert_stable(src_contents, dst_contents, mode=mode)


def format_file_contents(src_contents: str, *, fast: bool, mode: Mode) -> FileContent:
    """Reformat contents of a file and return new contents.

    If `fast` is False, additionally confirm that the reformatted code is
    valid by calling :func:`assert_equivalent` and :func:`assert_stable` on it.
    `mode` is passed to :func:`format_str`.
    """
    if not src_contents.strip():
        raise NothingChanged

    if mode.is_ipynb:
        dst_contents = format_ipynb_string(src_contents, fast=fast, mode=mode)
    else:
        dst_contents = format_str(src_contents, mode=mode)
    if src_contents == dst_contents:
        raise NothingChanged

    if not fast and not mode.is_ipynb:
        # Jupyter notebooks will already have been checked above.
        check_stability_and_equivalence(src_contents, dst_contents, mode=mode)
    return dst_contents


def validate_cell(src: str, mode: Mode) -> None:
    """Check that cell does not already contain TransformerManager transformations,
    or non-Python cell magics, which might cause tokenizer_rt to break because of
    indentations.

    If a cell contains ``!ls``, then it'll be transformed to
    ``get_ipython().system('ls')``. However, if the cell originally contained
    ``get_ipython().system('ls')``, then it would get transformed in the same way:

        >>> TransformerManager().transform_cell("get_ipython().system('ls')")
        "get_ipython().system('ls')\n"
        >>> TransformerManager().transform_cell("!ls")
        "get_ipython().system('ls')\n"

    Due to the impossibility of safely roundtripping in such situations, cells
    containing transformed magics will be ignored.
    """
    if any(transformed_magic in src for transformed_magic in TRANSFORMED_MAGICS):
        raise NothingChanged
    if (
        src[:2] == "%%"
        and src.split()[0][2:] not in PYTHON_CELL_MAGICS | mode.python_cell_magics
    ):
        raise NothingChanged


def format_cell(src: str, *, fast: bool, mode: Mode) -> str:
    """Format code in given cell of Jupyter notebook.

    General idea is:

      - if cell has trailing semicolon, remove it;
      - if cell has IPython magics, mask them;
      - format cell;
      - reinstate IPython magics;
      - reinstate trailing semicolon (if originally present);
      - strip trailing newlines.

    Cells with syntax errors will not be processed, as they
    could potentially be automagics or multi-line magics, which
    are currently not supported.
    """
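    # For instance (illustrative), a cell containing "x = 1 ;" comes back as
    # "x = 1;": the trailing semicolon is removed before formatting and put back
    # afterwards, and trailing newlines are stripped.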
    validate_cell(src, mode)
    src_without_trailing_semicolon, has_trailing_semicolon = remove_trailing_semicolon(
        src
    )
    try:
        masked_src, replacements = mask_cell(src_without_trailing_semicolon)
    except SyntaxError:
        raise NothingChanged from None
    masked_dst = format_str(masked_src, mode=mode)
    if not fast:
        check_stability_and_equivalence(masked_src, masked_dst, mode=mode)
    dst_without_trailing_semicolon = unmask_cell(masked_dst, replacements)
    dst = put_trailing_semicolon_back(
        dst_without_trailing_semicolon, has_trailing_semicolon
    )
    dst = dst.rstrip("\n")
    if dst == src:
        raise NothingChanged from None
    return dst


def validate_metadata(nb: MutableMapping[str, Any]) -> None:
    """If notebook is marked as non-Python, don't format it.

    All notebook metadata fields are optional, see
    https://nbformat.readthedocs.io/en/latest/format_description.html. So
    if a notebook has empty metadata, we will try to parse it anyway.
    """
    language = nb.get("metadata", {}).get("language_info", {}).get("name", None)
    if language is not None and language != "python":
        raise NothingChanged from None


def format_ipynb_string(src_contents: str, *, fast: bool, mode: Mode) -> FileContent:
    """Format Jupyter notebook.

    Operate cell-by-cell, only on code cells, only for Python notebooks.
    If the ``.ipynb`` originally had a trailing newline, it'll be preserved.
    """
    trailing_newline = src_contents[-1] == "\n"
    modified = False
    nb = json.loads(src_contents)
    validate_metadata(nb)
    for cell in nb["cells"]:
        if cell.get("cell_type", None) == "code":
            try:
                src = "".join(cell["source"])
                dst = format_cell(src, fast=fast, mode=mode)
            except NothingChanged:
                pass
            else:
                cell["source"] = dst.splitlines(keepends=True)
                modified = True
    if modified:
        dst_contents = json.dumps(nb, indent=1, ensure_ascii=False)
        if trailing_newline:
            dst_contents = dst_contents + "\n"
        return dst_contents
    else:
        raise NothingChanged


def format_str(src_contents: str, *, mode: Mode) -> str:
    """Reformat a string and return new contents.

    `mode` determines formatting options, such as how many characters per line are
    allowed.  Example:

    >>> import black
    >>> print(black.format_str("def f(arg:str='')->None:...", mode=black.Mode()))
    def f(arg: str = "") -> None:
        ...

    A more complex example:

    >>> print(
    ...   black.format_str(
    ...     "def f(arg:str='')->None: hey",
    ...     mode=black.Mode(
    ...       target_versions={black.TargetVersion.PY36},
    ...       line_length=10,
    ...       string_normalization=False,
    ...       is_pyi=False,
    ...     ),
    ...   ),
    ... )
    def f(
        arg: str = '',
    ) -> None:
        hey

    """
    dst_contents = _format_str_once(src_contents, mode=mode)
    # Forced second pass to work around optional trailing commas (becoming
    # forced trailing commas on pass 2) interacting differently with optional
    # parentheses.  Admittedly ugly.
    if src_contents != dst_contents:
        return _format_str_once(dst_contents, mode=mode)
    return dst_contents


def _format_str_once(src_contents: str, *, mode: Mode) -> str:
    src_node = lib2to3_parse(src_contents.lstrip(), mode.target_versions)
    dst_contents = []
    future_imports = get_future_imports(src_node)
    if mode.target_versions:
        versions = mode.target_versions
    else:
        versions = detect_target_versions(src_node, future_imports=future_imports)

    normalize_fmt_off(src_node, preview=mode.preview)
    lines = LineGenerator(mode=mode)
    elt = EmptyLineTracker(is_pyi=mode.is_pyi)
    empty_line = Line(mode=mode)
    after = 0
    split_line_features = {
        feature
        for feature in {Feature.TRAILING_COMMA_IN_CALL, Feature.TRAILING_COMMA_IN_DEF}
        if supports_feature(versions, feature)
    }
    for current_line in lines.visit(src_node):
        dst_contents.append(str(empty_line) * after)
        before, after = elt.maybe_empty_lines(current_line)
        dst_contents.append(str(empty_line) * before)
        for line in transform_line(
            current_line, mode=mode, features=split_line_features
        ):
            dst_contents.append(str(line))
    return "".join(dst_contents)


def decode_bytes(src: bytes) -> Tuple[FileContent, Encoding, NewLine]:
    """Return a tuple of (decoded_contents, encoding, newline).

    `newline` is either CRLF or LF but `decoded_contents` is decoded with
    universal newlines (i.e. only contains LF).
    """
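    # Example (illustrative):
    #     decode_bytes(b"x = 1\r\ny = 2\r\n")
    #     == ("x = 1\ny = 2\n", "utf-8", "\r\n")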
    srcbuf = io.BytesIO(src)
    encoding, lines = tokenize.detect_encoding(srcbuf.readline)
    if not lines:
        return "", encoding, "\n"

    newline = "\r\n" if b"\r\n" == lines[0][-2:] else "\n"
    srcbuf.seek(0)
    with io.TextIOWrapper(srcbuf, encoding) as tiow:
        return tiow.read(), encoding, newline


def get_features_used(  # noqa: C901
    node: Node, *, future_imports: Optional[Set[str]] = None
) -> Set[Feature]:
    """Return a set of (relatively) new Python features used in this file.

    Currently looking for:
    - f-strings;
    - underscores in numeric literals;
    - trailing commas after * or ** in function signatures and calls;
    - positional only arguments in function signatures and lambdas;
    - assignment expression;
    - relaxed decorator syntax;
    - usage of __future__ flags (annotations);
    - print / exec statements;
    """
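    # For example (illustrative), a file containing `f"{x}"`, `1_000_000`, and
    # `if (y := f(x)):` would report F_STRINGS, NUMERIC_UNDERSCORES, and
    # ASSIGNMENT_EXPRESSIONS.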
    features: Set[Feature] = set()
    if future_imports:
        features |= {
            FUTURE_FLAG_TO_FEATURE[future_import]
            for future_import in future_imports
            if future_import in FUTURE_FLAG_TO_FEATURE
        }

    for n in node.pre_order():
        if is_string_token(n):
            value_head = n.value[:2]
            if value_head in {'f"', 'F"', "f'", "F'", "rf", "fr", "RF", "FR"}:
                features.add(Feature.F_STRINGS)

        elif n.type == token.NUMBER:
            assert isinstance(n, Leaf)
            if "_" in n.value:
                features.add(Feature.NUMERIC_UNDERSCORES)

        elif n.type == token.SLASH:
            if n.parent and n.parent.type in {
                syms.typedargslist,
                syms.arglist,
                syms.varargslist,
            }:
                features.add(Feature.POS_ONLY_ARGUMENTS)

        elif n.type == token.COLONEQUAL:
            features.add(Feature.ASSIGNMENT_EXPRESSIONS)

        elif n.type == syms.decorator:
            if len(n.children) > 1 and not is_simple_decorator_expression(
                n.children[1]
            ):
                features.add(Feature.RELAXED_DECORATORS)

        elif (
            n.type in {syms.typedargslist, syms.arglist}
            and n.children
            and n.children[-1].type == token.COMMA
        ):
            if n.type == syms.typedargslist:
                feature = Feature.TRAILING_COMMA_IN_DEF
            else:
                feature = Feature.TRAILING_COMMA_IN_CALL

            for ch in n.children:
                if ch.type in STARS:
                    features.add(feature)

                if ch.type == syms.argument:
                    for argch in ch.children:
                        if argch.type in STARS:
                            features.add(feature)

        elif (
            n.type in {syms.return_stmt, syms.yield_expr}
            and len(n.children) >= 2
            and n.children[1].type == syms.testlist_star_expr
            and any(child.type == syms.star_expr for child in n.children[1].children)
        ):
            features.add(Feature.UNPACKING_ON_FLOW)

        elif (
            n.type == syms.annassign
            and len(n.children) >= 4
            and n.children[3].type == syms.testlist_star_expr
        ):
            features.add(Feature.ANN_ASSIGN_EXTENDED_RHS)

    return features


def detect_target_versions(
    node: Node, *, future_imports: Optional[Set[str]] = None
) -> Set[TargetVersion]:
    """Detect the version to target based on the nodes used."""
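    # Sketch (illustrative): if the only detected feature is F_STRINGS, every
    # TargetVersion whose feature set includes f-strings (3.6 and later) is
    # returned as a possible target.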
    features = get_features_used(node, future_imports=future_imports)
    return {
        version for version in TargetVersion if features <= VERSION_TO_FEATURES[version]
    }


def get_future_imports(node: Node) -> Set[str]:
    """Return a set of __future__ imports in the file."""
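    # e.g. (illustrative) a module starting with
    #     from __future__ import annotations, division
    # yields {"annotations", "division"}.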
    imports: Set[str] = set()

    def get_imports_from_children(children: List[LN]) -> Generator[str, None, None]:
        for child in children:
            if isinstance(child, Leaf):
                if child.type == token.NAME:
                    yield child.value

            elif child.type == syms.import_as_name:
                orig_name = child.children[0]
                assert isinstance(orig_name, Leaf), "Invalid syntax parsing imports"
                assert orig_name.type == token.NAME, "Invalid syntax parsing imports"
                yield orig_name.value

            elif child.type == syms.import_as_names:
                yield from get_imports_from_children(child.children)

            else:
                raise AssertionError("Invalid syntax parsing imports")

    for child in node.children:
        if child.type != syms.simple_stmt:
            break

        first_child = child.children[0]
        if isinstance(first_child, Leaf):
            # Continue looking if we see a docstring; otherwise stop.
            if (
                len(child.children) == 2
                and first_child.type == token.STRING
                and child.children[1].type == token.NEWLINE
            ):
                continue

            break

        elif first_child.type == syms.import_from:
            module_name = first_child.children[1]
            if not isinstance(module_name, Leaf) or module_name.value != "__future__":
                break

            imports |= set(get_imports_from_children(first_child.children[3:]))
        else:
            break

    return imports


def assert_equivalent(src: str, dst: str) -> None:
    """Raise AssertionError if `src` and `dst` aren't equivalent."""
    try:
        src_ast = parse_ast(src)
    except Exception as exc:
        raise AssertionError(
            "cannot use --safe with this file; failed to parse source file AST: "
            f"{exc}\n"
            "This could be caused by running Black with an older Python version "
            "that does not support new syntax used in your source file."
        ) from exc

    try:
        dst_ast = parse_ast(dst)
    except Exception as exc:
        log = dump_to_file("".join(traceback.format_tb(exc.__traceback__)), dst)
        raise AssertionError(
            f"INTERNAL ERROR: Black produced invalid code: {exc}. "
            "Please report a bug on https://github.com/psf/black/issues.  "
            f"This invalid output might be helpful: {log}"
        ) from None

    src_ast_str = "\n".join(stringify_ast(src_ast))
    dst_ast_str = "\n".join(stringify_ast(dst_ast))
    if src_ast_str != dst_ast_str:
        log = dump_to_file(diff(src_ast_str, dst_ast_str, "src", "dst"))
        raise AssertionError(
            "INTERNAL ERROR: Black produced code that is not equivalent to the"
            " source.  Please report a bug on "
            f"https://github.com/psf/black/issues.  This diff might be helpful: {log}"
        ) from None


def assert_stable(src: str, dst: str, mode: Mode) -> None:
    """Raise AssertionError if `dst` reformats differently the second time."""
    # We shouldn't call format_str() here, because that formats the string
    # twice and may hide a bug where we bounce back and forth between two
    # versions.
    newdst = _format_str_once(dst, mode=mode)
    if dst != newdst:
        log = dump_to_file(
            str(mode),
            diff(src, dst, "source", "first pass"),
            diff(dst, newdst, "first pass", "second pass"),
        )
        raise AssertionError(
            "INTERNAL ERROR: Black produced different code on the second pass of the"
            " formatter.  Please report a bug on https://github.com/psf/black/issues."
            f"  This diff might be helpful: {log}"
        ) from None


@contextmanager
def nullcontext() -> Iterator[None]:
    """Return an empty context manager.

    To be used like `nullcontext` in Python 3.7.
    """
    yield


def patch_click() -> None:
    """Make Click not crash on Python 3.6 with LANG=C.

    In certain misconfigured environments, Python 3 selects the ASCII encoding as the
    default, which restricts the paths it can access during the lifetime of the
    application.  Click refuses to work in this scenario by raising a RuntimeError.

    In the case of Black, the likelihood that non-ASCII characters are going to be
    used in file paths is minimal since it's Python source code.  Moreover, this
    crash was spurious on Python 3.7 thanks to PEP 538 and PEP 540.
    """
    try:
        from click import core
        from click import _unicodefun
    except ModuleNotFoundError:
        return

    for module in (core, _unicodefun):
        if hasattr(module, "_verify_python3_env"):
            module._verify_python3_env = lambda: None  # type: ignore
        if hasattr(module, "_verify_python_env"):
            module._verify_python_env = lambda: None  # type: ignore


def patched_main() -> None:
    maybe_install_uvloop()
    freeze_support()
    patch_click()
    main()


if __name__ == "__main__":
    patched_main()