git.madduck.net Git - etc/vim.git/blob - src/black/__init__.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Switch to Furo (#2793)
[etc/vim.git] / src / black / __init__.py
1 import asyncio
2 from json.decoder import JSONDecodeError
3 import json
4 from concurrent.futures import Executor, ThreadPoolExecutor, ProcessPoolExecutor
5 from contextlib import contextmanager
6 from datetime import datetime
7 from enum import Enum
8 import io
9 from multiprocessing import Manager, freeze_support
10 import os
11 from pathlib import Path
12 from pathspec.patterns.gitwildmatch import GitWildMatchPatternError
13 import re
14 import signal
15 import sys
16 import tokenize
17 import traceback
18 from typing import (
19     Any,
20     Dict,
21     Generator,
22     Iterator,
23     List,
24     MutableMapping,
25     Optional,
26     Pattern,
27     Sequence,
28     Set,
29     Sized,
30     Tuple,
31     Union,
32 )
33
34 import click
35 from click.core import ParameterSource
36 from dataclasses import replace
37 from mypy_extensions import mypyc_attr
38
39 from black.const import DEFAULT_LINE_LENGTH, DEFAULT_INCLUDES, DEFAULT_EXCLUDES
40 from black.const import STDIN_PLACEHOLDER
41 from black.nodes import STARS, syms, is_simple_decorator_expression
42 from black.nodes import is_string_token
43 from black.lines import Line, EmptyLineTracker
44 from black.linegen import transform_line, LineGenerator, LN
45 from black.comments import normalize_fmt_off
46 from black.mode import FUTURE_FLAG_TO_FEATURE, Mode, TargetVersion
47 from black.mode import Feature, supports_feature, VERSION_TO_FEATURES
48 from black.cache import read_cache, write_cache, get_cache_info, filter_cached, Cache
49 from black.concurrency import cancel, shutdown, maybe_install_uvloop
50 from black.output import dump_to_file, ipynb_diff, diff, color_diff, out, err
51 from black.report import Report, Changed, NothingChanged
52 from black.files import find_project_root, find_pyproject_toml, parse_pyproject_toml
53 from black.files import gen_python_files, get_gitignore, normalize_path_maybe_ignore
54 from black.files import wrap_stream_for_windows
55 from black.parsing import InvalidInput  # noqa F401
56 from black.parsing import lib2to3_parse, parse_ast, stringify_ast
57 from black.handle_ipynb_magics import (
58     mask_cell,
59     unmask_cell,
60     remove_trailing_semicolon,
61     put_trailing_semicolon_back,
62     TRANSFORMED_MAGICS,
63     PYTHON_CELL_MAGICS,
64     jupyter_dependencies_are_installed,
65 )
66
67
68 # lib2to3 fork
69 from blib2to3.pytree import Node, Leaf
70 from blib2to3.pgen2 import token
71
72 from _black_version import version as __version__
73
# True when this module is loaded from a compiled extension file (.pyd/.so)
# rather than from plain Python source.
COMPILED = Path(__file__).suffix in {".pyd", ".so"}

# Readability aliases for the plain strings passed around this module.
FileContent = str
Encoding = str
NewLine = str
80
81
class WriteBack(Enum):
    """Where (and how) formatted output goes once a file has been processed."""

    NO = 0
    YES = 1
    DIFF = 2
    CHECK = 3
    COLOR_DIFF = 4

    @classmethod
    def from_configuration(
        cls, *, check: bool, diff: bool, color: bool = False
    ) -> "WriteBack":
        """Map the ``--check`` / ``--diff`` / ``--color`` flags to a member.

        Without ``diff``: CHECK when ``check`` is set, otherwise YES (rewrite
        files). With ``diff`` (``check`` is then irrelevant): COLOR_DIFF when
        ``color`` is set, otherwise DIFF.
        """
        if not diff:
            return cls.CHECK if check else cls.YES
        return cls.COLOR_DIFF if color else cls.DIFF
100
101
# `FileMode` is the pre-`Mode` name of the formatting-mode class; the alias is
# kept so existing integrations that import it continue to work.
FileMode = Mode

# Default for --workers. os.cpu_count() can return None when the count is
# unknown; reformat_many treats a None worker count specially.
DEFAULT_WORKERS = os.cpu_count()
106
107
def read_pyproject_toml(
    ctx: click.Context, param: click.Parameter, value: Optional[str]
) -> Optional[str]:
    """Click callback for ``--config``: merge a pyproject.toml into ``ctx.default_map``.

    When ``value`` is empty, the file is located with ``find_pyproject_toml``
    from the ``src`` values already parsed into ``ctx.params``. Values read
    from the file are layered over any existing default map, so they become
    the defaults for the remaining options.

    Returns the path of the configuration file that was read, or None when no
    file was found or it produced no configuration.

    Raises:
        click.FileError: the file could not be read or parsed.
        click.BadOptionUsage: ``target_version`` in the file is not a list.
    """
    path = value or find_pyproject_toml(ctx.params.get("src", ()))
    if path is None:
        return None

    try:
        raw_config = parse_pyproject_toml(path)
    except (OSError, ValueError) as e:
        raise click.FileError(
            filename=path, hint=f"Error reading configuration file: {e}"
        ) from None

    if not raw_config:
        return None

    # Click wants scalar defaults as strings, so stringify everything that is
    # not a list or a dict. Background:
    # https://github.com/psf/black/issues/1458
    # https://github.com/pallets/click/issues/1567
    config = {
        key: val if isinstance(val, (list, dict)) else str(val)
        for key, val in raw_config.items()
    }

    target_version = config.get("target_version")
    if not (target_version is None or isinstance(target_version, list)):
        raise click.BadOptionUsage(
            "target-version", "Config key target-version must be a list"
        )

    # File values win over whatever default map was already installed.
    default_map: Dict[str, Any] = {**(ctx.default_map or {}), **config}
    ctx.default_map = default_map
    return path
152
153
def target_version_option_callback(
    c: click.Context, p: Union[click.Option, click.Parameter], v: Tuple[str, ...]
) -> List[TargetVersion]:
    """Convert the lower-case ``--target-version`` choices into TargetVersion members.

    This is a named function rather than a lambda because mypy could not infer
    the lambda's type, which in turn caused trouble for mypyc.
    """
    return [TargetVersion[choice.upper()] for choice in v]
163
164
def re_compile_maybe_verbose(regex: str) -> Pattern[str]:
    """Compile ``regex``, enabling verbose mode when it spans several lines.

    A pattern containing a newline gets an inline ``(?x)`` flag prepended, so
    whitespace and ``#`` comments inside it are ignored by the regex engine.
    """
    pattern_text = f"(?x){regex}" if "\n" in regex else regex
    compiled: Pattern[str] = re.compile(pattern_text)
    return compiled
174
175
def validate_regex(
    ctx: click.Context,
    param: click.Parameter,
    value: Optional[str],
) -> Optional[Pattern[str]]:
    """Click callback that compiles a regular-expression option value.

    Returns None when the option was not supplied. Raises click.BadParameter
    when the value is not a valid regular expression.
    """
    if value is None:
        return None
    try:
        return re_compile_maybe_verbose(value)
    except re.error as e:
        raise click.BadParameter(f"Not a valid regular expression: {e}") from None
185
186
@click.command(
    context_settings={"help_option_names": ["-h", "--help"]},
    # While Click does set this field automatically using the docstring, mypyc
    # (annoyingly) strips 'em so we need to set it here too.
    help="The uncompromising code formatter.",
)
@click.option("-c", "--code", type=str, help="Format the code passed in as a string.")
@click.option(
    "-l",
    "--line-length",
    type=int,
    default=DEFAULT_LINE_LENGTH,
    help="How many characters per line to allow.",
    show_default=True,
)
@click.option(
    "-t",
    "--target-version",
    type=click.Choice([v.name.lower() for v in TargetVersion]),
    # Converts the chosen names into a list of TargetVersion members.
    callback=target_version_option_callback,
    multiple=True,
    help=(
        "Python versions that should be supported by Black's output. [default: per-file"
        " auto-detection]"
    ),
)
@click.option(
    "--pyi",
    is_flag=True,
    help=(
        "Format all input files like typing stubs regardless of file extension (useful"
        " when piping source on standard input)."
    ),
)
@click.option(
    "--ipynb",
    is_flag=True,
    help=(
        "Format all input files like Jupyter Notebooks regardless of file extension "
        "(useful when piping source on standard input)."
    ),
)
@click.option(
    "--python-cell-magics",
    multiple=True,
    help=(
        "When processing Jupyter Notebooks, add the given magic to the list"
        f" of known python-magics ({', '.join(PYTHON_CELL_MAGICS)})."
        " Useful for formatting cells with custom python magics."
    ),
    default=[],
)
@click.option(
    "-S",
    "--skip-string-normalization",
    is_flag=True,
    help="Don't normalize string quotes or prefixes.",
)
@click.option(
    "-C",
    "--skip-magic-trailing-comma",
    is_flag=True,
    help="Don't use trailing commas as a reason to split lines.",
)
@click.option(
    "--experimental-string-processing",
    is_flag=True,
    hidden=True,
    help="(DEPRECATED and now included in --preview) Normalize string literals.",
)
@click.option(
    "--preview",
    is_flag=True,
    help=(
        "Enable potentially disruptive style changes that will be added to Black's main"
        " functionality in the next major release."
    ),
)
@click.option(
    "--check",
    is_flag=True,
    help=(
        "Don't write the files back, just return the status. Return code 0 means"
        " nothing would change. Return code 1 means some files would be reformatted."
        " Return code 123 means there was an internal error."
    ),
)
@click.option(
    "--diff",
    is_flag=True,
    help="Don't write the files back, just output a diff for each file on stdout.",
)
@click.option(
    "--color/--no-color",
    is_flag=True,
    help="Show colored diff. Only applies when `--diff` is given.",
)
@click.option(
    "--fast/--safe",
    is_flag=True,
    help="If --fast given, skip temporary sanity checks. [default: --safe]",
)
@click.option(
    "--required-version",
    type=str,
    help=(
        "Require a specific version of Black to be running (useful for unifying results"
        " across many environments e.g. with a pyproject.toml file)."
    ),
)
# The regex options below are compiled by validate_regex, so main() receives
# compiled patterns (or None when the option was not given).
@click.option(
    "--include",
    type=str,
    default=DEFAULT_INCLUDES,
    callback=validate_regex,
    help=(
        "A regular expression that matches files and directories that should be"
        " included on recursive searches. An empty value means all files are included"
        " regardless of the name. Use forward slashes for directories on all platforms"
        " (Windows, too). Exclusions are calculated first, inclusions later."
    ),
    show_default=True,
)
# No click default here: get_sources substitutes DEFAULT_EXCLUDES (and also
# consults .gitignore) only when --exclude is absent.
@click.option(
    "--exclude",
    type=str,
    callback=validate_regex,
    help=(
        "A regular expression that matches files and directories that should be"
        " excluded on recursive searches. An empty value means no paths are excluded."
        " Use forward slashes for directories on all platforms (Windows, too)."
        " Exclusions are calculated first, inclusions later. [default:"
        f" {DEFAULT_EXCLUDES}]"
    ),
    show_default=False,
)
@click.option(
    "--extend-exclude",
    type=str,
    callback=validate_regex,
    help=(
        "Like --exclude, but adds additional files and directories on top of the"
        " excluded ones. (Useful if you simply want to add to the default)"
    ),
)
@click.option(
    "--force-exclude",
    type=str,
    callback=validate_regex,
    help=(
        "Like --exclude, but files and directories matching this regex will be "
        "excluded even when they are passed explicitly as arguments."
    ),
)
@click.option(
    "--stdin-filename",
    type=str,
    help=(
        "The name of the file when passing it through stdin. Useful to make "
        "sure Black will respect --force-exclude option on some "
        "editors that rely on using stdin."
    ),
)
@click.option(
    "-W",
    "--workers",
    type=click.IntRange(min=1),
    default=DEFAULT_WORKERS,
    show_default=True,
    help="Number of parallel workers",
)
@click.option(
    "-q",
    "--quiet",
    is_flag=True,
    help=(
        "Don't emit non-error messages to stderr. Errors are still emitted; silence"
        " those with 2>/dev/null."
    ),
)
@click.option(
    "-v",
    "--verbose",
    is_flag=True,
    help=(
        "Also emit messages to stderr about files that were not changed or were ignored"
        " due to exclusion patterns."
    ),
)
@click.version_option(
    version=__version__,
    message=f"%(prog)s, %(version)s (compiled: {'yes' if COMPILED else 'no'})",
)
# Eager so that `src` is already in ctx.params when the (also eager) --config
# callback, read_pyproject_toml, looks it up to locate a pyproject.toml.
@click.argument(
    "src",
    nargs=-1,
    type=click.Path(
        exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
    ),
    is_eager=True,
    metavar="SRC ...",
)
# Eager so the configuration file is loaded into ctx.default_map (by
# read_pyproject_toml) before the non-eager options are processed.
@click.option(
    "--config",
    type=click.Path(
        exists=True,
        file_okay=True,
        dir_okay=False,
        readable=True,
        allow_dash=False,
        path_type=str,
    ),
    is_eager=True,
    callback=read_pyproject_toml,
    help="Read configuration from FILE path.",
)
@click.pass_context
def main(
    ctx: click.Context,
    code: Optional[str],
    line_length: int,
    target_version: List[TargetVersion],
    check: bool,
    diff: bool,
    color: bool,
    fast: bool,
    pyi: bool,
    ipynb: bool,
    python_cell_magics: Sequence[str],
    skip_string_normalization: bool,
    skip_magic_trailing_comma: bool,
    experimental_string_processing: bool,
    preview: bool,
    quiet: bool,
    verbose: bool,
    required_version: Optional[str],
    include: Pattern[str],
    exclude: Optional[Pattern[str]],
    extend_exclude: Optional[Pattern[str]],
    force_exclude: Optional[Pattern[str]],
    stdin_filename: Optional[str],
    workers: int,
    src: Tuple[str, ...],
    config: Optional[str],
) -> None:
    """The uncompromising code formatter."""
    ctx.ensure_object(dict)
    # With --code there are no files, so no project root is looked up.
    root, method = find_project_root(src) if code is None else (None, None)
    # get_sources reads the root back from ctx.obj["root"].
    ctx.obj["root"] = root

    if verbose:
        if root:
            out(
                f"Identified `{root}` as project root containing a {method}.",
                fg="blue",
            )

            # Pair each source with its root-relative form; sources for which
            # normalize_path_maybe_ignore returns a falsy value are shown in
            # red as skipped.
            normalized = [
                (normalize_path_maybe_ignore(Path(source), root), source)
                for source in src
            ]
            srcs_string = ", ".join(
                [
                    f'"{_norm}"'
                    if _norm
                    else f'\033[31m"{source} (skipping - invalid)"\033[34m'
                    for _norm, source in normalized
                ]
            )
            out(f"Sources to be formatted: {srcs_string}", fg="blue")

        if config:
            # Distinguish a config file found automatically from one given
            # explicitly with --config.
            config_source = ctx.get_parameter_source("config")
            if config_source in (ParameterSource.DEFAULT, ParameterSource.DEFAULT_MAP):
                out("Using configuration from project root.", fg="blue")
            else:
                out(f"Using configuration in '{config}'.", fg="blue")

    error_msg = "Oh no! 💥 💔 💥"
    # Refuse to run when the required version differs or the flags conflict.
    if required_version and required_version != __version__:
        err(
            f"{error_msg} The required version `{required_version}` does not match"
            f" the running version `{__version__}`!"
        )
        ctx.exit(1)
    if ipynb and pyi:
        err("Cannot pass both `pyi` and `ipynb` flags!")
        ctx.exit(1)

    write_back = WriteBack.from_configuration(check=check, diff=diff, color=color)
    if target_version:
        versions = set(target_version)
    else:
        # We'll autodetect later.
        versions = set()
    mode = Mode(
        target_versions=versions,
        line_length=line_length,
        is_pyi=pyi,
        is_ipynb=ipynb,
        string_normalization=not skip_string_normalization,
        magic_trailing_comma=not skip_magic_trailing_comma,
        experimental_string_processing=experimental_string_processing,
        preview=preview,
        python_cell_magics=set(python_cell_magics),
    )

    if code is not None:
        # Run in quiet mode by default with -c; the extra output isn't useful.
        # You can still pass -v to get verbose output.
        quiet = True

    report = Report(check=check, diff=diff, quiet=quiet, verbose=verbose)

    if code is not None:
        reformat_code(
            content=code, fast=fast, write_back=write_back, mode=mode, report=report
        )
    else:
        try:
            sources = get_sources(
                ctx=ctx,
                src=src,
                quiet=quiet,
                verbose=verbose,
                include=include,
                exclude=exclude,
                extend_exclude=extend_exclude,
                force_exclude=force_exclude,
                report=report,
                stdin_filename=stdin_filename,
            )
        except GitWildMatchPatternError:
            ctx.exit(1)

        # Exits with status 0 when nothing is left to format.
        path_empty(
            sources,
            "No Python files are present to be formatted. Nothing to do 😴",
            quiet,
            verbose,
            ctx,
        )

        # A single source is formatted in-process; several go through the
        # executor-based reformat_many.
        if len(sources) == 1:
            reformat_one(
                src=sources.pop(),
                fast=fast,
                write_back=write_back,
                mode=mode,
                report=report,
            )
        else:
            reformat_many(
                sources=sources,
                fast=fast,
                write_back=write_back,
                mode=mode,
                report=report,
                workers=workers,
            )

    # Final summary; the exit status is taken from the report either way.
    if verbose or not quiet:
        if code is None and (verbose or report.change_count or report.failure_count):
            out()
        out(error_msg if report.return_code else "All done! ✨ 🍰 ✨")
        if code is None:
            click.echo(str(report), err=True)
    ctx.exit(report.return_code)
555
556
def get_sources(
    *,
    ctx: click.Context,
    src: Tuple[str, ...],
    quiet: bool,
    verbose: bool,
    include: Pattern[str],
    exclude: Optional[Pattern[str]],
    extend_exclude: Optional[Pattern[str]],
    force_exclude: Optional[Pattern[str]],
    report: "Report",
    stdin_filename: Optional[str],
) -> Set[Path]:
    """Compute the set of files to be formatted.

    Exits (via ``path_empty``) when ``src`` is empty. Each entry of ``src`` is
    handled as follows:

    * ``-`` with ``stdin_filename`` set, or an existing file: kept unless
      ``force_exclude`` matches it. Stdin entries are re-encoded with
      ``STDIN_PLACEHOLDER`` so later stages can recover the user-facing name.
    * a directory: expanded with ``gen_python_files``.
    * a bare ``-``: kept as stdin.
    * anything else: reported with ``err``.

    The project root's .gitignore is only consulted when ``exclude`` is None
    (in which case ``DEFAULT_EXCLUDES`` is used).
    """
    path_empty(src, "No Path provided. Nothing to do 😴", quiet, verbose, ctx)

    gitignore = None
    if exclude is None:
        exclude = re_compile_maybe_verbose(DEFAULT_EXCLUDES)
        gitignore = get_gitignore(ctx.obj["root"])

    sources: Set[Path] = set()
    for s in src:
        if s == "-" and stdin_filename:
            is_stdin, p = True, Path(stdin_filename)
        else:
            is_stdin, p = False, Path(s)

        if not (is_stdin or p.is_file()):
            if p.is_dir():
                sources.update(
                    gen_python_files(
                        p.iterdir(),
                        ctx.obj["root"],
                        include,
                        exclude,
                        extend_exclude,
                        force_exclude,
                        report,
                        gitignore,
                        verbose=verbose,
                        quiet=quiet,
                    )
                )
            elif s == "-":
                sources.add(p)
            else:
                err(f"invalid path: {s}")
            continue

        relative = normalize_path_maybe_ignore(p, ctx.obj["root"], report)
        if relative is None:
            continue

        # Explicitly named files are only dropped by --force-exclude, which is
        # matched against the root-relative path with a leading slash.
        hit = force_exclude.search("/" + relative) if force_exclude else None
        if hit and hit.group(0):
            report.path_ignored(p, "matches the --force-exclude regular expression")
            continue

        if is_stdin:
            p = Path(f"{STDIN_PLACEHOLDER}{p}")

        needs_jupyter = p.suffix == ".ipynb"
        if needs_jupyter and not jupyter_dependencies_are_installed(
            verbose=verbose, quiet=quiet
        ):
            continue

        sources.add(p)
    return sources
632
633
def path_empty(
    src: Sized, msg: str, quiet: bool, verbose: bool, ctx: click.Context
) -> None:
    """Exit with status 0 when ``src`` is empty; otherwise do nothing.

    Before exiting, ``msg`` is printed unless output is silenced (``quiet``
    without ``verbose``).
    """
    if src:
        return
    if verbose or not quiet:
        out(msg)
    ctx.exit(0)
644
645
def reformat_code(
    content: str, fast: bool, write_back: WriteBack, mode: Mode, report: Report
) -> None:
    """Format the in-memory string ``content`` and record the outcome on ``report``.

    The string counterpart of ``reformat_one``: it runs in-process and passes
    ``fast``, ``write_back`` and ``mode`` to ``format_stdin_to_stdout``. The
    result is reported under the pseudo-path ``<string>``. Any exception is
    recorded with ``report.failed`` (after printing a traceback in verbose
    mode) instead of propagating.
    """
    path = Path("<string>")
    try:
        was_changed = format_stdin_to_stdout(
            content=content, fast=fast, write_back=write_back, mode=mode
        )
        report.done(path, Changed.YES if was_changed else Changed.NO)
    except Exception as exc:
        if report.verbose:
            traceback.print_exc()
        report.failed(path, str(exc))
668
669
def reformat_one(
    src: Path, fast: bool, write_back: WriteBack, mode: Mode, report: "Report"
) -> None:
    """Format one source in the current process and record the outcome on ``report``.

    ``src`` may be ``-`` (stdin), a ``STDIN_PLACEHOLDER``-prefixed stdin name
    (see get_sources), or a regular file. For stdin, ``.pyi``/``.ipynb``
    suffixes on the supplied name switch the mode accordingly and
    ``format_stdin_to_stdout`` does the work. Regular files go through
    ``format_file_in_place``, consulting the cache unless a diff is requested.
    Any exception is recorded with ``report.failed`` instead of propagating.
    """
    try:
        changed = Changed.NO
        name = str(src)
        is_stdin = True
        if name.startswith(STDIN_PLACEHOLDER):
            # Report under the user-supplied --stdin-filename, not the marker.
            src = Path(name[len(STDIN_PLACEHOLDER) :])
        elif name != "-":
            is_stdin = False

        if is_stdin:
            suffix = src.suffix
            if suffix == ".pyi":
                mode = replace(mode, is_pyi=True)
            elif suffix == ".ipynb":
                mode = replace(mode, is_ipynb=True)
            if format_stdin_to_stdout(fast=fast, write_back=write_back, mode=mode):
                changed = Changed.YES
        else:
            cache: Cache = {}
            diffing = write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF)
            if not diffing:
                cache = read_cache(mode)
                resolved = src.resolve()
                key = str(resolved)
                if key in cache and cache[key] == get_cache_info(resolved):
                    changed = Changed.CACHED
            if changed is not Changed.CACHED and format_file_in_place(
                src, fast=fast, write_back=write_back, mode=mode
            ):
                changed = Changed.YES
            # Cache files that were written back, or verified as already clean.
            rewritten = write_back is WriteBack.YES and changed is not Changed.CACHED
            verified = write_back is WriteBack.CHECK and changed is Changed.NO
            if rewritten or verified:
                write_cache(cache, [src], mode)
        report.done(src, changed)
    except Exception as exc:
        if report.verbose:
            traceback.print_exc()
        report.failed(src, str(exc))
