git.madduck.net Git - etc/vim.git/blob - src/black/__init__.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

c2b52e6eadb07936e41d53c59f61abf37bbb1e9d
[etc/vim.git] / src / black / __init__.py
1 import asyncio
2 from json.decoder import JSONDecodeError
3 import json
4 from concurrent.futures import Executor, ThreadPoolExecutor, ProcessPoolExecutor
5 from contextlib import contextmanager
6 from datetime import datetime
7 from enum import Enum
8 import io
9 from multiprocessing import Manager, freeze_support
10 import os
11 from pathlib import Path
12 from pathspec.patterns.gitwildmatch import GitWildMatchPatternError
13 import regex as re
14 import signal
15 import sys
16 import tokenize
17 import traceback
18 from typing import (
19     Any,
20     Dict,
21     Generator,
22     Iterator,
23     List,
24     MutableMapping,
25     Optional,
26     Pattern,
27     Set,
28     Sized,
29     Tuple,
30     Union,
31 )
32
33 import click
34 from dataclasses import replace
35 from mypy_extensions import mypyc_attr
36
37 from black.const import DEFAULT_LINE_LENGTH, DEFAULT_INCLUDES, DEFAULT_EXCLUDES
38 from black.const import STDIN_PLACEHOLDER
39 from black.nodes import STARS, syms, is_simple_decorator_expression
40 from black.lines import Line, EmptyLineTracker
41 from black.linegen import transform_line, LineGenerator, LN
42 from black.comments import normalize_fmt_off
43 from black.mode import Mode, TargetVersion
44 from black.mode import Feature, supports_feature, VERSION_TO_FEATURES
45 from black.cache import read_cache, write_cache, get_cache_info, filter_cached, Cache
46 from black.concurrency import cancel, shutdown, maybe_install_uvloop
47 from black.output import dump_to_file, ipynb_diff, diff, color_diff, out, err
48 from black.report import Report, Changed, NothingChanged
49 from black.files import find_project_root, find_pyproject_toml, parse_pyproject_toml
50 from black.files import gen_python_files, get_gitignore, normalize_path_maybe_ignore
51 from black.files import wrap_stream_for_windows
52 from black.parsing import InvalidInput  # noqa F401
53 from black.parsing import lib2to3_parse, parse_ast, stringify_ast
54 from black.handle_ipynb_magics import (
55     mask_cell,
56     unmask_cell,
57     remove_trailing_semicolon,
58     put_trailing_semicolon_back,
59     TRANSFORMED_MAGICS,
60     PYTHON_CELL_MAGICS,
61     jupyter_dependencies_are_installed,
62 )
63
64
65 # lib2to3 fork
66 from blib2to3.pytree import Node, Leaf
67 from blib2to3.pgen2 import token
68
69 from _black_version import version as __version__
70
# True when this module was loaded from a compiled extension (.so / .pyd)
# rather than from plain Python source.
COMPILED = Path(__file__).suffix in {".so", ".pyd"}

# types -- plain ``str`` aliases that only document intent in signatures.
FileContent = str
Encoding = str
NewLine = str
77
78
class WriteBack(Enum):
    """What to do with the result of formatting a file."""

    NO = 0
    YES = 1
    DIFF = 2
    CHECK = 3
    COLOR_DIFF = 4

    @classmethod
    def from_configuration(
        cls, *, check: bool, diff: bool, color: bool = False
    ) -> "WriteBack":
        """Translate the ``--check`` / ``--diff`` / ``--color`` flags.

        ``--diff`` takes precedence over ``--check``; ``--color`` only has an
        effect together with ``--diff``. With none of the flags set, files are
        written back (YES).
        """
        if diff:
            return cls.COLOR_DIFF if color else cls.DIFF
        if check:
            return cls.CHECK
        return cls.YES
97
98
# Legacy name, left for integrations.
# ``FileMode`` is the very same object as ``Mode``.
FileMode = Mode

# Default for the ``--workers`` option. ``os.cpu_count()`` can return None
# when the CPU count is undeterminable; reformat_many copes with that.
DEFAULT_WORKERS = os.cpu_count()
103
104
def read_pyproject_toml(
    ctx: click.Context, param: click.Parameter, value: Optional[str]
) -> Optional[str]:
    """Inject Black configuration from "pyproject.toml" into defaults in `ctx`.

    Used as the callback of the ``--config`` option. When `value` is empty,
    the configuration file is looked up from the ``src`` arguments parsed so
    far. Settings found in the file are merged over any existing
    ``ctx.default_map``, which Click consults for option defaults.

    Returns the path to a successfully found and read configuration file,
    None otherwise.

    Raises click.FileError if the file cannot be read or parsed, and
    click.BadOptionUsage if ``target_version`` is not a list.
    """
    config_path = value or find_pyproject_toml(ctx.params.get("src", ()))
    if config_path is None:
        return None

    try:
        parsed = parse_pyproject_toml(config_path)
    except (OSError, ValueError) as e:
        raise click.FileError(
            filename=config_path, hint=f"Error reading configuration file: {e}"
        ) from None

    if not parsed:
        return None

    # Sanitize the values to be Click friendly. For more information please see:
    # https://github.com/psf/black/issues/1458
    # https://github.com/pallets/click/issues/1567
    config = {
        key: val if isinstance(val, (list, dict)) else str(val)
        for key, val in parsed.items()
    }

    target_version = config.get("target_version")
    if target_version is not None and not isinstance(target_version, list):
        raise click.BadOptionUsage(
            "target-version", "Config key target-version must be a list"
        )

    # Values from the file win over any defaults that were already in place.
    merged: Dict[str, Any] = dict(ctx.default_map or {})
    merged.update(config)
    ctx.default_map = merged
    return config_path
149
150
def target_version_option_callback(
    c: click.Context, p: Union[click.Option, click.Parameter], v: Tuple[str, ...]
) -> List[TargetVersion]:
    """Convert the ``--target-version`` choices into ``TargetVersion`` members.

    Kept as a named function instead of a lambda because mypy could not infer
    the lambda's type correctly, which caused trouble for mypyc.
    """
    versions: List[TargetVersion] = []
    for name in v:
        versions.append(TargetVersion[name.upper()])
    return versions
160
161
def re_compile_maybe_verbose(regex: str) -> Pattern[str]:
    """Compile `regex`, enabling verbose mode for multi-line patterns.

    A pattern containing a newline is prefixed with the inline ``(?x)`` flag,
    so whitespace and ``#`` comments inside it are ignored.
    """
    pattern_text = f"(?x){regex}" if "\n" in regex else regex
    compiled: Pattern[str] = re.compile(pattern_text)
    return compiled
