git.madduck.net Git - etc/vim.git/blob - src/black/__init__.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access, can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

67c272e3cc9a5e11ef7fe5419a08d0a413b3c9cd
[etc/vim.git] / src / black / __init__.py
1 import asyncio
2 from json.decoder import JSONDecodeError
3 import json
4 from concurrent.futures import Executor, ThreadPoolExecutor, ProcessPoolExecutor
5 from contextlib import contextmanager
6 from datetime import datetime
7 from enum import Enum
8 import io
9 from multiprocessing import Manager, freeze_support
10 import os
11 from pathlib import Path
12 from pathspec.patterns.gitwildmatch import GitWildMatchPatternError
13 import re
14 import signal
15 import sys
16 import tokenize
17 import traceback
18 from typing import (
19     Any,
20     Dict,
21     Generator,
22     Iterator,
23     List,
24     MutableMapping,
25     Optional,
26     Pattern,
27     Set,
28     Sized,
29     Tuple,
30     Union,
31 )
32
33 import click
34 from click.core import ParameterSource
35 from dataclasses import replace
36 from mypy_extensions import mypyc_attr
37
38 from black.const import DEFAULT_LINE_LENGTH, DEFAULT_INCLUDES, DEFAULT_EXCLUDES
39 from black.const import STDIN_PLACEHOLDER
40 from black.nodes import STARS, syms, is_simple_decorator_expression
41 from black.nodes import is_string_token
42 from black.lines import Line, EmptyLineTracker
43 from black.linegen import transform_line, LineGenerator, LN
44 from black.comments import normalize_fmt_off
45 from black.mode import FUTURE_FLAG_TO_FEATURE, Mode, TargetVersion
46 from black.mode import Feature, supports_feature, VERSION_TO_FEATURES
47 from black.cache import read_cache, write_cache, get_cache_info, filter_cached, Cache
48 from black.concurrency import cancel, shutdown, maybe_install_uvloop
49 from black.output import dump_to_file, ipynb_diff, diff, color_diff, out, err
50 from black.report import Report, Changed, NothingChanged
51 from black.files import find_project_root, find_pyproject_toml, parse_pyproject_toml
52 from black.files import gen_python_files, get_gitignore, normalize_path_maybe_ignore
53 from black.files import wrap_stream_for_windows
54 from black.parsing import InvalidInput  # noqa F401
55 from black.parsing import lib2to3_parse, parse_ast, stringify_ast
56 from black.handle_ipynb_magics import (
57     mask_cell,
58     unmask_cell,
59     remove_trailing_semicolon,
60     put_trailing_semicolon_back,
61     TRANSFORMED_MAGICS,
62     PYTHON_CELL_MAGICS,
63     jupyter_dependencies_are_installed,
64 )
65
66
67 # lib2to3 fork
68 from blib2to3.pytree import Node, Leaf
69 from blib2to3.pgen2 import token
70
71 from _black_version import version as __version__
72
# True when this module's file has a compiled-extension suffix (".pyd" or ".so");
# reported in the `--version` message.
COMPILED = Path(__file__).suffix in (".pyd", ".so")

# types
# All three are plain aliases of `str`, used as descriptive names in signatures.
FileContent = str
Encoding = str
NewLine = str
79
80
class WriteBack(Enum):
    """Write-back modes that can be selected through the command-line flags."""

    NO = 0
    YES = 1
    DIFF = 2
    CHECK = 3
    COLOR_DIFF = 4

    @classmethod
    def from_configuration(
        cls, *, check: bool, diff: bool, color: bool = False
    ) -> "WriteBack":
        """Return the member that corresponds to the given flag combination.

        `diff` takes precedence over `check`; `color` only has an effect together
        with `diff`. Without `diff`, the result is CHECK when `check` is set and
        YES otherwise. NO is never returned from here.
        """
        if diff:
            return cls.COLOR_DIFF if color else cls.DIFF

        return cls.CHECK if check else cls.YES
99
100
# Legacy name, left for integrations.
FileMode = Mode

# Default for the `-W/--workers` option. `os.cpu_count()` returns None when the
# number of CPUs cannot be determined, so consumers must tolerate None.
DEFAULT_WORKERS = os.cpu_count()
105
106
def read_pyproject_toml(
    ctx: click.Context, param: click.Parameter, value: Optional[str]
) -> Optional[str]:
    """Merge Black settings from a "pyproject.toml" into the defaults of `ctx`.

    When `value` is empty, the file is looked up via `find_pyproject_toml` using
    the "src" parameter already stored on `ctx`.

    Returns the path of the configuration file that was read and applied, or
    None when no file was found or the file held no Black configuration.

    Raises `click.FileError` if the file cannot be read or parsed, and
    `click.BadOptionUsage` if "target_version" is present but is not a list.
    """
    config_path = value or find_pyproject_toml(ctx.params.get("src", ()))
    if config_path is None:
        return None

    try:
        raw_config = parse_pyproject_toml(config_path)
    except (OSError, ValueError) as exc:
        raise click.FileError(
            filename=config_path, hint=f"Error reading configuration file: {exc}"
        ) from None

    if not raw_config:
        return None

    # Sanitize the values to be Click friendly. For more information please see:
    # https://github.com/psf/black/issues/1458
    # https://github.com/pallets/click/issues/1567
    # Lists and dicts are passed through; every other value becomes a string.
    config = {
        key: (setting if isinstance(setting, (list, dict)) else str(setting))
        for key, setting in raw_config.items()
    }

    target_version = config.get("target_version")
    if not (target_version is None or isinstance(target_version, list)):
        raise click.BadOptionUsage(
            "target-version", "Config key target-version must be a list"
        )

    # Values from the file override whatever defaults were already registered.
    merged_defaults: Dict[str, Any] = {}
    merged_defaults.update(ctx.default_map or {})
    merged_defaults.update(config)
    ctx.default_map = merged_defaults

    return config_path
151
152
def target_version_option_callback(
    c: click.Context, p: Union[click.Option, click.Parameter], v: Tuple[str, ...]
) -> List[TargetVersion]:
    """Translate the values of a --target-version flag into `TargetVersion`s.

    Each value is upper-cased and looked up by name in `TargetVersion`.

    This is its own function because mypy couldn't infer the type correctly
    when it was a lambda, causing mypyc trouble.
    """
    versions: List[TargetVersion] = []
    for name in v:
        versions.append(TargetVersion[name.upper()])
    return versions