719
720
# diff-shades depends on being to monkeypatch this function to operate. I know it's
# not ideal, but this shouldn't cause any issues ... hopefully. ~ichard26
@mypyc_attr(patchable=True)
def reformat_many(
    sources: Set[Path],
    fast: bool,
    write_back: WriteBack,
    mode: Mode,
    report: "Report",
    workers: Optional[int],
) -> None:
    """Reformat multiple files using a ProcessPoolExecutor.

    ``workers`` defaults to DEFAULT_WORKERS when None. On platforms where a
    process pool cannot be created, this falls back to a single-threaded
    ThreadPoolExecutor. The work is driven by ``schedule_formatting`` on a
    dedicated event loop, which is shut down and unset afterwards together
    with the executor.
    """
    executor: Executor
    worker_count = workers if workers is not None else DEFAULT_WORKERS
    if sys.platform == "win32" and worker_count is not None:
        # Work around https://bugs.python.org/issue26903
        # (A None count is left to ProcessPoolExecutor's own default.)
        worker_count = min(worker_count, 60)
    try:
        executor = ProcessPoolExecutor(max_workers=worker_count)
    except (ImportError, NotImplementedError, OSError):
        # we arrive here if the underlying system does not support multi-processing
        # like in AWS Lambda or Termux, in which case we gracefully fallback to
        # a ThreadPoolExecutor with just a single worker (more workers would not do us
        # any good due to the Global Interpreter Lock)
        executor = ThreadPoolExecutor(max_workers=1)

    # Use a private loop instead of `asyncio.get_event_loop()`: that call is
    # deprecated when no loop is running on newer Pythons, and since the loop
    # is handed to `shutdown()` as final cleanup (which is expected to close
    # it), reusing the thread's default loop would leave it unusable for any
    # later caller in the same process.
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        loop.run_until_complete(
            schedule_formatting(
                sources=sources,
                fast=fast,
                write_back=write_back,
                mode=mode,
                report=report,
                loop=loop,
                executor=executor,
            )
        )
    finally:
        try:
            shutdown(loop)
        finally:
            # Don't leave our (now finished) loop installed as the current one.
            asyncio.set_event_loop(None)
        executor.shutdown()