171
172
def validate_regex(
    ctx: click.Context,
    param: click.Parameter,
    value: Optional[str],
) -> Optional[Pattern[str]]:
    """Click callback compiling a regex option; None when the option is unset.

    An invalid pattern is reported to the user as a click.BadParameter.
    """
    if value is None:
        return None
    try:
        return re_compile_maybe_verbose(value)
    except re.error:
        raise click.BadParameter("Not a valid regular expression") from None
182
183
@click.command(
    context_settings={"help_option_names": ["-h", "--help"]},
    # While Click does set this field automatically using the docstring, mypyc
    # (annoyingly) strips 'em so we need to set it here too.
    help="The uncompromising code formatter.",
)
@click.option("-c", "--code", type=str, help="Format the code passed in as a string.")
@click.option(
    "-l",
    "--line-length",
    type=int,
    default=DEFAULT_LINE_LENGTH,
    help="How many characters per line to allow.",
    show_default=True,
)
@click.option(
    "-t",
    "--target-version",
    type=click.Choice([v.name.lower() for v in TargetVersion]),
    callback=target_version_option_callback,
    multiple=True,
    help=(
        "Python versions that should be supported by Black's output. [default: per-file"
        " auto-detection]"
    ),
)
@click.option(
    "--pyi",
    is_flag=True,
    help=(
        "Format all input files like typing stubs regardless of file extension (useful"
        " when piping source on standard input)."
    ),
)
@click.option(
    "--ipynb",
    is_flag=True,
    help=(
        "Format all input files like Jupyter Notebooks regardless of file extension "
        "(useful when piping source on standard input)."
    ),
)
@click.option(
    "-S",
    "--skip-string-normalization",
    is_flag=True,
    help="Don't normalize string quotes or prefixes.",
)
@click.option(
    "-C",
    "--skip-magic-trailing-comma",
    is_flag=True,
    help="Don't use trailing commas as a reason to split lines.",
)
@click.option(
    "--experimental-string-processing",
    is_flag=True,
    hidden=True,
    help=(
        "Experimental option that performs more normalization on string literals."
        " Currently disabled because it leads to some crashes."
    ),
)
@click.option(
    "--check",
    is_flag=True,
    help=(
        "Don't write the files back, just return the status. Return code 0 means"
        " nothing would change. Return code 1 means some files would be reformatted."
        " Return code 123 means there was an internal error."
    ),
)
@click.option(
    "--diff",
    is_flag=True,
    help="Don't write the files back, just output a diff for each file on stdout.",
)
@click.option(
    "--color/--no-color",
    is_flag=True,
    help="Show colored diff. Only applies when `--diff` is given.",
)
@click.option(
    "--fast/--safe",
    is_flag=True,
    help="If --fast given, skip temporary sanity checks. [default: --safe]",
)
@click.option(
    "--required-version",
    type=str,
    help=(
        "Require a specific version of Black to be running (useful for unifying results"
        " across many environments e.g. with a pyproject.toml file)."
    ),
)
@click.option(
    "--include",
    type=str,
    default=DEFAULT_INCLUDES,
    callback=validate_regex,
    help=(
        "A regular expression that matches files and directories that should be"
        " included on recursive searches. An empty value means all files are included"
        " regardless of the name. Use forward slashes for directories on all platforms"
        " (Windows, too). Exclusions are calculated first, inclusions later."
    ),
    show_default=True,
)
@click.option(
    "--exclude",
    type=str,
    callback=validate_regex,
    help=(
        "A regular expression that matches files and directories that should be"
        " excluded on recursive searches. An empty value means no paths are excluded."
        " Use forward slashes for directories on all platforms (Windows, too)."
        " Exclusions are calculated first, inclusions later. [default:"
        f" {DEFAULT_EXCLUDES}]"
    ),
    show_default=False,
)
@click.option(
    "--extend-exclude",
    type=str,
    callback=validate_regex,
    help=(
        "Like --exclude, but adds additional files and directories on top of the"
        " excluded ones. (Useful if you simply want to add to the default)"
    ),
)
@click.option(
    "--force-exclude",
    type=str,
    callback=validate_regex,
    help=(
        "Like --exclude, but files and directories matching this regex will be "
        "excluded even when they are passed explicitly as arguments."
    ),
)
@click.option(
    "--stdin-filename",
    type=str,
    help=(
        "The name of the file when passing it through stdin. Useful to make "
        "sure Black will respect --force-exclude option on some "
        "editors that rely on using stdin."
    ),
)
@click.option(
    "-W",
    "--workers",
    type=click.IntRange(min=1),
    default=DEFAULT_WORKERS,
    show_default=True,
    help="Number of parallel workers",
)
@click.option(
    "-q",
    "--quiet",
    is_flag=True,
    help=(
        "Don't emit non-error messages to stderr. Errors are still emitted; silence"
        " those with 2>/dev/null."
    ),
)
@click.option(
    "-v",
    "--verbose",
    is_flag=True,
    help=(
        "Also emit messages to stderr about files that were not changed or were ignored"
        " due to exclusion patterns."
    ),
)
@click.version_option(
    version=__version__,
    message=f"%(prog)s, %(version)s (compiled: {'yes' if COMPILED else 'no'})",
)
# `src` and `--config` are both eager, so Click processes them before the other
# parameters; read_pyproject_toml (the --config callback) looks up
# ctx.params["src"] to locate pyproject.toml when no path is given.
@click.argument(
    "src",
    nargs=-1,
    type=click.Path(
        exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
    ),
    is_eager=True,
    metavar="SRC ...",
)
# The callback merges the configuration file into ctx.default_map, which Click
# uses for every option not given explicitly on the command line.
@click.option(
    "--config",
    type=click.Path(
        exists=True,
        file_okay=True,
        dir_okay=False,
        readable=True,
        allow_dash=False,
        path_type=str,
    ),
    is_eager=True,
    callback=read_pyproject_toml,
    help="Read configuration from FILE path.",
)
@click.pass_context
def main(
    ctx: click.Context,
    code: Optional[str],
    line_length: int,
    target_version: List[TargetVersion],
    check: bool,
    diff: bool,
    color: bool,
    fast: bool,
    pyi: bool,
    ipynb: bool,
    skip_string_normalization: bool,
    skip_magic_trailing_comma: bool,
    experimental_string_processing: bool,
    quiet: bool,
    verbose: bool,
    required_version: Optional[str],
    include: Pattern[str],
    exclude: Optional[Pattern[str]],
    extend_exclude: Optional[Pattern[str]],
    force_exclude: Optional[Pattern[str]],
    stdin_filename: Optional[str],
    workers: int,
    src: Tuple[str, ...],
    config: Optional[str],
) -> None:
    """The uncompromising code formatter."""
    # Entry point of the CLI: validate the flags, build a Mode, format either
    # the `--code` string or the collected source files, then exit with the
    # report's return code.
    if config and verbose:
        out(f"Using configuration from {config}.", bold=False, fg="blue")

    error_msg = "Oh no! 💥 💔 💥"
    # Refuse to run when a different Black version is pinned (from the CLI or
    # from the configuration file's defaults).
    if required_version and required_version != __version__:
        err(
            f"{error_msg} The required version `{required_version}` does not match"
            f" the running version `{__version__}`!"
        )
        ctx.exit(1)
    # Stub mode and notebook mode are mutually exclusive.
    if ipynb and pyi:
        err("Cannot pass both `pyi` and `ipynb` flags!")
        ctx.exit(1)

    write_back = WriteBack.from_configuration(check=check, diff=diff, color=color)
    if target_version:
        versions = set(target_version)
    else:
        # We'll autodetect later.
        versions = set()
    mode = Mode(
        target_versions=versions,
        line_length=line_length,
        is_pyi=pyi,
        is_ipynb=ipynb,
        string_normalization=not skip_string_normalization,
        magic_trailing_comma=not skip_magic_trailing_comma,
        experimental_string_processing=experimental_string_processing,
    )

    if code is not None:
        # Run in quiet mode by default with -c; the extra output isn't useful.
        # You can still pass -v to get verbose output.
        quiet = True

    report = Report(check=check, diff=diff, quiet=quiet, verbose=verbose)

    if code is not None:
        reformat_code(
            content=code, fast=fast, write_back=write_back, mode=mode, report=report
        )
    else:
        try:
            sources = get_sources(
                ctx=ctx,
                src=src,
                quiet=quiet,
                verbose=verbose,
                include=include,
                exclude=exclude,
                extend_exclude=extend_exclude,
                force_exclude=force_exclude,
                report=report,
                stdin_filename=stdin_filename,
            )
        except GitWildMatchPatternError:
            # NOTE(review): presumably raised for an invalid .gitignore pattern
            # while collecting sources; exit with status 1.
            ctx.exit(1)

        # Exits with status 0 when nothing was collected.
        path_empty(
            sources,
            "No Python files are present to be formatted. Nothing to do 😴",
            quiet,
            verbose,
            ctx,
        )

        # A single file is formatted in-process; several go to a worker pool.
        if len(sources) == 1:
            reformat_one(
                src=sources.pop(),
                fast=fast,
                write_back=write_back,
                mode=mode,
                report=report,
            )
        else:
            reformat_many(
                sources=sources,
                fast=fast,
                write_back=write_back,
                mode=mode,
                report=report,
                workers=workers,
            )

    if verbose or not quiet:
        out(error_msg if report.return_code else "All done! ✨ 🍰 ✨")
        if code is None:
            click.echo(str(report), err=True)
    # The report's return code becomes the process exit status.
    ctx.exit(report.return_code)