162
163
def re_compile_maybe_verbose(regex: str) -> Pattern[str]:
    """Return the compiled form of the regular expression string `regex`.

    A pattern that spans several lines is compiled in verbose mode by prefixing
    it with the inline `(?x)` flag; single-line patterns are compiled as given.
    """
    source = f"(?x){regex}" if "\n" in regex else regex
    compiled: Pattern[str] = re.compile(source)
    return compiled
173
174
def validate_regex(
    ctx: click.Context,
    param: click.Parameter,
    value: Optional[str],
) -> Optional[Pattern[str]]:
    """Click callback that compiles a regular-expression option value.

    Returns None when the option was not given, otherwise the pattern compiled
    by `re_compile_maybe_verbose`. An invalid expression is reported to the user
    as a `click.BadParameter`.
    """
    if value is None:
        return None

    try:
        return re_compile_maybe_verbose(value)
    except re.error as exc:
        raise click.BadParameter(f"Not a valid regular expression: {exc}") from None
184
185
@click.command(
    context_settings={"help_option_names": ["-h", "--help"]},
    # While Click does set this field automatically using the docstring, mypyc
    # (annoyingly) strips 'em so we need to set it here too.
    help="The uncompromising code formatter.",
)
@click.option("-c", "--code", type=str, help="Format the code passed in as a string.")
@click.option(
    "-l",
    "--line-length",
    type=int,
    default=DEFAULT_LINE_LENGTH,
    help="How many characters per line to allow.",
    show_default=True,
)
@click.option(
    "-t",
    "--target-version",
    type=click.Choice([v.name.lower() for v in TargetVersion]),
    callback=target_version_option_callback,
    multiple=True,
    help=(
        "Python versions that should be supported by Black's output. [default: per-file"
        " auto-detection]"
    ),
)
@click.option(
    "--pyi",
    is_flag=True,
    help=(
        "Format all input files like typing stubs regardless of file extension (useful"
        " when piping source on standard input)."
    ),
)
@click.option(
    "--ipynb",
    is_flag=True,
    help=(
        "Format all input files like Jupyter Notebooks regardless of file extension "
        "(useful when piping source on standard input)."
    ),
)
@click.option(
    "-S",
    "--skip-string-normalization",
    is_flag=True,
    help="Don't normalize string quotes or prefixes.",
)
@click.option(
    "-C",
    "--skip-magic-trailing-comma",
    is_flag=True,
    help="Don't use trailing commas as a reason to split lines.",
)
@click.option(
    "--experimental-string-processing",
    is_flag=True,
    hidden=True,
    help="(DEPRECATED and now included in --preview) Normalize string literals.",
)
@click.option(
    "--preview",
    is_flag=True,
    help=(
        "Enable potentially disruptive style changes that will be added to Black's main"
        " functionality in the next major release."
    ),
)
@click.option(
    "--check",
    is_flag=True,
    help=(
        "Don't write the files back, just return the status. Return code 0 means"
        " nothing would change. Return code 1 means some files would be reformatted."
        " Return code 123 means there was an internal error."
    ),
)
@click.option(
    "--diff",
    is_flag=True,
    help="Don't write the files back, just output a diff for each file on stdout.",
)
@click.option(
    "--color/--no-color",
    is_flag=True,
    help="Show colored diff. Only applies when `--diff` is given.",
)
@click.option(
    "--fast/--safe",
    is_flag=True,
    help="If --fast given, skip temporary sanity checks. [default: --safe]",
)
@click.option(
    "--required-version",
    type=str,
    help=(
        "Require a specific version of Black to be running (useful for unifying results"
        " across many environments e.g. with a pyproject.toml file)."
    ),
)
@click.option(
    "--include",
    type=str,
    default=DEFAULT_INCLUDES,
    callback=validate_regex,
    help=(
        "A regular expression that matches files and directories that should be"
        " included on recursive searches. An empty value means all files are included"
        " regardless of the name. Use forward slashes for directories on all platforms"
        " (Windows, too). Exclusions are calculated first, inclusions later."
    ),
    show_default=True,
)
@click.option(
    "--exclude",
    type=str,
    callback=validate_regex,
    help=(
        "A regular expression that matches files and directories that should be"
        " excluded on recursive searches. An empty value means no paths are excluded."
        " Use forward slashes for directories on all platforms (Windows, too)."
        " Exclusions are calculated first, inclusions later. [default:"
        f" {DEFAULT_EXCLUDES}]"
    ),
    show_default=False,
)
@click.option(
    "--extend-exclude",
    type=str,
    callback=validate_regex,
    help=(
        "Like --exclude, but adds additional files and directories on top of the"
        " excluded ones. (Useful if you simply want to add to the default)"
    ),
)
@click.option(
    "--force-exclude",
    type=str,
    callback=validate_regex,
    help=(
        "Like --exclude, but files and directories matching this regex will be "
        "excluded even when they are passed explicitly as arguments."
    ),
)
@click.option(
    "--stdin-filename",
    type=str,
    help=(
        "The name of the file when passing it through stdin. Useful to make "
        "sure Black will respect --force-exclude option on some "
        "editors that rely on using stdin."
    ),
)
@click.option(
    "-W",
    "--workers",
    type=click.IntRange(min=1),
    default=DEFAULT_WORKERS,
    show_default=True,
    help="Number of parallel workers",
)
@click.option(
    "-q",
    "--quiet",
    is_flag=True,
    help=(
        "Don't emit non-error messages to stderr. Errors are still emitted; silence"
        " those with 2>/dev/null."
    ),
)
@click.option(
    "-v",
    "--verbose",
    is_flag=True,
    help=(
        "Also emit messages to stderr about files that were not changed or were ignored"
        " due to exclusion patterns."
    ),
)
@click.version_option(
    version=__version__,
    message=f"%(prog)s, %(version)s (compiled: {'yes' if COMPILED else 'no'})",
)
@click.argument(
    "src",
    nargs=-1,
    type=click.Path(
        exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
    ),
    is_eager=True,
    metavar="SRC ...",
)
@click.option(
    "--config",
    type=click.Path(
        exists=True,
        file_okay=True,
        dir_okay=False,
        readable=True,
        allow_dash=False,
        path_type=str,
    ),
    is_eager=True,
    callback=read_pyproject_toml,
    help="Read configuration from FILE path.",
)
@click.pass_context
def main(
    ctx: click.Context,
    code: Optional[str],
    line_length: int,
    target_version: List[TargetVersion],
    check: bool,
    diff: bool,
    color: bool,
    fast: bool,
    pyi: bool,
    ipynb: bool,
    skip_string_normalization: bool,
    skip_magic_trailing_comma: bool,
    experimental_string_processing: bool,
    preview: bool,
    quiet: bool,
    verbose: bool,
    required_version: Optional[str],
    include: Pattern[str],
    exclude: Optional[Pattern[str]],
    extend_exclude: Optional[Pattern[str]],
    force_exclude: Optional[Pattern[str]],
    stdin_filename: Optional[str],
    workers: int,
    src: Tuple[str, ...],
    config: Optional[str],
) -> None:
    """The uncompromising code formatter."""
    ctx.ensure_object(dict)
    # The project root is only looked up when formatting paths; with `--code`
    # there is nothing to search from.
    project = find_project_root(src) if code is None else (None, None)
    root, method = project
    ctx.obj["root"] = root

    if verbose:
        if root:
            out(
                f"Identified `{root}` as project root containing a {method}.",
                fg="blue",
            )

            # Describe every source as given; ones that fail to normalize are
            # highlighted in red before switching back to blue.
            described_sources = []
            for source in src:
                normalized_source = normalize_path_maybe_ignore(Path(source), root)
                if normalized_source:
                    described_sources.append(f'"{normalized_source}"')
                else:
                    described_sources.append(
                        f'\033[31m"{source} (skipping - invalid)"\033[34m'
                    )
            srcs_string = ", ".join(described_sources)
            out(f"Sources to be formatted: {srcs_string}", fg="blue")

        if config:
            config_source = ctx.get_parameter_source("config")
            if config_source not in (
                ParameterSource.DEFAULT,
                ParameterSource.DEFAULT_MAP,
            ):
                out(f"Using configuration in '{config}'.", fg="blue")
            else:
                out("Using configuration from project root.", fg="blue")

    failure_banner = "Oh no! 💥 💔 💥"
    if required_version and required_version != __version__:
        err(
            f"{failure_banner} The required version `{required_version}` does not match"
            f" the running version `{__version__}`!"
        )
        ctx.exit(1)
    if ipynb and pyi:
        err("Cannot pass both `pyi` and `ipynb` flags!")
        ctx.exit(1)

    write_back = WriteBack.from_configuration(check=check, diff=diff, color=color)
    # An empty set means the target versions are autodetected later.
    versions: Set[TargetVersion] = set(target_version) if target_version else set()
    mode = Mode(
        target_versions=versions,
        line_length=line_length,
        is_pyi=pyi,
        is_ipynb=ipynb,
        string_normalization=not skip_string_normalization,
        magic_trailing_comma=not skip_magic_trailing_comma,
        experimental_string_processing=experimental_string_processing,
        preview=preview,
    )

    if code is not None:
        # Run in quiet mode by default with -c; the extra output isn't useful.
        # You can still pass -v to get verbose output.
        quiet = True

    report = Report(check=check, diff=diff, quiet=quiet, verbose=verbose)

    if code is None:
        try:
            sources = get_sources(
                ctx=ctx,
                src=src,
                quiet=quiet,
                verbose=verbose,
                include=include,
                exclude=exclude,
                extend_exclude=extend_exclude,
                force_exclude=force_exclude,
                report=report,
                stdin_filename=stdin_filename,
            )
        except GitWildMatchPatternError:
            ctx.exit(1)

        # Exits with status 0 when nothing is left to format.
        path_empty(
            sources,
            "No Python files are present to be formatted. Nothing to do 😴",
            quiet,
            verbose,
            ctx,
        )

        if len(sources) != 1:
            reformat_many(
                sources=sources,
                fast=fast,
                write_back=write_back,
                mode=mode,
                report=report,
                workers=workers,
            )
        else:
            # A single file is formatted in-process.
            reformat_one(
                src=sources.pop(),
                fast=fast,
                write_back=write_back,
                mode=mode,
                report=report,
            )
    else:
        reformat_code(
            content=code, fast=fast, write_back=write_back, mode=mode, report=report
        )

    if not quiet or verbose:
        if code is None and (verbose or report.change_count or report.failure_count):
            out()
        summary = failure_banner if report.return_code else "All done! ✨ 🍰 ✨"
        out(summary)
        if code is None:
            click.echo(str(report), err=True)
    ctx.exit(report.return_code)
