git.madduck.net Git - etc/vim.git/blob - src/black/__init__.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access, can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

d8b98196aa068e9be8c192d71358333dab9377bf
[etc/vim.git] / src / black / __init__.py
1 import asyncio
2 from json.decoder import JSONDecodeError
3 import json
4 from concurrent.futures import Executor, ThreadPoolExecutor, ProcessPoolExecutor
5 from contextlib import contextmanager
6 from datetime import datetime
7 from enum import Enum
8 import io
9 from multiprocessing import Manager, freeze_support
10 import os
11 from pathlib import Path
12 from pathspec.patterns.gitwildmatch import GitWildMatchPatternError
13 import re
14 import signal
15 import sys
16 import tokenize
17 import traceback
18 from typing import (
19     Any,
20     Dict,
21     Generator,
22     Iterator,
23     List,
24     MutableMapping,
25     Optional,
26     Pattern,
27     Set,
28     Sized,
29     Tuple,
30     Union,
31 )
32
33 import click
34 from dataclasses import replace
35 from mypy_extensions import mypyc_attr
36
37 from black.const import DEFAULT_LINE_LENGTH, DEFAULT_INCLUDES, DEFAULT_EXCLUDES
38 from black.const import STDIN_PLACEHOLDER
39 from black.nodes import STARS, syms, is_simple_decorator_expression
40 from black.lines import Line, EmptyLineTracker
41 from black.linegen import transform_line, LineGenerator, LN
42 from black.comments import normalize_fmt_off
43 from black.mode import FUTURE_FLAG_TO_FEATURE, Mode, TargetVersion
44 from black.mode import Feature, supports_feature, VERSION_TO_FEATURES
45 from black.cache import read_cache, write_cache, get_cache_info, filter_cached, Cache
46 from black.concurrency import cancel, shutdown, maybe_install_uvloop
47 from black.output import dump_to_file, ipynb_diff, diff, color_diff, out, err
48 from black.report import Report, Changed, NothingChanged
49 from black.files import find_project_root, find_pyproject_toml, parse_pyproject_toml
50 from black.files import gen_python_files, get_gitignore, normalize_path_maybe_ignore
51 from black.files import wrap_stream_for_windows
52 from black.parsing import InvalidInput  # noqa F401
53 from black.parsing import lib2to3_parse, parse_ast, stringify_ast
54 from black.handle_ipynb_magics import (
55     mask_cell,
56     unmask_cell,
57     remove_trailing_semicolon,
58     put_trailing_semicolon_back,
59     TRANSFORMED_MAGICS,
60     PYTHON_CELL_MAGICS,
61     jupyter_dependencies_are_installed,
62 )
63
64
65 # lib2to3 fork
66 from blib2to3.pytree import Node, Leaf
67 from blib2to3.pgen2 import token
68
69 from _black_version import version as __version__
70
# True when this module was loaded from a compiled extension file rather than
# from a plain ".py" source file.
COMPILED = Path(__file__).suffix in {".pyd", ".so"}

# Descriptive aliases for the plain strings passed around in this module.
FileContent = str
Encoding = str
NewLine = str
77
78
class WriteBack(Enum):
    """What should happen with the formatted result of a run."""

    NO = 0
    YES = 1
    DIFF = 2
    CHECK = 3
    COLOR_DIFF = 4

    @classmethod
    def from_configuration(
        cls, *, check: bool, diff: bool, color: bool = False
    ) -> "WriteBack":
        """Translate the `check`, `diff`, and `color` flags into a member.

        `diff` wins over `check`; `color` only has an effect together with
        `diff`. When neither `check` nor `diff` is set, the result is `YES`.
        """
        if diff:
            return cls.COLOR_DIFF if color else cls.DIFF

        return cls.CHECK if check else cls.YES
97
98
# `Mode` under its legacy name; kept so existing integrations keep working.
FileMode = Mode

# Default for the `--workers` option. Note that `os.cpu_count()` returns None
# when the number of CPUs cannot be determined.
DEFAULT_WORKERS = os.cpu_count()
103
104
def read_pyproject_toml(
    ctx: click.Context, param: click.Parameter, value: Optional[str]
) -> Optional[str]:
    """Merge Black settings from a "pyproject.toml" file into `ctx.default_map`.

    When `value` is empty, the file is looked up with `find_pyproject_toml`
    based on the "src" parameter already stored in `ctx`.

    Returns the path of the configuration file that was read, or None when no
    file was found or the file produced no configuration.

    Raises `click.FileError` when the file cannot be read or parsed, and
    `click.BadOptionUsage` when "target_version" is present but not a list.
    """
    config_path = value or find_pyproject_toml(ctx.params.get("src", ()))
    if config_path is None:
        return None

    try:
        raw_config = parse_pyproject_toml(config_path)
    except (OSError, ValueError) as exc:
        raise click.FileError(
            filename=config_path, hint=f"Error reading configuration file: {exc}"
        ) from None

    if not raw_config:
        return None

    # Click expects string values, so everything except lists and dicts is
    # stringified. For background see:
    # https://github.com/psf/black/issues/1458
    # https://github.com/pallets/click/issues/1567
    config: Dict[str, Any] = {}
    for key, setting in raw_config.items():
        config[key] = setting if isinstance(setting, (list, dict)) else str(setting)

    target_version = config.get("target_version")
    if not (target_version is None or isinstance(target_version, list)):
        raise click.BadOptionUsage(
            "target-version", "Config key target-version must be a list"
        )

    # Values from the file take precedence over any defaults already present.
    merged_defaults: Dict[str, Any] = {}
    merged_defaults.update(ctx.default_map or {})
    merged_defaults.update(config)
    ctx.default_map = merged_defaults

    return config_path
149
150
def target_version_option_callback(
    c: click.Context, p: Union[click.Option, click.Parameter], v: Tuple[str, ...]
) -> List[TargetVersion]:
    """Turn the values of a --target-version flag into `TargetVersion` members.

    Each value is upper-cased and looked up by name. This is a named function
    rather than a lambda because mypy couldn't infer the lambda's type
    correctly, which caused mypyc trouble.
    """
    versions: List[TargetVersion] = []
    for name in v:
        versions.append(TargetVersion[name.upper()])
    return versions