765
766
async def schedule_formatting(
    sources: Set[Path],
    fast: bool,
    write_back: WriteBack,
    mode: Mode,
    report: "Report",
    loop: asyncio.AbstractEventLoop,
    executor: Executor,
) -> None:
    """Run formatting of `sources` in parallel using the provided `executor`.

    (Use ProcessPoolExecutors for actual parallelism.)

    `write_back`, `fast`, and `mode` options are passed to
    :func:`format_file_in_place`.

    Unless a diff is requested, sources already up to date in the cache are
    reported as CACHED and skipped. Each remaining file is formatted in the
    executor; results are reported as tasks complete, and files that were
    written back (YES) or verified unchanged (CHECK) are written to the cache
    at the end. SIGINT/SIGTERM cancel outstanding tasks where the loop
    supports signal handlers.
    """
    cache: Cache = {}
    if write_back not in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
        cache = read_cache(mode)
        sources, cached = filter_cached(cache, sources)
        for src in sorted(cached):
            report.done(src, Changed.CACHED)
    if not sources:
        return

    cancelled = []
    sources_to_cache = []
    lock = None
    if write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
        # For diff output, we need locks to ensure we don't interleave output
        # from different processes.
        manager = Manager()
        lock = manager.Lock()
    # Map each executor future back to the source path it is formatting.
    tasks = {
        asyncio.ensure_future(
            loop.run_in_executor(
                executor, format_file_in_place, src, fast, mode, write_back, lock
            )
        ): src
        for src in sorted(sources)
    }
    # A live view of `tasks`: it shrinks as finished tasks are popped below, so
    # both the loop condition and the signal handlers see only outstanding tasks.
    pending = tasks.keys()
    try:
        loop.add_signal_handler(signal.SIGINT, cancel, pending)
        loop.add_signal_handler(signal.SIGTERM, cancel, pending)
    except NotImplementedError:
        # There are no good alternatives for these on Windows.
        pass
    while pending:
        done, _ = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
        for task in done:
            src = tasks.pop(task)
            if task.cancelled():
                cancelled.append(task)
            elif task.exception():
                report.failed(src, str(task.exception()))
            else:
                changed = Changed.YES if task.result() else Changed.NO
                # If the file was written back or was successfully checked as
                # well-formatted, store this information in the cache.
                if write_back is WriteBack.YES or (
                    write_back is WriteBack.CHECK and changed is Changed.NO
                ):
                    sources_to_cache.append(src)
                report.done(src, changed)
    if cancelled:
        # Let the cancelled tasks settle; their exceptions are collected, not raised.
        if sys.version_info >= (3, 7):
            await asyncio.gather(*cancelled, return_exceptions=True)
        else:
            await asyncio.gather(*cancelled, loop=loop, return_exceptions=True)
    if sources_to_cache:
        write_cache(cache, sources_to_cache, mode)