542
543
def get_sources(
    *,
    ctx: click.Context,
    src: Tuple[str, ...],
    quiet: bool,
    verbose: bool,
    include: Pattern[str],
    exclude: Optional[Pattern[str]],
    extend_exclude: Optional[Pattern[str]],
    force_exclude: Optional[Pattern[str]],
    report: "Report",
    stdin_filename: Optional[str],
) -> Set[Path]:
    """Compute the set of files to be formatted.

    Explicitly named files (and stdin) are only subject to `force_exclude`;
    directories are expanded through `gen_python_files` with all of the
    include/exclude patterns. Exits through `path_empty` when `src` is empty.
    """
    sources: Set[Path] = set()
    path_empty(src, "No Path provided. Nothing to do 😴", quiet, verbose, ctx)

    if exclude is not None:
        # An explicit --exclude turns gitignore handling off.
        gitignore = None
    else:
        exclude = re_compile_maybe_verbose(DEFAULT_EXCLUDES)
        gitignore = get_gitignore(ctx.obj["root"])

    for s in src:
        if s == "-" and stdin_filename:
            p = Path(stdin_filename)
            is_stdin = True
        else:
            p = Path(s)
            is_stdin = False

        if not is_stdin and not p.is_file():
            if p.is_dir():
                discovered = gen_python_files(
                    p.iterdir(),
                    ctx.obj["root"],
                    include,
                    exclude,
                    extend_exclude,
                    force_exclude,
                    report,
                    gitignore,
                    verbose=verbose,
                    quiet=quiet,
                )
                sources.update(discovered)
            elif s == "-":
                # Plain stdin without a --stdin-filename.
                sources.add(p)
            else:
                err(f"invalid path: {s}")
            continue

        # From here on `p` is a regular file or the named stand-in for stdin.
        normalized_path = normalize_path_maybe_ignore(p, ctx.obj["root"], report)
        if normalized_path is None:
            continue

        rooted_path = f"/{normalized_path}"
        # Hard-exclude any files that matches the `--force-exclude` regex.
        forced_out = force_exclude.search(rooted_path) if force_exclude else None
        if forced_out and forced_out.group(0):
            report.path_ignored(p, "matches the --force-exclude regular expression")
            continue

        if is_stdin:
            # Tag the name so later stages can tell it denotes stdin.
            p = Path(f"{STDIN_PLACEHOLDER}{p}")

        if p.suffix == ".ipynb" and not jupyter_dependencies_are_installed(
            verbose=verbose, quiet=quiet
        ):
            continue

        sources.add(p)

    return sources
