[etc/vim.git] / src / black / __init__.py
1 import asyncio
2 from json.decoder import JSONDecodeError
3 import json
4 from concurrent.futures import Executor, ThreadPoolExecutor, ProcessPoolExecutor
5 from contextlib import contextmanager
6 from datetime import datetime
7 from enum import Enum
8 import io
9 from multiprocessing import Manager, freeze_support
10 import os
11 from pathlib import Path
12 from pathspec.patterns.gitwildmatch import GitWildMatchPatternError
13 import re
14 import signal
15 import sys
16 import tokenize
17 import traceback
18 from typing import (
19     Any,
20     Dict,
21     Generator,
22     Iterator,
23     List,
24     MutableMapping,
25     Optional,
26     Pattern,
27     Sequence,
28     Set,
29     Sized,
30     Tuple,
31     Union,
32 )
33
34 import click
35 from click.core import ParameterSource
36 from dataclasses import replace
37 from mypy_extensions import mypyc_attr
38
39 from black.const import DEFAULT_LINE_LENGTH, DEFAULT_INCLUDES, DEFAULT_EXCLUDES
40 from black.const import STDIN_PLACEHOLDER
41 from black.nodes import STARS, syms, is_simple_decorator_expression
42 from black.nodes import is_string_token
43 from black.lines import Line, EmptyLineTracker
44 from black.linegen import transform_line, LineGenerator, LN
45 from black.comments import normalize_fmt_off
46 from black.mode import FUTURE_FLAG_TO_FEATURE, Mode, TargetVersion
47 from black.mode import Feature, supports_feature, VERSION_TO_FEATURES
48 from black.cache import read_cache, write_cache, get_cache_info, filter_cached, Cache
49 from black.concurrency import cancel, shutdown, maybe_install_uvloop
50 from black.output import dump_to_file, ipynb_diff, diff, color_diff, out, err
51 from black.report import Report, Changed, NothingChanged
52 from black.files import (
53     find_project_root,
54     find_pyproject_toml,
55     parse_pyproject_toml,
56     find_user_pyproject_toml,
57 )
58 from black.files import gen_python_files, get_gitignore, normalize_path_maybe_ignore
59 from black.files import wrap_stream_for_windows
60 from black.parsing import InvalidInput  # noqa F401
61 from black.parsing import lib2to3_parse, parse_ast, stringify_ast
62 from black.handle_ipynb_magics import (
63     mask_cell,
64     unmask_cell,
65     remove_trailing_semicolon,
66     put_trailing_semicolon_back,
67     TRANSFORMED_MAGICS,
68     PYTHON_CELL_MAGICS,
69     jupyter_dependencies_are_installed,
70 )
71
72
73 # lib2to3 fork
74 from blib2to3.pytree import Node, Leaf
75 from blib2to3.pgen2 import token
76
77 from _black_version import version as __version__
78
79 COMPILED = Path(__file__).suffix in (".pyd", ".so")
80
81 # types
82 FileContent = str
83 Encoding = str
84 NewLine = str
85
86
87 class WriteBack(Enum):
88     NO = 0
89     YES = 1
90     DIFF = 2
91     CHECK = 3
92     COLOR_DIFF = 4
93
94     @classmethod
95     def from_configuration(
96         cls, *, check: bool, diff: bool, color: bool = False
97     ) -> "WriteBack":
98         if check and not diff:
99             return cls.CHECK
100
101         if diff and color:
102             return cls.COLOR_DIFF
103
104         return cls.DIFF if diff else cls.YES
105
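# (Illustrative, not part of the original module.) How the CLI flags map onto
# WriteBack values through from_configuration():
#
#     WriteBack.from_configuration(check=True, diff=False)              -> WriteBack.CHECK
#     WriteBack.from_configuration(check=True, diff=True)               -> WriteBack.DIFF
#     WriteBack.from_configuration(check=False, diff=True, color=True)  -> WriteBack.COLOR_DIFF
#     WriteBack.from_configuration(check=False, diff=False)             -> WriteBack.YES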
106
107 # Legacy name, left for integrations.
108 FileMode = Mode
109
110 DEFAULT_WORKERS = os.cpu_count()
111
112
113 def read_pyproject_toml(
114     ctx: click.Context, param: click.Parameter, value: Optional[str]
115 ) -> Optional[str]:
116     """Inject Black configuration from "pyproject.toml" into defaults in `ctx`.
117
118     Returns the path to a successfully found and read configuration file, None
119     otherwise.
120     """
121     if not value:
122         value = find_pyproject_toml(ctx.params.get("src", ()))
123         if value is None:
124             return None
125
126     try:
127         config = parse_pyproject_toml(value)
128     except (OSError, ValueError) as e:
129         raise click.FileError(
130             filename=value, hint=f"Error reading configuration file: {e}"
131         ) from None
132
133     if not config:
134         return None
135     else:
136         # Sanitize the values to be Click friendly. For more information please see:
137         # https://github.com/psf/black/issues/1458
138         # https://github.com/pallets/click/issues/1567
139         config = {
140             k: str(v) if not isinstance(v, (list, dict)) else v
141             for k, v in config.items()
142         }
143
144     target_version = config.get("target_version")
145     if target_version is not None and not isinstance(target_version, list):
146         raise click.BadOptionUsage(
147             "target-version", "Config key target-version must be a list"
148         )
149
150     default_map: Dict[str, Any] = {}
151     if ctx.default_map:
152         default_map.update(ctx.default_map)
153     default_map.update(config)
154
155     ctx.default_map = default_map
156     return value
157
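# (Illustrative, not part of the original module.) A hypothetical pyproject.toml
# that read_pyproject_toml() would pick up; keys from [tool.black] only populate
# Click's default_map, so options given explicitly on the command line still win:
#
#     [tool.black]
#     line-length = 100
#     target-version = ["py38", "py39"]
#     extend-exclude = '''
#     /generated/
#     '''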
158
159 def target_version_option_callback(
160     c: click.Context, p: Union[click.Option, click.Parameter], v: Tuple[str, ...]
161 ) -> List[TargetVersion]:
162     """Compute the target versions from a --target-version flag.
163
164     This is its own function because mypy couldn't infer the type correctly
165     when it was a lambda, causing mypyc trouble.
166     """
167     return [TargetVersion[val.upper()] for val in v]
168
169
170 def re_compile_maybe_verbose(regex: str) -> Pattern[str]:
171     """Compile a regular expression string in `regex`.
172
173     If it contains newlines, use verbose mode.
174     """
175     if "\n" in regex:
176         regex = "(?x)" + regex
177     compiled: Pattern[str] = re.compile(regex)
178     return compiled
179
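# (Illustrative, not part of the original module.) A single-line pattern is
# compiled as-is, while a multi-line pattern (as commonly written for the
# exclude options in pyproject.toml) contains "\n" and is compiled in verbose
# mode, so whitespace and "#" comments inside it are ignored:
#
#     re_compile_maybe_verbose(r"\.pyi?$")                       # compiled unchanged
#     re_compile_maybe_verbose("/(\n  \\.git\n  | \\.venv\n)/")  # "(?x)" is prepended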
180
181 def validate_regex(
182     ctx: click.Context,
183     param: click.Parameter,
184     value: Optional[str],
185 ) -> Optional[Pattern[str]]:
186     try:
187         return re_compile_maybe_verbose(value) if value is not None else None
188     except re.error as e:
189         raise click.BadParameter(f"Not a valid regular expression: {e}") from None
190
191
192 @click.command(
193     context_settings={"help_option_names": ["-h", "--help"]},
194     # While Click does set this field automatically using the docstring, mypyc
195     # (annoyingly) strips 'em so we need to set it here too.
196     help="The uncompromising code formatter.",
197 )
198 @click.option("-c", "--code", type=str, help="Format the code passed in as a string.")
199 @click.option(
200     "-l",
201     "--line-length",
202     type=int,
203     default=DEFAULT_LINE_LENGTH,
204     help="How many characters per line to allow.",
205     show_default=True,
206 )
207 @click.option(
208     "-t",
209     "--target-version",
210     type=click.Choice([v.name.lower() for v in TargetVersion]),
211     callback=target_version_option_callback,
212     multiple=True,
213     help=(
214         "Python versions that should be supported by Black's output. [default: per-file"
215         " auto-detection]"
216     ),
217 )
218 @click.option(
219     "--pyi",
220     is_flag=True,
221     help=(
222         "Format all input files like typing stubs regardless of file extension (useful"
223         " when piping source on standard input)."
224     ),
225 )
226 @click.option(
227     "--ipynb",
228     is_flag=True,
229     help=(
230         "Format all input files like Jupyter Notebooks regardless of file extension "
231         "(useful when piping source on standard input)."
232     ),
233 )
234 @click.option(
235     "--python-cell-magics",
236     multiple=True,
237     help=(
238         "When processing Jupyter Notebooks, add the given magic to the list"
239         f" of known python-magics ({', '.join(PYTHON_CELL_MAGICS)})."
240         " Useful for formatting cells with custom python magics."
241     ),
242     default=[],
243 )
244 @click.option(
245     "-S",
246     "--skip-string-normalization",
247     is_flag=True,
248     help="Don't normalize string quotes or prefixes.",
249 )
250 @click.option(
251     "-C",
252     "--skip-magic-trailing-comma",
253     is_flag=True,
254     help="Don't use trailing commas as a reason to split lines.",
255 )
256 @click.option(
257     "--experimental-string-processing",
258     is_flag=True,
259     hidden=True,
260     help="(DEPRECATED and now included in --preview) Normalize string literals.",
261 )
262 @click.option(
263     "--preview",
264     is_flag=True,
265     help=(
266         "Enable potentially disruptive style changes that may be added to Black's main"
267         " functionality in the next major release."
268     ),
269 )
270 @click.option(
271     "--check",
272     is_flag=True,
273     help=(
274         "Don't write the files back, just return the status. Return code 0 means"
275         " nothing would change. Return code 1 means some files would be reformatted."
276         " Return code 123 means there was an internal error."
277     ),
278 )
279 @click.option(
280     "--diff",
281     is_flag=True,
282     help="Don't write the files back, just output a diff for each file on stdout.",
283 )
284 @click.option(
285     "--color/--no-color",
286     is_flag=True,
287     help="Show colored diff. Only applies when `--diff` is given.",
288 )
289 @click.option(
290     "--fast/--safe",
291     is_flag=True,
292     help="If --fast given, skip temporary sanity checks. [default: --safe]",
293 )
294 @click.option(
295     "--required-version",
296     type=str,
297     help=(
298         "Require a specific version of Black to be running (useful for unifying results"
299         " across many environments e.g. with a pyproject.toml file). It can be"
300         " either a major version number or an exact version."
301     ),
302 )
303 @click.option(
304     "--include",
305     type=str,
306     default=DEFAULT_INCLUDES,
307     callback=validate_regex,
308     help=(
309         "A regular expression that matches files and directories that should be"
310         " included on recursive searches. An empty value means all files are included"
311         " regardless of the name. Use forward slashes for directories on all platforms"
312         " (Windows, too). Exclusions are calculated first, inclusions later."
313     ),
314     show_default=True,
315 )
316 @click.option(
317     "--exclude",
318     type=str,
319     callback=validate_regex,
320     help=(
321         "A regular expression that matches files and directories that should be"
322         " excluded on recursive searches. An empty value means no paths are excluded."
323         " Use forward slashes for directories on all platforms (Windows, too)."
324         " Exclusions are calculated first, inclusions later. [default:"
325         f" {DEFAULT_EXCLUDES}]"
326     ),
327     show_default=False,
328 )
329 @click.option(
330     "--extend-exclude",
331     type=str,
332     callback=validate_regex,
333     help=(
334         "Like --exclude, but adds additional files and directories on top of the"
335         " excluded ones. (Useful if you simply want to add to the default)"
336     ),
337 )
338 @click.option(
339     "--force-exclude",
340     type=str,
341     callback=validate_regex,
342     help=(
343         "Like --exclude, but files and directories matching this regex will be "
344         "excluded even when they are passed explicitly as arguments."
345     ),
346 )
347 @click.option(
348     "--stdin-filename",
349     type=str,
350     help=(
351         "The name of the file when passing it through stdin. Useful to make "
352         "sure Black will respect --force-exclude option on some "
353         "editors that rely on using stdin."
354     ),
355 )
356 @click.option(
357     "-W",
358     "--workers",
359     type=click.IntRange(min=1),
360     default=DEFAULT_WORKERS,
361     show_default=True,
362     help="Number of parallel workers",
363 )
364 @click.option(
365     "-q",
366     "--quiet",
367     is_flag=True,
368     help=(
369         "Don't emit non-error messages to stderr. Errors are still emitted; silence"
370         " those with 2>/dev/null."
371     ),
372 )
373 @click.option(
374     "-v",
375     "--verbose",
376     is_flag=True,
377     help=(
378         "Also emit messages to stderr about files that were not changed or were ignored"
379         " due to exclusion patterns."
380     ),
381 )
382 @click.version_option(
383     version=__version__,
384     message=f"%(prog)s, %(version)s (compiled: {'yes' if COMPILED else 'no'})",
385 )
386 @click.argument(
387     "src",
388     nargs=-1,
389     type=click.Path(
390         exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
391     ),
392     is_eager=True,
393     metavar="SRC ...",
394 )
395 @click.option(
396     "--config",
397     type=click.Path(
398         exists=True,
399         file_okay=True,
400         dir_okay=False,
401         readable=True,
402         allow_dash=False,
403         path_type=str,
404     ),
405     is_eager=True,
406     callback=read_pyproject_toml,
407     help="Read configuration from FILE path.",
408 )
409 @click.pass_context
410 def main(  # noqa: C901
411     ctx: click.Context,
412     code: Optional[str],
413     line_length: int,
414     target_version: List[TargetVersion],
415     check: bool,
416     diff: bool,
417     color: bool,
418     fast: bool,
419     pyi: bool,
420     ipynb: bool,
421     python_cell_magics: Sequence[str],
422     skip_string_normalization: bool,
423     skip_magic_trailing_comma: bool,
424     experimental_string_processing: bool,
425     preview: bool,
426     quiet: bool,
427     verbose: bool,
428     required_version: Optional[str],
429     include: Pattern[str],
430     exclude: Optional[Pattern[str]],
431     extend_exclude: Optional[Pattern[str]],
432     force_exclude: Optional[Pattern[str]],
433     stdin_filename: Optional[str],
434     workers: int,
435     src: Tuple[str, ...],
436     config: Optional[str],
437 ) -> None:
438     """The uncompromising code formatter."""
439     ctx.ensure_object(dict)
440
441     if src and code is not None:
442         out(
443             main.get_usage(ctx)
444             + "\n\n'SRC' and 'code' cannot be passed simultaneously."
445         )
446         ctx.exit(1)
447     if not src and code is None:
448         out(main.get_usage(ctx) + "\n\nOne of 'SRC' or 'code' is required.")
449         ctx.exit(1)
450
451     root, method = find_project_root(src) if code is None else (None, None)
452     ctx.obj["root"] = root
453
454     if verbose:
455         if root:
456             out(
457                 f"Identified `{root}` as project root containing a {method}.",
458                 fg="blue",
459             )
460
461             normalized = [
462                 (normalize_path_maybe_ignore(Path(source), root), source)
463                 for source in src
464             ]
465             srcs_string = ", ".join(
466                 [
467                     f'"{_norm}"'
468                     if _norm
469                     else f'\033[31m"{source} (skipping - invalid)"\033[34m'
470                     for _norm, source in normalized
471                 ]
472             )
473             out(f"Sources to be formatted: {srcs_string}", fg="blue")
474
475         if config:
476             config_source = ctx.get_parameter_source("config")
477             user_level_config = str(find_user_pyproject_toml())
478             if config == user_level_config:
479                 out(
480                     "Using configuration from user-level config at "
481                     f"'{user_level_config}'.",
482                     fg="blue",
483                 )
484             elif config_source in (
485                 ParameterSource.DEFAULT,
486                 ParameterSource.DEFAULT_MAP,
487             ):
488                 out("Using configuration from project root.", fg="blue")
489             else:
490                 out(f"Using configuration in '{config}'.", fg="blue")
491
492     error_msg = "Oh no! 💥 💔 💥"
493     if (
494         required_version
495         and required_version != __version__
496         and required_version != __version__.split(".")[0]
497     ):
498         err(
499             f"{error_msg} The required version `{required_version}` does not match"
500             f" the running version `{__version__}`!"
501         )
502         ctx.exit(1)
503     if ipynb and pyi:
504         err("Cannot pass both `pyi` and `ipynb` flags!")
505         ctx.exit(1)
506
507     write_back = WriteBack.from_configuration(check=check, diff=diff, color=color)
508     if target_version:
509         versions = set(target_version)
510     else:
511         # We'll autodetect later.
512         versions = set()
513     mode = Mode(
514         target_versions=versions,
515         line_length=line_length,
516         is_pyi=pyi,
517         is_ipynb=ipynb,
518         string_normalization=not skip_string_normalization,
519         magic_trailing_comma=not skip_magic_trailing_comma,
520         experimental_string_processing=experimental_string_processing,
521         preview=preview,
522         python_cell_magics=set(python_cell_magics),
523     )
524
525     if code is not None:
526         # Run in quiet mode by default with -c; the extra output isn't useful.
527         # You can still pass -v to get verbose output.
528         quiet = True
529
530     report = Report(check=check, diff=diff, quiet=quiet, verbose=verbose)
531
532     if code is not None:
533         reformat_code(
534             content=code, fast=fast, write_back=write_back, mode=mode, report=report
535         )
536     else:
537         try:
538             sources = get_sources(
539                 ctx=ctx,
540                 src=src,
541                 quiet=quiet,
542                 verbose=verbose,
543                 include=include,
544                 exclude=exclude,
545                 extend_exclude=extend_exclude,
546                 force_exclude=force_exclude,
547                 report=report,
548                 stdin_filename=stdin_filename,
549             )
550         except GitWildMatchPatternError:
551             ctx.exit(1)
552
553         path_empty(
554             sources,
555             "No Python files are present to be formatted. Nothing to do 😴",
556             quiet,
557             verbose,
558             ctx,
559         )
560
561         if len(sources) == 1:
562             reformat_one(
563                 src=sources.pop(),
564                 fast=fast,
565                 write_back=write_back,
566                 mode=mode,
567                 report=report,
568             )
569         else:
570             reformat_many(
571                 sources=sources,
572                 fast=fast,
573                 write_back=write_back,
574                 mode=mode,
575                 report=report,
576                 workers=workers,
577             )
578
579     if verbose or not quiet:
580         if code is None and (verbose or report.change_count or report.failure_count):
581             out()
582         out(error_msg if report.return_code else "All done! ✨ 🍰 ✨")
583         if code is None:
584             click.echo(str(report), err=True)
585     ctx.exit(report.return_code)
586
587
588 def get_sources(
589     *,
590     ctx: click.Context,
591     src: Tuple[str, ...],
592     quiet: bool,
593     verbose: bool,
594     include: Pattern[str],
595     exclude: Optional[Pattern[str]],
596     extend_exclude: Optional[Pattern[str]],
597     force_exclude: Optional[Pattern[str]],
598     report: "Report",
599     stdin_filename: Optional[str],
600 ) -> Set[Path]:
601     """Compute the set of files to be formatted."""
602     sources: Set[Path] = set()
603
604     if exclude is None:
605         exclude = re_compile_maybe_verbose(DEFAULT_EXCLUDES)
606         gitignore = get_gitignore(ctx.obj["root"])
607     else:
608         gitignore = None
609
610     for s in src:
611         if s == "-" and stdin_filename:
612             p = Path(stdin_filename)
613             is_stdin = True
614         else:
615             p = Path(s)
616             is_stdin = False
617
618         if is_stdin or p.is_file():
619             normalized_path = normalize_path_maybe_ignore(p, ctx.obj["root"], report)
620             if normalized_path is None:
621                 continue
622
623             normalized_path = "/" + normalized_path
624             # Hard-exclude any files that match the `--force-exclude` regex.
625             if force_exclude:
626                 force_exclude_match = force_exclude.search(normalized_path)
627             else:
628                 force_exclude_match = None
629             if force_exclude_match and force_exclude_match.group(0):
630                 report.path_ignored(p, "matches the --force-exclude regular expression")
631                 continue
632
633             if is_stdin:
634                 p = Path(f"{STDIN_PLACEHOLDER}{str(p)}")
635
636             if p.suffix == ".ipynb" and not jupyter_dependencies_are_installed(
637                 verbose=verbose, quiet=quiet
638             ):
639                 continue
640
641             sources.add(p)
642         elif p.is_dir():
643             sources.update(
644                 gen_python_files(
645                     p.iterdir(),
646                     ctx.obj["root"],
647                     include,
648                     exclude,
649                     extend_exclude,
650                     force_exclude,
651                     report,
652                     gitignore,
653                     verbose=verbose,
654                     quiet=quiet,
655                 )
656             )
657         elif s == "-":
658             sources.add(p)
659         else:
660             err(f"invalid path: {s}")
661     return sources
662
663
664 def path_empty(
665     src: Sized, msg: str, quiet: bool, verbose: bool, ctx: click.Context
666 ) -> None:
667     """
668     Exit if there is no `src` provided for formatting
669     """
670     if not src:
671         if verbose or not quiet:
672             out(msg)
673         ctx.exit(0)
674
675
676 def reformat_code(
677     content: str, fast: bool, write_back: WriteBack, mode: Mode, report: Report
678 ) -> None:
679     """
680     Reformat and print out `content` without spawning child processes.
681     Similar to `reformat_one`, but for string content.
682
683     `fast`, `write_back`, and `mode` options are passed to
684     :func:`format_file_in_place` or :func:`format_stdin_to_stdout`.
685     """
686     path = Path("<string>")
687     try:
688         changed = Changed.NO
689         if format_stdin_to_stdout(
690             content=content, fast=fast, write_back=write_back, mode=mode
691         ):
692             changed = Changed.YES
693         report.done(path, changed)
694     except Exception as exc:
695         if report.verbose:
696             traceback.print_exc()
697         report.failed(path, str(exc))
698
699
700 def reformat_one(
701     src: Path, fast: bool, write_back: WriteBack, mode: Mode, report: "Report"
702 ) -> None:
703     """Reformat a single file under `src` without spawning child processes.
704
705     `fast`, `write_back`, and `mode` options are passed to
706     :func:`format_file_in_place` or :func:`format_stdin_to_stdout`.
707     """
708     try:
709         changed = Changed.NO
710
711         if str(src) == "-":
712             is_stdin = True
713         elif str(src).startswith(STDIN_PLACEHOLDER):
714             is_stdin = True
715             # Use the original name again in case we want to print something
716             # to the user
717             src = Path(str(src)[len(STDIN_PLACEHOLDER) :])
718         else:
719             is_stdin = False
720
721         if is_stdin:
722             if src.suffix == ".pyi":
723                 mode = replace(mode, is_pyi=True)
724             elif src.suffix == ".ipynb":
725                 mode = replace(mode, is_ipynb=True)
726             if format_stdin_to_stdout(fast=fast, write_back=write_back, mode=mode):
727                 changed = Changed.YES
728         else:
729             cache: Cache = {}
730             if write_back not in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
731                 cache = read_cache(mode)
732                 res_src = src.resolve()
733                 res_src_s = str(res_src)
734                 if res_src_s in cache and cache[res_src_s] == get_cache_info(res_src):
735                     changed = Changed.CACHED
736             if changed is not Changed.CACHED and format_file_in_place(
737                 src, fast=fast, write_back=write_back, mode=mode
738             ):
739                 changed = Changed.YES
740             if (write_back is WriteBack.YES and changed is not Changed.CACHED) or (
741                 write_back is WriteBack.CHECK and changed is Changed.NO
742             ):
743                 write_cache(cache, [src], mode)
744         report.done(src, changed)
745     except Exception as exc:
746         if report.verbose:
747             traceback.print_exc()
748         report.failed(src, str(exc))
749
750
751 # diff-shades depends on being able to monkeypatch this function to operate. I know it's
752 # not ideal, but this shouldn't cause any issues ... hopefully. ~ichard26
753 @mypyc_attr(patchable=True)
754 def reformat_many(
755     sources: Set[Path],
756     fast: bool,
757     write_back: WriteBack,
758     mode: Mode,
759     report: "Report",
760     workers: Optional[int],
761 ) -> None:
762     """Reformat multiple files using a ProcessPoolExecutor."""
763     executor: Executor
764     loop = asyncio.get_event_loop()
765     worker_count = workers if workers is not None else DEFAULT_WORKERS
766     if sys.platform == "win32":
767         # Work around https://bugs.python.org/issue26903
768         assert worker_count is not None
769         worker_count = min(worker_count, 60)
770     try:
771         executor = ProcessPoolExecutor(max_workers=worker_count)
772     except (ImportError, NotImplementedError, OSError):
773         # we arrive here if the underlying system does not support multi-processing
774         # like in AWS Lambda or Termux, in which case we gracefully fall back to
775         # a ThreadPoolExecutor with just a single worker (more workers would not do us
776         # any good due to the Global Interpreter Lock)
777         executor = ThreadPoolExecutor(max_workers=1)
778
779     try:
780         loop.run_until_complete(
781             schedule_formatting(
782                 sources=sources,
783                 fast=fast,
784                 write_back=write_back,
785                 mode=mode,
786                 report=report,
787                 loop=loop,
788                 executor=executor,
789             )
790         )
791     finally:
792         shutdown(loop)
793         if executor is not None:
794             executor.shutdown()
795
796
797 async def schedule_formatting(
798     sources: Set[Path],
799     fast: bool,
800     write_back: WriteBack,
801     mode: Mode,
802     report: "Report",
803     loop: asyncio.AbstractEventLoop,
804     executor: Executor,
805 ) -> None:
806     """Run formatting of `sources` in parallel using the provided `executor`.
807
808     (Use ProcessPoolExecutors for actual parallelism.)
809
810     `write_back`, `fast`, and `mode` options are passed to
811     :func:`format_file_in_place`.
812     """
813     cache: Cache = {}
814     if write_back not in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
815         cache = read_cache(mode)
816         sources, cached = filter_cached(cache, sources)
817         for src in sorted(cached):
818             report.done(src, Changed.CACHED)
819     if not sources:
820         return
821
822     cancelled = []
823     sources_to_cache = []
824     lock = None
825     if write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
826         # For diff output, we need locks to ensure we don't interleave output
827         # from different processes.
828         manager = Manager()
829         lock = manager.Lock()
830     tasks = {
831         asyncio.ensure_future(
832             loop.run_in_executor(
833                 executor, format_file_in_place, src, fast, mode, write_back, lock
834             )
835         ): src
836         for src in sorted(sources)
837     }
838     pending = tasks.keys()
839     try:
840         loop.add_signal_handler(signal.SIGINT, cancel, pending)
841         loop.add_signal_handler(signal.SIGTERM, cancel, pending)
842     except NotImplementedError:
843         # There are no good alternatives for these on Windows.
844         pass
845     while pending:
846         done, _ = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
847         for task in done:
848             src = tasks.pop(task)
849             if task.cancelled():
850                 cancelled.append(task)
851             elif task.exception():
852                 report.failed(src, str(task.exception()))
853             else:
854                 changed = Changed.YES if task.result() else Changed.NO
855                 # If the file was written back or was successfully checked as
856                 # well-formatted, store this information in the cache.
857                 if write_back is WriteBack.YES or (
858                     write_back is WriteBack.CHECK and changed is Changed.NO
859                 ):
860                     sources_to_cache.append(src)
861                 report.done(src, changed)
862     if cancelled:
863         if sys.version_info >= (3, 7):
864             await asyncio.gather(*cancelled, return_exceptions=True)
865         else:
866             await asyncio.gather(*cancelled, loop=loop, return_exceptions=True)
867     if sources_to_cache:
868         write_cache(cache, sources_to_cache, mode)
869
870
871 def format_file_in_place(
872     src: Path,
873     fast: bool,
874     mode: Mode,
875     write_back: WriteBack = WriteBack.NO,
876     lock: Any = None,  # multiprocessing.Manager().Lock() is some crazy proxy
877 ) -> bool:
878     """Format file under `src` path. Return True if changed.
879
880     If `write_back` is DIFF, write a diff to stdout. If it is YES, write reformatted
881     code to the file.
882     `mode` and `fast` options are passed to :func:`format_file_contents`.
883     """
884     if src.suffix == ".pyi":
885         mode = replace(mode, is_pyi=True)
886     elif src.suffix == ".ipynb":
887         mode = replace(mode, is_ipynb=True)
888
889     then = datetime.utcfromtimestamp(src.stat().st_mtime)
890     with open(src, "rb") as buf:
891         src_contents, encoding, newline = decode_bytes(buf.read())
892     try:
893         dst_contents = format_file_contents(src_contents, fast=fast, mode=mode)
894     except NothingChanged:
895         return False
896     except JSONDecodeError:
897         raise ValueError(
898             f"File '{src}' cannot be parsed as valid Jupyter notebook."
899         ) from None
900
901     if write_back == WriteBack.YES:
902         with open(src, "w", encoding=encoding, newline=newline) as f:
903             f.write(dst_contents)
904     elif write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
905         now = datetime.utcnow()
906         src_name = f"{src}\t{then} +0000"
907         dst_name = f"{src}\t{now} +0000"
908         if mode.is_ipynb:
909             diff_contents = ipynb_diff(src_contents, dst_contents, src_name, dst_name)
910         else:
911             diff_contents = diff(src_contents, dst_contents, src_name, dst_name)
912
913         if write_back == WriteBack.COLOR_DIFF:
914             diff_contents = color_diff(diff_contents)
915
916         with lock or nullcontext():
917             f = io.TextIOWrapper(
918                 sys.stdout.buffer,
919                 encoding=encoding,
920                 newline=newline,
921                 write_through=True,
922             )
923             f = wrap_stream_for_windows(f)
924             f.write(diff_contents)
925             f.detach()
926
927     return True
928
929
930 def format_stdin_to_stdout(
931     fast: bool,
932     *,
933     content: Optional[str] = None,
934     write_back: WriteBack = WriteBack.NO,
935     mode: Mode,
936 ) -> bool:
937     """Format file on stdin. Return True if changed.
938
939     If content is None, it's read from sys.stdin.
940
941     If `write_back` is YES, write reformatted code back to stdout. If it is DIFF,
942     write a diff to stdout. The `mode` argument is passed to
943     :func:`format_file_contents`.
944     """
945     then = datetime.utcnow()
946
947     if content is None:
948         src, encoding, newline = decode_bytes(sys.stdin.buffer.read())
949     else:
950         src, encoding, newline = content, "utf-8", ""
951
952     dst = src
953     try:
954         dst = format_file_contents(src, fast=fast, mode=mode)
955         return True
956
957     except NothingChanged:
958         return False
959
960     finally:
961         f = io.TextIOWrapper(
962             sys.stdout.buffer, encoding=encoding, newline=newline, write_through=True
963         )
964         if write_back == WriteBack.YES:
965             # Make sure there's a newline after the content
966             if dst and dst[-1] != "\n":
967                 dst += "\n"
968             f.write(dst)
969         elif write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
970             now = datetime.utcnow()
971             src_name = f"STDIN\t{then} +0000"
972             dst_name = f"STDOUT\t{now} +0000"
973             d = diff(src, dst, src_name, dst_name)
974             if write_back == WriteBack.COLOR_DIFF:
975                 d = color_diff(d)
976                 f = wrap_stream_for_windows(f)
977             f.write(d)
978         f.detach()
979
980
981 def check_stability_and_equivalence(
982     src_contents: str, dst_contents: str, *, mode: Mode
983 ) -> None:
984     """Perform stability and equivalence checks.
985
986     Raise AssertionError if source and destination contents are not
987     equivalent, or if a second pass of the formatter would format the
988     content differently.
989     """
990     assert_equivalent(src_contents, dst_contents)
991     assert_stable(src_contents, dst_contents, mode=mode)
992
993
994 def format_file_contents(src_contents: str, *, fast: bool, mode: Mode) -> FileContent:
995     """Reformat contents of a file and return new contents.
996
997     If `fast` is False, additionally confirm that the reformatted code is
998     valid by calling :func:`assert_equivalent` and :func:`assert_stable` on it.
999     `mode` is passed to :func:`format_str`.
1000     """
1001     if not src_contents.strip():
1002         raise NothingChanged
1003
1004     if mode.is_ipynb:
1005         dst_contents = format_ipynb_string(src_contents, fast=fast, mode=mode)
1006     else:
1007         dst_contents = format_str(src_contents, mode=mode)
1008     if src_contents == dst_contents:
1009         raise NothingChanged
1010
1011     if not fast and not mode.is_ipynb:
1012         # Jupyter notebooks will already have been checked above.
1013         check_stability_and_equivalence(src_contents, dst_contents, mode=mode)
1014     return dst_contents
1015
1016
1017 def validate_cell(src: str, mode: Mode) -> None:
1018     """Check that cell does not already contain TransformerManager transformations,
1019     or non-Python cell magics, which might cause tokenizer_rt to break because of
1020     indentations.
1021
1022     If a cell contains ``!ls``, then it'll be transformed to
1023     ``get_ipython().system('ls')``. However, if the cell originally contained
1024     ``get_ipython().system('ls')``, then it would get transformed in the same way:
1025
1026         >>> TransformerManager().transform_cell("get_ipython().system('ls')")
1027         "get_ipython().system('ls')\n"
1028         >>> TransformerManager().transform_cell("!ls")
1029         "get_ipython().system('ls')\n"
1030
1031     Due to the impossibility of safely roundtripping in such situations, cells
1032     containing transformed magics will be ignored.
1033     """
1034     if any(transformed_magic in src for transformed_magic in TRANSFORMED_MAGICS):
1035         raise NothingChanged
1036     if (
1037         src[:2] == "%%"
1038         and src.split()[0][2:] not in PYTHON_CELL_MAGICS | mode.python_cell_magics
1039     ):
1040         raise NothingChanged
1041
1042
1043 def format_cell(src: str, *, fast: bool, mode: Mode) -> str:
1044     """Format code in given cell of Jupyter notebook.
1045
1046     General idea is:
1047
1048       - if cell has trailing semicolon, remove it;
1049       - if cell has IPython magics, mask them;
1050       - format cell;
1051       - reinstate IPython magics;
1052       - reinstate trailing semicolon (if originally present);
1053       - strip trailing newlines.
1054
1055     Cells with syntax errors will not be processed, as they
1056     could potentially be automagics or multi-line magics, which
1057     are currently not supported.
1058     """
1059     validate_cell(src, mode)
1060     src_without_trailing_semicolon, has_trailing_semicolon = remove_trailing_semicolon(
1061         src
1062     )
1063     try:
1064         masked_src, replacements = mask_cell(src_without_trailing_semicolon)
1065     except SyntaxError:
1066         raise NothingChanged from None
1067     masked_dst = format_str(masked_src, mode=mode)
1068     if not fast:
1069         check_stability_and_equivalence(masked_src, masked_dst, mode=mode)
1070     dst_without_trailing_semicolon = unmask_cell(masked_dst, replacements)
1071     dst = put_trailing_semicolon_back(
1072         dst_without_trailing_semicolon, has_trailing_semicolon
1073     )
1074     dst = dst.rstrip("\n")
1075     if dst == src:
1076         raise NothingChanged from None
1077     return dst
1078
1079
1080 def validate_metadata(nb: MutableMapping[str, Any]) -> None:
1081     """If notebook is marked as non-Python, don't format it.
1082
1083     All notebook metadata fields are optional, see
1084     https://nbformat.readthedocs.io/en/latest/format_description.html. So
1085     if a notebook has empty metadata, we will try to parse it anyway.
1086     """
1087     language = nb.get("metadata", {}).get("language_info", {}).get("name", None)
1088     if language is not None and language != "python":
1089         raise NothingChanged from None
1090
1091
1092 def format_ipynb_string(src_contents: str, *, fast: bool, mode: Mode) -> FileContent:
1093     """Format Jupyter notebook.
1094
1095     Operate cell-by-cell, only on code cells, only for Python notebooks.
1096     If the ``.ipynb`` originally had a trailing newline, it'll be preserved.
1097     """
1098     trailing_newline = src_contents[-1] == "\n"
1099     modified = False
1100     nb = json.loads(src_contents)
1101     validate_metadata(nb)
1102     for cell in nb["cells"]:
1103         if cell.get("cell_type", None) == "code":
1104             try:
1105                 src = "".join(cell["source"])
1106                 dst = format_cell(src, fast=fast, mode=mode)
1107             except NothingChanged:
1108                 pass
1109             else:
1110                 cell["source"] = dst.splitlines(keepends=True)
1111                 modified = True
1112     if modified:
1113         dst_contents = json.dumps(nb, indent=1, ensure_ascii=False)
1114         if trailing_newline:
1115             dst_contents = dst_contents + "\n"
1116         return dst_contents
1117     else:
1118         raise NothingChanged
1119
1120
1121 def format_str(src_contents: str, *, mode: Mode) -> str:
1122     """Reformat a string and return new contents.
1123
1124     `mode` determines formatting options, such as how many characters per line are
1125     allowed.  Example:
1126
1127     >>> import black
1128     >>> print(black.format_str("def f(arg:str='')->None:...", mode=black.Mode()))
1129     def f(arg: str = "") -> None:
1130         ...
1131
1132     A more complex example:
1133
1134     >>> print(
1135     ...   black.format_str(
1136     ...     "def f(arg:str='')->None: hey",
1137     ...     mode=black.Mode(
1138     ...       target_versions={black.TargetVersion.PY36},
1139     ...       line_length=10,
1140     ...       string_normalization=False,
1141     ...       is_pyi=False,
1142     ...     ),
1143     ...   ),
1144     ... )
1145     def f(
1146         arg: str = '',
1147     ) -> None:
1148         hey
1149
1150     """
1151     dst_contents = _format_str_once(src_contents, mode=mode)
1152     # Forced second pass to work around optional trailing commas (becoming
1153     # forced trailing commas on pass 2) interacting differently with optional
1154     # parentheses.  Admittedly ugly.
1155     if src_contents != dst_contents:
1156         return _format_str_once(dst_contents, mode=mode)
1157     return dst_contents
1158
1159
1160 def _format_str_once(src_contents: str, *, mode: Mode) -> str:
1161     src_node = lib2to3_parse(src_contents.lstrip(), mode.target_versions)
1162     dst_contents = []
1163     future_imports = get_future_imports(src_node)
1164     if mode.target_versions:
1165         versions = mode.target_versions
1166     else:
1167         versions = detect_target_versions(src_node, future_imports=future_imports)
1168
1169     normalize_fmt_off(src_node, preview=mode.preview)
1170     lines = LineGenerator(mode=mode)
1171     elt = EmptyLineTracker(is_pyi=mode.is_pyi)
1172     empty_line = Line(mode=mode)
1173     after = 0
1174     split_line_features = {
1175         feature
1176         for feature in {Feature.TRAILING_COMMA_IN_CALL, Feature.TRAILING_COMMA_IN_DEF}
1177         if supports_feature(versions, feature)
1178     }
1179     for current_line in lines.visit(src_node):
1180         dst_contents.append(str(empty_line) * after)
1181         before, after = elt.maybe_empty_lines(current_line)
1182         dst_contents.append(str(empty_line) * before)
1183         for line in transform_line(
1184             current_line, mode=mode, features=split_line_features
1185         ):
1186             dst_contents.append(str(line))
1187     return "".join(dst_contents)
1188
1189
1190 def decode_bytes(src: bytes) -> Tuple[FileContent, Encoding, NewLine]:
1191     """Return a tuple of (decoded_contents, encoding, newline).
1192
1193     `newline` is either CRLF or LF but `decoded_contents` is decoded with
1194     universal newlines (i.e. only contains LF).
1195     """
1196     srcbuf = io.BytesIO(src)
1197     encoding, lines = tokenize.detect_encoding(srcbuf.readline)
1198     if not lines:
1199         return "", encoding, "\n"
1200
1201     newline = "\r\n" if b"\r\n" == lines[0][-2:] else "\n"
1202     srcbuf.seek(0)
1203     with io.TextIOWrapper(srcbuf, encoding) as tiow:
1204         return tiow.read(), encoding, newline
1205
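# (Illustrative, not part of the original module.) The detected newline is
# returned separately while the decoded text always uses "\n", which lets the
# callers above write files back with their original line endings:
#
#     decode_bytes(b"x = 1\r\ny = 2\r\n")  ->  ("x = 1\ny = 2\n", "utf-8", "\r\n")
#     decode_bytes(b"x = 1\n")             ->  ("x = 1\n", "utf-8", "\n")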
1206
1207 def get_features_used(  # noqa: C901
1208     node: Node, *, future_imports: Optional[Set[str]] = None
1209 ) -> Set[Feature]:
1210     """Return a set of (relatively) new Python features used in this file.
1211
1212     Currently looking for:
1213     - f-strings;
1214     - underscores in numeric literals;
1215     - trailing commas after * or ** in function signatures and calls;
1216     - positional only arguments in function signatures and lambdas;
1217     - assignment expression;
1218     - relaxed decorator syntax;
1219     - usage of __future__ flags (annotations);
1220     - print / exec statements;
1221     """
1222     features: Set[Feature] = set()
1223     if future_imports:
1224         features |= {
1225             FUTURE_FLAG_TO_FEATURE[future_import]
1226             for future_import in future_imports
1227             if future_import in FUTURE_FLAG_TO_FEATURE
1228         }
1229
1230     for n in node.pre_order():
1231         if is_string_token(n):
1232             value_head = n.value[:2]
1233             if value_head in {'f"', 'F"', "f'", "F'", "rf", "fr", "RF", "FR"}:
1234                 features.add(Feature.F_STRINGS)
1235
1236         elif n.type == token.NUMBER:
1237             assert isinstance(n, Leaf)
1238             if "_" in n.value:
1239                 features.add(Feature.NUMERIC_UNDERSCORES)
1240
1241         elif n.type == token.SLASH:
1242             if n.parent and n.parent.type in {
1243                 syms.typedargslist,
1244                 syms.arglist,
1245                 syms.varargslist,
1246             }:
1247                 features.add(Feature.POS_ONLY_ARGUMENTS)
1248
1249         elif n.type == token.COLONEQUAL:
1250             features.add(Feature.ASSIGNMENT_EXPRESSIONS)
1251
1252         elif n.type == syms.decorator:
1253             if len(n.children) > 1 and not is_simple_decorator_expression(
1254                 n.children[1]
1255             ):
1256                 features.add(Feature.RELAXED_DECORATORS)
1257
1258         elif (
1259             n.type in {syms.typedargslist, syms.arglist}
1260             and n.children
1261             and n.children[-1].type == token.COMMA
1262         ):
1263             if n.type == syms.typedargslist:
1264                 feature = Feature.TRAILING_COMMA_IN_DEF
1265             else:
1266                 feature = Feature.TRAILING_COMMA_IN_CALL
1267
1268             for ch in n.children:
1269                 if ch.type in STARS:
1270                     features.add(feature)
1271
1272                 if ch.type == syms.argument:
1273                     for argch in ch.children:
1274                         if argch.type in STARS:
1275                             features.add(feature)
1276
1277         elif (
1278             n.type in {syms.return_stmt, syms.yield_expr}
1279             and len(n.children) >= 2
1280             and n.children[1].type == syms.testlist_star_expr
1281             and any(child.type == syms.star_expr for child in n.children[1].children)
1282         ):
1283             features.add(Feature.UNPACKING_ON_FLOW)
1284
1285         elif (
1286             n.type == syms.annassign
1287             and len(n.children) >= 4
1288             and n.children[3].type == syms.testlist_star_expr
1289         ):
1290             features.add(Feature.ANN_ASSIGN_EXTENDED_RHS)
1291
1292     return features
1293
1294
1295 def detect_target_versions(
1296     node: Node, *, future_imports: Optional[Set[str]] = None
1297 ) -> Set[TargetVersion]:
1298     """Detect the version to target based on the nodes used."""
1299     features = get_features_used(node, future_imports=future_imports)
1300     return {
1301         version for version in TargetVersion if features <= VERSION_TO_FEATURES[version]
1302     }
1303
1304
1305 def get_future_imports(node: Node) -> Set[str]:
1306     """Return a set of __future__ imports in the file."""
1307     imports: Set[str] = set()
1308
1309     def get_imports_from_children(children: List[LN]) -> Generator[str, None, None]:
1310         for child in children:
1311             if isinstance(child, Leaf):
1312                 if child.type == token.NAME:
1313                     yield child.value
1314
1315             elif child.type == syms.import_as_name:
1316                 orig_name = child.children[0]
1317                 assert isinstance(orig_name, Leaf), "Invalid syntax parsing imports"
1318                 assert orig_name.type == token.NAME, "Invalid syntax parsing imports"
1319                 yield orig_name.value
1320
1321             elif child.type == syms.import_as_names:
1322                 yield from get_imports_from_children(child.children)
1323
1324             else:
1325                 raise AssertionError("Invalid syntax parsing imports")
1326
1327     for child in node.children:
1328         if child.type != syms.simple_stmt:
1329             break
1330
1331         first_child = child.children[0]
1332         if isinstance(first_child, Leaf):
1333             # Continue looking if we see a docstring; otherwise stop.
1334             if (
1335                 len(child.children) == 2
1336                 and first_child.type == token.STRING
1337                 and child.children[1].type == token.NEWLINE
1338             ):
1339                 continue
1340
1341             break
1342
1343         elif first_child.type == syms.import_from:
1344             module_name = first_child.children[1]
1345             if not isinstance(module_name, Leaf) or module_name.value != "__future__":
1346                 break
1347
1348             imports |= set(get_imports_from_children(first_child.children[3:]))
1349         else:
1350             break
1351
1352     return imports
1353
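# (Illustrative, not part of the original module.) Only the leading run of
# simple statements is scanned, so a docstring may precede the imports. For a
# module starting with
#
#     """docstring"""
#     from __future__ import annotations, generator_stop
#     import os
#
# get_future_imports() returns {"annotations", "generator_stop"}.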
1354
1355 def assert_equivalent(src: str, dst: str) -> None:
1356     """Raise AssertionError if `src` and `dst` aren't equivalent."""
1357     try:
1358         src_ast = parse_ast(src)
1359     except Exception as exc:
1360         raise AssertionError(
1361             "cannot use --safe with this file; failed to parse source file AST: "
1362             f"{exc}\n"
1363             "This could be caused by running Black with an older Python version "
1364             "that does not support new syntax used in your source file."
1365         ) from exc
1366
1367     try:
1368         dst_ast = parse_ast(dst)
1369     except Exception as exc:
1370         log = dump_to_file("".join(traceback.format_tb(exc.__traceback__)), dst)
1371         raise AssertionError(
1372             f"INTERNAL ERROR: Black produced invalid code: {exc}. "
1373             "Please report a bug on https://github.com/psf/black/issues.  "
1374             f"This invalid output might be helpful: {log}"
1375         ) from None
1376
1377     src_ast_str = "\n".join(stringify_ast(src_ast))
1378     dst_ast_str = "\n".join(stringify_ast(dst_ast))
1379     if src_ast_str != dst_ast_str:
1380         log = dump_to_file(diff(src_ast_str, dst_ast_str, "src", "dst"))
1381         raise AssertionError(
1382             "INTERNAL ERROR: Black produced code that is not equivalent to the"
1383             " source.  Please report a bug on "
1384             f"https://github.com/psf/black/issues.  This diff might be helpful: {log}"
1385         ) from None
1386
1387
1388 def assert_stable(src: str, dst: str, mode: Mode) -> None:
1389     """Raise AssertionError if `dst` reformats differently the second time."""
1390     # We shouldn't call format_str() here, because that formats the string
1391     # twice and may hide a bug where we bounce back and forth between two
1392     # versions.
1393     newdst = _format_str_once(dst, mode=mode)
1394     if dst != newdst:
1395         log = dump_to_file(
1396             str(mode),
1397             diff(src, dst, "source", "first pass"),
1398             diff(dst, newdst, "first pass", "second pass"),
1399         )
1400         raise AssertionError(
1401             "INTERNAL ERROR: Black produced different code on the second pass of the"
1402             " formatter.  Please report a bug on https://github.com/psf/black/issues."
1403             f"  This diff might be helpful: {log}"
1404         ) from None
1405
1406
1407 @contextmanager
1408 def nullcontext() -> Iterator[None]:
1409     """Return an empty context manager.
1410
1411     To be used like `nullcontext` in Python 3.7.
1412     """
1413     yield
1414
1415
1416 def patch_click() -> None:
1417     """Make Click not crash on Python 3.6 with LANG=C.
1418
1419     On certain misconfigured environments, Python 3 selects the ASCII encoding as the
1420     default which restricts paths that it can access during the lifetime of the
1421     application.  Click refuses to work in this scenario by raising a RuntimeError.
1422
1423     In the case of Black, the likelihood that non-ASCII characters are going to be used in
1424     file paths is minimal since it's Python source code.  Moreover, this crash was
1425     spurious on Python 3.7 thanks to PEP 538 and PEP 540.
1426     """
1427     try:
1428         from click import core
1429         from click import _unicodefun
1430     except ModuleNotFoundError:
1431         return
1432
1433     for module in (core, _unicodefun):
1434         if hasattr(module, "_verify_python3_env"):
1435             module._verify_python3_env = lambda: None  # type: ignore
1436         if hasattr(module, "_verify_python_env"):
1437             module._verify_python_env = lambda: None  # type: ignore
1438
1439
1440 def patched_main() -> None:
1441     maybe_install_uvloop()
1442     freeze_support()
1443     patch_click()
1444     main()
1445
1446
1447 if __name__ == "__main__":
1448     patched_main()