839
840
def format_file_in_place(
    src: Path,
    fast: bool,
    mode: Mode,
    write_back: WriteBack = WriteBack.NO,
    lock: Any = None,  # multiprocessing.Manager().Lock() is some crazy proxy
) -> bool:
    """Format file under `src` path. Return True if changed.

    A ``.pyi`` or ``.ipynb`` suffix switches `mode` to stub or notebook
    formatting. If `write_back` is YES, write reformatted code to the file. If
    it is DIFF or COLOR_DIFF, write a (colored) diff to stdout instead, holding
    `lock` (when given) while writing so parallel workers don't interleave.
    `mode` and `fast` options are passed to :func:`format_file_contents`.

    Raises ValueError when a notebook's contents cannot be parsed as JSON.
    """
    # Local import keeps this change self-contained within the function.
    from datetime import timezone

    if src.suffix == ".pyi":
        mode = replace(mode, is_pyi=True)
    elif src.suffix == ".ipynb":
        mode = replace(mode, is_ipynb=True)

    # Naive UTC timestamps are kept because the diff headers append " +0000"
    # explicitly. `datetime.utcfromtimestamp()` / `datetime.utcnow()` are
    # deprecated since Python 3.12; an aware UTC datetime with its tzinfo
    # stripped has the same value and the same str() form.
    then = datetime.fromtimestamp(src.stat().st_mtime, timezone.utc).replace(
        tzinfo=None
    )
    with open(src, "rb") as buf:
        src_contents, encoding, newline = decode_bytes(buf.read())
    try:
        dst_contents = format_file_contents(src_contents, fast=fast, mode=mode)
    except NothingChanged:
        return False
    except JSONDecodeError:
        raise ValueError(
            f"File '{src}' cannot be parsed as valid Jupyter notebook."
        ) from None

    if write_back == WriteBack.YES:
        # Write back with the original encoding and newline style.
        with open(src, "w", encoding=encoding, newline=newline) as f:
            f.write(dst_contents)
    elif write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
        now = datetime.now(timezone.utc).replace(tzinfo=None)
        src_name = f"{src}\t{then} +0000"
        dst_name = f"{src}\t{now} +0000"
        if mode.is_ipynb:
            diff_contents = ipynb_diff(src_contents, dst_contents, src_name, dst_name)
        else:
            diff_contents = diff(src_contents, dst_contents, src_name, dst_name)

        if write_back == WriteBack.COLOR_DIFF:
            diff_contents = color_diff(diff_contents)

        with lock or nullcontext():
            f = io.TextIOWrapper(
                sys.stdout.buffer,
                encoding=encoding,
                newline=newline,
                write_through=True,
            )
            f = wrap_stream_for_windows(f)
            f.write(diff_contents)
            # Detach so closing the wrapper never closes sys.stdout's buffer.
            f.detach()

    return True