619
620
def path_empty(
    src: Sized, msg: str, quiet: bool, verbose: bool, ctx: click.Context
) -> None:
    """
    Exit with status 0 when `src` is empty; otherwise do nothing.

    `msg` is printed before exiting unless output is quiet and not verbose.
    """
    if src:
        return

    if verbose or not quiet:
        out(msg)
    ctx.exit(0)
631
632
def reformat_code(
    content: str, fast: bool, write_back: WriteBack, mode: Mode, report: Report
) -> None:
    """
    Reformat and print out `content` without spawning child processes.
    Similar to `reformat_one`, but for string content.

    `fast`, `write_back`, and `mode` options are passed to
    :func:`format_stdin_to_stdout`. The outcome is recorded on `report` under
    the path "<string>"; any exception is recorded as a failure instead of
    propagating.
    """
    path = Path("<string>")
    try:
        was_changed = format_stdin_to_stdout(
            content=content, fast=fast, write_back=write_back, mode=mode
        )
        report.done(path, Changed.YES if was_changed else Changed.NO)
    except Exception as exc:
        if report.verbose:
            traceback.print_exc()
        report.failed(path, str(exc))
655
656
def reformat_one(
    src: Path, fast: bool, write_back: WriteBack, mode: Mode, report: "Report"
) -> None:
    """Reformat the single file `src` in this process and record it on `report`.

    `src` may be "-" or a name prefixed with `STDIN_PLACEHOLDER`, in which case
    the content is read from stdin. `fast`, `write_back`, and `mode` options are
    passed to :func:`format_file_in_place` or :func:`format_stdin_to_stdout`.
    Any exception is recorded as a failure instead of propagating.
    """
    try:
        changed = Changed.NO

        raw_name = str(src)
        has_placeholder = raw_name != "-" and raw_name.startswith(STDIN_PLACEHOLDER)
        is_stdin = raw_name == "-" or has_placeholder
        if has_placeholder:
            # Use the original name again in case we want to print something
            # to the user
            src = Path(raw_name[len(STDIN_PLACEHOLDER) :])

        if is_stdin:
            suffix = src.suffix
            if suffix == ".pyi":
                mode = replace(mode, is_pyi=True)
            elif suffix == ".ipynb":
                mode = replace(mode, is_ipynb=True)
            if format_stdin_to_stdout(fast=fast, write_back=write_back, mode=mode):
                changed = Changed.YES
        else:
            cache: Cache = {}
            # The cache is not consulted when a diff was requested.
            if write_back not in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
                cache = read_cache(mode)
                resolved = src.resolve()
                cache_key = str(resolved)
                if cache_key in cache and cache[cache_key] == get_cache_info(resolved):
                    changed = Changed.CACHED

            if changed is not Changed.CACHED:
                if format_file_in_place(
                    src, fast=fast, write_back=write_back, mode=mode
                ):
                    changed = Changed.YES
                # Remember files that were written back, or that were checked
                # and found to need no changes.
                if write_back is WriteBack.YES or (
                    write_back is WriteBack.CHECK and changed is Changed.NO
                ):
                    write_cache(cache, [src], mode)
        report.done(src, changed)
    except Exception as exc:
        if report.verbose:
            traceback.print_exc()
        report.failed(src, str(exc))
706
707
# diff-shades depends on being able to monkeypatch this function to operate. I know
# it's not ideal, but this shouldn't cause any issues ... hopefully. ~ichard26
@mypyc_attr(patchable=True)
def reformat_many(
    sources: Set[Path],
    fast: bool,
    write_back: WriteBack,
    mode: Mode,
    report: "Report",
    workers: Optional[int],
) -> None:
    """Reformat multiple files using a ProcessPoolExecutor.

    `workers` is the number of worker processes; None falls back to
    `DEFAULT_WORKERS`. If process pools are unavailable on this system, a
    single-threaded `ThreadPoolExecutor` is used instead.

    The formatting itself is driven by :func:`schedule_formatting` on a fresh
    event loop that is created for this call and handed to `shutdown` at the
    end, so repeated calls do not depend on a loop left over from earlier.
    """
    executor: Executor
    worker_count = workers if workers is not None else DEFAULT_WORKERS
    if sys.platform == "win32":
        # Work around https://bugs.python.org/issue26903
        # `worker_count` may legitimately be None when `os.cpu_count()` could not
        # determine the CPU count; in that case let the executor pick its default.
        if worker_count is not None:
            worker_count = min(worker_count, 60)
    try:
        executor = ProcessPoolExecutor(max_workers=worker_count)
    except (ImportError, NotImplementedError, OSError):
        # we arrive here if the underlying system does not support multi-processing
        # like in AWS Lambda or Termux, in which case we gracefully fallback to
        # a ThreadPoolExecutor with just a single worker (more workers would not do us
        # any good due to the Global Interpreter Lock)
        executor = ThreadPoolExecutor(max_workers=1)

    # `asyncio.get_event_loop()` without a running loop is deprecated and hands
    # back the policy's shared default loop; use a loop owned by this call.
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        loop.run_until_complete(
            schedule_formatting(
                sources=sources,
                fast=fast,
                write_back=write_back,
                mode=mode,
                report=report,
                loop=loop,
                executor=executor,
            )
        )
    finally:
        try:
            # NOTE(review): `shutdown` presumably cancels leftover tasks and
            # closes the loop - confirm against black.concurrency.
            shutdown(loop)
        finally:
            # Do not leave a finished loop registered as the current one.
            asyncio.set_event_loop(None)
            executor.shutdown()