160
161
def re_compile_maybe_verbose(regex: str) -> Pattern[str]:
    """Return `regex` compiled into a pattern object.

    A pattern that contains a newline is prefixed with "(?x)" so that it is
    compiled in verbose mode.
    """
    source = f"(?x){regex}" if "\n" in regex else regex
    pattern: Pattern[str] = re.compile(source)
    return pattern
171
172
def validate_regex(
    ctx: click.Context,
    param: click.Parameter,
    value: Optional[str],
) -> Optional[Pattern[str]]:
    """Compile the regular expression given for an option, if any.

    Returns None when `value` is None. Raises `click.BadParameter` when the
    expression cannot be compiled.
    """
    if value is None:
        return None

    try:
        return re_compile_maybe_verbose(value)
    except re.error as exc:
        raise click.BadParameter(
            f"Not a valid regular expression: {exc}"
        ) from None
182
183
@click.command(
    context_settings={"help_option_names": ["-h", "--help"]},
    # While Click does set this field automatically using the docstring, mypyc
    # (annoyingly) strips 'em so we need to set it here too.
    help="The uncompromising code formatter.",
)
@click.option("-c", "--code", type=str, help="Format the code passed in as a string.")
@click.option(
    "-l",
    "--line-length",
    type=int,
    default=DEFAULT_LINE_LENGTH,
    help="How many characters per line to allow.",
    show_default=True,
)
@click.option(
    "-t",
    "--target-version",
    type=click.Choice([v.name.lower() for v in TargetVersion]),
    callback=target_version_option_callback,
    multiple=True,
    help=(
        "Python versions that should be supported by Black's output. [default: per-file"
        " auto-detection]"
    ),
)
@click.option(
    "--pyi",
    is_flag=True,
    help=(
        "Format all input files like typing stubs regardless of file extension (useful"
        " when piping source on standard input)."
    ),
)
@click.option(
    "--ipynb",
    is_flag=True,
    help=(
        "Format all input files like Jupyter Notebooks regardless of file extension "
        "(useful when piping source on standard input)."
    ),
)
@click.option(
    "-S",
    "--skip-string-normalization",
    is_flag=True,
    help="Don't normalize string quotes or prefixes.",
)
@click.option(
    "-C",
    "--skip-magic-trailing-comma",
    is_flag=True,
    help="Don't use trailing commas as a reason to split lines.",
)
@click.option(
    "--experimental-string-processing",
    is_flag=True,
    hidden=True,
    help=(
        "Experimental option that performs more normalization on string literals."
        " Currently disabled because it leads to some crashes."
    ),
)
@click.option(
    "--check",
    is_flag=True,
    help=(
        "Don't write the files back, just return the status. Return code 0 means"
        " nothing would change. Return code 1 means some files would be reformatted."
        " Return code 123 means there was an internal error."
    ),
)
@click.option(
    "--diff",
    is_flag=True,
    help="Don't write the files back, just output a diff for each file on stdout.",
)
@click.option(
    "--color/--no-color",
    is_flag=True,
    help="Show colored diff. Only applies when `--diff` is given.",
)
@click.option(
    "--fast/--safe",
    is_flag=True,
    help="If --fast given, skip temporary sanity checks. [default: --safe]",
)
@click.option(
    "--required-version",
    type=str,
    help=(
        "Require a specific version of Black to be running (useful for unifying results"
        " across many environments e.g. with a pyproject.toml file)."
    ),
)
@click.option(
    "--include",
    type=str,
    default=DEFAULT_INCLUDES,
    callback=validate_regex,
    help=(
        "A regular expression that matches files and directories that should be"
        " included on recursive searches. An empty value means all files are included"
        " regardless of the name. Use forward slashes for directories on all platforms"
        " (Windows, too). Exclusions are calculated first, inclusions later."
    ),
    show_default=True,
)
@click.option(
    "--exclude",
    type=str,
    callback=validate_regex,
    help=(
        "A regular expression that matches files and directories that should be"
        " excluded on recursive searches. An empty value means no paths are excluded."
        " Use forward slashes for directories on all platforms (Windows, too)."
        " Exclusions are calculated first, inclusions later. [default:"
        f" {DEFAULT_EXCLUDES}]"
    ),
    show_default=False,
)
@click.option(
    "--extend-exclude",
    type=str,
    callback=validate_regex,
    help=(
        "Like --exclude, but adds additional files and directories on top of the"
        " excluded ones. (Useful if you simply want to add to the default)"
    ),
)
@click.option(
    "--force-exclude",
    type=str,
    callback=validate_regex,
    help=(
        "Like --exclude, but files and directories matching this regex will be "
        "excluded even when they are passed explicitly as arguments."
    ),
)
@click.option(
    "--stdin-filename",
    type=str,
    help=(
        "The name of the file when passing it through stdin. Useful to make "
        "sure Black will respect --force-exclude option on some "
        "editors that rely on using stdin."
    ),
)
@click.option(
    "-W",
    "--workers",
    type=click.IntRange(min=1),
    default=DEFAULT_WORKERS,
    show_default=True,
    help="Number of parallel workers",
)
@click.option(
    "-q",
    "--quiet",
    is_flag=True,
    help=(
        "Don't emit non-error messages to stderr. Errors are still emitted; silence"
        " those with 2>/dev/null."
    ),
)
@click.option(
    "-v",
    "--verbose",
    is_flag=True,
    help=(
        "Also emit messages to stderr about files that were not changed or were ignored"
        " due to exclusion patterns."
    ),
)
@click.version_option(
    version=__version__,
    message=f"%(prog)s, %(version)s (compiled: {'yes' if COMPILED else 'no'})",
)
@click.argument(
    "src",
    nargs=-1,
    type=click.Path(
        exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
    ),
    is_eager=True,
    metavar="SRC ...",
)
@click.option(
    "--config",
    type=click.Path(
        exists=True,
        file_okay=True,
        dir_okay=False,
        readable=True,
        allow_dash=False,
        path_type=str,
    ),
    is_eager=True,
    callback=read_pyproject_toml,
    help="Read configuration from FILE path.",
)
@click.pass_context
def main(
    ctx: click.Context,
    code: Optional[str],
    line_length: int,
    target_version: List[TargetVersion],
    check: bool,
    diff: bool,
    color: bool,
    fast: bool,
    pyi: bool,
    ipynb: bool,
    skip_string_normalization: bool,
    skip_magic_trailing_comma: bool,
    experimental_string_processing: bool,
    quiet: bool,
    verbose: bool,
    required_version: Optional[str],
    include: Pattern[str],
    exclude: Optional[Pattern[str]],
    extend_exclude: Optional[Pattern[str]],
    force_exclude: Optional[Pattern[str]],
    stdin_filename: Optional[str],
    workers: int,
    src: Tuple[str, ...],
    config: Optional[str],
) -> None:
    """The uncompromising code formatter."""
    if verbose and config:
        out(f"Using configuration from {config}.", bold=False, fg="blue")

    error_msg = "Oh no! 💥 💔 💥"

    # Validate option combinations up front; each failure exits with status 1.
    if required_version and required_version != __version__:
        err(
            f"{error_msg} The required version `{required_version}` does not match"
            f" the running version `{__version__}`!"
        )
        ctx.exit(1)
    if pyi and ipynb:
        err("Cannot pass both `pyi` and `ipynb` flags!")
        ctx.exit(1)

    write_back = WriteBack.from_configuration(check=check, diff=diff, color=color)
    # An empty set means the target versions get auto-detected later.
    versions = set(target_version) if target_version else set()
    mode = Mode(
        target_versions=versions,
        line_length=line_length,
        is_pyi=pyi,
        is_ipynb=ipynb,
        string_normalization=not skip_string_normalization,
        magic_trailing_comma=not skip_magic_trailing_comma,
        experimental_string_processing=experimental_string_processing,
    )

    from_string = code is not None
    if from_string:
        # With -c the extra output isn't useful, so run quietly by default.
        # Passing -v still produces verbose output.
        quiet = True

    report = Report(check=check, diff=diff, quiet=quiet, verbose=verbose)

    if from_string:
        reformat_code(
            content=code, fast=fast, write_back=write_back, mode=mode, report=report
        )
    else:
        try:
            sources = get_sources(
                ctx=ctx,
                src=src,
                quiet=quiet,
                verbose=verbose,
                include=include,
                exclude=exclude,
                extend_exclude=extend_exclude,
                force_exclude=force_exclude,
                report=report,
                stdin_filename=stdin_filename,
            )
        except GitWildMatchPatternError:
            ctx.exit(1)

        # Exits with status 0 when nothing is left to format.
        path_empty(
            sources,
            "No Python files are present to be formatted. Nothing to do 😴",
            quiet,
            verbose,
            ctx,
        )

        # A single source is formatted in-process; several go to the pool.
        if len(sources) == 1:
            reformat_one(
                src=sources.pop(),
                fast=fast,
                write_back=write_back,
                mode=mode,
                report=report,
            )
        else:
            reformat_many(
                sources=sources,
                fast=fast,
                write_back=write_back,
                mode=mode,
                report=report,
                workers=workers,
            )

    if verbose or not quiet:
        if report.return_code:
            out(error_msg)
        else:
            out("All done! ✨ 🍰 ✨")
        if not from_string:
            click.echo(str(report), err=True)
    ctx.exit(report.return_code)
