src/black/__init__.py (etc/vim.git)
Commit: Fix handling of standalone match/case with newlines/comments (#2760)

import asyncio
from json.decoder import JSONDecodeError
import json
from concurrent.futures import Executor, ThreadPoolExecutor, ProcessPoolExecutor
from contextlib import contextmanager
from datetime import datetime
from enum import Enum
import io
from multiprocessing import Manager, freeze_support
import os
from pathlib import Path
from pathspec.patterns.gitwildmatch import GitWildMatchPatternError
import re
import signal
import sys
import tokenize
import traceback
from typing import (
    Any,
    Dict,
    Generator,
    Iterator,
    List,
    MutableMapping,
    Optional,
    Pattern,
    Set,
    Sized,
    Tuple,
    Union,
)

import click
from click.core import ParameterSource
from dataclasses import replace
from mypy_extensions import mypyc_attr

from black.const import DEFAULT_LINE_LENGTH, DEFAULT_INCLUDES, DEFAULT_EXCLUDES
from black.const import STDIN_PLACEHOLDER
from black.nodes import STARS, syms, is_simple_decorator_expression
from black.nodes import is_string_token
from black.lines import Line, EmptyLineTracker
from black.linegen import transform_line, LineGenerator, LN
from black.comments import normalize_fmt_off
from black.mode import FUTURE_FLAG_TO_FEATURE, Mode, TargetVersion
from black.mode import Feature, supports_feature, VERSION_TO_FEATURES
from black.cache import read_cache, write_cache, get_cache_info, filter_cached, Cache
from black.concurrency import cancel, shutdown, maybe_install_uvloop
from black.output import dump_to_file, ipynb_diff, diff, color_diff, out, err
from black.report import Report, Changed, NothingChanged
from black.files import find_project_root, find_pyproject_toml, parse_pyproject_toml
from black.files import gen_python_files, get_gitignore, normalize_path_maybe_ignore
from black.files import wrap_stream_for_windows
from black.parsing import InvalidInput  # noqa F401
from black.parsing import lib2to3_parse, parse_ast, stringify_ast
from black.handle_ipynb_magics import (
    mask_cell,
    unmask_cell,
    remove_trailing_semicolon,
    put_trailing_semicolon_back,
    TRANSFORMED_MAGICS,
    PYTHON_CELL_MAGICS,
    jupyter_dependencies_are_installed,
)


# lib2to3 fork
from blib2to3.pytree import Node, Leaf
from blib2to3.pgen2 import token

from _black_version import version as __version__

COMPILED = Path(__file__).suffix in (".pyd", ".so")

# types
FileContent = str
Encoding = str
NewLine = str


class WriteBack(Enum):
    NO = 0
    YES = 1
    DIFF = 2
    CHECK = 3
    COLOR_DIFF = 4

    @classmethod
    def from_configuration(
        cls, *, check: bool, diff: bool, color: bool = False
    ) -> "WriteBack":
        if check and not diff:
            return cls.CHECK

        if diff and color:
            return cls.COLOR_DIFF

        return cls.DIFF if diff else cls.YES
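
# Illustrative mapping of CLI flags to WriteBack values, mirroring the classmethod above:
#   check=True,  diff=False  -> WriteBack.CHECK
#   diff=True,   color=True  -> WriteBack.COLOR_DIFF
#   diff=True,   color=False -> WriteBack.DIFF
#   otherwise (plain formatting run) -> WriteBack.YES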


# Legacy name, left for integrations.
FileMode = Mode

DEFAULT_WORKERS = os.cpu_count()


def read_pyproject_toml(
    ctx: click.Context, param: click.Parameter, value: Optional[str]
) -> Optional[str]:
    """Inject Black configuration from "pyproject.toml" into defaults in `ctx`.

    Returns the path to a successfully found and read configuration file, None
    otherwise.
    """
    if not value:
        value = find_pyproject_toml(ctx.params.get("src", ()))
        if value is None:
            return None

    try:
        config = parse_pyproject_toml(value)
    except (OSError, ValueError) as e:
        raise click.FileError(
            filename=value, hint=f"Error reading configuration file: {e}"
        ) from None

    if not config:
        return None
    else:
        # Sanitize the values to be Click friendly. For more information please see:
        # https://github.com/psf/black/issues/1458
        # https://github.com/pallets/click/issues/1567
        config = {
            k: str(v) if not isinstance(v, (list, dict)) else v
            for k, v in config.items()
        }

    target_version = config.get("target_version")
    if target_version is not None and not isinstance(target_version, list):
        raise click.BadOptionUsage(
            "target-version", "Config key target-version must be a list"
        )

    default_map: Dict[str, Any] = {}
    if ctx.default_map:
        default_map.update(ctx.default_map)
    default_map.update(config)

    ctx.default_map = default_map
    return value
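
# Example (illustrative) of the pyproject.toml section this callback reads; the keys
# mirror the command-line options defined below and end up in ctx.default_map:
#
#   [tool.black]
#   line-length = 88
#   target-version = ["py37", "py38"]
#   skip-string-normalization = false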


def target_version_option_callback(
    c: click.Context, p: Union[click.Option, click.Parameter], v: Tuple[str, ...]
) -> List[TargetVersion]:
    """Compute the target versions from a --target-version flag.

    This is its own function because mypy couldn't infer the type correctly
    when it was a lambda, causing mypyc trouble.
    """
    return [TargetVersion[val.upper()] for val in v]
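
# Illustrative example: the tuple ("py37", "py38") becomes
# [TargetVersion.PY37, TargetVersion.PY38].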


def re_compile_maybe_verbose(regex: str) -> Pattern[str]:
    """Compile a regular expression string in `regex`.

    If it contains newlines, use verbose mode.
    """
    if "\n" in regex:
        regex = "(?x)" + regex
    compiled: Pattern[str] = re.compile(regex)
    return compiled
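
# Illustrative example: a pattern supplied over multiple lines (e.g. from
# pyproject.toml) is compiled with the inline (?x) verbose flag, so whitespace,
# newlines, and "#" comments inside it are ignored; single-line patterns are
# compiled unchanged.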


def validate_regex(
    ctx: click.Context,
    param: click.Parameter,
    value: Optional[str],
) -> Optional[Pattern[str]]:
    try:
        return re_compile_maybe_verbose(value) if value is not None else None
    except re.error as e:
        raise click.BadParameter(f"Not a valid regular expression: {e}") from None


@click.command(
    context_settings={"help_option_names": ["-h", "--help"]},
    # While Click does set this field automatically using the docstring, mypyc
    # (annoyingly) strips 'em so we need to set it here too.
    help="The uncompromising code formatter.",
)
@click.option("-c", "--code", type=str, help="Format the code passed in as a string.")
@click.option(
    "-l",
    "--line-length",
    type=int,
    default=DEFAULT_LINE_LENGTH,
    help="How many characters per line to allow.",
    show_default=True,
)
@click.option(
    "-t",
    "--target-version",
    type=click.Choice([v.name.lower() for v in TargetVersion]),
    callback=target_version_option_callback,
    multiple=True,
    help=(
        "Python versions that should be supported by Black's output. [default: per-file"
        " auto-detection]"
    ),
)
@click.option(
    "--pyi",
    is_flag=True,
    help=(
        "Format all input files like typing stubs regardless of file extension (useful"
        " when piping source on standard input)."
    ),
)
@click.option(
    "--ipynb",
    is_flag=True,
    help=(
        "Format all input files like Jupyter Notebooks regardless of file extension "
        "(useful when piping source on standard input)."
    ),
)
@click.option(
    "-S",
    "--skip-string-normalization",
    is_flag=True,
    help="Don't normalize string quotes or prefixes.",
)
@click.option(
    "-C",
    "--skip-magic-trailing-comma",
    is_flag=True,
    help="Don't use trailing commas as a reason to split lines.",
)
@click.option(
    "--experimental-string-processing",
    is_flag=True,
    hidden=True,
    help=(
        "Experimental option that performs more normalization on string literals."
        " Currently disabled because it leads to some crashes."
    ),
)
@click.option(
    "--check",
    is_flag=True,
    help=(
        "Don't write the files back, just return the status. Return code 0 means"
        " nothing would change. Return code 1 means some files would be reformatted."
        " Return code 123 means there was an internal error."
    ),
)
@click.option(
    "--diff",
    is_flag=True,
    help="Don't write the files back, just output a diff for each file on stdout.",
)
@click.option(
    "--color/--no-color",
    is_flag=True,
    help="Show colored diff. Only applies when `--diff` is given.",
)
@click.option(
    "--fast/--safe",
    is_flag=True,
    help="If --fast given, skip temporary sanity checks. [default: --safe]",
)
@click.option(
    "--required-version",
    type=str,
    help=(
        "Require a specific version of Black to be running (useful for unifying results"
        " across many environments e.g. with a pyproject.toml file)."
    ),
)
@click.option(
    "--include",
    type=str,
    default=DEFAULT_INCLUDES,
    callback=validate_regex,
    help=(
        "A regular expression that matches files and directories that should be"
        " included on recursive searches. An empty value means all files are included"
        " regardless of the name. Use forward slashes for directories on all platforms"
        " (Windows, too). Exclusions are calculated first, inclusions later."
    ),
    show_default=True,
)
@click.option(
    "--exclude",
    type=str,
    callback=validate_regex,
    help=(
        "A regular expression that matches files and directories that should be"
        " excluded on recursive searches. An empty value means no paths are excluded."
        " Use forward slashes for directories on all platforms (Windows, too)."
        " Exclusions are calculated first, inclusions later. [default:"
        f" {DEFAULT_EXCLUDES}]"
    ),
    show_default=False,
)
@click.option(
    "--extend-exclude",
    type=str,
    callback=validate_regex,
    help=(
        "Like --exclude, but adds additional files and directories on top of the"
        " excluded ones. (Useful if you simply want to add to the default)"
    ),
)
@click.option(
    "--force-exclude",
    type=str,
    callback=validate_regex,
    help=(
        "Like --exclude, but files and directories matching this regex will be "
        "excluded even when they are passed explicitly as arguments."
    ),
)
@click.option(
    "--stdin-filename",
    type=str,
    help=(
        "The name of the file when passing it through stdin. Useful to make "
        "sure Black will respect the --force-exclude option on some "
        "editors that rely on using stdin."
    ),
)
@click.option(
    "-W",
    "--workers",
    type=click.IntRange(min=1),
    default=DEFAULT_WORKERS,
    show_default=True,
    help="Number of parallel workers",
)
@click.option(
    "-q",
    "--quiet",
    is_flag=True,
    help=(
        "Don't emit non-error messages to stderr. Errors are still emitted; silence"
        " those with 2>/dev/null."
    ),
)
@click.option(
    "-v",
    "--verbose",
    is_flag=True,
    help=(
        "Also emit messages to stderr about files that were not changed or were ignored"
        " due to exclusion patterns."
    ),
)
@click.version_option(
    version=__version__,
    message=f"%(prog)s, %(version)s (compiled: {'yes' if COMPILED else 'no'})",
)
@click.argument(
    "src",
    nargs=-1,
    type=click.Path(
        exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
    ),
    is_eager=True,
    metavar="SRC ...",
)
@click.option(
    "--config",
    type=click.Path(
        exists=True,
        file_okay=True,
        dir_okay=False,
        readable=True,
        allow_dash=False,
        path_type=str,
    ),
    is_eager=True,
    callback=read_pyproject_toml,
    help="Read configuration from FILE path.",
)
@click.pass_context
def main(
    ctx: click.Context,
    code: Optional[str],
    line_length: int,
    target_version: List[TargetVersion],
    check: bool,
    diff: bool,
    color: bool,
    fast: bool,
    pyi: bool,
    ipynb: bool,
    skip_string_normalization: bool,
    skip_magic_trailing_comma: bool,
    experimental_string_processing: bool,
    quiet: bool,
    verbose: bool,
    required_version: Optional[str],
    include: Pattern[str],
    exclude: Optional[Pattern[str]],
    extend_exclude: Optional[Pattern[str]],
    force_exclude: Optional[Pattern[str]],
    stdin_filename: Optional[str],
    workers: int,
    src: Tuple[str, ...],
    config: Optional[str],
) -> None:
    """The uncompromising code formatter."""
    ctx.ensure_object(dict)
    root, method = find_project_root(src) if code is None else (None, None)
    ctx.obj["root"] = root

    if verbose:
        if root:
            out(
                f"Identified `{root}` as project root containing a {method}.",
                fg="blue",
            )

            normalized = [
                (normalize_path_maybe_ignore(Path(source), root), source)
                for source in src
            ]
            srcs_string = ", ".join(
                [
                    f'"{_norm}"'
                    if _norm
                    else f'\033[31m"{source} (skipping - invalid)"\033[34m'
                    for _norm, source in normalized
                ]
            )
            out(f"Sources to be formatted: {srcs_string}", fg="blue")

        if config:
            config_source = ctx.get_parameter_source("config")
            if config_source in (ParameterSource.DEFAULT, ParameterSource.DEFAULT_MAP):
                out("Using configuration from project root.", fg="blue")
            else:
                out(f"Using configuration in '{config}'.", fg="blue")

    error_msg = "Oh no! 💥 💔 💥"
    if required_version and required_version != __version__:
        err(
            f"{error_msg} The required version `{required_version}` does not match"
            f" the running version `{__version__}`!"
        )
        ctx.exit(1)
    if ipynb and pyi:
        err("Cannot pass both `pyi` and `ipynb` flags!")
        ctx.exit(1)

    write_back = WriteBack.from_configuration(check=check, diff=diff, color=color)
    if target_version:
        versions = set(target_version)
    else:
        # We'll autodetect later.
        versions = set()
    mode = Mode(
        target_versions=versions,
        line_length=line_length,
        is_pyi=pyi,
        is_ipynb=ipynb,
        string_normalization=not skip_string_normalization,
        magic_trailing_comma=not skip_magic_trailing_comma,
        experimental_string_processing=experimental_string_processing,
    )

    if code is not None:
        # Run in quiet mode by default with -c; the extra output isn't useful.
        # You can still pass -v to get verbose output.
        quiet = True

    report = Report(check=check, diff=diff, quiet=quiet, verbose=verbose)

    if code is not None:
        reformat_code(
            content=code, fast=fast, write_back=write_back, mode=mode, report=report
        )
    else:
        try:
            sources = get_sources(
                ctx=ctx,
                src=src,
                quiet=quiet,
                verbose=verbose,
                include=include,
                exclude=exclude,
                extend_exclude=extend_exclude,
                force_exclude=force_exclude,
                report=report,
                stdin_filename=stdin_filename,
            )
        except GitWildMatchPatternError:
            ctx.exit(1)

        path_empty(
            sources,
            "No Python files are present to be formatted. Nothing to do 😴",
            quiet,
            verbose,
            ctx,
        )

        if len(sources) == 1:
            reformat_one(
                src=sources.pop(),
                fast=fast,
                write_back=write_back,
                mode=mode,
                report=report,
            )
        else:
            reformat_many(
                sources=sources,
                fast=fast,
                write_back=write_back,
                mode=mode,
                report=report,
                workers=workers,
            )

    if verbose or not quiet:
        out(error_msg if report.return_code else "All done! ✨ 🍰 ✨")
        if code is None:
            click.echo(str(report), err=True)
    ctx.exit(report.return_code)


def get_sources(
    *,
    ctx: click.Context,
    src: Tuple[str, ...],
    quiet: bool,
    verbose: bool,
    include: Pattern[str],
    exclude: Optional[Pattern[str]],
    extend_exclude: Optional[Pattern[str]],
    force_exclude: Optional[Pattern[str]],
    report: "Report",
    stdin_filename: Optional[str],
) -> Set[Path]:
    """Compute the set of files to be formatted."""
    sources: Set[Path] = set()
    path_empty(src, "No Path provided. Nothing to do 😴", quiet, verbose, ctx)

    if exclude is None:
        exclude = re_compile_maybe_verbose(DEFAULT_EXCLUDES)
        gitignore = get_gitignore(ctx.obj["root"])
    else:
        gitignore = None

    for s in src:
        if s == "-" and stdin_filename:
            p = Path(stdin_filename)
            is_stdin = True
        else:
            p = Path(s)
            is_stdin = False

        if is_stdin or p.is_file():
            normalized_path = normalize_path_maybe_ignore(p, ctx.obj["root"], report)
            if normalized_path is None:
                continue

            normalized_path = "/" + normalized_path
            # Hard-exclude any files that match the `--force-exclude` regex.
            if force_exclude:
                force_exclude_match = force_exclude.search(normalized_path)
            else:
                force_exclude_match = None
            if force_exclude_match and force_exclude_match.group(0):
                report.path_ignored(p, "matches the --force-exclude regular expression")
                continue

            if is_stdin:
                p = Path(f"{STDIN_PLACEHOLDER}{str(p)}")

            if p.suffix == ".ipynb" and not jupyter_dependencies_are_installed(
                verbose=verbose, quiet=quiet
            ):
                continue

            sources.add(p)
        elif p.is_dir():
            sources.update(
                gen_python_files(
                    p.iterdir(),
                    ctx.obj["root"],
                    include,
                    exclude,
                    extend_exclude,
                    force_exclude,
                    report,
                    gitignore,
                    verbose=verbose,
                    quiet=quiet,
                )
            )
        elif s == "-":
            sources.add(p)
        else:
            err(f"invalid path: {s}")
    return sources
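
# Note on the logic above: paths given explicitly (or via stdin) are only checked
# against the --force-exclude pattern, whereas directories are walked by
# gen_python_files, which also applies --include, --extend-exclude, --force-exclude,
# and either --exclude or the default excludes plus the project .gitignore.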


def path_empty(
    src: Sized, msg: str, quiet: bool, verbose: bool, ctx: click.Context
) -> None:
    """
    Exit if there is no `src` provided for formatting
    """
    if not src:
        if verbose or not quiet:
            out(msg)
        ctx.exit(0)


def reformat_code(
    content: str, fast: bool, write_back: WriteBack, mode: Mode, report: Report
) -> None:
    """
    Reformat and print out `content` without spawning child processes.
    Similar to `reformat_one`, but for string content.

    `fast`, `write_back`, and `mode` options are passed to
    :func:`format_file_in_place` or :func:`format_stdin_to_stdout`.
    """
    path = Path("<string>")
    try:
        changed = Changed.NO
        if format_stdin_to_stdout(
            content=content, fast=fast, write_back=write_back, mode=mode
        ):
            changed = Changed.YES
        report.done(path, changed)
    except Exception as exc:
        if report.verbose:
            traceback.print_exc()
        report.failed(path, str(exc))


def reformat_one(
    src: Path, fast: bool, write_back: WriteBack, mode: Mode, report: "Report"
) -> None:
    """Reformat a single file under `src` without spawning child processes.

    `fast`, `write_back`, and `mode` options are passed to
    :func:`format_file_in_place` or :func:`format_stdin_to_stdout`.
    """
    try:
        changed = Changed.NO

        if str(src) == "-":
            is_stdin = True
        elif str(src).startswith(STDIN_PLACEHOLDER):
            is_stdin = True
            # Use the original name again in case we want to print something
            # to the user
            src = Path(str(src)[len(STDIN_PLACEHOLDER) :])
        else:
            is_stdin = False

        if is_stdin:
            if src.suffix == ".pyi":
                mode = replace(mode, is_pyi=True)
            elif src.suffix == ".ipynb":
                mode = replace(mode, is_ipynb=True)
            if format_stdin_to_stdout(fast=fast, write_back=write_back, mode=mode):
                changed = Changed.YES
        else:
            cache: Cache = {}
            if write_back not in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
                cache = read_cache(mode)
                res_src = src.resolve()
                res_src_s = str(res_src)
                if res_src_s in cache and cache[res_src_s] == get_cache_info(res_src):
                    changed = Changed.CACHED
            if changed is not Changed.CACHED and format_file_in_place(
                src, fast=fast, write_back=write_back, mode=mode
            ):
                changed = Changed.YES
            if (write_back is WriteBack.YES and changed is not Changed.CACHED) or (
                write_back is WriteBack.CHECK and changed is Changed.NO
            ):
                write_cache(cache, [src], mode)
        report.done(src, changed)
    except Exception as exc:
        if report.verbose:
            traceback.print_exc()
        report.failed(src, str(exc))


# diff-shades depends on being able to monkeypatch this function to operate. I know
# it's not ideal, but this shouldn't cause any issues ... hopefully. ~ichard26
@mypyc_attr(patchable=True)
def reformat_many(
    sources: Set[Path],
    fast: bool,
    write_back: WriteBack,
    mode: Mode,
    report: "Report",
    workers: Optional[int],
) -> None:
    """Reformat multiple files using a ProcessPoolExecutor."""
    executor: Executor
    loop = asyncio.get_event_loop()
    worker_count = workers if workers is not None else DEFAULT_WORKERS
    if sys.platform == "win32":
        # Work around https://bugs.python.org/issue26903
        assert worker_count is not None
        worker_count = min(worker_count, 60)
    try:
        executor = ProcessPoolExecutor(max_workers=worker_count)
    except (ImportError, NotImplementedError, OSError):
        # we arrive here if the underlying system does not support multi-processing
        # like in AWS Lambda or Termux, in which case we gracefully fall back to
        # a ThreadPoolExecutor with just a single worker (more workers would not do us
        # any good due to the Global Interpreter Lock)
        executor = ThreadPoolExecutor(max_workers=1)

    try:
        loop.run_until_complete(
            schedule_formatting(
                sources=sources,
                fast=fast,
                write_back=write_back,
                mode=mode,
                report=report,
                loop=loop,
                executor=executor,
            )
        )
    finally:
        shutdown(loop)
        if executor is not None:
            executor.shutdown()


async def schedule_formatting(
    sources: Set[Path],
    fast: bool,
    write_back: WriteBack,
    mode: Mode,
    report: "Report",
    loop: asyncio.AbstractEventLoop,
    executor: Executor,
) -> None:
    """Run formatting of `sources` in parallel using the provided `executor`.

    (Use ProcessPoolExecutors for actual parallelism.)

    `write_back`, `fast`, and `mode` options are passed to
    :func:`format_file_in_place`.
    """
    cache: Cache = {}
    if write_back not in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
        cache = read_cache(mode)
        sources, cached = filter_cached(cache, sources)
        for src in sorted(cached):
            report.done(src, Changed.CACHED)
    if not sources:
        return

    cancelled = []
    sources_to_cache = []
    lock = None
    if write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
        # For diff output, we need locks to ensure we don't interleave output
        # from different processes.
        manager = Manager()
        lock = manager.Lock()
    tasks = {
        asyncio.ensure_future(
            loop.run_in_executor(
                executor, format_file_in_place, src, fast, mode, write_back, lock
            )
        ): src
        for src in sorted(sources)
    }
    pending = tasks.keys()
    try:
        loop.add_signal_handler(signal.SIGINT, cancel, pending)
        loop.add_signal_handler(signal.SIGTERM, cancel, pending)
    except NotImplementedError:
        # There are no good alternatives for these on Windows.
        pass
    while pending:
        done, _ = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
        for task in done:
            src = tasks.pop(task)
            if task.cancelled():
                cancelled.append(task)
            elif task.exception():
                report.failed(src, str(task.exception()))
            else:
                changed = Changed.YES if task.result() else Changed.NO
                # If the file was written back or was successfully checked as
                # well-formatted, store this information in the cache.
                if write_back is WriteBack.YES or (
                    write_back is WriteBack.CHECK and changed is Changed.NO
                ):
                    sources_to_cache.append(src)
                report.done(src, changed)
    if cancelled:
        if sys.version_info >= (3, 7):
            await asyncio.gather(*cancelled, return_exceptions=True)
        else:
            await asyncio.gather(*cancelled, loop=loop, return_exceptions=True)
    if sources_to_cache:
        write_cache(cache, sources_to_cache, mode)


def format_file_in_place(
    src: Path,
    fast: bool,
    mode: Mode,
    write_back: WriteBack = WriteBack.NO,
    lock: Any = None,  # multiprocessing.Manager().Lock() is some crazy proxy
) -> bool:
    """Format file under `src` path. Return True if changed.

    If `write_back` is DIFF, write a diff to stdout. If it is YES, write reformatted
    code to the file.
    `mode` and `fast` options are passed to :func:`format_file_contents`.
    """
    if src.suffix == ".pyi":
        mode = replace(mode, is_pyi=True)
    elif src.suffix == ".ipynb":
        mode = replace(mode, is_ipynb=True)

    then = datetime.utcfromtimestamp(src.stat().st_mtime)
    with open(src, "rb") as buf:
        src_contents, encoding, newline = decode_bytes(buf.read())
    try:
        dst_contents = format_file_contents(src_contents, fast=fast, mode=mode)
    except NothingChanged:
        return False
    except JSONDecodeError:
        raise ValueError(
            f"File '{src}' cannot be parsed as a valid Jupyter notebook."
        ) from None

    if write_back == WriteBack.YES:
        with open(src, "w", encoding=encoding, newline=newline) as f:
            f.write(dst_contents)
    elif write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
        now = datetime.utcnow()
        src_name = f"{src}\t{then} +0000"
        dst_name = f"{src}\t{now} +0000"
        if mode.is_ipynb:
            diff_contents = ipynb_diff(src_contents, dst_contents, src_name, dst_name)
        else:
            diff_contents = diff(src_contents, dst_contents, src_name, dst_name)

        if write_back == WriteBack.COLOR_DIFF:
            diff_contents = color_diff(diff_contents)

        with lock or nullcontext():
            f = io.TextIOWrapper(
                sys.stdout.buffer,
                encoding=encoding,
                newline=newline,
                write_through=True,
            )
            f = wrap_stream_for_windows(f)
            f.write(diff_contents)
            f.detach()

    return True
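
# Illustrative call (hypothetical file name):
#   format_file_in_place(Path("example.py"), fast=False, mode=Mode(),
#                        write_back=WriteBack.YES)
# rewrites example.py and returns True if Black would change it, False otherwise.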


def format_stdin_to_stdout(
    fast: bool,
    *,
    content: Optional[str] = None,
    write_back: WriteBack = WriteBack.NO,
    mode: Mode,
) -> bool:
    """Format file on stdin. Return True if changed.

    If content is None, it's read from sys.stdin.

    If `write_back` is YES, write reformatted code back to stdout. If it is DIFF,
    write a diff to stdout. The `mode` argument is passed to
    :func:`format_file_contents`.
    """
    then = datetime.utcnow()

    if content is None:
        src, encoding, newline = decode_bytes(sys.stdin.buffer.read())
    else:
        src, encoding, newline = content, "utf-8", ""

    dst = src
    try:
        dst = format_file_contents(src, fast=fast, mode=mode)
        return True

    except NothingChanged:
        return False

    finally:
        f = io.TextIOWrapper(
            sys.stdout.buffer, encoding=encoding, newline=newline, write_through=True
        )
        if write_back == WriteBack.YES:
            # Make sure there's a newline after the content
            if dst and dst[-1] != "\n":
                dst += "\n"
            f.write(dst)
        elif write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
            now = datetime.utcnow()
            src_name = f"STDIN\t{then} +0000"
            dst_name = f"STDOUT\t{now} +0000"
            d = diff(src, dst, src_name, dst_name)
            if write_back == WriteBack.COLOR_DIFF:
                d = color_diff(d)
                f = wrap_stream_for_windows(f)
            f.write(d)
        f.detach()


def check_stability_and_equivalence(
    src_contents: str, dst_contents: str, *, mode: Mode
) -> None:
    """Perform stability and equivalence checks.

    Raise AssertionError if source and destination contents are not
    equivalent, or if a second pass of the formatter would format the
    content differently.
    """
    assert_equivalent(src_contents, dst_contents)

    # Forced second pass to work around optional trailing commas (becoming
    # forced trailing commas on pass 2) interacting differently with optional
    # parentheses.  Admittedly ugly.
    dst_contents_pass2 = format_str(dst_contents, mode=mode)
    if dst_contents != dst_contents_pass2:
        dst_contents = dst_contents_pass2
        assert_equivalent(src_contents, dst_contents, pass_num=2)
        assert_stable(src_contents, dst_contents, mode=mode)
    # Note: no need to explicitly call `assert_stable` if `dst_contents` was
    # the same as `dst_contents_pass2`.


def format_file_contents(src_contents: str, *, fast: bool, mode: Mode) -> FileContent:
    """Reformat contents of a file and return new contents.

    If `fast` is False, additionally confirm that the reformatted code is
    valid by calling :func:`assert_equivalent` and :func:`assert_stable` on it.
    `mode` is passed to :func:`format_str`.
    """
    if not src_contents.strip():
        raise NothingChanged

    if mode.is_ipynb:
        dst_contents = format_ipynb_string(src_contents, fast=fast, mode=mode)
    else:
        dst_contents = format_str(src_contents, mode=mode)
    if src_contents == dst_contents:
        raise NothingChanged

    if not fast and not mode.is_ipynb:
        # Jupyter notebooks will already have been checked above.
        check_stability_and_equivalence(src_contents, dst_contents, mode=mode)
    return dst_contents


def validate_cell(src: str) -> None:
    """Check that cell does not already contain TransformerManager transformations,
    or non-Python cell magics, which might cause tokenize_rt to break because of
    indentations.

    If a cell contains ``!ls``, then it'll be transformed to
    ``get_ipython().system('ls')``. However, if the cell originally contained
    ``get_ipython().system('ls')``, then it would get transformed in the same way:

        >>> TransformerManager().transform_cell("get_ipython().system('ls')")
        "get_ipython().system('ls')\n"
        >>> TransformerManager().transform_cell("!ls")
        "get_ipython().system('ls')\n"

    Due to the impossibility of safely roundtripping in such situations, cells
    containing transformed magics will be ignored.
    """
    if any(transformed_magic in src for transformed_magic in TRANSFORMED_MAGICS):
        raise NothingChanged
    if src[:2] == "%%" and src.split()[0][2:] not in PYTHON_CELL_MAGICS:
        raise NothingChanged


def format_cell(src: str, *, fast: bool, mode: Mode) -> str:
    """Format code in given cell of Jupyter notebook.

    General idea is:

      - if cell has trailing semicolon, remove it;
      - if cell has IPython magics, mask them;
      - format cell;
      - reinstate IPython magics;
      - reinstate trailing semicolon (if originally present);
      - strip trailing newlines.

    Cells with syntax errors will not be processed, as they
    could potentially be automagics or multi-line magics, which
    are currently not supported.
    """
    validate_cell(src)
    src_without_trailing_semicolon, has_trailing_semicolon = remove_trailing_semicolon(
        src
    )
    try:
        masked_src, replacements = mask_cell(src_without_trailing_semicolon)
    except SyntaxError:
        raise NothingChanged from None
    masked_dst = format_str(masked_src, mode=mode)
    if not fast:
        check_stability_and_equivalence(masked_src, masked_dst, mode=mode)
    dst_without_trailing_semicolon = unmask_cell(masked_dst, replacements)
    dst = put_trailing_semicolon_back(
        dst_without_trailing_semicolon, has_trailing_semicolon
    )
    dst = dst.rstrip("\n")
    if dst == src:
        raise NothingChanged from None
    return dst


def validate_metadata(nb: MutableMapping[str, Any]) -> None:
    """If notebook is marked as non-Python, don't format it.

    All notebook metadata fields are optional, see
    https://nbformat.readthedocs.io/en/latest/format_description.html. So
    if a notebook has empty metadata, we will try to parse it anyway.
    """
    language = nb.get("metadata", {}).get("language_info", {}).get("name", None)
    if language is not None and language != "python":
        raise NothingChanged from None


def format_ipynb_string(src_contents: str, *, fast: bool, mode: Mode) -> FileContent:
    """Format Jupyter notebook.

    Operate cell-by-cell, only on code cells, only for Python notebooks.
    If the ``.ipynb`` originally had a trailing newline, it'll be preserved.
    """
    trailing_newline = src_contents[-1] == "\n"
    modified = False
    nb = json.loads(src_contents)
    validate_metadata(nb)
    for cell in nb["cells"]:
        if cell.get("cell_type", None) == "code":
            try:
                src = "".join(cell["source"])
                dst = format_cell(src, fast=fast, mode=mode)
            except NothingChanged:
                pass
            else:
                cell["source"] = dst.splitlines(keepends=True)
                modified = True
    if modified:
        dst_contents = json.dumps(nb, indent=1, ensure_ascii=False)
        if trailing_newline:
            dst_contents = dst_contents + "\n"
        return dst_contents
    else:
        raise NothingChanged


def format_str(src_contents: str, *, mode: Mode) -> FileContent:
    """Reformat a string and return new contents.

    `mode` determines formatting options, such as how many characters per line are
    allowed.  Example:

    >>> import black
    >>> print(black.format_str("def f(arg:str='')->None:...", mode=black.Mode()))
    def f(arg: str = "") -> None:
        ...

    A more complex example:

    >>> print(
    ...   black.format_str(
    ...     "def f(arg:str='')->None: hey",
    ...     mode=black.Mode(
    ...       target_versions={black.TargetVersion.PY36},
    ...       line_length=10,
    ...       string_normalization=False,
    ...       is_pyi=False,
    ...     ),
    ...   ),
    ... )
    def f(
        arg: str = '',
    ) -> None:
        hey

    """
    src_node = lib2to3_parse(src_contents.lstrip(), mode.target_versions)
    dst_contents = []
    future_imports = get_future_imports(src_node)
    if mode.target_versions:
        versions = mode.target_versions
    else:
        versions = detect_target_versions(src_node, future_imports=future_imports)

    normalize_fmt_off(src_node)
    lines = LineGenerator(mode=mode)
    elt = EmptyLineTracker(is_pyi=mode.is_pyi)
    empty_line = Line(mode=mode)
    after = 0
    split_line_features = {
        feature
        for feature in {Feature.TRAILING_COMMA_IN_CALL, Feature.TRAILING_COMMA_IN_DEF}
        if supports_feature(versions, feature)
    }
    for current_line in lines.visit(src_node):
        dst_contents.append(str(empty_line) * after)
        before, after = elt.maybe_empty_lines(current_line)
        dst_contents.append(str(empty_line) * before)
        for line in transform_line(
            current_line, mode=mode, features=split_line_features
        ):
            dst_contents.append(str(line))
    return "".join(dst_contents)


def decode_bytes(src: bytes) -> Tuple[FileContent, Encoding, NewLine]:
    """Return a tuple of (decoded_contents, encoding, newline).

    `newline` is either CRLF or LF but `decoded_contents` is decoded with
    universal newlines (i.e. only contains LF).
    """
    srcbuf = io.BytesIO(src)
    encoding, lines = tokenize.detect_encoding(srcbuf.readline)
    if not lines:
        return "", encoding, "\n"

    newline = "\r\n" if b"\r\n" == lines[0][-2:] else "\n"
    srcbuf.seek(0)
    with io.TextIOWrapper(srcbuf, encoding) as tiow:
        return tiow.read(), encoding, newline
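
# Illustrative example: decode_bytes(b"x = 1\r\n") returns
# ("x = 1\n", "utf-8", "\r\n"): the universal-newline contents plus the detected
# encoding and original newline style.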


def get_features_used(  # noqa: C901
    node: Node, *, future_imports: Optional[Set[str]] = None
) -> Set[Feature]:
    """Return a set of (relatively) new Python features used in this file.

    Currently looking for:
    - f-strings;
    - underscores in numeric literals;
    - trailing commas after * or ** in function signatures and calls;
    - positional only arguments in function signatures and lambdas;
    - assignment expression;
    - relaxed decorator syntax;
    - usage of __future__ flags (annotations);
    - print / exec statements;
    """
    features: Set[Feature] = set()
    if future_imports:
        features |= {
            FUTURE_FLAG_TO_FEATURE[future_import]
            for future_import in future_imports
            if future_import in FUTURE_FLAG_TO_FEATURE
        }

    for n in node.pre_order():
        if is_string_token(n):
            value_head = n.value[:2]
            if value_head in {'f"', 'F"', "f'", "F'", "rf", "fr", "RF", "FR"}:
                features.add(Feature.F_STRINGS)

        elif n.type == token.NUMBER:
            assert isinstance(n, Leaf)
            if "_" in n.value:
                features.add(Feature.NUMERIC_UNDERSCORES)

        elif n.type == token.SLASH:
            if n.parent and n.parent.type in {
                syms.typedargslist,
                syms.arglist,
                syms.varargslist,
            }:
                features.add(Feature.POS_ONLY_ARGUMENTS)

        elif n.type == token.COLONEQUAL:
            features.add(Feature.ASSIGNMENT_EXPRESSIONS)

        elif n.type == syms.decorator:
            if len(n.children) > 1 and not is_simple_decorator_expression(
                n.children[1]
            ):
                features.add(Feature.RELAXED_DECORATORS)

        elif (
            n.type in {syms.typedargslist, syms.arglist}
            and n.children
            and n.children[-1].type == token.COMMA
        ):
            if n.type == syms.typedargslist:
                feature = Feature.TRAILING_COMMA_IN_DEF
            else:
                feature = Feature.TRAILING_COMMA_IN_CALL

            for ch in n.children:
                if ch.type in STARS:
                    features.add(feature)

                if ch.type == syms.argument:
                    for argch in ch.children:
                        if argch.type in STARS:
                            features.add(feature)

        elif (
            n.type in {syms.return_stmt, syms.yield_expr}
            and len(n.children) >= 2
            and n.children[1].type == syms.testlist_star_expr
            and any(child.type == syms.star_expr for child in n.children[1].children)
        ):
            features.add(Feature.UNPACKING_ON_FLOW)

        elif (
            n.type == syms.annassign
            and len(n.children) >= 4
            and n.children[3].type == syms.testlist_star_expr
        ):
            features.add(Feature.ANN_ASSIGN_EXTENDED_RHS)

    return features


def detect_target_versions(
    node: Node, *, future_imports: Optional[Set[str]] = None
) -> Set[TargetVersion]:
    """Detect the version to target based on the nodes used."""
    features = get_features_used(node, future_imports=future_imports)
    return {
        version for version in TargetVersion if features <= VERSION_TO_FEATURES[version]
    }
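
# Illustrative example (assuming the standard feature tables in black.mode): a file
# whose newest detected feature is an f-string is reported as compatible with every
# target version from PY36 onwards.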


def get_future_imports(node: Node) -> Set[str]:
    """Return a set of __future__ imports in the file."""
    imports: Set[str] = set()

    def get_imports_from_children(children: List[LN]) -> Generator[str, None, None]:
        for child in children:
            if isinstance(child, Leaf):
                if child.type == token.NAME:
                    yield child.value

            elif child.type == syms.import_as_name:
                orig_name = child.children[0]
                assert isinstance(orig_name, Leaf), "Invalid syntax parsing imports"
                assert orig_name.type == token.NAME, "Invalid syntax parsing imports"
                yield orig_name.value

            elif child.type == syms.import_as_names:
                yield from get_imports_from_children(child.children)

            else:
                raise AssertionError("Invalid syntax parsing imports")

    for child in node.children:
        if child.type != syms.simple_stmt:
            break

        first_child = child.children[0]
        if isinstance(first_child, Leaf):
            # Continue looking if we see a docstring; otherwise stop.
            if (
                len(child.children) == 2
                and first_child.type == token.STRING
                and child.children[1].type == token.NEWLINE
            ):
                continue

            break

        elif first_child.type == syms.import_from:
            module_name = first_child.children[1]
            if not isinstance(module_name, Leaf) or module_name.value != "__future__":
                break

            imports |= set(get_imports_from_children(first_child.children[3:]))
        else:
            break

    return imports


def assert_equivalent(src: str, dst: str, *, pass_num: int = 1) -> None:
    """Raise AssertionError if `src` and `dst` aren't equivalent."""
    try:
        src_ast = parse_ast(src)
    except Exception as exc:
        raise AssertionError(
            f"cannot use --safe with this file; failed to parse source file: {exc}"
        ) from exc

    try:
        dst_ast = parse_ast(dst)
    except Exception as exc:
        log = dump_to_file("".join(traceback.format_tb(exc.__traceback__)), dst)
        raise AssertionError(
            f"INTERNAL ERROR: Black produced invalid code on pass {pass_num}: {exc}. "
            "Please report a bug on https://github.com/psf/black/issues.  "
            f"This invalid output might be helpful: {log}"
        ) from None

    src_ast_str = "\n".join(stringify_ast(src_ast))
    dst_ast_str = "\n".join(stringify_ast(dst_ast))
    if src_ast_str != dst_ast_str:
        log = dump_to_file(diff(src_ast_str, dst_ast_str, "src", "dst"))
        raise AssertionError(
            "INTERNAL ERROR: Black produced code that is not equivalent to the"
            f" source on pass {pass_num}.  Please report a bug on "
            f"https://github.com/psf/black/issues.  This diff might be helpful: {log}"
        ) from None


def assert_stable(src: str, dst: str, mode: Mode) -> None:
    """Raise AssertionError if `dst` reformats differently the second time."""
    newdst = format_str(dst, mode=mode)
    if dst != newdst:
        log = dump_to_file(
            str(mode),
            diff(src, dst, "source", "first pass"),
            diff(dst, newdst, "first pass", "second pass"),
        )
        raise AssertionError(
            "INTERNAL ERROR: Black produced different code on the second pass of the"
            " formatter.  Please report a bug on https://github.com/psf/black/issues."
            f"  This diff might be helpful: {log}"
        ) from None


@contextmanager
def nullcontext() -> Iterator[None]:
    """Return an empty context manager.

    To be used like `contextlib.nullcontext`, available only in Python 3.7+.
    """
    yield


def patch_click() -> None:
    """Make Click not crash on Python 3.6 with LANG=C.

    On certain misconfigured environments, Python 3 selects the ASCII encoding as the
    default which restricts paths that it can access during the lifetime of the
    application.  Click refuses to work in this scenario by raising a RuntimeError.

    In the case of Black, the likelihood that non-ASCII characters are going to
    be used in file paths is minimal since it's Python source code.  Moreover,
    this crash was spurious on Python 3.7 thanks to PEP 538 and PEP 540.
    """
    try:
        from click import core
        from click import _unicodefun
    except ModuleNotFoundError:
        return

    for module in (core, _unicodefun):
        if hasattr(module, "_verify_python3_env"):
            module._verify_python3_env = lambda: None  # type: ignore
        if hasattr(module, "_verify_python_env"):
            module._verify_python_env = lambda: None  # type: ignore


def patched_main() -> None:
    maybe_install_uvloop()
    freeze_support()
    patch_click()
    main()


if __name__ == "__main__":
    patched_main()