898
899
def format_stdin_to_stdout(
    fast: bool,
    *,
    content: Optional[str] = None,
    write_back: WriteBack = WriteBack.NO,
    mode: Mode,
) -> bool:
    """Format file on stdin. Return True if changed.

    If content is None, it's read from sys.stdin.

    If `write_back` is YES, write reformatted code back to stdout. If it is DIFF,
    write a diff to stdout. The `mode` argument is passed to
    :func:`format_file_contents`.

    Output is produced in the ``finally`` clause. It is written whether formatting
    succeeded, raised NothingChanged, or raised any other exception. In the latter
    two cases ``dst`` still equals the original source, so the input is echoed
    back (YES) or an empty diff is produced (DIFF/COLOR_DIFF). When ``content`` is
    passed in directly, output is encoded as UTF-8 with ``newline=""`` (no newline
    translation).
    """
    then = datetime.utcnow()

    if content is None:
        src, encoding, newline = decode_bytes(sys.stdin.buffer.read())
    else:
        src, encoding, newline = content, "utf-8", ""

    # Fallback value: if format_file_contents does not return, `dst` stays as the
    # unmodified input and that is what the `finally` block writes out.
    dst = src
    try:
        dst = format_file_contents(src, fast=fast, mode=mode)
        return True

    except NothingChanged:
        return False

    finally:
        # Write through the raw stdout buffer so the output uses the detected
        # source encoding and newline style rather than the console defaults.
        f = io.TextIOWrapper(
            sys.stdout.buffer, encoding=encoding, newline=newline, write_through=True
        )
        if write_back == WriteBack.YES:
            # Make sure there's a newline after the content
            if dst and dst[-1] != "\n":
                dst += "\n"
            f.write(dst)
        elif write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
            now = datetime.utcnow()
            src_name = f"STDIN\t{then} +0000"
            dst_name = f"STDOUT\t{now} +0000"
            d = diff(src, dst, src_name, dst_name)
            if write_back == WriteBack.COLOR_DIFF:
                d = color_diff(d)
                # NOTE(review): wrap_stream_for_windows is defined elsewhere;
                # presumably it makes ANSI colour codes work on Windows consoles.
                f = wrap_stream_for_windows(f)
            f.write(d)
        # Detach instead of closing so that sys.stdout.buffer itself stays open.
        f.detach()