502
503
def get_sources(
    *,
    ctx: click.Context,
    src: Tuple[str, ...],
    quiet: bool,
    verbose: bool,
    include: Pattern[str],
    exclude: Optional[Pattern[str]],
    extend_exclude: Optional[Pattern[str]],
    force_exclude: Optional[Pattern[str]],
    report: "Report",
    stdin_filename: Optional[str],
) -> Set[Path]:
    """Compute the set of files to be formatted.

    Each entry of `src` is handled according to what it names: files (and
    stdin given with a `stdin_filename`) are added unless `force_exclude`
    matches them, directories are expanded with `gen_python_files`, a bare
    "-" is added as is, and anything else is reported as an invalid path.

    Exits through `path_empty` when `src` is empty.
    """
    root = find_project_root(src)
    sources: Set[Path] = set()
    path_empty(src, "No Path provided. Nothing to do 😴", quiet, verbose, ctx)

    # The gitignore file is only consulted when the default excludes are used.
    gitignore = None
    if exclude is None:
        exclude = re_compile_maybe_verbose(DEFAULT_EXCLUDES)
        gitignore = get_gitignore(root)

    for s in src:
        is_stdin = s == "-" and bool(stdin_filename)
        p = Path(stdin_filename) if is_stdin else Path(s)

        if is_stdin or p.is_file():
            normalized_path = normalize_path_maybe_ignore(p, root, report)
            if normalized_path is None:
                continue

            # Files matching the `--force-exclude` regex are dropped even
            # though they were passed explicitly.
            rooted_path = "/" + normalized_path
            force_exclude_match = (
                force_exclude.search(rooted_path) if force_exclude else None
            )
            if force_exclude_match and force_exclude_match.group(0):
                report.path_ignored(p, "matches the --force-exclude regular expression")
                continue

            if is_stdin:
                # Mark the path so it can be recognized as stdin later on.
                p = Path(f"{STDIN_PLACEHOLDER}{p}")

            if p.suffix == ".ipynb" and not jupyter_dependencies_are_installed(
                verbose=verbose, quiet=quiet
            ):
                continue

            sources.add(p)
        elif p.is_dir():
            discovered = gen_python_files(
                p.iterdir(),
                root,
                include,
                exclude,
                extend_exclude,
                force_exclude,
                report,
                gitignore,
                verbose=verbose,
                quiet=quiet,
            )
            sources.update(discovered)
        elif s == "-":
            sources.add(p)
        else:
            err(f"invalid path: {s}")

    return sources
581
582
def path_empty(
    src: Sized, msg: str, quiet: bool, verbose: bool, ctx: click.Context
) -> None:
    """Exit with status 0 if there is no `src` provided for formatting.

    `msg` is printed first when `verbose` is set or `quiet` is not.
    """
    if src:
        return

    if verbose or not quiet:
        out(msg)
    ctx.exit(0)
593
594
def reformat_code(
    content: str, fast: bool, write_back: WriteBack, mode: Mode, report: Report
) -> None:
    """Reformat and print out `content` without spawning child processes.

    Similar to `reformat_one`, but for string content. The `fast`,
    `write_back`, and `mode` options are passed on to
    :func:`format_stdin_to_stdout`.

    The outcome is recorded on `report` under the path "<string>"; any
    exception is reported as a failure instead of being raised.
    """
    path = Path("<string>")
    try:
        was_changed = format_stdin_to_stdout(
            content=content, fast=fast, write_back=write_back, mode=mode
        )
        report.done(path, Changed.YES if was_changed else Changed.NO)
    except Exception as exc:
        if report.verbose:
            traceback.print_exc()
        report.failed(path, str(exc))