752
753
async def schedule_formatting(
    sources: Set[Path],
    fast: bool,
    write_back: WriteBack,
    mode: Mode,
    report: "Report",
    loop: asyncio.AbstractEventLoop,
    executor: Executor,
) -> None:
    """Format `sources` concurrently on `executor` and record results on `report`.

    (Use ProcessPoolExecutors for actual parallelism.)

    `write_back`, `fast`, and `mode` options are passed to
    :func:`format_file_in_place`. Unless a diff was requested, files already in
    the cache are reported as cached and skipped, and the cache is updated at
    the end with the files that were written back or checked as unchanged.
    """
    wants_diff = write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF)

    cache: Cache = {}
    if not wants_diff:
        cache = read_cache(mode)
        sources, cached = filter_cached(cache, sources)
        for src in sorted(cached):
            report.done(src, Changed.CACHED)
    if not sources:
        return

    cancelled = []
    sources_to_cache = []
    lock = None
    if wants_diff:
        # For diff output, we need locks to ensure we don't interleave output
        # from different processes.
        manager = Manager()
        lock = manager.Lock()

    # One future per file, submitted in sorted order.
    tasks = {}
    for src in sorted(sources):
        future = asyncio.ensure_future(
            loop.run_in_executor(
                executor, format_file_in_place, src, fast, mode, write_back, lock
            )
        )
        tasks[future] = src

    # Live view: entries disappear as finished tasks are popped from `tasks`.
    pending = tasks.keys()
    try:
        for signum in (signal.SIGINT, signal.SIGTERM):
            loop.add_signal_handler(signum, cancel, pending)
    except NotImplementedError:
        # There are no good alternatives for these on Windows.
        pass

    while pending:
        done, _ = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
        for task in done:
            src = tasks.pop(task)
            if task.cancelled():
                cancelled.append(task)
                continue

            failure = task.exception()
            if failure:
                report.failed(src, str(failure))
                continue

            changed = Changed.YES if task.result() else Changed.NO
            # If the file was written back or was successfully checked as
            # well-formatted, store this information in the cache.
            if write_back is WriteBack.YES or (
                write_back is WriteBack.CHECK and changed is Changed.NO
            ):
                sources_to_cache.append(src)
            report.done(src, changed)

    if cancelled:
        if sys.version_info >= (3, 7):
            await asyncio.gather(*cancelled, return_exceptions=True)
        else:
            await asyncio.gather(*cancelled, loop=loop, return_exceptions=True)
    if sources_to_cache:
        write_cache(cache, sources_to_cache, mode)
826
827
def format_file_in_place(
    src: Path,
    fast: bool,
    mode: Mode,
    write_back: WriteBack = WriteBack.NO,
    lock: Any = None,  # multiprocessing.Manager().Lock() is some crazy proxy
) -> bool:
    """Format the file at `src`. Return True if its content changed.

    With `write_back` YES the reformatted code replaces the file; with DIFF or
    COLOR_DIFF a diff is written to stdout (while holding `lock`, if one is
    given). A ".pyi" or ".ipynb" suffix switches `mode` accordingly.
    `mode` and `fast` options are passed to :func:`format_file_contents`.

    Raises ValueError when the content cannot be decoded as JSON while
    formatting (reported as an invalid Jupyter notebook).
    """
    suffix = src.suffix
    if suffix == ".pyi":
        mode = replace(mode, is_pyi=True)
    elif suffix == ".ipynb":
        mode = replace(mode, is_ipynb=True)

    # Modification time of the original, used in the diff header.
    then = datetime.utcfromtimestamp(src.stat().st_mtime)
    with open(src, "rb") as raw_file:
        src_contents, encoding, newline = decode_bytes(raw_file.read())
    try:
        dst_contents = format_file_contents(src_contents, fast=fast, mode=mode)
    except NothingChanged:
        return False
    except JSONDecodeError:
        raise ValueError(
            f"File '{src}' cannot be parsed as valid Jupyter notebook."
        ) from None

    if write_back == WriteBack.YES:
        # Keep the encoding and newline style detected when reading.
        with open(src, "w", encoding=encoding, newline=newline) as target:
            target.write(dst_contents)
    elif write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
        now = datetime.utcnow()
        src_name = f"{src}\t{then} +0000"
        dst_name = f"{src}\t{now} +0000"
        make_diff = ipynb_diff if mode.is_ipynb else diff
        diff_contents = make_diff(src_contents, dst_contents, src_name, dst_name)

        if write_back == WriteBack.COLOR_DIFF:
            diff_contents = color_diff(diff_contents)

        with lock or nullcontext():
            stream = io.TextIOWrapper(
                sys.stdout.buffer,
                encoding=encoding,
                newline=newline,
                write_through=True,
            )
            stream = wrap_stream_for_windows(stream)
            stream.write(diff_contents)
            # Detach so that stdout's buffer stays open after the wrapper goes away.
            stream.detach()

    return True
885
886
def format_stdin_to_stdout(
    fast: bool,
    *,
    content: Optional[str] = None,
    write_back: WriteBack = WriteBack.NO,
    mode: Mode,
) -> bool:
    """Format `content`, or stdin when `content` is None. Return True if changed.

    If `write_back` is YES, the result is written to stdout; if it is DIFF or
    COLOR_DIFF, a diff is written to stdout instead. The output step runs even
    when formatting raises, in which case the unmodified input is used as the
    result. The `mode` argument is passed to :func:`format_file_contents`.
    """
    then = datetime.utcnow()

    if content is not None:
        src, encoding, newline = content, "utf-8", ""
    else:
        src, encoding, newline = decode_bytes(sys.stdin.buffer.read())

    # Falls back to the input if formatting does not produce a new result.
    dst = src
    try:
        dst = format_file_contents(src, fast=fast, mode=mode)
        return True

    except NothingChanged:
        return False

    finally:
        stream = io.TextIOWrapper(
            sys.stdout.buffer, encoding=encoding, newline=newline, write_through=True
        )
        if write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
            now = datetime.utcnow()
            src_name = f"STDIN\t{then} +0000"
            dst_name = f"STDOUT\t{now} +0000"
            diff_text = diff(src, dst, src_name, dst_name)
            if write_back == WriteBack.COLOR_DIFF:
                diff_text = color_diff(diff_text)
                stream = wrap_stream_for_windows(stream)
            stream.write(diff_text)
        elif write_back == WriteBack.YES:
            # Make sure there's a newline after the content
            if dst and not dst.endswith("\n"):
                dst += "\n"
            stream.write(dst)
        # Detach so that stdout's buffer stays open after the wrapper goes away.
        stream.detach()