502
503
def get_sources(
    *,
    ctx: click.Context,
    src: Tuple[str, ...],
    quiet: bool,
    verbose: bool,
    include: Pattern[str],
    exclude: Optional[Pattern[str]],
    extend_exclude: Optional[Pattern[str]],
    force_exclude: Optional[Pattern[str]],
    report: "Report",
    stdin_filename: Optional[str],
) -> Set[Path]:
    """Compute the set of files to be formatted.

    Each entry of `src` is handled as follows:

    * an existing file, or ``-`` when `stdin_filename` is set, is added unless
      ``normalize_path_maybe_ignore`` rejects it or `force_exclude` matches it;
      stdin entries are prefixed with ``STDIN_PLACEHOLDER`` so the supplied
      name survives;
    * a directory is expanded with ``gen_python_files``;
    * a bare ``-`` is added unchanged;
    * anything else is reported as an invalid path.

    ``.ipynb`` files are dropped when the Jupyter dependencies are missing.
    Without an explicit `exclude`, ``DEFAULT_EXCLUDES`` is used together with
    the project's gitignore rules. Exits with status 0 when `src` is empty.
    """
    root = find_project_root(src)
    collected: Set[Path] = set()
    path_empty(src, "No Path provided. Nothing to do 😴", quiet, verbose, ctx)

    if exclude is None:
        exclude = re_compile_maybe_verbose(DEFAULT_EXCLUDES)
        gitignore = get_gitignore(root)
    else:
        gitignore = None

    for entry in src:
        if entry == "-" and stdin_filename:
            path, is_stdin = Path(stdin_filename), True
        else:
            path, is_stdin = Path(entry), False

        if is_stdin or path.is_file():
            relative = normalize_path_maybe_ignore(path, root, report)
            if relative is None:
                continue

            # Hard-exclude any files that match the `--force-exclude` regex.
            excluded = force_exclude.search("/" + relative) if force_exclude else None
            if excluded and excluded.group(0):
                report.path_ignored(
                    path, "matches the --force-exclude regular expression"
                )
                continue

            if is_stdin:
                # Remember the user-supplied name while marking it as stdin.
                path = Path(f"{STDIN_PLACEHOLDER}{str(path)}")

            is_notebook = path.suffix == ".ipynb"
            if is_notebook and not jupyter_dependencies_are_installed(
                verbose=verbose, quiet=quiet
            ):
                continue

            collected.add(path)
        elif path.is_dir():
            collected.update(
                gen_python_files(
                    path.iterdir(),
                    root,
                    include,
                    exclude,
                    extend_exclude,
                    force_exclude,
                    report,
                    gitignore,
                    verbose=verbose,
                    quiet=quiet,
                )
            )
        elif entry == "-":
            collected.add(path)
        else:
            err(f"invalid path: {entry}")
    return collected
581
582
def path_empty(
    src: Sized, msg: str, quiet: bool, verbose: bool, ctx: click.Context
) -> None:
    """Exit successfully when there is nothing to format.

    If `src` is empty, print `msg` (unless quiet without verbose) and exit
    with status 0; otherwise return without doing anything.
    """
    if src:
        return
    if verbose or not quiet:
        out(msg)
    ctx.exit(0)
593
594
def reformat_code(
    content: str, fast: bool, write_back: WriteBack, mode: Mode, report: Report
) -> None:
    """Reformat the string `content` and print the result, in-process.

    The string counterpart of `reformat_one`; the outcome is recorded in
    `report` under the pseudo-path ``<string>``. `fast`, `write_back` and
    `mode` are passed to :func:`format_stdin_to_stdout`. Any exception is
    reported as a failure (with a traceback in verbose mode) instead of
    propagating.
    """
    path = Path("<string>")
    try:
        was_changed = format_stdin_to_stdout(
            content=content, fast=fast, write_back=write_back, mode=mode
        )
        report.done(path, Changed.YES if was_changed else Changed.NO)
    except Exception as exc:
        if report.verbose:
            traceback.print_exc()
        report.failed(path, str(exc))