617
618
def reformat_one(
    src: Path, fast: bool, write_back: WriteBack, mode: Mode, report: "Report"
) -> None:
    """Reformat a single file under `src` without spawning child processes.

    A `src` of "-" or one that starts with `STDIN_PLACEHOLDER` is read from
    stdin; anything else is formatted in place, consulting the cache unless a
    diff was requested.

    `fast`, `write_back`, and `mode` options are passed to
    :func:`format_file_in_place` or :func:`format_stdin_to_stdout`. Any
    exception is recorded on `report` as a failure instead of being raised.
    """
    try:
        changed = Changed.NO

        name = str(src)
        is_stdin = name == "-"
        if not is_stdin and name.startswith(STDIN_PLACEHOLDER):
            is_stdin = True
            # Restore the original name in case something is printed to the
            # user.
            src = Path(name[len(STDIN_PLACEHOLDER) :])

        if is_stdin:
            suffix = src.suffix
            if suffix == ".pyi":
                mode = replace(mode, is_pyi=True)
            elif suffix == ".ipynb":
                mode = replace(mode, is_ipynb=True)
            if format_stdin_to_stdout(fast=fast, write_back=write_back, mode=mode):
                changed = Changed.YES
        else:
            cache: Cache = {}
            wants_diff = write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF)
            if not wants_diff:
                cache = read_cache(mode)
                resolved = src.resolve()
                cache_key = str(resolved)
                if cache_key in cache and cache[cache_key] == get_cache_info(resolved):
                    changed = Changed.CACHED

            if changed is not Changed.CACHED and format_file_in_place(
                src, fast=fast, write_back=write_back, mode=mode
            ):
                changed = Changed.YES

            # Remember files that were written back, or that were checked and
            # found to need no changes.
            written = write_back is WriteBack.YES and changed is not Changed.CACHED
            checked_clean = write_back is WriteBack.CHECK and changed is Changed.NO
            if written or checked_clean:
                write_cache(cache, [src], mode)
        report.done(src, changed)
    except Exception as exc:
        if report.verbose:
            traceback.print_exc()
        report.failed(src, str(exc))
668
669
# diff-shades depends on being able to monkeypatch this function to operate. I
# know it's not ideal, but this shouldn't cause any issues ... hopefully. ~ichard26
@mypyc_attr(patchable=True)
def reformat_many(
    sources: Set[Path],
    fast: bool,
    write_back: WriteBack,
    mode: Mode,
    report: "Report",
    workers: Optional[int],
) -> None:
    """Reformat multiple files using a ProcessPoolExecutor.

    `workers` is the number of worker processes; when it is None,
    `DEFAULT_WORKERS` is used, falling back to a single worker if the CPU
    count could not be determined. If a process pool cannot be created, a
    single-worker ThreadPoolExecutor is used instead.

    `sources`, `fast`, `write_back`, `mode`, and `report` are passed to
    :func:`schedule_formatting`.
    """
    executor: Executor
    loop = asyncio.get_event_loop()
    # `os.cpu_count()` (and therefore DEFAULT_WORKERS) may be None, so make
    # sure the count is always an int before it is compared below.
    worker_count = workers if workers is not None else (DEFAULT_WORKERS or 1)
    if sys.platform == "win32":
        # Work around https://bugs.python.org/issue26903
        worker_count = min(worker_count, 60)
    try:
        executor = ProcessPoolExecutor(max_workers=worker_count)
    except (ImportError, NotImplementedError, OSError):
        # we arrive here if the underlying system does not support multi-processing
        # like in AWS Lambda or Termux, in which case we gracefully fallback to
        # a ThreadPoolExecutor with just a single worker (more workers would not do us
        # any good due to the Global Interpreter Lock)
        executor = ThreadPoolExecutor(max_workers=1)

    try:
        loop.run_until_complete(
            schedule_formatting(
                sources=sources,
                fast=fast,
                write_back=write_back,
                mode=mode,
                report=report,
                loop=loop,
                executor=executor,
            )
        )
    finally:
        shutdown(loop)
        executor.shutdown()
714
715
async def schedule_formatting(
    sources: Set[Path],
    fast: bool,
    write_back: WriteBack,
    mode: Mode,
    report: "Report",
    loop: asyncio.AbstractEventLoop,
    executor: Executor,
) -> None:
    """Run formatting of `sources` in parallel using the provided `executor`.

    (Use ProcessPoolExecutors for actual parallelism.)

    Unless a diff was requested, sources found in the cache are reported as
    cached and skipped. `write_back`, `fast`, and `mode` options are passed
    to :func:`format_file_in_place`; each outcome is recorded on `report`.
    """
    diff_modes = (WriteBack.DIFF, WriteBack.COLOR_DIFF)

    cache: Cache = {}
    if write_back not in diff_modes:
        cache = read_cache(mode)
        sources, cached = filter_cached(cache, sources)
        for cached_src in sorted(cached):
            report.done(cached_src, Changed.CACHED)
    if not sources:
        return

    cancelled = []
    sources_to_cache = []
    lock = None
    if write_back in diff_modes:
        # Diff output from different processes must not interleave, so the
        # workers share a lock.
        manager = Manager()
        lock = manager.Lock()

    tasks = {}
    for source in sorted(sources):
        future = asyncio.ensure_future(
            loop.run_in_executor(
                executor, format_file_in_place, source, fast, mode, write_back, lock
            )
        )
        tasks[future] = source

    pending = tasks.keys()
    try:
        for signum in (signal.SIGINT, signal.SIGTERM):
            loop.add_signal_handler(signum, cancel, pending)
    except NotImplementedError:
        # There are no good alternatives for these on Windows.
        pass

    while pending:
        done, _ = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
        for task in done:
            src = tasks.pop(task)
            if task.cancelled():
                cancelled.append(task)
                continue

            failure = task.exception()
            if failure:
                report.failed(src, str(failure))
                continue

            changed = Changed.YES if task.result() else Changed.NO
            # Cache files that were written back, or that were checked and
            # found to be well-formatted already.
            if write_back is WriteBack.YES or (
                write_back is WriteBack.CHECK and changed is Changed.NO
            ):
                sources_to_cache.append(src)
            report.done(src, changed)

    if cancelled:
        if sys.version_info >= (3, 7):
            await asyncio.gather(*cancelled, return_exceptions=True)
        else:
            await asyncio.gather(*cancelled, loop=loop, return_exceptions=True)
    if sources_to_cache:
        write_cache(cache, sources_to_cache, mode)