936
937
def check_stability_and_equivalence(
    src_contents: str, dst_contents: str, *, mode: Mode
) -> None:
    """Verify that formatted output is equivalent to its source and stable.

    Raise AssertionError when `dst_contents` is not equivalent to
    `src_contents`, or when formatting the output again keeps producing
    different content.
    """
    assert_equivalent(src_contents, dst_contents)

    # A second pass is forced on purpose: optional trailing commas become
    # forced trailing commas on pass 2 and may then interact differently with
    # optional parentheses.  Admittedly ugly.
    second_pass = format_str(dst_contents, mode=mode)
    if second_pass == dst_contents:
        # Identical second pass: stability is already established, so there
        # is no need to call `assert_stable` explicitly.
        return

    assert_equivalent(src_contents, second_pass, pass_num=2)
    assert_stable(src_contents, second_pass, mode=mode)
959
960
def format_file_contents(src_contents: str, *, fast: bool, mode: Mode) -> FileContent:
    """Return the reformatted version of `src_contents`.

    Raise NothingChanged when the input is blank or already formatted.

    Unless `fast` is True, the result is additionally validated through
    :func:`check_stability_and_equivalence`.  `mode` is forwarded to the
    underlying formatter.
    """
    if not src_contents.strip():
        raise NothingChanged

    if mode.is_ipynb:
        formatted = format_ipynb_string(src_contents, fast=fast, mode=mode)
    else:
        formatted = format_str(src_contents, mode=mode)

    if formatted == src_contents:
        raise NothingChanged

    # Notebooks are skipped here: the notebook path is handed `fast` and
    # runs the checks itself.
    if not (fast or mode.is_ipynb):
        check_stability_and_equivalence(src_contents, formatted, mode=mode)
    return formatted
982
983
def validate_cell(src: str) -> None:
    """Raise NothingChanged for cells that cannot be round-tripped safely.

    Two kinds of cells are rejected:

      - cells that already contain the output of a TransformerManager
        transformation;
      - cells starting with a ``%%`` cell magic that is not listed in
        PYTHON_CELL_MAGICS.

    The first restriction exists because a magic and its transformed form are
    indistinguishable after transformation:

        >>> TransformerManager().transform_cell("get_ipython().system('ls')")
        "get_ipython().system('ls')\n"
        >>> TransformerManager().transform_cell("!ls")
        "get_ipython().system('ls')\n"

    Since such a cell could not be restored to its original text, it is left
    untouched.
    """
    for transformed_magic in TRANSFORMED_MAGICS:
        if transformed_magic in src:
            raise NothingChanged

    if src[:2] != "%%":
        return

    # The first whitespace-delimited word, minus the leading "%%", is the
    # name of the cell magic.
    magic_name = src.split()[0][2:]
    if magic_name not in PYTHON_CELL_MAGICS:
        raise NothingChanged
1005
1006
def format_cell(src: str, *, fast: bool, mode: Mode) -> str:
    """Return the formatted source of a single Jupyter notebook code cell.

    Steps, in order:

      1. validate the cell;
      2. drop a trailing semicolon, remembering whether there was one;
      3. mask IPython magics;
      4. format the masked code (and verify it unless `fast`);
      5. unmask the magics;
      6. restore the trailing semicolon if the cell had one;
      7. strip trailing newlines.

    Raise NothingChanged when the cell is rejected, when masking fails with a
    SyntaxError (such cells might be automagics or multi-line magics, which
    are not supported), or when the result equals the input.
    """
    validate_cell(src)
    stripped_src, had_semicolon = remove_trailing_semicolon(src)

    try:
        masked_src, replacements = mask_cell(stripped_src)
    except SyntaxError:
        raise NothingChanged from None

    masked_dst = format_str(masked_src, mode=mode)
    if not fast:
        check_stability_and_equivalence(masked_src, masked_dst, mode=mode)

    unmasked_dst = unmask_cell(masked_dst, replacements)
    dst = put_trailing_semicolon_back(unmasked_dst, had_semicolon).rstrip("\n")

    if dst == src:
        raise NothingChanged from None
    return dst
1042
1043
def validate_metadata(nb: MutableMapping[str, Any]) -> None:
    """Raise NothingChanged if the notebook declares a non-Python language.

    Every metadata field of a notebook is optional (see
    https://nbformat.readthedocs.io/en/latest/format_description.html), so a
    notebook that does not name its language is still accepted.
    """
    metadata = nb.get("metadata", {})
    language_info = metadata.get("language_info", {})
    language = language_info.get("name", None)

    if language is None:
        # No declared language: attempt to format anyway.
        return
    if language != "python":
        raise NothingChanged from None
1054
1055
def format_ipynb_string(src_contents: str, *, fast: bool, mode: Mode) -> FileContent:
    """Format the JSON source of a Jupyter notebook and return the new source.

    Only cells whose ``cell_type`` is ``"code"`` are formatted, one cell at a
    time, and only if the notebook metadata does not name a non-Python
    language.  If the input ended with a newline, so does the output.

    Raise NothingChanged when `src_contents` is empty, when the notebook is
    rejected by :func:`validate_metadata`, or when no cell was modified.
    Errors from ``json.loads`` (invalid JSON) propagate to the caller.
    """
    # An empty string has no last character to inspect and holds no cells, so
    # report it as unchanged instead of failing with IndexError below.
    if not src_contents:
        raise NothingChanged

    trailing_newline = src_contents[-1] == "\n"
    modified = False
    nb = json.loads(src_contents)
    validate_metadata(nb)
    for cell in nb["cells"]:
        if cell.get("cell_type", None) == "code":
            try:
                # A cell's source is stored as a list of lines.
                src = "".join(cell["source"])
                dst = format_cell(src, fast=fast, mode=mode)
            except NothingChanged:
                # Leave this cell as it is and move on to the next one.
                pass
            else:
                cell["source"] = dst.splitlines(keepends=True)
                modified = True
    if modified:
        dst_contents = json.dumps(nb, indent=1, ensure_ascii=False)
        if trailing_newline:
            dst_contents = dst_contents + "\n"
        return dst_contents
    else:
        raise NothingChanged