617
618
def reformat_one(
    src: Path, fast: bool, write_back: WriteBack, mode: Mode, report: "Report"
) -> None:
    """Reformat a single file under `src` without spawning child processes.

    ``-`` and stdin placeholder paths are read from stdin; for those, a
    ``.pyi`` or ``.ipynb`` suffix on the supplied name selects stub or
    notebook mode. Real files consult the cache (except in diff modes) and
    are skipped when their cache entry is still current. `fast`,
    `write_back`, and `mode` options are passed to
    :func:`format_file_in_place` or :func:`format_stdin_to_stdout`.
    Failures are recorded in `report` rather than raised.
    """
    try:
        changed = Changed.NO

        raw = str(src)
        is_stdin = raw == "-"
        if not is_stdin and raw.startswith(STDIN_PLACEHOLDER):
            is_stdin = True
            # Use the original name again in case we want to print something
            # to the user
            src = Path(raw[len(STDIN_PLACEHOLDER) :])

        if is_stdin:
            if src.suffix == ".pyi":
                mode = replace(mode, is_pyi=True)
            elif src.suffix == ".ipynb":
                mode = replace(mode, is_ipynb=True)
            if format_stdin_to_stdout(fast=fast, write_back=write_back, mode=mode):
                changed = Changed.YES
        else:
            cache: Cache = {}
            if write_back not in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
                cache = read_cache(mode)
                resolved = src.resolve()
                key = str(resolved)
                if key in cache and cache[key] == get_cache_info(resolved):
                    changed = Changed.CACHED

            if changed is not Changed.CACHED and format_file_in_place(
                src, fast=fast, write_back=write_back, mode=mode
            ):
                changed = Changed.YES

            # Cache files that were written back, or checked and found clean.
            written_back = write_back is WriteBack.YES and changed is not Changed.CACHED
            checked_clean = write_back is WriteBack.CHECK and changed is Changed.NO
            if written_back or checked_clean:
                write_cache(cache, [src], mode)
        report.done(src, changed)
    except Exception as exc:
        if report.verbose:
            traceback.print_exc()
        report.failed(src, str(exc))
668
669
# diff-shades depends on being able to monkeypatch this function to operate. I know
# it's not ideal, but this shouldn't cause any issues ... hopefully. ~ichard26
@mypyc_attr(patchable=True)
def reformat_many(
    sources: Set[Path],
    fast: bool,
    write_back: WriteBack,
    mode: Mode,
    report: "Report",
    workers: Optional[int],
) -> None:
    """Reformat multiple files using a ProcessPoolExecutor.

    Up to `workers` processes are used (``DEFAULT_WORKERS`` when None; at most
    60 on Windows). If worker processes cannot be started, a single worker
    thread is used instead. `fast`, `write_back` and `mode` are passed on to
    :func:`schedule_formatting`, which records results in `report`.

    Each call runs on its own freshly created event loop. The loop is
    unregistered again afterwards. Calling ``asyncio.get_event_loop()`` with
    no running loop is deprecated. It would also return the same loop on
    every call in a process. If ``shutdown(loop)`` closes that loop, as its
    cleanup role suggests, a second call would then get a dead loop.
    """
    executor: Executor
    worker_count = workers if workers is not None else DEFAULT_WORKERS
    if sys.platform == "win32" and worker_count is not None:
        # Work around https://bugs.python.org/issue26903
        # (None lets the executor choose its own default count).
        worker_count = min(worker_count, 60)
    try:
        executor = ProcessPoolExecutor(max_workers=worker_count)
    except (ImportError, NotImplementedError, OSError):
        # we arrive here if the underlying system does not support multi-processing
        # like in AWS Lambda or Termux, in which case we gracefully fallback to
        # a ThreadPoolExecutor with just a single worker (more workers would not do us
        # any good due to the Global Interpreter Lock)
        executor = ThreadPoolExecutor(max_workers=1)

    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        loop.run_until_complete(
            schedule_formatting(
                sources=sources,
                fast=fast,
                write_back=write_back,
                mode=mode,
                report=report,
                loop=loop,
                executor=executor,
            )
        )
    finally:
        try:
            shutdown(loop)
        finally:
            # Don't leave this call's loop registered as the current loop.
            asyncio.set_event_loop(None)
        executor.shutdown()
714
715
async def schedule_formatting(
    sources: Set[Path],
    fast: bool,
    write_back: WriteBack,
    mode: Mode,
    report: "Report",
    loop: asyncio.AbstractEventLoop,
    executor: Executor,
) -> None:
    """Run formatting of `sources` in parallel using the provided `executor`.

    (Use ProcessPoolExecutors for actual parallelism.)

    `write_back`, `fast`, and `mode` options are passed to
    :func:`format_file_in_place`.

    Outside the diff modes, sources the cache considers up to date are
    reported as CACHED and not formatted. Every scheduled file is reported as
    done or failed in `report`. Files written back (YES), or found unchanged
    under CHECK, are added to the cache at the end.
    """
    cache: Cache = {}
    # Diff modes never read the cache, and never write it (sources_to_cache
    # only collects files under YES/CHECK below).
    if write_back not in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
        cache = read_cache(mode)
        # Split `sources` into files still to format and files the cache
        # says are unchanged since they were last formatted.
        sources, cached = filter_cached(cache, sources)
        for src in sorted(cached):
            report.done(src, Changed.CACHED)
    if not sources:
        return

    cancelled = []
    sources_to_cache = []
    lock = None
    if write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
        # For diff output, we need locks to ensure we don't interleave output
        # from different processes.
        manager = Manager()
        lock = manager.Lock()
    # Map each scheduled future back to the file it is formatting.
    tasks = {
        asyncio.ensure_future(
            loop.run_in_executor(
                executor, format_file_in_place, src, fast, mode, write_back, lock
            )
        ): src
        for src in sorted(sources)
    }
    # `pending` is a live view of `tasks`: it shrinks as finished tasks are
    # popped below. The signal handlers receive this same view, presumably so
    # `cancel` acts on whatever is still pending when a signal arrives.
    pending = tasks.keys()
    try:
        loop.add_signal_handler(signal.SIGINT, cancel, pending)
        loop.add_signal_handler(signal.SIGTERM, cancel, pending)
    except NotImplementedError:
        # There are no good alternatives for these on Windows.
        pass
    while pending:
        done, _ = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
        for task in done:
            src = tasks.pop(task)
            if task.cancelled():
                cancelled.append(task)
            elif task.exception():
                report.failed(src, str(task.exception()))
            else:
                changed = Changed.YES if task.result() else Changed.NO
                # If the file was written back or was successfully checked as
                # well-formatted, store this information in the cache.
                if write_back is WriteBack.YES or (
                    write_back is WriteBack.CHECK and changed is Changed.NO
                ):
                    sources_to_cache.append(src)
                report.done(src, changed)
    if cancelled:
        # Await the cancelled tasks; return_exceptions=True returns their
        # CancelledErrors as results instead of raising them here.
        if sys.version_info >= (3, 7):
            await asyncio.gather(*cancelled, return_exceptions=True)
        else:
            await asyncio.gather(*cancelled, loop=loop, return_exceptions=True)
    if sources_to_cache:
        write_cache(cache, sources_to_cache, mode)