788
789
def format_file_in_place(
    src: Path,
    fast: bool,
    mode: Mode,
    write_back: WriteBack = WriteBack.NO,
    lock: Any = None,  # multiprocessing.Manager().Lock() is some crazy proxy
) -> bool:
    """Format file under `src` path. Return True if changed.

    If `write_back` is DIFF or COLOR_DIFF, write a diff to stdout, holding
    `lock` while writing when one is given. If it is YES, write reformatted
    code to the file.
    `mode` and `fast` options are passed to :func:`format_file_contents`.

    Raises ValueError when the contents cannot be decoded as JSON while
    formatting (reported as an invalid Jupyter notebook).
    """
    suffix = src.suffix
    if suffix == ".pyi":
        mode = replace(mode, is_pyi=True)
    elif suffix == ".ipynb":
        mode = replace(mode, is_ipynb=True)

    # Modification time of the source, used in the diff header.
    then = datetime.utcfromtimestamp(src.stat().st_mtime)
    with open(src, "rb") as buf:
        src_contents, encoding, newline = decode_bytes(buf.read())

    try:
        dst_contents = format_file_contents(src_contents, fast=fast, mode=mode)
    except NothingChanged:
        return False
    except JSONDecodeError:
        raise ValueError(
            f"File '{src}' cannot be parsed as valid Jupyter notebook."
        ) from None

    if write_back == WriteBack.YES:
        with open(src, "w", encoding=encoding, newline=newline) as f:
            f.write(dst_contents)
        return True

    if write_back not in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
        return True

    now = datetime.utcnow()
    src_name = f"{src}\t{then} +0000"
    dst_name = f"{src}\t{now} +0000"
    make_diff = ipynb_diff if mode.is_ipynb else diff
    diff_contents = make_diff(src_contents, dst_contents, src_name, dst_name)

    if write_back == WriteBack.COLOR_DIFF:
        diff_contents = color_diff(diff_contents)

    with lock or nullcontext():
        stream = io.TextIOWrapper(
            sys.stdout.buffer,
            encoding=encoding,
            newline=newline,
            write_through=True,
        )
        stream = wrap_stream_for_windows(stream)
        stream.write(diff_contents)
        # Detach so that stdout's buffer stays usable after the wrapper goes away.
        stream.detach()

    return True
847
848
def format_stdin_to_stdout(
    fast: bool,
    *,
    content: Optional[str] = None,
    write_back: WriteBack = WriteBack.NO,
    mode: Mode,
) -> bool:
    """Format file on stdin. Return True if changed.

    If content is None, it's read from sys.stdin.

    If `write_back` is YES, write reformatted code back to stdout. If it is
    DIFF or COLOR_DIFF, write a diff to stdout. The output is written in a
    `finally` clause, so it also happens when nothing changed or formatting
    raised. The `mode` argument is passed to :func:`format_file_contents`.
    """
    then = datetime.utcnow()

    if content is not None:
        src, encoding, newline = content, "utf-8", ""
    else:
        src, encoding, newline = decode_bytes(sys.stdin.buffer.read())

    # Falls back to the unmodified source when formatting does not complete.
    dst = src
    try:
        dst = format_file_contents(src, fast=fast, mode=mode)
        return True

    except NothingChanged:
        return False

    finally:
        stream = io.TextIOWrapper(
            sys.stdout.buffer, encoding=encoding, newline=newline, write_through=True
        )
        if write_back == WriteBack.YES:
            # Make sure there's a newline after the content
            if dst and not dst.endswith("\n"):
                dst += "\n"
            stream.write(dst)
        elif write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
            now = datetime.utcnow()
            src_name = f"STDIN\t{then} +0000"
            dst_name = f"STDOUT\t{now} +0000"
            diff_text = diff(src, dst, src_name, dst_name)
            if write_back == WriteBack.COLOR_DIFF:
                diff_text = color_diff(diff_text)
                stream = wrap_stream_for_windows(stream)
            stream.write(diff_text)
        stream.detach()
898
899
def check_stability_and_equivalence(
    src_contents: str, dst_contents: str, *, mode: Mode
) -> None:
    """Perform stability and equivalence checks.

    Raise AssertionError if source and destination contents are not
    equivalent, or if a second pass of the formatter would format the
    content differently.
    """
    assert_equivalent(src_contents, dst_contents)

    # A second pass is forced to work around optional trailing commas (which
    # become forced trailing commas on pass 2) interacting differently with
    # optional parentheses.  Admittedly ugly.
    second_pass = format_str(dst_contents, mode=mode)
    if second_pass == dst_contents:
        # Both passes agree, so there is no need to call `assert_stable`.
        return

    assert_equivalent(src_contents, second_pass, pass_num=2)
    assert_stable(src_contents, second_pass, mode=mode)
921
922
def format_file_contents(src_contents: str, *, fast: bool, mode: Mode) -> FileContent:
    """Return the reformatted version of `src_contents`.

    Raise NothingChanged if the source is blank or if formatting leaves it
    untouched.

    Unless `fast` is True, the result is additionally validated with
    :func:`assert_equivalent` and :func:`assert_stable`.  `mode` is handed on
    to :func:`format_str`.
    """
    if src_contents.strip() == "":
        raise NothingChanged

    is_notebook = mode.is_ipynb
    if is_notebook:
        formatted = format_ipynb_string(src_contents, fast=fast, mode=mode)
    else:
        formatted = format_str(src_contents, mode=mode)

    if formatted == src_contents:
        raise NothingChanged

    # Jupyter notebooks have already been checked cell by cell above.
    if not (fast or is_notebook):
        check_stability_and_equivalence(src_contents, formatted, mode=mode)
    return formatted