949
950
def check_stability_and_equivalence(
    src_contents: str, dst_contents: str, *, mode: Mode
) -> None:
    """Verify that formatting preserved the AST and reached a fixed point.

    Raises AssertionError (via :func:`assert_equivalent` / :func:`assert_stable`)
    in either of two cases:

    * `dst_contents` is not AST-equivalent to `src_contents`;
    * re-formatting `dst_contents` produces something different whose own
      second pass is unstable or non-equivalent.
    """
    assert_equivalent(src_contents, dst_contents)

    # A second pass is forced because optional trailing commas (which become
    # forced trailing commas on pass 2) can interact differently with optional
    # parentheses.  Admittedly ugly.
    second_pass = format_str(dst_contents, mode=mode)
    if second_pass == dst_contents:
        # Output is already a fixed point; assert_stable would be redundant.
        return

    assert_equivalent(src_contents, second_pass, pass_num=2)
    assert_stable(src_contents, second_pass, mode=mode)
972
973
def format_file_contents(src_contents: str, *, fast: bool, mode: Mode) -> FileContent:
    """Return the reformatted version of a whole file's contents.

    Notebooks (``mode.is_ipynb``) go through :func:`format_ipynb_string`. All
    other content goes through :func:`format_str`.

    Raises NothingChanged when the input is empty or whitespace-only, or when
    formatting yields exactly the input. Unless `fast` is set, non-notebook
    output is verified with :func:`check_stability_and_equivalence`.
    """
    if src_contents.strip() == "":
        raise NothingChanged

    is_notebook = mode.is_ipynb
    dst_contents = (
        format_ipynb_string(src_contents, fast=fast, mode=mode)
        if is_notebook
        else format_str(src_contents, mode=mode)
    )
    if dst_contents == src_contents:
        raise NothingChanged

    # Notebook cells are already checked one by one inside format_cell.
    if not (fast or is_notebook):
        check_stability_and_equivalence(src_contents, dst_contents, mode=mode)
    return dst_contents
995
996
def validate_cell(src: str, mode: Mode) -> None:
    """Raise NothingChanged for notebook cells that cannot be safely round-tripped.

    Two kinds of cell are rejected.

    The first kind contains text matching one of ``TRANSFORMED_MAGICS``.
    IPython's ``TransformerManager`` turns ``!ls`` into
    ``get_ipython().system('ls')``. A cell that already contains
    ``get_ipython().system('ls')`` transforms to exactly the same text. Once
    magics are masked, the original could therefore not be recovered.

    The second kind starts with a ``%%`` cell magic whose name is in neither
    ``PYTHON_CELL_MAGICS`` nor ``mode.python_cell_magics``. Such bodies may not
    be Python and could break the tokenizer because of their indentation.
    """
    if any(magic in src for magic in TRANSFORMED_MAGICS):
        raise NothingChanged
    if not src.startswith("%%"):
        return
    # First whitespace-separated token, minus the leading "%%".
    cell_magic = src.split()[0][2:]
    if cell_magic not in PYTHON_CELL_MAGICS | mode.python_cell_magics:
        raise NothingChanged
1021
1022
def format_cell(src: str, *, fast: bool, mode: Mode) -> str:
    """Format the source of one Jupyter notebook code cell and return it.

    Pipeline:

      - reject cells that cannot be round-tripped (:func:`validate_cell`);
      - strip a trailing semicolon, remembering whether there was one;
      - mask IPython magics so the cell parses as plain Python;
      - format, and unless `fast` is set, verify stability/equivalence;
      - unmask the magics and restore the trailing semicolon;
      - drop trailing newlines.

    Raises NothingChanged when the cell is rejected, when masking hits a
    SyntaxError (possibly automagics or multi-line magics, which are not
    supported), or when the result equals the input.
    """
    validate_cell(src, mode)
    unterminated_src, had_semicolon = remove_trailing_semicolon(src)

    try:
        masked_src, replacements = mask_cell(unterminated_src)
    except SyntaxError:
        raise NothingChanged from None

    masked_dst = format_str(masked_src, mode=mode)
    if not fast:
        check_stability_and_equivalence(masked_src, masked_dst, mode=mode)

    restored = put_trailing_semicolon_back(
        unmask_cell(masked_dst, replacements), had_semicolon
    )
    dst = restored.rstrip("\n")
    if dst == src:
        raise NothingChanged from None
    return dst
1058
1059
def validate_metadata(nb: MutableMapping[str, Any]) -> None:
    """Refuse to format notebooks whose metadata declares a non-Python language.

    Every notebook metadata field is optional
    (https://nbformat.readthedocs.io/en/latest/format_description.html). A
    missing ``metadata`` / ``language_info`` / ``name`` entry is therefore
    treated as Python, and formatting proceeds.

    Raises NothingChanged when ``metadata.language_info.name`` is present and
    is anything other than ``"python"``.
    """
    language_info = nb.get("metadata", {}).get("language_info", {})
    language = language_info.get("name", None)
    if language is None or language == "python":
        return
    raise NothingChanged from None
1070
1071
def format_ipynb_string(src_contents: str, *, fast: bool, mode: Mode) -> FileContent:
    """Format Jupyter notebook.

    Operate cell-by-cell, only on code cells, only for Python notebooks.
    If the ``.ipynb`` originally had a trailing newline, it'll be preserved.

    Raises NothingChanged in three cases:

    * `src_contents` is empty;
    * :func:`validate_metadata` rejects the notebook's language;
    * no code cell was changed.

    Invalid JSON propagates ``json.JSONDecodeError`` from :func:`json.loads`.
    """
    if not src_contents:
        # Empty input has no last character to inspect and nothing to format.
        # Without this guard, `src_contents[-1]` below would raise IndexError.
        raise NothingChanged

    trailing_newline = src_contents[-1] == "\n"
    modified = False
    nb = json.loads(src_contents)
    validate_metadata(nb)
    for cell in nb["cells"]:
        if cell.get("cell_type", None) != "code":
            continue
        try:
            src = "".join(cell["source"])
            dst = format_cell(src, fast=fast, mode=mode)
        except NothingChanged:
            continue
        # Store the cell back as a list of lines, each keeping its newline.
        cell["source"] = dst.splitlines(keepends=True)
        modified = True

    if not modified:
        raise NothingChanged

    dst_contents = json.dumps(nb, indent=1, ensure_ascii=False)
    if trailing_newline:
        dst_contents = dst_contents + "\n"
    return dst_contents