788
789
def format_file_in_place(
    src: Path,
    fast: bool,
    mode: Mode,
    write_back: WriteBack = WriteBack.NO,
    lock: Any = None,  # multiprocessing.Manager().Lock() is some crazy proxy
) -> bool:
    """Format the file at `src`. Return True if the formatted code differs.

    A ``.pyi`` suffix enables stub formatting and ``.ipynb`` notebook
    formatting. `mode` and `fast` are passed to :func:`format_file_contents`.

    With `write_back` YES, the new code replaces the file. It is written with
    the encoding and newline returned by ``decode_bytes``. With DIFF or
    COLOR_DIFF, a diff is printed to stdout while `lock` (if any) is held.
    Otherwise nothing is written.

    Raises ValueError when the contents cannot be parsed as a Jupyter notebook
    (JSON decoding failed).
    """
    if src.suffix == ".pyi":
        mode = replace(mode, is_pyi=True)
    elif src.suffix == ".ipynb":
        mode = replace(mode, is_ipynb=True)

    # Modification time of the original, shown in the "before" diff header.
    then = datetime.utcfromtimestamp(src.stat().st_mtime)
    with open(src, "rb") as raw:
        before, encoding, newline = decode_bytes(raw.read())

    try:
        after = format_file_contents(before, fast=fast, mode=mode)
    except NothingChanged:
        return False
    except JSONDecodeError:
        raise ValueError(
            f"File '{src}' cannot be parsed as valid Jupyter notebook."
        ) from None

    if write_back == WriteBack.YES:
        with open(src, "w", encoding=encoding, newline=newline) as target:
            target.write(after)
        return True

    if write_back not in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
        # NO / CHECK: only the return value matters.
        return True

    now = datetime.utcnow()
    src_name = f"{src}\t{then} +0000"
    dst_name = f"{src}\t{now} +0000"
    if mode.is_ipynb:
        diff_contents = ipynb_diff(before, after, src_name, dst_name)
    else:
        diff_contents = diff(before, after, src_name, dst_name)
    if write_back == WriteBack.COLOR_DIFF:
        diff_contents = color_diff(diff_contents)

    # Hold the shared lock (if any) so concurrent workers don't interleave diffs.
    with lock or nullcontext():
        stream = io.TextIOWrapper(
            sys.stdout.buffer,
            encoding=encoding,
            newline=newline,
            write_through=True,
        )
        stream = wrap_stream_for_windows(stream)
        stream.write(diff_contents)
        # Detach so discarding the wrapper doesn't close sys.stdout.buffer.
        stream.detach()
    return True
847
848
def format_stdin_to_stdout(
    fast: bool,
    *,
    content: Optional[str] = None,
    write_back: WriteBack = WriteBack.NO,
    mode: Mode,
) -> bool:
    """Format file on stdin. Return True if changed.

    If content is None, it's read from sys.stdin.

    If `write_back` is YES, write reformatted code back to stdout. If it is DIFF,
    write a diff to stdout. The `mode` argument is passed to
    :func:`format_file_contents`.

    Output is produced from a ``finally`` clause. If formatting raises, `dst`
    is still the unmodified input, and in YES mode that input is echoed back
    unchanged. Any exception other than NothingChanged propagates after the
    output has been written.
    """
    then = datetime.utcnow()

    if content is None:
        src, encoding, newline = decode_bytes(sys.stdin.buffer.read())
    else:
        # String input (e.g. `--code`): write UTF-8, and newline="" tells
        # TextIOWrapper not to translate newlines on output.
        src, encoding, newline = content, "utf-8", ""

    # Fallback output when formatting fails or changes nothing.
    dst = src
    try:
        dst = format_file_contents(src, fast=fast, mode=mode)
        return True

    except NothingChanged:
        return False

    finally:
        # Wrap stdout's binary buffer with the input's encoding/newline; it is
        # detached at the end so that discarding the wrapper does not close
        # sys.stdout.buffer.
        f = io.TextIOWrapper(
            sys.stdout.buffer, encoding=encoding, newline=newline, write_through=True
        )
        if write_back == WriteBack.YES:
            # Make sure there's a newline after the content
            if dst and dst[-1] != "\n":
                dst += "\n"
            f.write(dst)
        elif write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
            now = datetime.utcnow()
            src_name = f"STDIN\t{then} +0000"
            dst_name = f"STDOUT\t{now} +0000"
            d = diff(src, dst, src_name, dst_name)
            if write_back == WriteBack.COLOR_DIFF:
                d = color_diff(d)
                f = wrap_stream_for_windows(f)
            f.write(d)
        f.detach()
898
899
def check_stability_and_equivalence(
    src_contents: str, dst_contents: str, *, mode: Mode
) -> None:
    """Verify that formatting preserved the code's meaning and is stable.

    Raises AssertionError when `dst_contents` is not equivalent to
    `src_contents`. It is also raised when formatting `dst_contents` again
    (with `mode`) yields output that is inequivalent or unstable.
    """
    assert_equivalent(src_contents, dst_contents)

    # Forced second pass to work around optional trailing commas (becoming
    # forced trailing commas on pass 2) interacting differently with optional
    # parentheses.  Admittedly ugly.
    second_pass = format_str(dst_contents, mode=mode)
    if second_pass == dst_contents:
        # The first pass was already a fixed point; `assert_stable` would be
        # redundant.
        return
    assert_equivalent(src_contents, second_pass, pass_num=2)
    assert_stable(src_contents, second_pass, mode=mode)