944
945
def validate_cell(src: str) -> None:
    """Reject cells that cannot be formatted and roundtripped safely.

    Raise NothingChanged when the cell already contains TransformerManager
    transformations, or when it starts with a non-Python cell magic, which
    might cause tokenizer_rt to break because of indentations.

    If a cell contains ``!ls``, then it'll be transformed to
    ``get_ipython().system('ls')``. However, if the cell originally contained
    ``get_ipython().system('ls')``, then it would get transformed in the same way:

        >>> TransformerManager().transform_cell("get_ipython().system('ls')")
        "get_ipython().system('ls')\n"
        >>> TransformerManager().transform_cell("!ls")
        "get_ipython().system('ls')\n"

    Since such cells cannot be roundtripped safely, cells containing
    transformed magics are left alone.
    """
    for transformed_magic in TRANSFORMED_MAGICS:
        if transformed_magic in src:
            raise NothingChanged

    if src.startswith("%%"):
        # The first whitespace-delimited word is the magic; drop its "%%".
        magic_name = src.split()[0][2:]
        if magic_name not in PYTHON_CELL_MAGICS:
            raise NothingChanged
967
968
def format_cell(src: str, *, fast: bool, mode: Mode) -> str:
    """Format the code of a single Jupyter notebook cell.

    The steps are:

      - validate the cell (see :func:`validate_cell`);
      - drop a trailing semicolon, if there is one;
      - mask IPython magics;
      - format the masked code;
      - put the IPython magics back;
      - put the trailing semicolon back (if it was there originally);
      - strip trailing newlines.

    Cells with syntax errors are not processed, as they could be automagics
    or multi-line magics, which are currently not supported.

    Raise NothingChanged if the cell is skipped or comes out unchanged.
    """
    validate_cell(src)
    stripped_src, had_semicolon = remove_trailing_semicolon(src)
    try:
        masked_src, replacements = mask_cell(stripped_src)
    except SyntaxError:
        raise NothingChanged from None

    masked_dst = format_str(masked_src, mode=mode)
    if not fast:
        check_stability_and_equivalence(masked_src, masked_dst, mode=mode)

    unmasked_dst = unmask_cell(masked_dst, replacements)
    result = put_trailing_semicolon_back(unmasked_dst, had_semicolon).rstrip("\n")
    if result == src:
        raise NothingChanged from None
    return result
1004
1005
def validate_metadata(nb: MutableMapping[str, Any]) -> None:
    """Raise NothingChanged if the notebook declares a non-Python language.

    All notebook metadata fields are optional, see
    https://nbformat.readthedocs.io/en/latest/format_description.html. A
    notebook without language metadata is therefore still accepted and we
    will try to parse it anyway.
    """
    metadata = nb.get("metadata", {})
    language_info = metadata.get("language_info", {})
    language = language_info.get("name", None)
    if language is None or language == "python":
        return
    raise NothingChanged from None
1016
1017
def format_ipynb_string(src_contents: str, *, fast: bool, mode: Mode) -> FileContent:
    """Format Jupyter notebook.

    Operate cell-by-cell, only on code cells, only for Python notebooks.
    If the ``.ipynb`` originally had a trailing newline, it'll be preserved.

    `src_contents` is the notebook as JSON text; `fast` and `mode` are passed
    to :func:`format_cell` for every code cell.

    Return the re-serialized notebook.  Raise NothingChanged if
    `src_contents` is empty, if the notebook is marked as non-Python, or if
    no code cell was modified.
    """
    if not src_contents:
        # Nothing to parse; bail out here rather than fail on the
        # trailing-newline check below.
        raise NothingChanged

    trailing_newline = src_contents.endswith("\n")
    modified = False
    nb = json.loads(src_contents)
    validate_metadata(nb)
    for cell in nb["cells"]:
        if cell.get("cell_type", None) == "code":
            try:
                src = "".join(cell["source"])
                dst = format_cell(src, fast=fast, mode=mode)
            except NothingChanged:
                # This cell stays as it is; keep going with the others.
                pass
            else:
                cell["source"] = dst.splitlines(keepends=True)
                modified = True
    if modified:
        dst_contents = json.dumps(nb, indent=1, ensure_ascii=False)
        if trailing_newline:
            dst_contents = dst_contents + "\n"
        return dst_contents
    else:
        raise NothingChanged
1045
1046
def format_str(src_contents: str, *, mode: Mode) -> FileContent:
    """Reformat the string `src_contents` and return the new contents.

    `mode` carries the formatting options, such as how many characters per
    line are allowed.  Example:

    >>> import black
    >>> print(black.format_str("def f(arg:str='')->None:...", mode=black.Mode()))
    def f(arg: str = "") -> None:
        ...

    A more complex example:

    >>> print(
    ...   black.format_str(
    ...     "def f(arg:str='')->None: hey",
    ...     mode=black.Mode(
    ...       target_versions={black.TargetVersion.PY36},
    ...       line_length=10,
    ...       string_normalization=False,
    ...       is_pyi=False,
    ...     ),
    ...   ),
    ... )
    def f(
        arg: str = '',
    ) -> None:
        hey

    """
    src_node = lib2to3_parse(src_contents.lstrip(), mode.target_versions)
    future_imports = get_future_imports(src_node)
    # Explicitly configured target versions win; otherwise infer them.
    versions = mode.target_versions or detect_target_versions(
        src_node, future_imports=future_imports
    )

    # TODO: fully drop support and this code hopefully in January 2022 :D
    if TargetVersion.PY27 in mode.target_versions or versions == {TargetVersion.PY27}:
        err(
            "DEPRECATION: Python 2 support will be removed in the first stable release "
            "expected in January 2022.",
            fg="yellow",
            bold=True,
        )

    normalize_fmt_off(src_node)
    line_generator = LineGenerator(
        mode=mode,
        remove_u_prefix="unicode_literals" in future_imports
        or supports_feature(versions, Feature.UNICODE_LITERALS),
    )
    tracker = EmptyLineTracker(is_pyi=mode.is_pyi)
    blank_line = str(Line(mode=mode))

    split_line_features = set()
    for feature in (Feature.TRAILING_COMMA_IN_CALL, Feature.TRAILING_COMMA_IN_DEF):
        if supports_feature(versions, feature):
            split_line_features.add(feature)

    pieces = []
    pending_after = 0
    for current_line in line_generator.visit(src_node):
        # First emit the blank lines owed after the previous line, then the
        # ones required before this one.
        pieces.append(blank_line * pending_after)
        before, pending_after = tracker.maybe_empty_lines(current_line)
        pieces.append(blank_line * before)
        pieces.extend(
            str(line)
            for line in transform_line(
                current_line, mode=mode, features=split_line_features
            )
        )
    return "".join(pieces)