1099
1100
def format_str(src_contents: str, *, mode: Mode) -> FileContent:
    """Reformat a string and return new contents.

    Leading whitespace is stripped from `src_contents` before parsing. When
    `mode.target_versions` is empty, the target versions are inferred from the
    syntax the source uses (see :func:`detect_target_versions`).

    `mode` determines formatting options, such as how many characters per line are
    allowed.  Example:

    >>> import black
    >>> print(black.format_str("def f(arg:str='')->None:...", mode=black.Mode()))
    def f(arg: str = "") -> None:
        ...

    A more complex example:

    >>> print(
    ...   black.format_str(
    ...     "def f(arg:str='')->None: hey",
    ...     mode=black.Mode(
    ...       target_versions={black.TargetVersion.PY36},
    ...       line_length=10,
    ...       string_normalization=False,
    ...       is_pyi=False,
    ...     ),
    ...   ),
    ... )
    def f(
        arg: str = '',
    ) -> None:
        hey

    """
    src_node = lib2to3_parse(src_contents.lstrip(), mode.target_versions)
    future_imports = get_future_imports(src_node)
    # Explicit targets win; otherwise infer them from the features in use.
    versions = mode.target_versions or detect_target_versions(
        src_node, future_imports=future_imports
    )

    normalize_fmt_off(src_node)
    line_generator = LineGenerator(mode=mode)
    empty_line_tracker = EmptyLineTracker(is_pyi=mode.is_pyi)
    blank_line = str(Line(mode=mode))

    # Magic trailing commas may only be added where every target supports them.
    split_line_features: Set[Feature] = set()
    for feature in (Feature.TRAILING_COMMA_IN_CALL, Feature.TRAILING_COMMA_IN_DEF):
        if supports_feature(versions, feature):
            split_line_features.add(feature)

    chunks: List[str] = []
    pending_after = 0
    for current_line in line_generator.visit(src_node):
        # Blank lines requested *after* the previous line are emitted only once
        # another line follows, so none are appended at the very end.
        chunks.append(blank_line * pending_after)
        before, pending_after = empty_line_tracker.maybe_empty_lines(current_line)
        chunks.append(blank_line * before)
        chunks.extend(
            str(line)
            for line in transform_line(
                current_line, mode=mode, features=split_line_features
            )
        )
    return "".join(chunks)
1158
1159
def decode_bytes(src: bytes) -> Tuple[FileContent, Encoding, NewLine]:
    """Decode raw source bytes; return (decoded_contents, encoding, newline).

    The encoding comes from :func:`tokenize.detect_encoding`. `newline` is CRLF
    when the first line ends with a CRLF sequence and LF otherwise. The
    `decoded_contents` are read with universal newlines, so they only contain LF.
    Empty input yields an empty string, the detected encoding, and LF.
    """
    stream = io.BytesIO(src)
    encoding, head_lines = tokenize.detect_encoding(stream.readline)
    if not head_lines:
        return "", encoding, "\n"

    first_line = head_lines[0]
    newline = "\r\n" if first_line.endswith(b"\r\n") else "\n"
    # Rewind: detect_encoding consumed up to two lines from the stream.
    stream.seek(0)
    with io.TextIOWrapper(stream, encoding) as text:
        return text.read(), encoding, newline
1175
1176
def get_features_used(  # noqa: C901
    node: Node, *, future_imports: Optional[Set[str]] = None
) -> Set[Feature]:
    """Return a set of (relatively) new Python features used in this file.

    Currently looking for:
    - f-strings (prefix matched case-insensitively: ``f``, ``F``, ``rf``, ``Rf``,
      ``fR``, ...);
    - underscores in numeric literals;
    - trailing commas after * or ** in function signatures and calls;
    - positional only arguments in function signatures and lambdas;
    - assignment expression;
    - relaxed decorator syntax;
    - starred expressions in ``return`` / ``yield`` values;
    - a ``testlist_star_expr`` (e.g. an unparenthesized tuple) as the value of an
      annotated assignment;
    - usage of __future__ flags (annotations);
    """
    features: Set[Feature] = set()
    if future_imports:
        features |= {
            FUTURE_FLAG_TO_FEATURE[future_import]
            for future_import in future_imports
            if future_import in FUTURE_FLAG_TO_FEATURE
        }

    for n in node.pre_order():
        if is_string_token(n):
            # String prefixes are case-insensitive and "r"/"f" may appear in
            # either order, so normalise case before matching (catches "Rf",
            # "fR", ... which a case-sensitive set missed).
            value_head = n.value[:2].lower()
            if value_head in {'f"', "f'", "rf", "fr"}:
                features.add(Feature.F_STRINGS)

        elif n.type == token.NUMBER:
            assert isinstance(n, Leaf)
            if "_" in n.value:
                features.add(Feature.NUMERIC_UNDERSCORES)

        elif n.type == token.SLASH:
            if n.parent and n.parent.type in {
                syms.typedargslist,
                syms.arglist,
                syms.varargslist,
            }:
                features.add(Feature.POS_ONLY_ARGUMENTS)

        elif n.type == token.COLONEQUAL:
            features.add(Feature.ASSIGNMENT_EXPRESSIONS)

        elif n.type == syms.decorator:
            if len(n.children) > 1 and not is_simple_decorator_expression(
                n.children[1]
            ):
                features.add(Feature.RELAXED_DECORATORS)

        elif (
            n.type in {syms.typedargslist, syms.arglist}
            and n.children
            and n.children[-1].type == token.COMMA
        ):
            if n.type == syms.typedargslist:
                feature = Feature.TRAILING_COMMA_IN_DEF
            else:
                feature = Feature.TRAILING_COMMA_IN_CALL

            for ch in n.children:
                if ch.type in STARS:
                    features.add(feature)

                if ch.type == syms.argument:
                    for argch in ch.children:
                        if argch.type in STARS:
                            features.add(feature)

        elif (
            n.type in {syms.return_stmt, syms.yield_expr}
            and len(n.children) >= 2
            and n.children[1].type == syms.testlist_star_expr
            and any(child.type == syms.star_expr for child in n.children[1].children)
        ):
            features.add(Feature.UNPACKING_ON_FLOW)

        elif (
            n.type == syms.annassign
            and len(n.children) >= 4
            and n.children[3].type == syms.testlist_star_expr
        ):
            features.add(Feature.ANN_ASSIGN_EXTENDED_RHS)

    return features
1263
1264
def detect_target_versions(
    node: Node, *, future_imports: Optional[Set[str]] = None
) -> Set[TargetVersion]:
    """Return every TargetVersion that supports all features used under `node`."""
    used_features = get_features_used(node, future_imports=future_imports)
    compatible: Set[TargetVersion] = set()
    for version in TargetVersion:
        # A version qualifies only if its feature set covers everything used.
        if used_features <= VERSION_TO_FEATURES[version]:
            compatible.add(version)
    return compatible