921
922
def format_file_contents(src_contents: str, *, fast: bool, mode: Mode) -> FileContent:
    """Return the reformatted version of a whole file's contents.

    Raises NothingChanged when the input is blank or already formatted.

    Regular Python sources go through :func:`format_str`; Jupyter notebooks
    (``mode.is_ipynb``) go through :func:`format_ipynb_string`. Unless `fast`
    is set, Python sources are additionally verified with
    :func:`check_stability_and_equivalence`. Notebooks skip that step here
    because their cells are verified individually while being formatted.
    """
    if src_contents.strip() == "":
        raise NothingChanged

    if mode.is_ipynb:
        result = format_ipynb_string(src_contents, fast=fast, mode=mode)
    else:
        result = format_str(src_contents, mode=mode)

    if result == src_contents:
        raise NothingChanged

    needs_safety_check = not (fast or mode.is_ipynb)
    if needs_safety_check:
        check_stability_and_equivalence(src_contents, result, mode=mode)
    return result
944
945
def validate_cell(src: str) -> None:
    """Raise NothingChanged for notebook cells that must be left alone.

    Two kinds of cells are skipped:

    * Cells that already contain code produced by IPython's
      ``TransformerManager`` (any entry of ``TRANSFORMED_MAGICS``). Such cells
      cannot be round-tripped safely. For example, ``!ls`` and
      ``get_ipython().system('ls')`` both transform to
      ``get_ipython().system('ls')``, so the original spelling could not be
      restored after masking.
    * Cells starting with a ``%%`` cell magic whose name is not in
      ``PYTHON_CELL_MAGICS``. Their body may not be Python, and its
      indentation could break tokenization.
    """
    for magic in TRANSFORMED_MAGICS:
        if magic in src:
            raise NothingChanged

    if not src.startswith("%%"):
        return

    cell_magic_name = src.split()[0][2:]
    if cell_magic_name not in PYTHON_CELL_MAGICS:
        raise NothingChanged
967
968
def format_cell(src: str, *, fast: bool, mode: Mode) -> str:
    """Return the formatted source of a single Jupyter notebook code cell.

    Pipeline:

      1. reject cells refused by :func:`validate_cell`;
      2. detach a trailing semicolon, remembering whether there was one;
      3. mask IPython magics so the cell parses as plain Python;
      4. format the masked code (and verify it unless `fast`);
      5. restore the magics, then the trailing semicolon;
      6. strip trailing newlines.

    Cells that fail to mask because of a syntax error are not processed, as
    they could be automagics or multi-line magics, which are not supported.

    Raises NothingChanged if the cell is skipped or the result equals `src`.
    """
    validate_cell(src)
    stripped_src, had_semicolon = remove_trailing_semicolon(src)

    try:
        masked_src, replacements = mask_cell(stripped_src)
    except SyntaxError:
        raise NothingChanged from None

    masked_dst = format_str(masked_src, mode=mode)
    if not fast:
        check_stability_and_equivalence(masked_src, masked_dst, mode=mode)

    unmasked_dst = unmask_cell(masked_dst, replacements)
    dst = put_trailing_semicolon_back(unmasked_dst, had_semicolon).rstrip("\n")
    if dst == src:
        raise NothingChanged from None
    return dst
1004
1005
def validate_metadata(nb: MutableMapping[str, Any]) -> None:
    """Raise NothingChanged if the notebook declares a non-Python language.

    The language is read from ``nb["metadata"]["language_info"]["name"]``.
    Every one of those fields is optional according to
    https://nbformat.readthedocs.io/en/latest/format_description.html, so a
    notebook missing any of them is assumed to be Python and is formatted.
    """
    metadata = nb.get("metadata", {})
    language_info = metadata.get("language_info", {})
    language = language_info.get("name", None)

    if language is None or language == "python":
        return
    raise NothingChanged from None
1016
1017
def format_ipynb_string(src_contents: str, *, fast: bool, mode: Mode) -> FileContent:
    """Format a Jupyter notebook given as its raw JSON text.

    Operate cell-by-cell, only on code cells, only for Python notebooks (see
    :func:`validate_metadata`). Each code cell is passed to :func:`format_cell`.
    Cells for which it raises NothingChanged are kept verbatim. If the
    ``.ipynb`` originally had a trailing newline, it'll be preserved.

    Raises NothingChanged if `src_contents` is empty or if no cell changed.
    Errors from ``json.loads`` on malformed input propagate to the caller.
    """
    if not src_contents:
        # Checked up front: indexing src_contents[-1] below would otherwise
        # raise IndexError on an empty string.
        raise NothingChanged

    trailing_newline = src_contents[-1] == "\n"
    modified = False
    nb = json.loads(src_contents)
    validate_metadata(nb)
    for cell in nb["cells"]:
        if cell.get("cell_type", None) != "code":
            continue
        src = "".join(cell["source"])
        try:
            dst = format_cell(src, fast=fast, mode=mode)
        except NothingChanged:
            continue
        # Store the new source as a list of lines that keep their line endings.
        cell["source"] = dst.splitlines(keepends=True)
        modified = True

    if not modified:
        raise NothingChanged

    dst_contents = json.dumps(nb, indent=1, ensure_ascii=False)
    if trailing_newline:
        dst_contents += "\n"
    return dst_contents
1045
1046
def format_str(src_contents: str, *, mode: Mode) -> FileContent:
    """Reformat a string and return new contents.

    Leading whitespace of `src_contents` is dropped before parsing. If `mode`
    lists no target versions, they are inferred from the syntax in use via
    :func:`detect_target_versions`. A deprecation warning is printed when Python 2
    is targeted or is the only compatible version detected.

    `mode` determines formatting options, such as how many characters per line are
    allowed.  Example:

    >>> import black
    >>> print(black.format_str("def f(arg:str='')->None:...", mode=black.Mode()))
    def f(arg: str = "") -> None:
        ...

    A more complex example:

    >>> print(
    ...   black.format_str(
    ...     "def f(arg:str='')->None: hey",
    ...     mode=black.Mode(
    ...       target_versions={black.TargetVersion.PY36},
    ...       line_length=10,
    ...       string_normalization=False,
    ...       is_pyi=False,
    ...     ),
    ...   ),
    ... )
    def f(
        arg: str = '',
    ) -> None:
        hey

    """
    src_node = lib2to3_parse(src_contents.lstrip(), mode.target_versions)
    dst_contents = []
    future_imports = get_future_imports(src_node)
    # Explicit target versions win; otherwise infer them from the features used.
    if mode.target_versions:
        versions = mode.target_versions
    else:
        versions = detect_target_versions(src_node)

    # TODO: fully drop support and this code hopefully in January 2022 :D
    if TargetVersion.PY27 in mode.target_versions or versions == {TargetVersion.PY27}:
        msg = (
            "DEPRECATION: Python 2 support will be removed in the first stable release "
            "expected in January 2022."
        )
        err(msg, fg="yellow", bold=True)

    normalize_fmt_off(src_node)
    # remove_u_prefix is enabled by a `unicode_literals` future import or when
    # `versions` supports Feature.UNICODE_LITERALS (per supports_feature).
    lines = LineGenerator(
        mode=mode,
        remove_u_prefix="unicode_literals" in future_imports
        or supports_feature(versions, Feature.UNICODE_LITERALS),
    )
    elt = EmptyLineTracker(is_pyi=mode.is_pyi)
    empty_line = Line(mode=mode)
    # Number of blank lines still owed after the previously emitted line.
    after = 0
    # Only the trailing-comma features that `versions` supports are passed on
    # to line splitting.
    split_line_features = {
        feature
        for feature in {Feature.TRAILING_COMMA_IN_CALL, Feature.TRAILING_COMMA_IN_DEF}
        if supports_feature(versions, feature)
    }
    for current_line in lines.visit(src_node):
        # Blank lines requested *after* a line are emitted only once another
        # line follows, so those requested after the final line are dropped.
        dst_contents.append(str(empty_line) * after)
        before, after = elt.maybe_empty_lines(current_line)
        dst_contents.append(str(empty_line) * before)
        # One logical line may be emitted as several physical lines.
        for line in transform_line(
            current_line, mode=mode, features=split_line_features
        ):
            dst_contents.append(str(line))
    return "".join(dst_contents)