1116
1117
def decode_bytes(src: bytes) -> Tuple[FileContent, Encoding, NewLine]:
    """Decode `src` and return (decoded_contents, encoding, newline).

    `newline` is CRLF or LF, judged from the end of the first line.
    `decoded_contents` is read with universal newlines, so it only ever
    contains LF.
    """
    buffer = io.BytesIO(src)
    encoding, first_lines = tokenize.detect_encoding(buffer.readline)
    if len(first_lines) == 0:
        return "", encoding, "\n"

    if first_lines[0].endswith(b"\r\n"):
        newline = "\r\n"
    else:
        newline = "\n"

    # Encoding detection consumed part of the buffer; rewind before decoding.
    buffer.seek(0)
    with io.TextIOWrapper(buffer, encoding) as reader:
        decoded = reader.read()
    return decoded, encoding, newline
1133
1134
def get_features_used(  # noqa: C901
    node: Node, *, future_imports: Optional[Set[str]] = None
) -> Set[Feature]:
    """Return a set of (relatively) new Python features used in this file.

    `node` is walked in pre-order.  `future_imports`, when given, is the set
    of ``__future__`` import names; the ones present in
    `FUTURE_FLAG_TO_FEATURE` are mapped to their feature.

    Currently looking for:
    - f-strings (any case combination of the ``f`` / ``rf`` / ``fr`` prefix);
    - underscores in numeric literals;
    - trailing commas after * or ** in function signatures and calls;
    - positional only arguments in function signatures and lambdas;
    - assignment expression;
    - relaxed decorator syntax;
    - star-unpacking in return and yield;
    - an unparenthesized tuple on the right side of an annotated assignment;
    - usage of __future__ flags (annotations);
    - Python 2 only constructs: long and octal integer literals, print / exec
      statements, automatic parameter unpacking, comma-style except and raise,
      backquote repr.
    """
    features: Set[Feature] = set()
    if future_imports:
        features |= {
            FUTURE_FLAG_TO_FEATURE[future_import]
            for future_import in future_imports
            if future_import in FUTURE_FLAG_TO_FEATURE
        }

    for n in node.pre_order():
        if n.type == token.STRING:
            # String prefixes are case-insensitive, so normalize the case
            # first; otherwise mixed-case prefixes such as `Rf` or `fR` would
            # not be recognized as f-strings.
            value_head = n.value[:2].lower()  # type: ignore
            if value_head in {'f"', "f'", "rf", "fr"}:
                features.add(Feature.F_STRINGS)

        elif n.type == token.NUMBER:
            assert isinstance(n, Leaf)
            if "_" in n.value:
                features.add(Feature.NUMERIC_UNDERSCORES)
            elif n.value.endswith(("L", "l")):
                # Python 2: 10L
                features.add(Feature.LONG_INT_LITERAL)
            elif len(n.value) >= 2 and n.value[0] == "0" and n.value[1].isdigit():
                # Python 2: 0123; 00123; ...
                if not all(char == "0" for char in n.value):
                    # although we don't want to match 0000 or similar
                    features.add(Feature.OCTAL_INT_LITERAL)

        elif n.type == token.SLASH:
            # A slash directly inside an argument list is the positional-only
            # marker.
            if n.parent and n.parent.type in {
                syms.typedargslist,
                syms.arglist,
                syms.varargslist,
            }:
                features.add(Feature.POS_ONLY_ARGUMENTS)

        elif n.type == token.COLONEQUAL:
            features.add(Feature.ASSIGNMENT_EXPRESSIONS)

        elif n.type == syms.decorator:
            if len(n.children) > 1 and not is_simple_decorator_expression(
                n.children[1]
            ):
                features.add(Feature.RELAXED_DECORATORS)

        elif (
            n.type in {syms.typedargslist, syms.arglist}
            and n.children
            and n.children[-1].type == token.COMMA
        ):
            # The argument list ends with a comma; that only counts as a
            # feature when a * or ** argument is present as well.
            if n.type == syms.typedargslist:
                feature = Feature.TRAILING_COMMA_IN_DEF
            else:
                feature = Feature.TRAILING_COMMA_IN_CALL

            for ch in n.children:
                if ch.type in STARS:
                    features.add(feature)

                if ch.type == syms.argument:
                    for argch in ch.children:
                        if argch.type in STARS:
                            features.add(feature)

        elif (
            n.type in {syms.return_stmt, syms.yield_expr}
            and len(n.children) >= 2
            and n.children[1].type == syms.testlist_star_expr
            and any(child.type == syms.star_expr for child in n.children[1].children)
        ):
            features.add(Feature.UNPACKING_ON_FLOW)

        elif (
            n.type == syms.annassign
            and len(n.children) >= 4
            and n.children[3].type == syms.testlist_star_expr
        ):
            features.add(Feature.ANN_ASSIGN_EXTENDED_RHS)

        # Python 2 only features (for its deprecation) except for integers, see above
        elif n.type == syms.print_stmt:
            features.add(Feature.PRINT_STMT)
        elif n.type == syms.exec_stmt:
            features.add(Feature.EXEC_STMT)
        elif n.type == syms.tfpdef:
            # def set_position((x, y), value):
            #     ...
            features.add(Feature.AUTOMATIC_PARAMETER_UNPACKING)
        elif n.type == syms.except_clause:
            # try:
            #     ...
            # except Exception, err:
            #     ...
            if len(n.children) >= 4:
                if n.children[-2].type == token.COMMA:
                    features.add(Feature.COMMA_STYLE_EXCEPT)
        elif n.type == syms.raise_stmt:
            # raise Exception, "msg"
            if len(n.children) >= 4:
                if n.children[-2].type == token.COMMA:
                    features.add(Feature.COMMA_STYLE_RAISE)
        elif n.type == token.BACKQUOTE:
            # `i'm surprised this ever existed`
            features.add(Feature.BACKQUOTE_REPR)

    return features