1273
1274
def get_future_imports(node: Node) -> Set[str]:
    """Return the names imported from ``__future__`` at the top of the module.

    Top-level statements are scanned in order. Simple statements consisting of
    a lone string (docstrings) are skipped. Scanning stops at the first
    statement that is neither such a string nor ``from __future__ import ...``.
    For ``import x as y`` forms, the original name ``x`` is recorded.
    """

    def _imported_names(children: List[LN]) -> Generator[str, None, None]:
        for child in children:
            if isinstance(child, Leaf):
                # Punctuation leaves (commas, parentheses) are ignored.
                if child.type == token.NAME:
                    yield child.value
                continue

            if child.type == syms.import_as_name:
                orig_name = child.children[0]
                assert isinstance(orig_name, Leaf), "Invalid syntax parsing imports"
                assert orig_name.type == token.NAME, "Invalid syntax parsing imports"
                yield orig_name.value
            elif child.type == syms.import_as_names:
                yield from _imported_names(child.children)
            else:
                raise AssertionError("Invalid syntax parsing imports")

    found: Set[str] = set()
    for stmt in node.children:
        if stmt.type != syms.simple_stmt:
            break

        head = stmt.children[0]
        if isinstance(head, Leaf):
            is_docstring = (
                len(stmt.children) == 2
                and head.type == token.STRING
                and stmt.children[1].type == token.NEWLINE
            )
            if not is_docstring:
                break
            continue

        if head.type != syms.import_from:
            break

        module_name = head.children[1]
        if not isinstance(module_name, Leaf) or module_name.value != "__future__":
            break

        # children[0:3] are "from", the module name and "import".
        found.update(_imported_names(head.children[3:]))

    return found
1323
1324
def assert_equivalent(src: str, dst: str, *, pass_num: int = 1) -> None:
    """Raise AssertionError unless `src` and `dst` parse to equivalent ASTs.

    There are three failure modes:

    * `src` itself does not parse: an error that suggests the source uses syntax
      the running Python does not support;
    * `dst` does not parse: an internal error, with the traceback and `dst`
      saved to a log file;
    * the stringified ASTs differ: an internal error, with their diff saved to a
      log file.

    `pass_num` is only used in the error messages.
    """
    try:
        src_ast = parse_ast(src)
    except Exception as exc:
        message = (
            "cannot use --safe with this file; failed to parse source file AST: "
            f"{exc}\n"
            "This could be caused by running Black with an older Python version "
            "that does not support new syntax used in your source file."
        )
        raise AssertionError(message) from exc

    try:
        dst_ast = parse_ast(dst)
    except Exception as exc:
        tb_text = "".join(traceback.format_tb(exc.__traceback__))
        log = dump_to_file(tb_text, dst)
        message = (
            f"INTERNAL ERROR: Black produced invalid code on pass {pass_num}: {exc}. "
            "Please report a bug on https://github.com/psf/black/issues.  "
            f"This invalid output might be helpful: {log}"
        )
        raise AssertionError(message) from None

    src_dump = "\n".join(stringify_ast(src_ast))
    dst_dump = "\n".join(stringify_ast(dst_ast))
    if src_dump == dst_dump:
        return

    log = dump_to_file(diff(src_dump, dst_dump, "src", "dst"))
    message = (
        "INTERNAL ERROR: Black produced code that is not equivalent to the"
        f" source on pass {pass_num}.  Please report a bug on "
        f"https://github.com/psf/black/issues.  This diff might be helpful: {log}"
    )
    raise AssertionError(message) from None
1356
1357
def assert_stable(src: str, dst: str, mode: Mode) -> None:
    """Raise AssertionError if formatting `dst` again changes it.

    On failure, the mode plus the source→first-pass and first-pass→second-pass
    diffs are written to a log file, and its path is included in the message.
    """
    second_pass = format_str(dst, mode=mode)
    if second_pass == dst:
        return

    log = dump_to_file(
        str(mode),
        diff(src, dst, "source", "first pass"),
        diff(dst, second_pass, "first pass", "second pass"),
    )
    message = (
        "INTERNAL ERROR: Black produced different code on the second pass of the"
        " formatter.  Please report a bug on https://github.com/psf/black/issues."
        f"  This diff might be helpful: {log}"
    )
    raise AssertionError(message) from None
1372
1373
@contextmanager
def nullcontext() -> Iterator[None]:
    """Context manager that does nothing and binds ``None`` for ``as``.

    Stand-in for ``contextlib.nullcontext``, which only exists from Python 3.7.
    Exceptions raised in the ``with`` body propagate unchanged.
    """
    yield None
1381
1382
def patch_click() -> None:
    """Make Click not crash on Python 3.6 with LANG=C.

    On certain misconfigured environments, Python 3 selects the ASCII encoding as the
    default which restricts paths that it can access during the lifetime of the
    application.  Click refuses to work in this scenario by raising a RuntimeError.

    In case of Black the likelihood that non-ASCII characters are going to be used in
    file paths is minimal since it's Python source code.  Moreover, this crash was
    spurious on Python 3.7 thanks to PEP 538 and PEP 540.

    Each Click module is imported independently, and any ImportError is
    tolerated. Click 8.1 removed ``click._unicodefun``. A
    ``from click import _unicodefun`` then raises a plain ImportError (not
    ModuleNotFoundError), which used to crash Black at startup. It also
    prevented ``click.core`` from being patched.
    """
    modules: List[Any] = []
    try:
        from click import core
    except ImportError:
        pass
    else:
        modules.append(core)
    try:
        from click import _unicodefun
    except ImportError:
        pass
    else:
        modules.append(_unicodefun)

    for module in modules:
        # Disable whichever environment-check hook this Click version defines.
        if hasattr(module, "_verify_python3_env"):
            module._verify_python3_env = lambda: None  # type: ignore
        if hasattr(module, "_verify_python_env"):
            module._verify_python_env = lambda: None  # type: ignore
1405
1406
def patched_main() -> None:
    """Run ``main`` after applying environment patches.

    The call order matters:

    * ``freeze_support`` must run before any multiprocessing work when running
      as a frozen (e.g. PyInstaller) executable on Windows;
    * ``patch_click`` has to run before ``main``, which presumably is the Click
      command defined elsewhere in this module.

    NOTE(review): ``maybe_install_uvloop`` is defined elsewhere; judging by the
    name, it installs uvloop's event loop when available. Confirm.
    """
    maybe_install_uvloop()
    freeze_support()
    patch_click()
    main()
1412
1413
if __name__ == "__main__":
    # Run the CLI through patched_main so its environment patches apply first.
    patched_main()