1116
1117
def decode_bytes(src: bytes) -> Tuple[FileContent, Encoding, NewLine]:
    """Decode raw source bytes into text.

    Returns a ``(decoded_contents, encoding, newline)`` tuple:

    * `encoding` is what :func:`tokenize.detect_encoding` reports for `src`;
    * `newline` is CRLF if the first line ends with CRLF, otherwise LF;
    * `decoded_contents` is decoded with universal newlines, so it only
      contains LF line endings.

    Empty input yields an empty string with an LF newline.
    """
    buffer = io.BytesIO(src)
    encoding, first_lines = tokenize.detect_encoding(buffer.readline)
    if not first_lines:
        return "", encoding, "\n"

    newline = "\r\n" if first_lines[0].endswith(b"\r\n") else "\n"
    buffer.seek(0)
    with io.TextIOWrapper(buffer, encoding) as text_stream:
        decoded = text_stream.read()
    return decoded, encoding, newline
1133
1134
def get_features_used(node: Node) -> Set[Feature]:  # noqa: C901
    """Return a set of (relatively) new Python features used in this file.

    Currently looking for:
    - f-strings (``f``/``rf``/``fr`` prefixes in any letter case);
    - underscores in numeric literals;
    - trailing commas after * or ** in function signatures and calls;
    - positional only arguments in function signatures and lambdas;
    - assignment expression;
    - relaxed decorator syntax;
    - Python 2-only syntax: print / exec statements, long and octal integer
      literals, automatic parameter unpacking, comma-style except / raise,
      and backquote repr.
    """
    features: Set[Feature] = set()
    for n in node.pre_order():
        if n.type == token.STRING:
            # Lower-case the prefix so mixed-case spellings such as Rf"..." or
            # fR'...' (valid f-strings since Python 3.6) are detected as well.
            value_head = n.value[:2].lower()  # type: ignore
            if value_head in {'f"', "f'", "rf", "fr"}:
                features.add(Feature.F_STRINGS)

        elif n.type == token.NUMBER:
            assert isinstance(n, Leaf)
            if "_" in n.value:
                features.add(Feature.NUMERIC_UNDERSCORES)
            elif n.value.endswith(("L", "l")):
                # Python 2: 10L
                features.add(Feature.LONG_INT_LITERAL)
            elif len(n.value) >= 2 and n.value[0] == "0" and n.value.isdigit():
                # Python 2: 0123; 00123; ...
                # Only all-digit literals qualify: Python 3 still accepts leading
                # zeros in float and imaginary literals such as 01.5, 01e3, 01j.
                if not all(char == "0" for char in n.value):
                    # although we don't want to match 0000 or similar
                    features.add(Feature.OCTAL_INT_LITERAL)

        elif n.type == token.SLASH:
            if n.parent and n.parent.type in {
                syms.typedargslist,
                syms.arglist,
                syms.varargslist,
            }:
                features.add(Feature.POS_ONLY_ARGUMENTS)

        elif n.type == token.COLONEQUAL:
            features.add(Feature.ASSIGNMENT_EXPRESSIONS)

        elif n.type == syms.decorator:
            if len(n.children) > 1 and not is_simple_decorator_expression(
                n.children[1]
            ):
                features.add(Feature.RELAXED_DECORATORS)

        elif (
            n.type in {syms.typedargslist, syms.arglist}
            and n.children
            and n.children[-1].type == token.COMMA
        ):
            if n.type == syms.typedargslist:
                feature = Feature.TRAILING_COMMA_IN_DEF
            else:
                feature = Feature.TRAILING_COMMA_IN_CALL

            for ch in n.children:
                if ch.type in STARS:
                    features.add(feature)

                if ch.type == syms.argument:
                    for argch in ch.children:
                        if argch.type in STARS:
                            features.add(feature)

        # Python 2 only features (for its deprecation) except for integers, see above
        elif n.type == syms.print_stmt:
            features.add(Feature.PRINT_STMT)
        elif n.type == syms.exec_stmt:
            features.add(Feature.EXEC_STMT)
        elif n.type == syms.tfpdef:
            # def set_position((x, y), value):
            #     ...
            features.add(Feature.AUTOMATIC_PARAMETER_UNPACKING)
        elif n.type == syms.except_clause:
            # try:
            #     ...
            # except Exception, err:
            #     ...
            if len(n.children) >= 4:
                if n.children[-2].type == token.COMMA:
                    features.add(Feature.COMMA_STYLE_EXCEPT)
        elif n.type == syms.raise_stmt:
            # raise Exception, "msg"
            if len(n.children) >= 4:
                if n.children[-2].type == token.COMMA:
                    features.add(Feature.COMMA_STYLE_RAISE)
        elif n.type == token.BACKQUOTE:
            # `i'm surprised this ever existed`
            features.add(Feature.BACKQUOTE_REPR)

    return features
1230
1231
def detect_target_versions(node: Node) -> Set[TargetVersion]:
    """Return every target version able to handle all features used in `node`.

    A version qualifies when the features reported by :func:`get_features_used`
    are a subset of that version's entry in ``VERSION_TO_FEATURES``.
    """
    used = get_features_used(node)
    compatible: Set[TargetVersion] = set()
    for version in TargetVersion:
        if used.issubset(VERSION_TO_FEATURES[version]):
            compatible.add(version)
    return compatible