1255
1256
def detect_target_versions(
    node: Node, *, future_imports: Optional[Set[str]] = None
) -> Set[TargetVersion]:
    """Return every target version that supports all features used in `node`.

    The features are collected with :func:`get_features_used`, which also
    receives `future_imports`.
    """
    used_features = get_features_used(node, future_imports=future_imports)
    compatible: Set[TargetVersion] = set()
    for version in TargetVersion:
        if used_features <= VERSION_TO_FEATURES[version]:
            compatible.add(version)
    return compatible
1265
1266
def get_future_imports(node: Node) -> Set[str]:
    """Return the set of names imported from ``__future__`` in the file.

    Only the leading simple statements of `node` are inspected: docstring-like
    string statements are skipped, ``from __future__ import ...`` statements
    are collected, and anything else ends the search.
    """

    def get_imports_from_children(children: List[LN]) -> Generator[str, None, None]:
        # Yield the original (un-aliased) name of every imported item.
        for item in children:
            if isinstance(item, Leaf):
                if item.type == token.NAME:
                    yield item.value
                continue

            if item.type == syms.import_as_name:
                original = item.children[0]
                assert isinstance(original, Leaf), "Invalid syntax parsing imports"
                assert original.type == token.NAME, "Invalid syntax parsing imports"
                yield original.value
                continue

            if item.type == syms.import_as_names:
                yield from get_imports_from_children(item.children)
                continue

            raise AssertionError("Invalid syntax parsing imports")

    found: Set[str] = set()
    for statement in node.children:
        if statement.type != syms.simple_stmt:
            break

        head = statement.children[0]
        if isinstance(head, Leaf):
            # Continue looking if we see a docstring; otherwise stop.
            looks_like_docstring = (
                len(statement.children) == 2
                and head.type == token.STRING
                and statement.children[1].type == token.NEWLINE
            )
            if not looks_like_docstring:
                break
            continue

        if head.type != syms.import_from:
            break

        module_name = head.children[1]
        if not isinstance(module_name, Leaf) or module_name.value != "__future__":
            break

        found.update(get_imports_from_children(head.children[3:]))

    return found
1315
1316
def assert_equivalent(src: str, dst: str, *, pass_num: int = 1) -> None:
    """Raise AssertionError unless `src` and `dst` have equivalent ASTs.

    `pass_num` only appears in the error messages, to tell which formatting
    pass produced `dst`.
    """
    try:
        src_ast = parse_ast(src)
    except Exception as exc:
        raise AssertionError(
            f"cannot use --safe with this file; failed to parse source file: {exc}"
        ) from exc

    try:
        dst_ast = parse_ast(dst)
    except Exception as exc:
        # Save the traceback together with the broken output for bug reports.
        formatted_tb = "".join(traceback.format_tb(exc.__traceback__))
        log = dump_to_file(formatted_tb, dst)
        raise AssertionError(
            f"INTERNAL ERROR: Black produced invalid code on pass {pass_num}: {exc}. "
            "Please report a bug on https://github.com/psf/black/issues.  "
            f"This invalid output might be helpful: {log}"
        ) from None

    src_ast_str = "\n".join(stringify_ast(src_ast))
    dst_ast_str = "\n".join(stringify_ast(dst_ast))
    if src_ast_str == dst_ast_str:
        return

    log = dump_to_file(diff(src_ast_str, dst_ast_str, "src", "dst"))
    raise AssertionError(
        "INTERNAL ERROR: Black produced code that is not equivalent to the"
        f" source on pass {pass_num}.  Please report a bug on "
        f"https://github.com/psf/black/issues.  This diff might be helpful: {log}"
    ) from None
1345
1346
def assert_stable(src: str, dst: str, mode: Mode) -> None:
    """Raise AssertionError if formatting `dst` again changes it.

    `src` is only used for the diagnostic diff that gets dumped on failure.
    """
    second_pass = format_str(dst, mode=mode)
    if second_pass == dst:
        return

    source_diff = diff(src, dst, "source", "first pass")
    pass_diff = diff(dst, second_pass, "first pass", "second pass")
    log = dump_to_file(str(mode), source_diff, pass_diff)
    raise AssertionError(
        "INTERNAL ERROR: Black produced different code on the second pass of the"
        " formatter.  Please report a bug on https://github.com/psf/black/issues."
        f"  This diff might be helpful: {log}"
    ) from None
1361
1362
@contextmanager
def nullcontext() -> Iterator[None]:
    """Return a context manager that does nothing.

    A stand-in for `nullcontext` from Python 3.7.
    """
    yield None
1370
1371
def patch_click() -> None:
    """Make Click not crash on Python 3.6 with LANG=C.

    On certain misconfigured environments, Python 3 selects the ASCII encoding as the
    default which restricts paths that it can access during the lifetime of the
    application.  Click refuses to work in this scenario by raising a RuntimeError.

    In case of Black the likelihood that non-ASCII characters are going to be used in
    file paths is minimal since it's Python source code.  Moreover, this crash was
    spurious on Python 3.7 thanks to PEP 538 and PEP 540.

    Each Click module is imported on its own, and a failed import is ignored:
    a module that cannot be imported has nothing to patch, and its absence
    must not keep the other module from being patched.
    """
    modules: List[Any] = []
    try:
        from click import core
    except ImportError:
        # Nothing to patch if the module is not importable.
        pass
    else:
        modules.append(core)
    try:
        # `from click import <missing submodule>` raises plain ImportError,
        # not ModuleNotFoundError, when click itself is installed, so the
        # broader exception has to be caught here.
        from click import _unicodefun
    except ImportError:
        pass
    else:
        modules.append(_unicodefun)

    for module in modules:
        # Replace whichever verification hook the module has with a no-op.
        if hasattr(module, "_verify_python3_env"):
            module._verify_python3_env = lambda: None
        if hasattr(module, "_verify_python_env"):
            module._verify_python_env = lambda: None
1394
1395
def patched_main() -> None:
    """Run the setup steps, then hand over to `main`.

    Calls `maybe_install_uvloop`, `freeze_support` and `patch_click`, in that
    order, before calling `main`.
    """
    maybe_install_uvloop()
    freeze_support()
    # Click has to be patched before `main` runs.
    patch_click()
    main()
1401
1402
# Script entry point: run the patched CLI only when this module is executed
# directly, not when it is imported.
if __name__ == "__main__":
    patched_main()