git.madduck.net Git - etc/vim.git/blob - src/black/__init__.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access, can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Fix and speedup diff-shades integration (#2773)
[etc/vim.git] / src / black / __init__.py
1 import asyncio
2 from json.decoder import JSONDecodeError
3 import json
4 from concurrent.futures import Executor, ThreadPoolExecutor, ProcessPoolExecutor
5 from contextlib import contextmanager
6 from datetime import datetime
7 from enum import Enum
8 import io
9 from multiprocessing import Manager, freeze_support
10 import os
11 from pathlib import Path
12 from pathspec.patterns.gitwildmatch import GitWildMatchPatternError
13 import re
14 import signal
15 import sys
16 import tokenize
17 import traceback
18 from typing import (
19     Any,
20     Dict,
21     Generator,
22     Iterator,
23     List,
24     MutableMapping,
25     Optional,
26     Pattern,
27     Set,
28     Sized,
29     Tuple,
30     Union,
31 )
32
33 import click
34 from click.core import ParameterSource
35 from dataclasses import replace
36 from mypy_extensions import mypyc_attr
37
38 from black.const import DEFAULT_LINE_LENGTH, DEFAULT_INCLUDES, DEFAULT_EXCLUDES
39 from black.const import STDIN_PLACEHOLDER
40 from black.nodes import STARS, syms, is_simple_decorator_expression
41 from black.nodes import is_string_token
42 from black.lines import Line, EmptyLineTracker
43 from black.linegen import transform_line, LineGenerator, LN
44 from black.comments import normalize_fmt_off
45 from black.mode import FUTURE_FLAG_TO_FEATURE, Mode, TargetVersion
46 from black.mode import Feature, supports_feature, VERSION_TO_FEATURES
47 from black.cache import read_cache, write_cache, get_cache_info, filter_cached, Cache
48 from black.concurrency import cancel, shutdown, maybe_install_uvloop
49 from black.output import dump_to_file, ipynb_diff, diff, color_diff, out, err
50 from black.report import Report, Changed, NothingChanged
51 from black.files import find_project_root, find_pyproject_toml, parse_pyproject_toml
52 from black.files import gen_python_files, get_gitignore, normalize_path_maybe_ignore
53 from black.files import wrap_stream_for_windows
54 from black.parsing import InvalidInput  # noqa F401
55 from black.parsing import lib2to3_parse, parse_ast, stringify_ast
56 from black.handle_ipynb_magics import (
57     mask_cell,
58     unmask_cell,
59     remove_trailing_semicolon,
60     put_trailing_semicolon_back,
61     TRANSFORMED_MAGICS,
62     PYTHON_CELL_MAGICS,
63     jupyter_dependencies_are_installed,
64 )
65
66
67 # lib2to3 fork
68 from blib2to3.pytree import Node, Leaf
69 from blib2to3.pgen2 import token
70
71 from _black_version import version as __version__
72
# True when this module was loaded from a file with a compiled-extension suffix
# (".pyd" or ".so") rather than from plain Python source.
COMPILED = Path(__file__).suffix in (".pyd", ".so")

# types
# Plain `str` aliases used in signatures to name what a string stands for.
FileContent = str
Encoding = str
NewLine = str
79
80
class WriteBack(Enum):
    """What to do with the formatted result of a file."""

    NO = 0
    YES = 1
    DIFF = 2
    CHECK = 3
    COLOR_DIFF = 4

    @classmethod
    def from_configuration(
        cls, *, check: bool, diff: bool, color: bool = False
    ) -> "WriteBack":
        """Map the `check`, `diff` and `color` flags to a single member.

        `diff` takes precedence over `check`; `color` only matters together
        with `diff`. Without either flag the result is `YES`.
        """
        if diff:
            return cls.COLOR_DIFF if color else cls.DIFF

        return cls.CHECK if check else cls.YES
99
100
# Legacy name, left for integrations.
FileMode = Mode

# Default for the --workers option and for `reformat_many` when no explicit
# count is given. Note that os.cpu_count() may return None.
DEFAULT_WORKERS = os.cpu_count()
105
106
def read_pyproject_toml(
    ctx: click.Context, param: click.Parameter, value: Optional[str]
) -> Optional[str]:
    """Load Black settings from a "pyproject.toml" and merge them into `ctx`.

    When `value` is empty the file is looked up from the "src" parameter
    already stored on `ctx`. The parsed settings are layered on top of any
    existing `ctx.default_map`.

    Returns the path of the configuration file that was read, or None when no
    file was found or it contained no settings. Raises `click.FileError` when
    the file cannot be read or parsed, and `click.BadOptionUsage` when
    "target_version" is present but is not a list.
    """
    if not value:
        value = find_pyproject_toml(ctx.params.get("src", ()))
        if value is None:
            return None

    try:
        config = parse_pyproject_toml(value)
    except (OSError, ValueError) as e:
        raise click.FileError(
            filename=value, hint=f"Error reading configuration file: {e}"
        ) from None

    if not config:
        return None

    # Sanitize the values to be Click friendly. For more information please see:
    # https://github.com/psf/black/issues/1458
    # https://github.com/pallets/click/issues/1567
    config = {
        key: setting if isinstance(setting, (list, dict)) else str(setting)
        for key, setting in config.items()
    }

    target_version = config.get("target_version")
    if not (target_version is None or isinstance(target_version, list)):
        raise click.BadOptionUsage(
            "target-version", "Config key target-version must be a list"
        )

    # Existing defaults first, so that values from the file win on conflict.
    merged: Dict[str, Any] = {}
    if ctx.default_map:
        merged.update(ctx.default_map)
    merged.update(config)
    ctx.default_map = merged

    return value
151
152
def target_version_option_callback(
    c: click.Context, p: Union[click.Option, click.Parameter], v: Tuple[str, ...]
) -> List[TargetVersion]:
    """Translate the values of --target-version into `TargetVersion` members.

    Kept as a named function (rather than a lambda) because mypy could not
    infer the lambda's type correctly, which caused trouble for mypyc.
    """
    versions: List[TargetVersion] = []
    for name in v:
        # Members are looked up by their upper-cased name.
        versions.append(TargetVersion[name.upper()])
    return versions
162
163
def re_compile_maybe_verbose(regex: str) -> Pattern[str]:
    """Return `regex` compiled into a pattern object.

    A pattern that spans several lines is compiled in verbose mode by
    prefixing it with the inline `(?x)` flag.
    """
    source = f"(?x){regex}" if "\n" in regex else regex
    result: Pattern[str] = re.compile(source)
    return result
173
174
def validate_regex(
    ctx: click.Context,
    param: click.Parameter,
    value: Optional[str],
) -> Optional[Pattern[str]]:
    """Click callback that compiles an option value into a regular expression.

    Returns None when no value was given. Raises `click.BadParameter` when
    the value is not a valid regular expression.
    """
    if value is None:
        return None

    try:
        return re_compile_maybe_verbose(value)
    except re.error as e:
        raise click.BadParameter(f"Not a valid regular expression: {e}") from None
184
185
@click.command(
    context_settings={"help_option_names": ["-h", "--help"]},
    # While Click does set this field automatically using the docstring, mypyc
    # (annoyingly) strips 'em so we need to set it here too.
    help="The uncompromising code formatter.",
)
@click.option("-c", "--code", type=str, help="Format the code passed in as a string.")
@click.option(
    "-l",
    "--line-length",
    type=int,
    default=DEFAULT_LINE_LENGTH,
    help="How many characters per line to allow.",
    show_default=True,
)
@click.option(
    "-t",
    "--target-version",
    type=click.Choice([v.name.lower() for v in TargetVersion]),
    callback=target_version_option_callback,
    multiple=True,
    help=(
        "Python versions that should be supported by Black's output. [default: per-file"
        " auto-detection]"
    ),
)
@click.option(
    "--pyi",
    is_flag=True,
    help=(
        "Format all input files like typing stubs regardless of file extension (useful"
        " when piping source on standard input)."
    ),
)
@click.option(
    "--ipynb",
    is_flag=True,
    help=(
        "Format all input files like Jupyter Notebooks regardless of file extension "
        "(useful when piping source on standard input)."
    ),
)
@click.option(
    "-S",
    "--skip-string-normalization",
    is_flag=True,
    help="Don't normalize string quotes or prefixes.",
)
@click.option(
    "-C",
    "--skip-magic-trailing-comma",
    is_flag=True,
    help="Don't use trailing commas as a reason to split lines.",
)
@click.option(
    "--experimental-string-processing",
    is_flag=True,
    hidden=True,
    help=(
        "Experimental option that performs more normalization on string literals."
        " Currently disabled because it leads to some crashes."
    ),
)
@click.option(
    "--preview",
    is_flag=True,
    help=(
        "Enable potentially disruptive style changes that will be added to Black's main"
        " functionality in the next major release."
    ),
)
@click.option(
    "--check",
    is_flag=True,
    help=(
        "Don't write the files back, just return the status. Return code 0 means"
        " nothing would change. Return code 1 means some files would be reformatted."
        " Return code 123 means there was an internal error."
    ),
)
@click.option(
    "--diff",
    is_flag=True,
    help="Don't write the files back, just output a diff for each file on stdout.",
)
@click.option(
    "--color/--no-color",
    is_flag=True,
    help="Show colored diff. Only applies when `--diff` is given.",
)
@click.option(
    "--fast/--safe",
    is_flag=True,
    help="If --fast given, skip temporary sanity checks. [default: --safe]",
)
@click.option(
    "--required-version",
    type=str,
    help=(
        "Require a specific version of Black to be running (useful for unifying results"
        " across many environments e.g. with a pyproject.toml file)."
    ),
)
@click.option(
    "--include",
    type=str,
    default=DEFAULT_INCLUDES,
    callback=validate_regex,
    help=(
        "A regular expression that matches files and directories that should be"
        " included on recursive searches. An empty value means all files are included"
        " regardless of the name. Use forward slashes for directories on all platforms"
        " (Windows, too). Exclusions are calculated first, inclusions later."
    ),
    show_default=True,
)
@click.option(
    "--exclude",
    type=str,
    callback=validate_regex,
    help=(
        "A regular expression that matches files and directories that should be"
        " excluded on recursive searches. An empty value means no paths are excluded."
        " Use forward slashes for directories on all platforms (Windows, too)."
        " Exclusions are calculated first, inclusions later. [default:"
        f" {DEFAULT_EXCLUDES}]"
    ),
    show_default=False,
)
@click.option(
    "--extend-exclude",
    type=str,
    callback=validate_regex,
    help=(
        "Like --exclude, but adds additional files and directories on top of the"
        " excluded ones. (Useful if you simply want to add to the default)"
    ),
)
@click.option(
    "--force-exclude",
    type=str,
    callback=validate_regex,
    help=(
        "Like --exclude, but files and directories matching this regex will be "
        "excluded even when they are passed explicitly as arguments."
    ),
)
@click.option(
    "--stdin-filename",
    type=str,
    help=(
        "The name of the file when passing it through stdin. Useful to make "
        "sure Black will respect --force-exclude option on some "
        "editors that rely on using stdin."
    ),
)
@click.option(
    "-W",
    "--workers",
    type=click.IntRange(min=1),
    default=DEFAULT_WORKERS,
    show_default=True,
    help="Number of parallel workers",
)
@click.option(
    "-q",
    "--quiet",
    is_flag=True,
    help=(
        "Don't emit non-error messages to stderr. Errors are still emitted; silence"
        " those with 2>/dev/null."
    ),
)
@click.option(
    "-v",
    "--verbose",
    is_flag=True,
    help=(
        "Also emit messages to stderr about files that were not changed or were ignored"
        " due to exclusion patterns."
    ),
)
@click.version_option(
    version=__version__,
    message=f"%(prog)s, %(version)s (compiled: {'yes' if COMPILED else 'no'})",
)
@click.argument(
    "src",
    nargs=-1,
    type=click.Path(
        exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
    ),
    is_eager=True,
    metavar="SRC ...",
)
@click.option(
    "--config",
    type=click.Path(
        exists=True,
        file_okay=True,
        dir_okay=False,
        readable=True,
        allow_dash=False,
        path_type=str,
    ),
    is_eager=True,
    callback=read_pyproject_toml,
    help="Read configuration from FILE path.",
)
@click.pass_context
def main(
    ctx: click.Context,
    code: Optional[str],
    line_length: int,
    target_version: List[TargetVersion],
    check: bool,
    diff: bool,
    color: bool,
    fast: bool,
    pyi: bool,
    ipynb: bool,
    skip_string_normalization: bool,
    skip_magic_trailing_comma: bool,
    experimental_string_processing: bool,
    preview: bool,
    quiet: bool,
    verbose: bool,
    required_version: Optional[str],
    include: Pattern[str],
    exclude: Optional[Pattern[str]],
    extend_exclude: Optional[Pattern[str]],
    force_exclude: Optional[Pattern[str]],
    stdin_filename: Optional[str],
    workers: int,
    src: Tuple[str, ...],
    config: Optional[str],
) -> None:
    """The uncompromising code formatter."""
    # The command always ends through ctx.exit(): 0 when there is nothing to
    # format, 1 for a version mismatch, conflicting flags or a
    # GitWildMatchPatternError, and report.return_code otherwise.
    ctx.ensure_object(dict)
    # No project root is looked up when a code string is passed with -c.
    root, method = find_project_root(src) if code is None else (None, None)
    # Stored on the context; get_sources() reads it back from ctx.obj["root"].
    ctx.obj["root"] = root

    if verbose:
        if root:
            out(
                f"Identified `{root}` as project root containing a {method}.",
                fg="blue",
            )

            # Pair every source with its normalized form; a falsy normalized
            # value is reported below as "(skipping - invalid)".
            normalized = [
                (normalize_path_maybe_ignore(Path(source), root), source)
                for source in src
            ]
            # The escape sequences wrap the invalid entries so they stand out
            # from the rest of the (blue) message.
            srcs_string = ", ".join(
                [
                    f'"{_norm}"'
                    if _norm
                    else f'\033[31m"{source} (skipping - invalid)"\033[34m'
                    for _norm, source in normalized
                ]
            )
            out(f"Sources to be formatted: {srcs_string}", fg="blue")

        if config:
            # Tell apart a configuration file that was discovered by default
            # from one that was passed explicitly.
            config_source = ctx.get_parameter_source("config")
            if config_source in (ParameterSource.DEFAULT, ParameterSource.DEFAULT_MAP):
                out("Using configuration from project root.", fg="blue")
            else:
                out(f"Using configuration in '{config}'.", fg="blue")

    error_msg = "Oh no! 💥 💔 💥"
    # Refuse to run when a specific version is required and it is not this one.
    if required_version and required_version != __version__:
        err(
            f"{error_msg} The required version `{required_version}` does not match"
            f" the running version `{__version__}`!"
        )
        ctx.exit(1)
    # --pyi and --ipynb are mutually exclusive.
    if ipynb and pyi:
        err("Cannot pass both `pyi` and `ipynb` flags!")
        ctx.exit(1)

    write_back = WriteBack.from_configuration(check=check, diff=diff, color=color)
    if target_version:
        versions = set(target_version)
    else:
        # We'll autodetect later.
        versions = set()
    # The "skip" flags are inverted into the positive Mode settings.
    mode = Mode(
        target_versions=versions,
        line_length=line_length,
        is_pyi=pyi,
        is_ipynb=ipynb,
        string_normalization=not skip_string_normalization,
        magic_trailing_comma=not skip_magic_trailing_comma,
        experimental_string_processing=experimental_string_processing,
        preview=preview,
    )

    if code is not None:
        # Run in quiet mode by default with -c; the extra output isn't useful.
        # You can still pass -v to get verbose output.
        quiet = True

    report = Report(check=check, diff=diff, quiet=quiet, verbose=verbose)

    if code is not None:
        reformat_code(
            content=code, fast=fast, write_back=write_back, mode=mode, report=report
        )
    else:
        try:
            sources = get_sources(
                ctx=ctx,
                src=src,
                quiet=quiet,
                verbose=verbose,
                include=include,
                exclude=exclude,
                extend_exclude=extend_exclude,
                force_exclude=force_exclude,
                report=report,
                stdin_filename=stdin_filename,
            )
        except GitWildMatchPatternError:
            ctx.exit(1)

        # Exits with status 0 when nothing is left to format.
        path_empty(
            sources,
            "No Python files are present to be formatted. Nothing to do 😴",
            quiet,
            verbose,
            ctx,
        )

        # A single source is handled by reformat_one(); several sources are
        # handed to reformat_many() together with the worker count.
        if len(sources) == 1:
            reformat_one(
                src=sources.pop(),
                fast=fast,
                write_back=write_back,
                mode=mode,
                report=report,
            )
        else:
            reformat_many(
                sources=sources,
                fast=fast,
                write_back=write_back,
                mode=mode,
                report=report,
                workers=workers,
            )

    # Closing summary; -v overrides -q here.
    if verbose or not quiet:
        # Emit an empty separator line, but only when formatting files and there
        # is something to separate the summary from.
        if code is None and (verbose or report.change_count or report.failure_count):
            out()
        out(error_msg if report.return_code else "All done! ✨ 🍰 ✨")
        if code is None:
            click.echo(str(report), err=True)
    ctx.exit(report.return_code)
545
546
def get_sources(
    *,
    ctx: click.Context,
    src: Tuple[str, ...],
    quiet: bool,
    verbose: bool,
    include: Pattern[str],
    exclude: Optional[Pattern[str]],
    extend_exclude: Optional[Pattern[str]],
    force_exclude: Optional[Pattern[str]],
    report: "Report",
    stdin_filename: Optional[str],
) -> Set[Path]:
    """Collect the paths that should be formatted for the given `src` entries.

    Files named explicitly are only subject to `force_exclude`; directories
    are expanded through `gen_python_files` with all include/exclude patterns.
    A "-" entry stands for standard input and is tagged with
    `STDIN_PLACEHOLDER` when `stdin_filename` is given. Entries that are
    neither a file, a directory nor "-" are reported as invalid and skipped.

    Exits through `path_empty` when `src` is empty.
    """
    path_empty(src, "No Path provided. Nothing to do 😴", quiet, verbose, ctx)
    collected: Set[Path] = set()

    # The gitignore is only consulted when the user did not pass --exclude.
    if exclude is None:
        exclude = re_compile_maybe_verbose(DEFAULT_EXCLUDES)
        gitignore = get_gitignore(ctx.obj["root"])
    else:
        gitignore = None

    for entry in src:
        if entry == "-" and stdin_filename:
            path, from_stdin = Path(stdin_filename), True
        else:
            path, from_stdin = Path(entry), False

        if from_stdin or path.is_file():
            relative = normalize_path_maybe_ignore(path, ctx.obj["root"], report)
            if relative is None:
                continue

            # Hard-exclude any files that matches the `--force-exclude` regex.
            rooted = "/" + relative
            hit = force_exclude.search(rooted) if force_exclude else None
            if hit and hit.group(0):
                report.path_ignored(
                    path, "matches the --force-exclude regular expression"
                )
                continue

            if from_stdin:
                # Tag the name so later stages know the content comes from stdin.
                path = Path(f"{STDIN_PLACEHOLDER}{str(path)}")

            if path.suffix == ".ipynb" and not jupyter_dependencies_are_installed(
                verbose=verbose, quiet=quiet
            ):
                continue

            collected.add(path)
        elif path.is_dir():
            collected.update(
                gen_python_files(
                    path.iterdir(),
                    ctx.obj["root"],
                    include,
                    exclude,
                    extend_exclude,
                    force_exclude,
                    report,
                    gitignore,
                    verbose=verbose,
                    quiet=quiet,
                )
            )
        elif entry == "-":
            # Plain stdin without --stdin-filename.
            collected.add(path)
        else:
            err(f"invalid path: {entry}")
    return collected
622
623
def path_empty(
    src: Sized, msg: str, quiet: bool, verbose: bool, ctx: click.Context
) -> None:
    """
    Exit with status 0 when `src` is empty, printing `msg` unless silenced.

    `msg` is printed when `verbose` is set or `quiet` is not. Nothing happens
    when `src` is non-empty.
    """
    if src:
        return

    if verbose or not quiet:
        out(msg)
    ctx.exit(0)
634
635
def reformat_code(
    content: str, fast: bool, write_back: WriteBack, mode: Mode, report: Report
) -> None:
    """
    Reformat `content` and print the result, recording the outcome on `report`.

    The string counterpart of `reformat_one`: the work is delegated to
    :func:`format_stdin_to_stdout`, which receives `fast`, `write_back` and
    `mode` unchanged. Any exception is recorded as a failure on `report`
    instead of being propagated.
    """
    path = Path("<string>")
    try:
        was_changed = format_stdin_to_stdout(
            content=content, fast=fast, write_back=write_back, mode=mode
        )
        report.done(path, Changed.YES if was_changed else Changed.NO)
    except Exception as exc:
        if report.verbose:
            traceback.print_exc()
        report.failed(path, str(exc))
658
659
def reformat_one(
    src: Path, fast: bool, write_back: WriteBack, mode: Mode, report: "Report"
) -> None:
    """Reformat the single file at `src` in the current process.

    Standard input ("-" or a name carrying `STDIN_PLACEHOLDER`) goes through
    :func:`format_stdin_to_stdout`; real files go through
    :func:`format_file_in_place` and the cache. `fast`, `write_back` and
    `mode` are forwarded to those functions. Any exception is recorded as a
    failure on `report` instead of being propagated.
    """
    try:
        outcome = Changed.NO

        name = str(src)
        if name == "-":
            from_stdin = True
        elif name.startswith(STDIN_PLACEHOLDER):
            from_stdin = True
            # Use the original name again in case we want to print something
            # to the user
            src = Path(name[len(STDIN_PLACEHOLDER) :])
        else:
            from_stdin = False

        if from_stdin:
            suffix = src.suffix
            if suffix == ".pyi":
                mode = replace(mode, is_pyi=True)
            elif suffix == ".ipynb":
                mode = replace(mode, is_ipynb=True)
            if format_stdin_to_stdout(fast=fast, write_back=write_back, mode=mode):
                outcome = Changed.YES
        else:
            cache: Cache = {}
            # Diff output is always regenerated, so the cache is not consulted.
            if write_back not in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
                cache = read_cache(mode)
                resolved = src.resolve()
                cache_key = str(resolved)
                if cache_key in cache and cache[cache_key] == get_cache_info(resolved):
                    outcome = Changed.CACHED
            if outcome is not Changed.CACHED and format_file_in_place(
                src, fast=fast, write_back=write_back, mode=mode
            ):
                outcome = Changed.YES
            # Remember files that were written back, or that were checked and
            # found to need no changes.
            written = write_back is WriteBack.YES and outcome is not Changed.CACHED
            verified = write_back is WriteBack.CHECK and outcome is Changed.NO
            if written or verified:
                write_cache(cache, [src], mode)
        report.done(src, outcome)
    except Exception as exc:
        if report.verbose:
            traceback.print_exc()
        report.failed(src, str(exc))
709
710
# diff-shades depends on being able to monkeypatch this function to operate. I know
# it's not ideal, but this shouldn't cause any issues ... hopefully. ~ichard26
@mypyc_attr(patchable=True)
def reformat_many(
    sources: Set[Path],
    fast: bool,
    write_back: WriteBack,
    mode: Mode,
    report: "Report",
    workers: Optional[int],
) -> None:
    """Reformat multiple files using a ProcessPoolExecutor.

    `sources`, `fast`, `write_back`, `mode` and `report` are forwarded to
    :func:`schedule_formatting`. `workers` is the number of worker processes;
    None means `DEFAULT_WORKERS`, or 1 if that is unknown as well.

    The formatting runs on a fresh event loop owned by this call. The loop and
    the executor are both shut down before returning, even on error.
    """
    executor: Executor
    # os.cpu_count() (and therefore DEFAULT_WORKERS) may be None; fall back to a
    # single worker, which is also what ProcessPoolExecutor does for None.
    worker_count = workers if workers is not None else (DEFAULT_WORKERS or 1)
    if sys.platform == "win32":
        # Work around https://bugs.python.org/issue26903
        worker_count = min(worker_count, 60)
    try:
        executor = ProcessPoolExecutor(max_workers=worker_count)
    except (ImportError, NotImplementedError, OSError):
        # we arrive here if the underlying system does not support multi-processing
        # like in AWS Lambda or Termux, in which case we gracefully fallback to
        # a ThreadPoolExecutor with just a single worker (more workers would not do us
        # any good due to the Global Interpreter Lock)
        executor = ThreadPoolExecutor(max_workers=1)

    # Use a dedicated loop: asyncio.get_event_loop() is deprecated when no loop
    # is current, and it could hand back (and let us close) somebody else's loop.
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        loop.run_until_complete(
            schedule_formatting(
                sources=sources,
                fast=fast,
                write_back=write_back,
                mode=mode,
                report=report,
                loop=loop,
                executor=executor,
            )
        )
    finally:
        try:
            shutdown(loop)
        finally:
            # Runs even if shutting down the loop fails, so that neither a
            # closed loop stays registered nor the executor is leaked.
            asyncio.set_event_loop(None)
            executor.shutdown()
755
756
async def schedule_formatting(
    sources: Set[Path],
    fast: bool,
    write_back: WriteBack,
    mode: Mode,
    report: "Report",
    loop: asyncio.AbstractEventLoop,
    executor: Executor,
) -> None:
    """Run formatting of `sources` in parallel using the provided `executor`.

    (Use ProcessPoolExecutors for actual parallelism.)

    `write_back`, `fast`, and `mode` options are passed to
    :func:`format_file_in_place`.

    Unless a diff is requested, sources found in the cache are reported as
    `Changed.CACHED` without being formatted. Each remaining source is
    reported on `report` as done or failed. Sources that were written back,
    or checked and found unchanged, are added to the cache at the end.
    """
    cache: Cache = {}
    # The cache is bypassed in both diff modes.
    if write_back not in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
        cache = read_cache(mode)
        sources, cached = filter_cached(cache, sources)
        for src in sorted(cached):
            report.done(src, Changed.CACHED)
    if not sources:
        return

    cancelled = []
    sources_to_cache = []
    lock = None
    if write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
        # For diff output, we need locks to ensure we don't interleave output
        # from different processes.
        manager = Manager()
        lock = manager.Lock()
    # One future per source, submitted in sorted order; the mapping lets a
    # finished future be traced back to its source path.
    tasks = {
        asyncio.ensure_future(
            loop.run_in_executor(
                executor, format_file_in_place, src, fast, mode, write_back, lock
            )
        ): src
        for src in sorted(sources)
    }
    # `pending` is a live view of the keys of `tasks`: popping finished tasks
    # below shrinks it, which is what eventually ends the `while` loop.
    pending = tasks.keys()
    try:
        # NOTE(review): `cancel` comes from black.concurrency; presumably it
        # cancels the tasks it is given -- confirm there.
        loop.add_signal_handler(signal.SIGINT, cancel, pending)
        loop.add_signal_handler(signal.SIGTERM, cancel, pending)
    except NotImplementedError:
        # There are no good alternatives for these on Windows.
        pass
    while pending:
        done, _ = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
        for task in done:
            src = tasks.pop(task)
            if task.cancelled():
                cancelled.append(task)
            elif task.exception():
                report.failed(src, str(task.exception()))
            else:
                changed = Changed.YES if task.result() else Changed.NO
                # If the file was written back or was successfully checked as
                # well-formatted, store this information in the cache.
                if write_back is WriteBack.YES or (
                    write_back is WriteBack.CHECK and changed is Changed.NO
                ):
                    sources_to_cache.append(src)
                report.done(src, changed)
    if cancelled:
        # Await the cancelled tasks, collecting their exceptions instead of
        # raising them. The explicit `loop` argument is only passed on Python
        # versions older than 3.7.
        if sys.version_info >= (3, 7):
            await asyncio.gather(*cancelled, return_exceptions=True)
        else:
            await asyncio.gather(*cancelled, loop=loop, return_exceptions=True)
    if sources_to_cache:
        write_cache(cache, sources_to_cache, mode)
829
830
def format_file_in_place(
    src: Path,
    fast: bool,
    mode: Mode,
    write_back: WriteBack = WriteBack.NO,
    lock: Any = None,  # multiprocessing.Manager().Lock() is some crazy proxy
) -> bool:
    """Format the file at `src`. Return True if its content would change.

    With `write_back` YES the reformatted code replaces the file content; with
    DIFF or COLOR_DIFF a diff is written to stdout instead, guarded by `lock`
    when one is given. `mode` and `fast` are forwarded to
    :func:`format_file_contents`; the file suffix may switch `mode` to stub or
    notebook formatting.

    Raises ValueError when the content cannot be decoded as JSON while
    formatting (reported as an invalid Jupyter notebook).
    """
    suffix = src.suffix
    if suffix == ".pyi":
        mode = replace(mode, is_pyi=True)
    elif suffix == ".ipynb":
        mode = replace(mode, is_ipynb=True)

    # Modification time of the original, used in the diff header.
    modified_at = datetime.utcfromtimestamp(src.stat().st_mtime)
    with open(src, "rb") as raw:
        original, encoding, newline = decode_bytes(raw.read())
    try:
        formatted = format_file_contents(original, fast=fast, mode=mode)
    except NothingChanged:
        return False
    except JSONDecodeError:
        raise ValueError(
            f"File '{src}' cannot be parsed as valid Jupyter notebook."
        ) from None

    if write_back == WriteBack.YES:
        with open(src, "w", encoding=encoding, newline=newline) as target:
            target.write(formatted)
    elif write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
        formatted_at = datetime.utcnow()
        old_label = f"{src}\t{modified_at} +0000"
        new_label = f"{src}\t{formatted_at} +0000"
        if mode.is_ipynb:
            patch = ipynb_diff(original, formatted, old_label, new_label)
        else:
            patch = diff(original, formatted, old_label, new_label)

        if write_back == WriteBack.COLOR_DIFF:
            patch = color_diff(patch)

        with lock or nullcontext():
            # Write through a temporary text wrapper around the stdout buffer,
            # then detach it again once the diff has been written.
            stream = io.TextIOWrapper(
                sys.stdout.buffer,
                encoding=encoding,
                newline=newline,
                write_through=True,
            )
            stream = wrap_stream_for_windows(stream)
            stream.write(patch)
            stream.detach()

    return True
888
889
def format_stdin_to_stdout(
    fast: bool,
    *,
    content: Optional[str] = None,
    write_back: WriteBack = WriteBack.NO,
    mode: Mode,
) -> bool:
    """Format file on stdin. Return True if changed.

    If content is None, it's read from sys.stdin.

    If `write_back` is YES, write reformatted code back to stdout. If it is DIFF,
    write a diff to stdout. The `mode` argument is passed to
    :func:`format_file_contents`.

    The output is written in a `finally` clause, so it is also produced when
    nothing changed or when formatting raises; in those cases the unmodified
    source is what gets written or diffed.
    """
    # Timestamp used for the "STDIN" side of the diff header.
    then = datetime.utcnow()

    if content is None:
        src, encoding, newline = decode_bytes(sys.stdin.buffer.read())
    else:
        # Content passed in as a string: fixed encoding, no newline translation.
        src, encoding, newline = content, "utf-8", ""

    # Start from the input so that `dst` is defined for the `finally` block even
    # when formatting does not produce a result.
    dst = src
    try:
        dst = format_file_contents(src, fast=fast, mode=mode)
        return True

    except NothingChanged:
        return False

    finally:
        # Runs before either `return` above takes effect, and also when another
        # exception propagates.
        f = io.TextIOWrapper(
            sys.stdout.buffer, encoding=encoding, newline=newline, write_through=True
        )
        if write_back == WriteBack.YES:
            # Make sure there's a newline after the content
            if dst and dst[-1] != "\n":
                dst += "\n"
            f.write(dst)
        elif write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
            now = datetime.utcnow()
            src_name = f"STDIN\t{then} +0000"
            dst_name = f"STDOUT\t{now} +0000"
            d = diff(src, dst, src_name, dst_name)
            if write_back == WriteBack.COLOR_DIFF:
                d = color_diff(d)
                # Only the colored diff goes through the Windows stream wrapper.
                f = wrap_stream_for_windows(f)
            f.write(d)
        # Detach the wrapper from the stdout buffer now that writing is done.
        f.detach()
939
940
def check_stability_and_equivalence(
    src_contents: str, dst_contents: str, *, mode: Mode
) -> None:
    """Verify that the formatted output is equivalent to the source and stable.

    Raises AssertionError (via :func:`assert_equivalent` or
    :func:`assert_stable`) when the destination is not equivalent to the
    source, or when formatting the result again keeps changing it.
    """
    assert_equivalent(src_contents, dst_contents)

    # A second pass is forced because optional trailing commas become forced
    # trailing commas on pass 2 and may then interact differently with
    # optional parentheses.  Admittedly ugly.
    second_pass = format_str(dst_contents, mode=mode)
    if second_pass == dst_contents:
        # Already stable: an explicit `assert_stable` call would be redundant.
        return

    assert_equivalent(src_contents, second_pass, pass_num=2)
    assert_stable(src_contents, second_pass, mode=mode)
962
963
def format_file_contents(src_contents: str, *, fast: bool, mode: Mode) -> FileContent:
    """Return the reformatted contents of a file.

    Raises NothingChanged when the source is blank (only whitespace) or when
    formatting leaves it untouched.

    Unless `fast` is set, the result is additionally validated with
    :func:`check_stability_and_equivalence`.  Notebook sources
    (`mode.is_ipynb`) are handed to :func:`format_ipynb_string`, which receives
    `fast` itself; everything else goes through :func:`format_str`.
    """
    if not src_contents.strip():
        raise NothingChanged

    dst_contents = (
        format_ipynb_string(src_contents, fast=fast, mode=mode)
        if mode.is_ipynb
        else format_str(src_contents, mode=mode)
    )
    if dst_contents == src_contents:
        raise NothingChanged

    # Jupyter notebooks will already have been checked cell by cell above.
    needs_safety_check = not (fast or mode.is_ipynb)
    if needs_safety_check:
        check_stability_and_equivalence(src_contents, dst_contents, mode=mode)
    return dst_contents
985
986
def validate_cell(src: str) -> None:
    """Reject cells that cannot be safely round-tripped.

    Raises NothingChanged when the cell already contains the output of a
    TransformerManager transformation, or when it starts with a cell magic
    (``%%name``) that is not listed in PYTHON_CELL_MAGICS.

    If a cell contains ``!ls``, then it'll be transformed to
    ``get_ipython().system('ls')``. However, if the cell originally contained
    ``get_ipython().system('ls')``, then it would get transformed in the same way:

        >>> TransformerManager().transform_cell("get_ipython().system('ls')")
        "get_ipython().system('ls')\n"
        >>> TransformerManager().transform_cell("!ls")
        "get_ipython().system('ls')\n"

    The two inputs are indistinguishable after transformation, so cells
    containing already-transformed magics are left alone.
    """
    for transformed_magic in TRANSFORMED_MAGICS:
        if transformed_magic in src:
            raise NothingChanged

    if src.startswith("%%"):
        # The magic name is the first whitespace-delimited word minus the "%%".
        magic_name = src.split()[0][2:]
        if magic_name not in PYTHON_CELL_MAGICS:
            raise NothingChanged
1008
1009
def format_cell(src: str, *, fast: bool, mode: Mode) -> str:
    """Format the code of a single Jupyter notebook cell.

    The pipeline is:

      - validate the cell (see :func:`validate_cell`);
      - strip a trailing semicolon, remembering whether there was one;
      - mask IPython magics;
      - format the masked code (and, unless `fast`, check the result);
      - put the magics back;
      - put the trailing semicolon back if it was there originally;
      - drop trailing newlines.

    Raises NothingChanged when the cell is rejected by validation, when masking
    fails with a SyntaxError (such cells could be automagics or multi-line
    magics, which are currently not supported), or when the result equals the
    input.
    """
    validate_cell(src)
    body, had_trailing_semicolon = remove_trailing_semicolon(src)
    try:
        masked_src, replacements = mask_cell(body)
    except SyntaxError:
        raise NothingChanged from None

    masked_dst = format_str(masked_src, mode=mode)
    if not fast:
        check_stability_and_equivalence(masked_src, masked_dst, mode=mode)

    unmasked = unmask_cell(masked_dst, replacements)
    dst = put_trailing_semicolon_back(unmasked, had_trailing_semicolon).rstrip("\n")
    if dst == src:
        raise NothingChanged from None
    return dst
1045
1046
def validate_metadata(nb: MutableMapping[str, Any]) -> None:
    """Raise NothingChanged if the notebook declares a non-Python language.

    All notebook metadata fields are optional, see
    https://nbformat.readthedocs.io/en/latest/format_description.html. A
    notebook without a declared language is therefore still accepted and we
    will try to parse it anyway.
    """
    metadata = nb.get("metadata", {})
    language_info = metadata.get("language_info", {})
    language = language_info.get("name", None)
    if language is None or language == "python":
        return
    raise NothingChanged from None
1057
1058
def format_ipynb_string(src_contents: str, *, fast: bool, mode: Mode) -> FileContent:
    """Format Jupyter notebook.

    Operate cell-by-cell, only on code cells, only for Python notebooks.
    If the ``.ipynb`` originally had a trailing newline, it'll be preserved.

    `src_contents` is the notebook as a JSON string; `fast` and `mode` are
    forwarded to :func:`format_cell` for every code cell.

    Returns the re-serialized notebook.  Raises NothingChanged when the input
    is empty, when :func:`validate_metadata` rejects the notebook, or when no
    code cell was modified.  Invalid JSON propagates as the exception raised
    by `json.loads`.
    """
    if not src_contents:
        # Guard the `[-1]` lookup below: an empty string has nothing to format
        # and must not fail with IndexError.
        raise NothingChanged

    trailing_newline = src_contents[-1] == "\n"
    modified = False
    nb = json.loads(src_contents)
    validate_metadata(nb)
    for cell in nb["cells"]:
        if cell.get("cell_type", None) == "code":
            try:
                src = "".join(cell["source"])
                dst = format_cell(src, fast=fast, mode=mode)
            except NothingChanged:
                # This cell stays as it is; keep going with the others.
                pass
            else:
                cell["source"] = dst.splitlines(keepends=True)
                modified = True
    if not modified:
        raise NothingChanged

    dst_contents = json.dumps(nb, indent=1, ensure_ascii=False)
    if trailing_newline:
        dst_contents = dst_contents + "\n"
    return dst_contents
1086
1087
def format_str(src_contents: str, *, mode: Mode) -> FileContent:
    """Reformat a string and return new contents.

    `mode` carries the formatting options, such as the allowed line length.
    Example:

    >>> import black
    >>> print(black.format_str("def f(arg:str='')->None:...", mode=black.Mode()))
    def f(arg: str = "") -> None:
        ...

    A more complex example:

    >>> print(
    ...   black.format_str(
    ...     "def f(arg:str='')->None: hey",
    ...     mode=black.Mode(
    ...       target_versions={black.TargetVersion.PY36},
    ...       line_length=10,
    ...       string_normalization=False,
    ...       is_pyi=False,
    ...     ),
    ...   ),
    ... )
    def f(
        arg: str = '',
    ) -> None:
        hey

    """
    src_node = lib2to3_parse(src_contents.lstrip(), mode.target_versions)
    future_imports = get_future_imports(src_node)
    # Explicit target versions win; otherwise infer them from the source.
    versions = mode.target_versions or detect_target_versions(
        src_node, future_imports=future_imports
    )

    normalize_fmt_off(src_node)
    line_generator = LineGenerator(mode=mode)
    empty_line_tracker = EmptyLineTracker(is_pyi=mode.is_pyi)
    blank = Line(mode=mode)
    trailing_comma_features = (
        Feature.TRAILING_COMMA_IN_CALL,
        Feature.TRAILING_COMMA_IN_DEF,
    )
    split_line_features = {
        candidate
        for candidate in trailing_comma_features
        if supports_feature(versions, candidate)
    }

    chunks: List[str] = []
    # Number of blank lines still owed after the previously emitted line.
    pending_after = 0
    for current_line in line_generator.visit(src_node):
        chunks.append(str(blank) * pending_after)
        before, pending_after = empty_line_tracker.maybe_empty_lines(current_line)
        chunks.append(str(blank) * before)
        chunks.extend(
            str(line)
            for line in transform_line(
                current_line, mode=mode, features=split_line_features
            )
        )
    return "".join(chunks)
1145
1146
def decode_bytes(src: bytes) -> Tuple[FileContent, Encoding, NewLine]:
    """Decode raw source bytes into ``(decoded_contents, encoding, newline)``.

    The encoding is detected with `tokenize.detect_encoding`.  `newline` is
    CRLF when the first line read during detection ends with CRLF and LF
    otherwise.  `decoded_contents` is decoded with universal newlines, i.e. it
    only contains LF.  When detection reads no lines at all, the contents are
    reported as an empty string with an LF newline.
    """
    buffer = io.BytesIO(src)
    encoding, first_lines = tokenize.detect_encoding(buffer.readline)
    if not first_lines:
        return "", encoding, "\n"

    if first_lines[0].endswith(b"\r\n"):
        newline = "\r\n"
    else:
        newline = "\n"

    # Encoding detection consumed part of the buffer; rewind before decoding.
    buffer.seek(0)
    with io.TextIOWrapper(buffer, encoding) as reader:
        return reader.read(), encoding, newline
1162
1163
def get_features_used(  # noqa: C901
    node: Node, *, future_imports: Optional[Set[str]] = None
) -> Set[Feature]:
    """Return a set of (relatively) new Python features used in this file.

    `node` is walked in pre-order; `future_imports`, when given, is the set of
    names imported from ``__future__`` and contributes the features mapped in
    FUTURE_FLAG_TO_FEATURE.

    Currently looking for:
    - f-strings (any letter case of the string prefix);
    - underscores in numeric literals;
    - trailing commas after * or ** in function signatures and calls;
    - positional only arguments in function signatures and lambdas;
    - assignment expression;
    - relaxed decorator syntax;
    - usage of __future__ flags (annotations);
    - star expressions in the value of a ``return`` / ``yield``;
    - annotated assignments whose right-hand side is a ``testlist_star_expr``.
    """
    features: Set[Feature] = set()
    if future_imports:
        features |= {
            FUTURE_FLAG_TO_FEATURE[future_import]
            for future_import in future_imports
            if future_import in FUTURE_FLAG_TO_FEATURE
        }

    for n in node.pre_order():
        if is_string_token(n):
            # String prefixes are case-insensitive and may mix cases (e.g.
            # `Rf"..."` or `fR'...'`), so normalize before comparing.
            value_head = n.value[:2].lower()
            if value_head in {'f"', "f'", "rf", "fr"}:
                features.add(Feature.F_STRINGS)

        elif n.type == token.NUMBER:
            assert isinstance(n, Leaf)
            if "_" in n.value:
                features.add(Feature.NUMERIC_UNDERSCORES)

        elif n.type == token.SLASH:
            # A slash only marks positional-only parameters when it sits
            # directly inside an argument list.
            if n.parent and n.parent.type in {
                syms.typedargslist,
                syms.arglist,
                syms.varargslist,
            }:
                features.add(Feature.POS_ONLY_ARGUMENTS)

        elif n.type == token.COLONEQUAL:
            features.add(Feature.ASSIGNMENT_EXPRESSIONS)

        elif n.type == syms.decorator:
            if len(n.children) > 1 and not is_simple_decorator_expression(
                n.children[1]
            ):
                features.add(Feature.RELAXED_DECORATORS)

        elif (
            n.type in {syms.typedargslist, syms.arglist}
            and n.children
            and n.children[-1].type == token.COMMA
        ):
            if n.type == syms.typedargslist:
                feature = Feature.TRAILING_COMMA_IN_DEF
            else:
                feature = Feature.TRAILING_COMMA_IN_CALL

            # The trailing comma only counts as a feature when the list also
            # contains a star, either directly or inside an argument node.
            for ch in n.children:
                if ch.type in STARS:
                    features.add(feature)

                if ch.type == syms.argument:
                    for argch in ch.children:
                        if argch.type in STARS:
                            features.add(feature)

        elif (
            n.type in {syms.return_stmt, syms.yield_expr}
            and len(n.children) >= 2
            and n.children[1].type == syms.testlist_star_expr
            and any(child.type == syms.star_expr for child in n.children[1].children)
        ):
            features.add(Feature.UNPACKING_ON_FLOW)

        elif (
            n.type == syms.annassign
            and len(n.children) >= 4
            and n.children[3].type == syms.testlist_star_expr
        ):
            features.add(Feature.ANN_ASSIGN_EXTENDED_RHS)

    return features
1250
1251
def detect_target_versions(
    node: Node, *, future_imports: Optional[Set[str]] = None
) -> Set[TargetVersion]:
    """Return every target version whose feature set covers the features used.

    The used features are collected with :func:`get_features_used`; a version
    qualifies when they form a subset of its VERSION_TO_FEATURES entry.
    """
    features = get_features_used(node, future_imports=future_imports)
    compatible: Set[TargetVersion] = set()
    for version in TargetVersion:
        if features <= VERSION_TO_FEATURES[version]:
            compatible.add(version)
    return compatible
1260
1261
def get_future_imports(node: Node) -> Set[str]:
    """Collect the names imported from ``__future__`` at the top of the file.

    Only the leading statements are inspected: scanning continues past a
    docstring and past ``from __future__ import ...`` statements, and stops at
    the first statement that is anything else.
    """

    def imported_names(children: List[LN]) -> Generator[str, None, None]:
        # Yield the original (pre-`as`) name of every imported item.
        for item in children:
            if isinstance(item, Leaf):
                if item.type == token.NAME:
                    yield item.value
                continue

            if item.type == syms.import_as_name:
                orig_name = item.children[0]
                assert isinstance(orig_name, Leaf), "Invalid syntax parsing imports"
                assert orig_name.type == token.NAME, "Invalid syntax parsing imports"
                yield orig_name.value
                continue

            if item.type == syms.import_as_names:
                yield from imported_names(item.children)
                continue

            raise AssertionError("Invalid syntax parsing imports")

    found: Set[str] = set()
    for statement in node.children:
        if statement.type != syms.simple_stmt:
            break

        head = statement.children[0]
        if isinstance(head, Leaf):
            # Continue looking if we see a docstring; otherwise stop.
            is_docstring = (
                len(statement.children) == 2
                and head.type == token.STRING
                and statement.children[1].type == token.NEWLINE
            )
            if not is_docstring:
                break
            continue

        if head.type != syms.import_from:
            break

        module_name = head.children[1]
        if not isinstance(module_name, Leaf) or module_name.value != "__future__":
            break

        found.update(imported_names(head.children[3:]))

    return found
1310
1311
def assert_equivalent(src: str, dst: str, *, pass_num: int = 1) -> None:
    """Raise AssertionError unless `src` and `dst` have the same stringified AST.

    An AssertionError is also raised when either input cannot be parsed.
    `pass_num` only appears in the error messages, to tell which formatting
    pass produced `dst`.
    """
    try:
        src_ast = parse_ast(src)
    except Exception as exc:
        message = (
            f"cannot use --safe with this file; failed to parse source file: {exc}"
        )
        raise AssertionError(message) from exc

    try:
        dst_ast = parse_ast(dst)
    except Exception as exc:
        # Dump the traceback together with the offending output for bug reports.
        log = dump_to_file("".join(traceback.format_tb(exc.__traceback__)), dst)
        message = (
            f"INTERNAL ERROR: Black produced invalid code on pass {pass_num}: {exc}. "
            "Please report a bug on https://github.com/psf/black/issues.  "
            f"This invalid output might be helpful: {log}"
        )
        raise AssertionError(message) from None

    src_ast_str = "\n".join(stringify_ast(src_ast))
    dst_ast_str = "\n".join(stringify_ast(dst_ast))
    if src_ast_str == dst_ast_str:
        return

    log = dump_to_file(diff(src_ast_str, dst_ast_str, "src", "dst"))
    message = (
        "INTERNAL ERROR: Black produced code that is not equivalent to the"
        f" source on pass {pass_num}.  Please report a bug on "
        f"https://github.com/psf/black/issues.  This diff might be helpful: {log}"
    )
    raise AssertionError(message) from None
1340
1341
def assert_stable(src: str, dst: str, mode: Mode) -> None:
    """Raise AssertionError if formatting `dst` again changes it.

    `src` is only used to build the diagnostic dump that accompanies the
    error; `mode` is passed to :func:`format_str` for the extra pass.
    """
    newdst = format_str(dst, mode=mode)
    if newdst == dst:
        return

    log = dump_to_file(
        str(mode),
        diff(src, dst, "source", "first pass"),
        diff(dst, newdst, "first pass", "second pass"),
    )
    message = (
        "INTERNAL ERROR: Black produced different code on the second pass of the"
        " formatter.  Please report a bug on https://github.com/psf/black/issues."
        f"  This diff might be helpful: {log}"
    )
    raise AssertionError(message) from None
1356
1357
@contextmanager
def nullcontext() -> Iterator[None]:
    """Return a context manager that does nothing on entry or exit.

    Stand-in for `contextlib.nullcontext`, which is only available from
    Python 3.7 on.  The value bound by ``with ... as`` is None.
    """
    yield None
1365
1366
def patch_click() -> None:
    """Make Click not crash on Python 3.6 with LANG=C.

    On certain misconfigured environments, Python 3 selects the ASCII encoding as the
    default which restricts paths that it can access during the lifetime of the
    application.  Click refuses to work in this scenario by raising a RuntimeError.

    In case of Black the likelihood that non-ASCII characters are going to be used in
    file paths is minimal since it's Python source code.  Moreover, this crash was
    spurious on Python 3.7 thanks to PEP 538 and PEP 540.

    Each Click module is imported independently and silently skipped when it is
    unavailable, so a Click release that does not ship `click._unicodefun`
    neither crashes this function nor prevents `click.core` from being patched.
    """
    modules: List[Any] = []
    try:
        from click import core
    except ImportError:
        pass
    else:
        modules.append(core)
    try:
        # A missing submodule of an installed package raises plain ImportError
        # ("cannot import name ..."), not ModuleNotFoundError, so the broader
        # exception has to be caught here.
        from click import _unicodefun
    except ImportError:
        pass
    else:
        modules.append(_unicodefun)

    for module in modules:
        if hasattr(module, "_verify_python3_env"):
            module._verify_python3_env = lambda: None  # type: ignore
        if hasattr(module, "_verify_python_env"):
            module._verify_python_env = lambda: None  # type: ignore
1389
1390
def patched_main() -> None:
    """Run `main` after applying the process-level setup steps.

    The setup steps run in a fixed order before `main` is called:
    `maybe_install_uvloop`, multiprocessing's `freeze_support`, and
    `patch_click`.
    """
    for prepare in (maybe_install_uvloop, freeze_support, patch_click):
        prepare()
    main()
1396
1397
# Only start the patched entry point when this module is executed as a script,
# not when it is imported.
if __name__ == "__main__":
    patched_main()