1238
1239
def get_future_imports(node: Node) -> Set[str]:
    """Return the names imported via ``from __future__ import ...`` in the file.

    Only the leading run of top-level statements is scanned. Docstrings (a
    simple statement made of a lone string) are skipped. Scanning stops at the
    first statement that is neither a docstring nor a ``from __future__``
    import. Aliased imports (``x as y``) contribute their original name ``x``.
    """

    def names_from(children: List[LN]) -> Generator[str, None, None]:
        for kid in children:
            if isinstance(kid, Leaf):
                # Only NAME leaves are names; "(", ")" and "," are ignored.
                if kid.type == token.NAME:
                    yield kid.value
                continue

            if kid.type == syms.import_as_names:
                yield from names_from(kid.children)
                continue

            if kid.type != syms.import_as_name:
                raise AssertionError("Invalid syntax parsing imports")

            original = kid.children[0]
            assert isinstance(original, Leaf), "Invalid syntax parsing imports"
            assert original.type == token.NAME, "Invalid syntax parsing imports"
            yield original.value

    found: Set[str] = set()
    for stmt in node.children:
        if stmt.type != syms.simple_stmt:
            break

        head = stmt.children[0]
        if isinstance(head, Leaf):
            is_docstring = (
                len(stmt.children) == 2
                and head.type == token.STRING
                and stmt.children[1].type == token.NEWLINE
            )
            if not is_docstring:
                break
            continue

        if head.type != syms.import_from:
            break

        module_name = head.children[1]
        if not isinstance(module_name, Leaf) or module_name.value != "__future__":
            break

        # children[0:3] are "from", the module name and "import".
        found.update(names_from(head.children[3:]))

    return found
1288
1289
def assert_equivalent(src: str, dst: str, *, pass_num: int = 1) -> None:
    """Raise AssertionError if `src` and `dst` aren't equivalent.

    Both strings are parsed into ASTs with :func:`parse_ast` and compared through
    their :func:`stringify_ast` renderings. A source that cannot be parsed means
    ``--safe`` cannot be used. Output that cannot be parsed, or that differs
    structurally, is reported as an internal error for formatting pass
    `pass_num`, together with whatever :func:`dump_to_file` returns for the
    debugging data.
    """
    try:
        src_ast = parse_ast(src)
    except Exception as exc:
        raise AssertionError(
            "cannot use --safe with this file; failed to parse source file."
        ) from exc

    try:
        dst_ast = parse_ast(dst)
    except Exception as exc:
        traceback_text = "".join(traceback.format_tb(exc.__traceback__))
        log = dump_to_file(traceback_text, dst)
        raise AssertionError(
            f"INTERNAL ERROR: Black produced invalid code on pass {pass_num}: {exc}. "
            "Please report a bug on https://github.com/psf/black/issues.  "
            f"This invalid output might be helpful: {log}"
        ) from None

    src_dump = "\n".join(stringify_ast(src_ast))
    dst_dump = "\n".join(stringify_ast(dst_ast))
    if src_dump == dst_dump:
        return

    log = dump_to_file(diff(src_dump, dst_dump, "src", "dst"))
    raise AssertionError(
        "INTERNAL ERROR: Black produced code that is not equivalent to the"
        f" source on pass {pass_num}.  Please report a bug on "
        f"https://github.com/psf/black/issues.  This diff might be helpful: {log}"
    ) from None
1318
1319
def assert_stable(src: str, dst: str, mode: Mode) -> None:
    """Raise AssertionError if formatting `dst` a second time changes it.

    On failure, `mode` plus two diffs (source -> first pass and first pass ->
    second pass) are handed to :func:`dump_to_file`, and its return value is
    included in the error message.
    """
    second_pass = format_str(dst, mode=mode)
    if second_pass == dst:
        return

    log = dump_to_file(
        str(mode),
        diff(src, dst, "source", "first pass"),
        diff(dst, second_pass, "first pass", "second pass"),
    )
    raise AssertionError(
        "INTERNAL ERROR: Black produced different code on the second pass of the"
        " formatter.  Please report a bug on https://github.com/psf/black/issues."
        f"  This diff might be helpful: {log}"
    ) from None
1334
1335
@contextmanager
def nullcontext() -> Iterator[None]:
    """Context manager that does nothing and binds ``None``.

    A stand-in for :func:`contextlib.nullcontext`, which only exists on
    Python 3.7 and newer.
    """
    yield None
1343
1344
def patch_click() -> None:
    """Make Click not crash on Python 3.6 with LANG=C.

    On certain misconfigured environments, Python 3 selects the ASCII encoding as the
    default which restricts paths that it can access during the lifetime of the
    application.  Click refuses to work in this scenario by raising a RuntimeError.

    In case of Black the likelihood that non-ASCII characters are going to be used in
    file paths is minimal since it's Python source code.  Moreover, this crash was
    spurious on Python 3.7 thanks to PEP 538 and PEP 540.

    The environment-verification hooks (``_verify_python3_env`` /
    ``_verify_python_env``) are replaced with no-ops on ``click.core`` and, when
    present, ``click._unicodefun``. Click 8.1.0 removed ``_unicodefun``, so its
    absence is tolerated, and ``click.core`` is still patched.
    """
    try:
        from click import core
    except ImportError:
        # Click itself is unavailable; there is nothing to patch.
        return

    modules: List[Any] = [core]
    try:
        # A missing name in ``from pkg import name`` raises a plain ImportError
        # (not ModuleNotFoundError), so ImportError must be caught here.
        from click import _unicodefun  # type: ignore
    except ImportError:
        pass
    else:
        modules.append(_unicodefun)

    for module in modules:
        if hasattr(module, "_verify_python3_env"):
            module._verify_python3_env = lambda: None  # type: ignore
        if hasattr(module, "_verify_python_env"):
            module._verify_python_env = lambda: None  # type: ignore
1367
1368
def patched_main() -> None:
    """Prepare the runtime environment, then run the :func:`main` command.

    This is what the module's ``__main__`` guard calls. The setup steps run in
    a fixed order before the CLI starts.
    """
    # NOTE(review): presumably installs uvloop as the asyncio event loop when it
    # is available; confirm against maybe_install_uvloop's definition.
    maybe_install_uvloop()
    # multiprocessing support for frozen (e.g. PyInstaller-built) executables.
    freeze_support()
    # Neutralize Click's locale/encoding check before Click runs (see patch_click).
    patch_click()
    main()
1374
1375
# Allow running this module directly as a script.
if __name__ == "__main__":
    patched_main()