1083
1084
def format_str(src_contents: str, *, mode: Mode) -> FileContent:
    """Format the given source string and return the result.

    `mode` carries the formatting options, for instance the allowed line
    length.  Example:

    >>> import black
    >>> print(black.format_str("def f(arg:str='')->None:...", mode=black.Mode()))
    def f(arg: str = "") -> None:
        ...

    An example with more options:

    >>> print(
    ...   black.format_str(
    ...     "def f(arg:str='')->None: hey",
    ...     mode=black.Mode(
    ...       target_versions={black.TargetVersion.PY36},
    ...       line_length=10,
    ...       string_normalization=False,
    ...       is_pyi=False,
    ...     ),
    ...   ),
    ... )
    def f(
        arg: str = '',
    ) -> None:
        hey

    """
    src_node = lib2to3_parse(src_contents.lstrip(), mode.target_versions)
    future_imports = get_future_imports(src_node)
    # Explicit target versions win; otherwise infer them from the source.
    versions = mode.target_versions or detect_target_versions(
        src_node, future_imports=future_imports
    )

    normalize_fmt_off(src_node)
    line_generator = LineGenerator(mode=mode)
    blank_tracker = EmptyLineTracker(is_pyi=mode.is_pyi)
    blank_line = str(Line(mode=mode))

    split_line_features = set()
    for feature in (Feature.TRAILING_COMMA_IN_CALL, Feature.TRAILING_COMMA_IN_DEF):
        if supports_feature(versions, feature):
            split_line_features.add(feature)

    pieces: List[str] = []
    # Number of blank lines requested after the previously emitted line.
    pending_after = 0
    for current_line in line_generator.visit(src_node):
        pieces.append(blank_line * pending_after)
        before, pending_after = blank_tracker.maybe_empty_lines(current_line)
        pieces.append(blank_line * before)
        pieces.extend(
            str(line)
            for line in transform_line(
                current_line, mode=mode, features=split_line_features
            )
        )
    return "".join(pieces)
1142
1143
def decode_bytes(src: bytes) -> Tuple[FileContent, Encoding, NewLine]:
    """Decode `src` and return ``(decoded_contents, encoding, newline)``.

    The encoding is detected with :func:`tokenize.detect_encoding`.  `newline`
    is CRLF or LF, judged from the first line only, while `decoded_contents`
    is read with universal newlines and therefore contains only LF.
    """
    buffer = io.BytesIO(src)
    encoding, first_lines = tokenize.detect_encoding(buffer.readline)
    if not first_lines:
        return "", encoding, "\n"

    newline = "\r\n" if first_lines[0].endswith(b"\r\n") else "\n"

    # Encoding detection consumed part of the buffer; rewind before decoding.
    buffer.seek(0)
    with io.TextIOWrapper(buffer, encoding) as reader:
        decoded = reader.read()
    return decoded, encoding, newline
1159
1160
def get_features_used(  # noqa: C901
    node: Node, *, future_imports: Optional[Set[str]] = None
) -> Set[Feature]:
    """Return a set of (relatively) new Python features used in this file.

    `node` is walked in pre-order.  `future_imports`, if given, is the set of
    names imported from ``__future__``; those present in
    FUTURE_FLAG_TO_FEATURE contribute their mapped feature.

    Currently looking for:
    - f-strings (with any casing of the string prefix);
    - underscores in numeric literals;
    - trailing commas after * or ** in function signatures and calls;
    - positional only arguments in function signatures and lambdas;
    - assignment expression;
    - relaxed decorator syntax;
    - usage of __future__ flags (annotations);
    - star-unpacking in ``return`` and ``yield``;
    - annotated assignment whose right-hand side is a ``testlist_star_expr``.
    """
    # Lower-cased two-character heads that mark a string literal as an
    # f-string.  Comparing case-insensitively also covers mixed-case prefixes
    # such as ``Rf``, ``rF``, ``fR`` and ``Fr``, all of which Python accepts.
    fstring_heads = {'f"', "f'", "rf", "fr"}

    features: Set[Feature] = set()
    if future_imports:
        features |= {
            FUTURE_FLAG_TO_FEATURE[future_import]
            for future_import in future_imports
            if future_import in FUTURE_FLAG_TO_FEATURE
        }

    for n in node.pre_order():
        if is_string_token(n):
            if n.value[:2].lower() in fstring_heads:
                features.add(Feature.F_STRINGS)

        elif n.type == token.NUMBER:
            assert isinstance(n, Leaf)
            if "_" in n.value:
                features.add(Feature.NUMERIC_UNDERSCORES)

        elif n.type == token.SLASH:
            # A slash directly inside an argument list is the positional-only
            # marker rather than a division.
            if n.parent and n.parent.type in {
                syms.typedargslist,
                syms.arglist,
                syms.varargslist,
            }:
                features.add(Feature.POS_ONLY_ARGUMENTS)

        elif n.type == token.COLONEQUAL:
            features.add(Feature.ASSIGNMENT_EXPRESSIONS)

        elif n.type == syms.decorator:
            if len(n.children) > 1 and not is_simple_decorator_expression(
                n.children[1]
            ):
                features.add(Feature.RELAXED_DECORATORS)

        elif (
            n.type in {syms.typedargslist, syms.arglist}
            and n.children
            and n.children[-1].type == token.COMMA
        ):
            # The argument list ends with a trailing comma; it only counts as
            # a feature when the list also contains * or ** unpacking.
            if n.type == syms.typedargslist:
                feature = Feature.TRAILING_COMMA_IN_DEF
            else:
                feature = Feature.TRAILING_COMMA_IN_CALL

            for ch in n.children:
                if ch.type in STARS:
                    features.add(feature)

                if ch.type == syms.argument:
                    for argch in ch.children:
                        if argch.type in STARS:
                            features.add(feature)

        elif (
            n.type in {syms.return_stmt, syms.yield_expr}
            and len(n.children) >= 2
            and n.children[1].type == syms.testlist_star_expr
            and any(child.type == syms.star_expr for child in n.children[1].children)
        ):
            features.add(Feature.UNPACKING_ON_FLOW)

        elif (
            n.type == syms.annassign
            and len(n.children) >= 4
            and n.children[3].type == syms.testlist_star_expr
        ):
            features.add(Feature.ANN_ASSIGN_EXTENDED_RHS)

    return features
1247
1248
def detect_target_versions(
    node: Node, *, future_imports: Optional[Set[str]] = None
) -> Set[TargetVersion]:
    """Return every target version that supports all features used in `node`."""
    used_features = get_features_used(node, future_imports=future_imports)

    compatible: Set[TargetVersion] = set()
    for version in TargetVersion:
        # A version qualifies when it supports every feature in use.
        if used_features <= VERSION_TO_FEATURES[version]:
            compatible.add(version)
    return compatible
1257
1258
def get_future_imports(node: Node) -> Set[str]:
    """Collect the names imported from ``__future__`` at the top of the file.

    Only the leading run of simple statements is examined: docstring-like
    string statements are skipped, ``from __future__ import ...`` statements
    are collected, and anything else ends the search.
    """
    found: Set[str] = set()

    def names_in(children: List[LN]) -> Generator[str, None, None]:
        # Yield the original (pre-``as``) name of every imported item.
        for item in children:
            if isinstance(item, Leaf):
                if item.type == token.NAME:
                    yield item.value
                continue

            if item.type == syms.import_as_name:
                original = item.children[0]
                assert isinstance(original, Leaf), "Invalid syntax parsing imports"
                assert original.type == token.NAME, "Invalid syntax parsing imports"
                yield original.value
            elif item.type == syms.import_as_names:
                yield from names_in(item.children)
            else:
                raise AssertionError("Invalid syntax parsing imports")

    for stmt in node.children:
        if stmt.type != syms.simple_stmt:
            break

        head = stmt.children[0]
        if isinstance(head, Leaf):
            # A lone string followed by a newline is a docstring: keep
            # looking.  Any other leaf statement ends the search.
            is_docstring = (
                len(stmt.children) == 2
                and head.type == token.STRING
                and stmt.children[1].type == token.NEWLINE
            )
            if is_docstring:
                continue
            break

        if head.type != syms.import_from:
            break

        module_name = head.children[1]
        if not isinstance(module_name, Leaf) or module_name.value != "__future__":
            break

        # Children from index 3 onwards follow ``from __future__ import``.
        found.update(names_in(head.children[3:]))

    return found
1307
1308
def assert_equivalent(src: str, dst: str, *, pass_num: int = 1) -> None:
    """Raise AssertionError unless `src` and `dst` parse to equivalent ASTs.

    `pass_num` only appears in the error messages, to tell which formatting
    pass produced `dst`.
    """
    try:
        src_ast = parse_ast(src)
    except Exception as exc:
        raise AssertionError(
            f"cannot use --safe with this file; failed to parse source file: {exc}"
        ) from exc

    try:
        dst_ast = parse_ast(dst)
    except Exception as exc:
        # Dump the traceback along with the offending output for the report.
        tb_text = "".join(traceback.format_tb(exc.__traceback__))
        log = dump_to_file(tb_text, dst)
        message = (
            f"INTERNAL ERROR: Black produced invalid code on pass {pass_num}: {exc}. "
            "Please report a bug on https://github.com/psf/black/issues.  "
            f"This invalid output might be helpful: {log}"
        )
        raise AssertionError(message) from None

    src_dump = "\n".join(stringify_ast(src_ast))
    dst_dump = "\n".join(stringify_ast(dst_ast))
    if src_dump == dst_dump:
        return

    log = dump_to_file(diff(src_dump, dst_dump, "src", "dst"))
    message = (
        "INTERNAL ERROR: Black produced code that is not equivalent to the"
        f" source on pass {pass_num}.  Please report a bug on "
        f"https://github.com/psf/black/issues.  This diff might be helpful: {log}"
    )
    raise AssertionError(message) from None
1337
1338
def assert_stable(src: str, dst: str, mode: Mode) -> None:
    """Raise AssertionError if formatting `dst` again changes it.

    `src` is used only to include a source-to-first-pass diff in the dump
    that accompanies the error.
    """
    second_pass = format_str(dst, mode=mode)
    if second_pass == dst:
        return

    source_diff = diff(src, dst, "source", "first pass")
    pass_diff = diff(dst, second_pass, "first pass", "second pass")
    log = dump_to_file(str(mode), source_diff, pass_diff)
    message = (
        "INTERNAL ERROR: Black produced different code on the second pass of the"
        " formatter.  Please report a bug on https://github.com/psf/black/issues."
        f"  This diff might be helpful: {log}"
    )
    raise AssertionError(message) from None
1353
1354
@contextmanager
def nullcontext() -> Iterator[None]:
    """Return a context manager that does nothing on entry or exit.

    A stand-in for the `nullcontext` of Python 3.7; the value bound by
    ``with ... as`` is None.
    """
    yield None
1362
1363
def patch_click() -> None:
    """Make Click not crash on Python 3.6 with LANG=C.

    On certain misconfigured environments, Python 3 selects the ASCII encoding as the
    default which restricts paths that it can access during the lifetime of the
    application.  Click refuses to work in this scenario by raising a RuntimeError.

    In case of Black the likelihood that non-ASCII characters are going to be used in
    file paths is minimal since it's Python source code.  Moreover, this crash was
    spurious on Python 3.7 thanks to PEP 538 and PEP 540.

    Every Click module that can be imported and defines the environment check
    gets that check replaced by a no-op.  Modules that are missing are skipped.
    """
    modules: List[Any] = []

    # Each module is imported on its own so that one missing module does not
    # stop the other from being patched.  `ImportError` is caught rather than
    # `ModuleNotFoundError` because `from click import <name>` raises a plain
    # ImportError when the `click` package exists but the submodule does not.
    try:
        from click import core
    except ImportError:
        pass
    else:
        modules.append(core)

    try:
        # Removed in Click 8.1.0 and newer; still patched for older versions.
        from click import _unicodefun  # type: ignore
    except ImportError:
        pass
    else:
        modules.append(_unicodefun)

    for module in modules:
        if hasattr(module, "_verify_python3_env"):
            module._verify_python3_env = lambda: None  # type: ignore
        if hasattr(module, "_verify_python_env"):
            module._verify_python_env = lambda: None  # type: ignore
1386
1387
def patched_main() -> None:
    """Prepare the process environment, then run :func:`main`."""
    # NOTE(review): presumably switches asyncio to uvloop when it is
    # installed -- confirm against `maybe_install_uvloop`.
    maybe_install_uvloop()
    # `multiprocessing.freeze_support`, for frozen Windows executables.
    freeze_support()
    # Neutralize Click's environment check before the command line is handled.
    patch_click()
    main()
1393
1394
# Run the patched entry point only when this module is executed as a script.
if __name__ == "__main__":
    patched_main()