]> git.madduck.net Git - etc/vim.git/blob - src/black/__init__.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you read the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access, can be arranged individually.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Implementing mypyc support pt. 2 (#2431)
[etc/vim.git] / src / black / __init__.py
1 import asyncio
2 from json.decoder import JSONDecodeError
3 import json
4 from concurrent.futures import Executor, ThreadPoolExecutor, ProcessPoolExecutor
5 from contextlib import contextmanager
6 from datetime import datetime
7 from enum import Enum
8 import io
9 from multiprocessing import Manager, freeze_support
10 import os
11 from pathlib import Path
12 from pathspec.patterns.gitwildmatch import GitWildMatchPatternError
13 import regex as re
14 import signal
15 import sys
16 import tokenize
17 import traceback
18 from typing import (
19     Any,
20     Dict,
21     Generator,
22     Iterator,
23     List,
24     MutableMapping,
25     Optional,
26     Pattern,
27     Set,
28     Sized,
29     Tuple,
30     Union,
31 )
32
33 import click
34 from dataclasses import replace
35 from mypy_extensions import mypyc_attr
36
37 from black.const import DEFAULT_LINE_LENGTH, DEFAULT_INCLUDES, DEFAULT_EXCLUDES
38 from black.const import STDIN_PLACEHOLDER
39 from black.nodes import STARS, syms, is_simple_decorator_expression
40 from black.lines import Line, EmptyLineTracker
41 from black.linegen import transform_line, LineGenerator, LN
42 from black.comments import normalize_fmt_off
43 from black.mode import Mode, TargetVersion
44 from black.mode import Feature, supports_feature, VERSION_TO_FEATURES
45 from black.cache import read_cache, write_cache, get_cache_info, filter_cached, Cache
46 from black.concurrency import cancel, shutdown, maybe_install_uvloop
47 from black.output import dump_to_file, ipynb_diff, diff, color_diff, out, err
48 from black.report import Report, Changed, NothingChanged
49 from black.files import find_project_root, find_pyproject_toml, parse_pyproject_toml
50 from black.files import gen_python_files, get_gitignore, normalize_path_maybe_ignore
51 from black.files import wrap_stream_for_windows
52 from black.parsing import InvalidInput  # noqa F401
53 from black.parsing import lib2to3_parse, parse_ast, stringify_ast
54 from black.handle_ipynb_magics import (
55     mask_cell,
56     unmask_cell,
57     remove_trailing_semicolon,
58     put_trailing_semicolon_back,
59     TRANSFORMED_MAGICS,
60     jupyter_dependencies_are_installed,
61 )
62
63
64 # lib2to3 fork
65 from blib2to3.pytree import Node, Leaf
66 from blib2to3.pgen2 import token
67
68 from _black_version import version as __version__
69
# True when this module was loaded from a compiled extension module
# (".pyd" on Windows, ".so" elsewhere) rather than from a ".py" source file.
COMPILED = Path(__file__).suffix in {".pyd", ".so"}

# types
FileContent = str  # full text of a source file
Encoding = str  # name of a text encoding
NewLine = str  # newline sequence used by a file
76
77
class WriteBack(Enum):
    """Where (and whether) formatting results are written back."""

    NO = 0
    YES = 1
    DIFF = 2
    CHECK = 3
    COLOR_DIFF = 4

    @classmethod
    def from_configuration(
        cls, *, check: bool, diff: bool, color: bool = False
    ) -> "WriteBack":
        """Pick the write-back mode implied by the --check/--diff/--color flags.

        --diff wins over --check; --color only has an effect together with --diff.
        """
        if diff:
            return cls.COLOR_DIFF if color else cls.DIFF
        return cls.CHECK if check else cls.YES
96
97
# Legacy name, left for integrations.
FileMode = Mode

# Default for --workers. os.cpu_count() returns None when the number of CPUs
# cannot be determined; fall back to a single worker so the default is always
# a usable positive int (a None default would otherwise reach reformat_many).
DEFAULT_WORKERS = os.cpu_count() or 1
102
103
def read_pyproject_toml(
    ctx: click.Context, param: click.Parameter, value: Optional[str]
) -> Optional[str]:
    """Click callback for --config: merge pyproject.toml settings into defaults.

    When no explicit path is given, a pyproject.toml is searched for based on
    the already-parsed `src` arguments. The file's settings are layered on top
    of any existing `ctx.default_map`.

    Returns the path of the configuration file that was read, or None when no
    file was found or it contained no settings.
    """
    if not value:
        value = find_pyproject_toml(ctx.params.get("src", ()))
        if value is None:
            return None

    try:
        raw_config = parse_pyproject_toml(value)
    except (OSError, ValueError) as exc:
        raise click.FileError(
            filename=value, hint=f"Error reading configuration file: {exc}"
        ) from None

    if not raw_config:
        return None

    # Sanitize the values to be Click friendly. For more information please see:
    # https://github.com/psf/black/issues/1458
    # https://github.com/pallets/click/issues/1567
    config: Dict[str, Any] = {}
    for key, val in raw_config.items():
        config[key] = val if isinstance(val, (list, dict)) else str(val)

    target_version = config.get("target_version")
    if target_version is not None and not isinstance(target_version, list):
        raise click.BadOptionUsage(
            "target-version", "Config key target-version must be a list"
        )

    # Settings from the file override any defaults already on the context.
    merged: Dict[str, Any] = dict(ctx.default_map or {})
    merged.update(config)
    ctx.default_map = merged
    return value
148
149
def target_version_option_callback(
    c: click.Context, p: Union[click.Option, click.Parameter], v: Tuple[str, ...]
) -> List[TargetVersion]:
    """Map every --target-version choice to its TargetVersion member.

    Written as a named function instead of a lambda because mypy could not
    infer the lambda's type correctly, which caused trouble for mypyc.
    """
    versions: List[TargetVersion] = []
    for name in v:
        versions.append(TargetVersion[name.upper()])
    return versions
159
160
def re_compile_maybe_verbose(regex: str) -> Pattern[str]:
    """Compile `regex`, switching on verbose mode when it spans several lines.

    Multi-line patterns get an inline ``(?x)`` flag so whitespace and comments
    inside them are ignored.
    """
    verbose_flag = "(?x)" if "\n" in regex else ""
    compiled: Pattern[str] = re.compile(verbose_flag + regex)
    return compiled
170
171
def validate_regex(
    ctx: click.Context,
    param: click.Parameter,
    value: Optional[str],
) -> Optional[Pattern[str]]:
    """Click callback: compile a regex option, or pass None through unchanged.

    Raises click.BadParameter when `value` is not a valid regular expression.
    """
    if value is None:
        return None
    try:
        return re_compile_maybe_verbose(value)
    except re.error:
        raise click.BadParameter("Not a valid regular expression") from None
181
182
# Command-line entry point. Click passes each option/argument declared below
# to `main` as a keyword argument named after it (e.g. --line-length ->
# `line_length`).
@click.command(
    context_settings=dict(help_option_names=["-h", "--help"]),
    # While Click does set this field automatically using the docstring, mypyc
    # (annoyingly) strips 'em so we need to set it here too.
    help="The uncompromising code formatter.",
)
@click.option("-c", "--code", type=str, help="Format the code passed in as a string.")
@click.option(
    "-l",
    "--line-length",
    type=int,
    default=DEFAULT_LINE_LENGTH,
    help="How many characters per line to allow.",
    show_default=True,
)
# The callback converts the chosen names into a list of TargetVersion members.
@click.option(
    "-t",
    "--target-version",
    type=click.Choice([v.name.lower() for v in TargetVersion]),
    callback=target_version_option_callback,
    multiple=True,
    help=(
        "Python versions that should be supported by Black's output. [default: per-file"
        " auto-detection]"
    ),
)
@click.option(
    "--pyi",
    is_flag=True,
    help=(
        "Format all input files like typing stubs regardless of file extension (useful"
        " when piping source on standard input)."
    ),
)
@click.option(
    "--ipynb",
    is_flag=True,
    help=(
        "Format all input files like Jupyter Notebooks regardless of file extension "
        "(useful when piping source on standard input)."
    ),
)
@click.option(
    "-S",
    "--skip-string-normalization",
    is_flag=True,
    help="Don't normalize string quotes or prefixes.",
)
@click.option(
    "-C",
    "--skip-magic-trailing-comma",
    is_flag=True,
    help="Don't use trailing commas as a reason to split lines.",
)
@click.option(
    "--experimental-string-processing",
    is_flag=True,
    hidden=True,
    help=(
        "Experimental option that performs more normalization on string literals."
        " Currently disabled because it leads to some crashes."
    ),
)
@click.option(
    "--check",
    is_flag=True,
    help=(
        "Don't write the files back, just return the status. Return code 0 means"
        " nothing would change. Return code 1 means some files would be reformatted."
        " Return code 123 means there was an internal error."
    ),
)
@click.option(
    "--diff",
    is_flag=True,
    help="Don't write the files back, just output a diff for each file on stdout.",
)
@click.option(
    "--color/--no-color",
    is_flag=True,
    help="Show colored diff. Only applies when `--diff` is given.",
)
@click.option(
    "--fast/--safe",
    is_flag=True,
    help="If --fast given, skip temporary sanity checks. [default: --safe]",
)
@click.option(
    "--required-version",
    type=str,
    help=(
        "Require a specific version of Black to be running (useful for unifying results"
        " across many environments e.g. with a pyproject.toml file)."
    ),
)
# --include, --exclude, --extend-exclude and --force-exclude are compiled to
# regex patterns by the `validate_regex` callback.
@click.option(
    "--include",
    type=str,
    default=DEFAULT_INCLUDES,
    callback=validate_regex,
    help=(
        "A regular expression that matches files and directories that should be"
        " included on recursive searches. An empty value means all files are included"
        " regardless of the name. Use forward slashes for directories on all platforms"
        " (Windows, too). Exclusions are calculated first, inclusions later."
    ),
    show_default=True,
)
# Deliberately no Click default: leaving --exclude unset (None) makes
# `get_sources` fall back to DEFAULT_EXCLUDES and also apply .gitignore rules.
@click.option(
    "--exclude",
    type=str,
    callback=validate_regex,
    help=(
        "A regular expression that matches files and directories that should be"
        " excluded on recursive searches. An empty value means no paths are excluded."
        " Use forward slashes for directories on all platforms (Windows, too)."
        " Exclusions are calculated first, inclusions later. [default:"
        f" {DEFAULT_EXCLUDES}]"
    ),
    show_default=False,
)
@click.option(
    "--extend-exclude",
    type=str,
    callback=validate_regex,
    help=(
        "Like --exclude, but adds additional files and directories on top of the"
        " excluded ones. (Useful if you simply want to add to the default)"
    ),
)
@click.option(
    "--force-exclude",
    type=str,
    callback=validate_regex,
    help=(
        "Like --exclude, but files and directories matching this regex will be "
        "excluded even when they are passed explicitly as arguments."
    ),
)
@click.option(
    "--stdin-filename",
    type=str,
    help=(
        "The name of the file when passing it through stdin. Useful to make "
        "sure Black will respect --force-exclude option on some "
        "editors that rely on using stdin."
    ),
)
@click.option(
    "-W",
    "--workers",
    type=click.IntRange(min=1),
    default=DEFAULT_WORKERS,
    show_default=True,
    help="Number of parallel workers",
)
@click.option(
    "-q",
    "--quiet",
    is_flag=True,
    help=(
        "Don't emit non-error messages to stderr. Errors are still emitted; silence"
        " those with 2>/dev/null."
    ),
)
@click.option(
    "-v",
    "--verbose",
    is_flag=True,
    help=(
        "Also emit messages to stderr about files that were not changed or were ignored"
        " due to exclusion patterns."
    ),
)
@click.version_option(
    version=__version__,
    message=f"%(prog)s, %(version)s (compiled: {'yes' if COMPILED else 'no'})",
)
# `src` and `--config` are both eager. When no --config path is given, its
# callback (`read_pyproject_toml`) reads `ctx.params["src"]` to search for a
# pyproject.toml.
@click.argument(
    "src",
    nargs=-1,
    type=click.Path(
        exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
    ),
    is_eager=True,
    metavar="SRC ...",
)
@click.option(
    "--config",
    type=click.Path(
        exists=True,
        file_okay=True,
        dir_okay=False,
        readable=True,
        allow_dash=False,
        path_type=str,
    ),
    is_eager=True,
    callback=read_pyproject_toml,
    help="Read configuration from FILE path.",
)
@click.pass_context
def main(
    ctx: click.Context,
    code: Optional[str],
    line_length: int,
    target_version: List[TargetVersion],
    check: bool,
    diff: bool,
    color: bool,
    fast: bool,
    pyi: bool,
    ipynb: bool,
    skip_string_normalization: bool,
    skip_magic_trailing_comma: bool,
    experimental_string_processing: bool,
    quiet: bool,
    verbose: bool,
    required_version: Optional[str],
    include: Pattern[str],
    exclude: Optional[Pattern[str]],
    extend_exclude: Optional[Pattern[str]],
    force_exclude: Optional[Pattern[str]],
    stdin_filename: Optional[str],
    workers: int,
    src: Tuple[str, ...],
    config: Optional[str],
) -> None:
    """The uncompromising code formatter."""
    if config and verbose:
        out(f"Using configuration from {config}.", bold=False, fg="blue")

    # Also reused as the final summary line when the report has a non-zero
    # return code.
    error_msg = "Oh no! 💥 💔 💥"
    if required_version and required_version != __version__:
        err(
            f"{error_msg} The required version `{required_version}` does not match"
            f" the running version `{__version__}`!"
        )
        ctx.exit(1)
    # Stub mode and notebook mode are mutually exclusive.
    if ipynb and pyi:
        err("Cannot pass both `pyi` and `ipynb` flags!")
        ctx.exit(1)

    write_back = WriteBack.from_configuration(check=check, diff=diff, color=color)
    if target_version:
        versions = set(target_version)
    else:
        # We'll autodetect later.
        versions = set()
    mode = Mode(
        target_versions=versions,
        line_length=line_length,
        is_pyi=pyi,
        is_ipynb=ipynb,
        string_normalization=not skip_string_normalization,
        magic_trailing_comma=not skip_magic_trailing_comma,
        experimental_string_processing=experimental_string_processing,
    )

    if code is not None:
        # Run in quiet mode by default with -c; the extra output isn't useful.
        # You can still pass -v to get verbose output.
        quiet = True

    report = Report(check=check, diff=diff, quiet=quiet, verbose=verbose)

    if code is not None:
        # -c/--code formats the given string directly; no file discovery happens.
        reformat_code(
            content=code, fast=fast, write_back=write_back, mode=mode, report=report
        )
    else:
        try:
            sources = get_sources(
                ctx=ctx,
                src=src,
                quiet=quiet,
                verbose=verbose,
                include=include,
                exclude=exclude,
                extend_exclude=extend_exclude,
                force_exclude=force_exclude,
                report=report,
                stdin_filename=stdin_filename,
            )
        except GitWildMatchPatternError:
            # NOTE(review): presumably raised while parsing a malformed
            # .gitignore pattern; details are assumed to have been reported
            # before this point — confirm.
            ctx.exit(1)

        # Exits with status 0 when discovery found nothing to format.
        path_empty(
            sources,
            "No Python files are present to be formatted. Nothing to do 😴",
            quiet,
            verbose,
            ctx,
        )

        # A single source is formatted in this process; several are handed
        # to the executor-based `reformat_many`.
        if len(sources) == 1:
            reformat_one(
                src=sources.pop(),
                fast=fast,
                write_back=write_back,
                mode=mode,
                report=report,
            )
        else:
            reformat_many(
                sources=sources,
                fast=fast,
                write_back=write_back,
                mode=mode,
                report=report,
                workers=workers,
            )

    if verbose or not quiet:
        out(error_msg if report.return_code else "All done! ✨ 🍰 ✨")
        # The per-file summary is only printed for file/stdin input, not -c.
        if code is None:
            click.echo(str(report), err=True)
    # The process exit status is taken from the report.
    ctx.exit(report.return_code)
501
502
def get_sources(
    *,
    ctx: click.Context,
    src: Tuple[str, ...],
    quiet: bool,
    verbose: bool,
    include: Pattern[str],
    exclude: Optional[Pattern[str]],
    extend_exclude: Optional[Pattern[str]],
    force_exclude: Optional[Pattern[str]],
    report: "Report",
    stdin_filename: Optional[str],
) -> Set[Path]:
    """Collect the set of paths to format from the `src` command-line arguments.

    Explicit files are kept unless `normalize_path_maybe_ignore` drops them or
    they match `force_exclude`. Directories are expanded via `gen_python_files`.
    A "-" argument stands for stdin. When `stdin_filename` is also given, that
    name is used for the exclusion checks, and the resulting path is prefixed
    with STDIN_PLACEHOLDER.
    """
    root = find_project_root(src)
    found: Set[Path] = set()
    path_empty(src, "No Path provided. Nothing to do 😴", quiet, verbose, ctx)

    # .gitignore rules only apply when the user did not pass --exclude.
    if exclude is None:
        exclude = re_compile_maybe_verbose(DEFAULT_EXCLUDES)
        gitignore = get_gitignore(root)
    else:
        gitignore = None

    for arg in src:
        if arg == "-" and stdin_filename:
            path = Path(stdin_filename)
            from_stdin = True
        else:
            path = Path(arg)
            from_stdin = False

        if not (from_stdin or path.is_file()):
            # Not a single file: expand directories, keep bare stdin, and
            # complain about anything else.
            if path.is_dir():
                found.update(
                    gen_python_files(
                        path.iterdir(),
                        root,
                        include,
                        exclude,
                        extend_exclude,
                        force_exclude,
                        report,
                        gitignore,
                        verbose=verbose,
                        quiet=quiet,
                    )
                )
            elif arg == "-":
                found.add(path)
            else:
                err(f"invalid path: {arg}")
            continue

        normalized = normalize_path_maybe_ignore(path, root, report)
        if normalized is None:
            continue
        normalized = "/" + normalized

        # Hard-exclude any file matching --force-exclude, even if passed explicitly.
        if force_exclude:
            hit = force_exclude.search(normalized)
            if hit and hit.group(0):
                report.path_ignored(
                    path, "matches the --force-exclude regular expression"
                )
                continue

        if from_stdin:
            path = Path(f"{STDIN_PLACEHOLDER}{str(path)}")

        if path.suffix == ".ipynb" and not jupyter_dependencies_are_installed(
            verbose=verbose, quiet=quiet
        ):
            continue

        found.add(path)
    return found
580
581
def path_empty(
    src: Sized, msg: str, quiet: bool, verbose: bool, ctx: click.Context
) -> None:
    """End the run with exit status 0 when `src` is empty.

    `msg` is printed first unless output is quiet (verbose overrides quiet).
    Does nothing when `src` has at least one element.
    """
    if src:
        return
    if not quiet or verbose:
        out(msg)
    ctx.exit(0)
592
593
def reformat_code(
    content: str, fast: bool, write_back: WriteBack, mode: Mode, report: Report
) -> None:
    """Reformat the string `content` and print the result, in this process.

    The string counterpart of `reformat_one`. The outcome is recorded in
    `report` under the pseudo-path "<string>". `fast`, `write_back`, and
    `mode` are passed to :func:`format_stdin_to_stdout`. Any exception is
    recorded via `report.failed` (with a traceback when verbose) rather than
    raised.
    """
    path = Path("<string>")
    try:
        was_changed = format_stdin_to_stdout(
            content=content, fast=fast, write_back=write_back, mode=mode
        )
        report.done(path, Changed.YES if was_changed else Changed.NO)
    except Exception as exc:
        if report.verbose:
            traceback.print_exc()
        report.failed(path, str(exc))
616
617
def reformat_one(
    src: Path, fast: bool, write_back: WriteBack, mode: Mode, report: "Report"
) -> None:
    """Reformat a single file under `src` without spawning child processes.

    `fast`, `write_back`, and `mode` options are passed to
    :func:`format_file_in_place` or :func:`format_stdin_to_stdout`.

    `src` means stdin when it is exactly "-" or when it starts with
    STDIN_PLACEHOLDER (the prefix `get_sources` adds when --stdin-filename is
    used). The outcome is recorded in `report`. Any exception is reported via
    `report.failed` rather than raised.
    """
    try:
        changed = Changed.NO

        if str(src) == "-":
            is_stdin = True
        elif str(src).startswith(STDIN_PLACEHOLDER):
            is_stdin = True
            # Use the original name again in case we want to print something
            # to the user
            src = Path(str(src)[len(STDIN_PLACEHOLDER) :])
        else:
            is_stdin = False

        if is_stdin:
            # For stdin, the (possibly --stdin-filename-provided) suffix picks
            # stub or notebook mode.
            if src.suffix == ".pyi":
                mode = replace(mode, is_pyi=True)
            elif src.suffix == ".ipynb":
                mode = replace(mode, is_ipynb=True)
            if format_stdin_to_stdout(fast=fast, write_back=write_back, mode=mode):
                changed = Changed.YES
        else:
            cache: Cache = {}
            # Diff modes neither read nor write the cache.
            if write_back not in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
                cache = read_cache(mode)
                res_src = src.resolve()
                res_src_s = str(res_src)
                # If the stored entry equals the file's current cache info,
                # skip formatting and report the file as cached.
                if res_src_s in cache and cache[res_src_s] == get_cache_info(res_src):
                    changed = Changed.CACHED
            if changed is not Changed.CACHED and format_file_in_place(
                src, fast=fast, write_back=write_back, mode=mode
            ):
                changed = Changed.YES
            # Cache the file when it was processed with write_back YES (it is
            # now formatted) or when --check found it already formatted.
            if (write_back is WriteBack.YES and changed is not Changed.CACHED) or (
                write_back is WriteBack.CHECK and changed is Changed.NO
            ):
                write_cache(cache, [src], mode)
        report.done(src, changed)
    except Exception as exc:
        if report.verbose:
            traceback.print_exc()
        report.failed(src, str(exc))
667
668
# diff-shades depends on being able to monkeypatch this function to operate. I know
# it's not ideal, but this shouldn't cause any issues ... hopefully. ~ichard26
@mypyc_attr(patchable=True)
def reformat_many(
    sources: Set[Path],
    fast: bool,
    write_back: WriteBack,
    mode: Mode,
    report: "Report",
    workers: Optional[int],
) -> None:
    """Reformat multiple files using a ProcessPoolExecutor.

    Falls back to a single-worker ThreadPoolExecutor when process pools are
    unavailable. `workers` caps the number of worker processes; None means
    DEFAULT_WORKERS.

    Each call runs on its own freshly created event loop, which is uninstalled
    afterwards, so the function can safely be called more than once per
    process.
    """
    executor: Executor
    worker_count = workers if workers is not None else DEFAULT_WORKERS
    if worker_count is None:
        # os.cpu_count() (the source of DEFAULT_WORKERS) may return None when
        # the CPU count cannot be determined; use one worker instead of crashing.
        worker_count = 1
    if sys.platform == "win32":
        # Work around https://bugs.python.org/issue26903
        worker_count = min(worker_count, 60)
    try:
        executor = ProcessPoolExecutor(max_workers=worker_count)
    except (ImportError, OSError):
        # we arrive here if the underlying system does not support multi-processing
        # like in AWS Lambda or Termux, in which case we gracefully fallback to
        # a ThreadPoolExecutor with just a single worker (more workers would not do us
        # any good due to the Global Interpreter Lock)
        executor = ThreadPoolExecutor(max_workers=1)

    # Use a dedicated loop rather than asyncio.get_event_loop(): the latter is
    # deprecated outside a running loop. Also, the loop is handed to
    # `shutdown()` below (per black.concurrency it closes the loop), so reusing
    # the installed loop would leave a closed loop behind for any later call.
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        loop.run_until_complete(
            schedule_formatting(
                sources=sources,
                fast=fast,
                write_back=write_back,
                mode=mode,
                report=report,
                loop=loop,
                executor=executor,
            )
        )
    finally:
        try:
            shutdown(loop)
        finally:
            # Always uninstall the loop and release the workers, even if
            # shutting the loop down fails.
            asyncio.set_event_loop(None)
            executor.shutdown()
713
714
async def schedule_formatting(
    sources: Set[Path],
    fast: bool,
    write_back: WriteBack,
    mode: Mode,
    report: "Report",
    loop: asyncio.AbstractEventLoop,
    executor: Executor,
) -> None:
    """Run formatting of `sources` in parallel using the provided `executor`.

    (Use ProcessPoolExecutors for actual parallelism.)

    `write_back`, `fast`, and `mode` options are passed to
    :func:`format_file_in_place`.

    Outcomes are recorded in `report` (cached, changed, unchanged, or failed;
    cancelled tasks are not reported). Files that were processed with
    write_back YES, or checked and found already formatted, are written to the
    cache at the end.
    """
    cache: Cache = {}
    if write_back not in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
        cache = read_cache(mode)
        # `cached` holds the sources filter_cached considers up to date; they
        # are reported as CACHED and not formatted again.
        sources, cached = filter_cached(cache, sources)
        for src in sorted(cached):
            report.done(src, Changed.CACHED)
    if not sources:
        return

    cancelled = []
    sources_to_cache = []
    lock = None
    if write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
        # For diff output, we need locks to ensure we don't interleave output
        # from different processes.
        manager = Manager()
        lock = manager.Lock()
    # One future per file. Arguments are positional and must line up with
    # format_file_in_place(src, fast, mode, write_back, lock).
    tasks = {
        asyncio.ensure_future(
            loop.run_in_executor(
                executor, format_file_in_place, src, fast, mode, write_back, lock
            )
        ): src
        for src in sorted(sources)
    }
    # A live view of `tasks`: popping finished tasks below shrinks it, which
    # ends the `while` loop. The signal handlers receive this same view.
    pending = tasks.keys()
    try:
        loop.add_signal_handler(signal.SIGINT, cancel, pending)
        loop.add_signal_handler(signal.SIGTERM, cancel, pending)
    except NotImplementedError:
        # There are no good alternatives for these on Windows.
        pass
    while pending:
        done, _ = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
        for task in done:
            src = tasks.pop(task)
            if task.cancelled():
                cancelled.append(task)
            elif task.exception():
                report.failed(src, str(task.exception()))
            else:
                changed = Changed.YES if task.result() else Changed.NO
                # If the file was written back or was successfully checked as
                # well-formatted, store this information in the cache.
                if write_back is WriteBack.YES or (
                    write_back is WriteBack.CHECK and changed is Changed.NO
                ):
                    sources_to_cache.append(src)
                report.done(src, changed)
    if cancelled:
        # Await the cancelled tasks so they settle; gather()'s `loop` argument
        # is only passed on Python versions older than 3.7.
        if sys.version_info >= (3, 7):
            await asyncio.gather(*cancelled, return_exceptions=True)
        else:
            await asyncio.gather(*cancelled, loop=loop, return_exceptions=True)
    if sources_to_cache:
        write_cache(cache, sources_to_cache, mode)
787
788
def format_file_in_place(
    src: Path,
    fast: bool,
    mode: Mode,
    write_back: WriteBack = WriteBack.NO,
    lock: Any = None,  # multiprocessing.Manager().Lock() is some crazy proxy
) -> bool:
    """Format file under `src` path. Return True if changed.

    If `write_back` is DIFF (or COLOR_DIFF), write a (colored) diff to stdout.
    If it is YES, write reformatted code to the file.
    `mode` and `fast` options are passed to :func:`format_file_contents`.
    A ".pyi" or ".ipynb" suffix switches `mode` to stub or notebook handling.
    `lock`, when given, is held while the diff is written, so output from
    concurrent workers does not interleave.

    Raises ValueError if a notebook cannot be parsed as JSON.
    """
    if src.suffix == ".pyi":
        mode = replace(mode, is_pyi=True)
    elif src.suffix == ".ipynb":
        mode = replace(mode, is_ipynb=True)

    then = datetime.utcfromtimestamp(src.stat().st_mtime)
    with open(src, "rb") as buf:
        src_contents, encoding, newline = decode_bytes(buf.read())
    try:
        dst_contents = format_file_contents(src_contents, fast=fast, mode=mode)
    except NothingChanged:
        return False
    except JSONDecodeError:
        raise ValueError(
            f"File '{src}' cannot be parsed as valid Jupyter notebook."
        ) from None

    if write_back == WriteBack.YES:
        with open(src, "w", encoding=encoding, newline=newline) as f:
            f.write(dst_contents)
    elif write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
        now = datetime.utcnow()
        src_name = f"{src}\t{then} +0000"
        dst_name = f"{src}\t{now} +0000"
        if mode.is_ipynb:
            diff_contents = ipynb_diff(src_contents, dst_contents, src_name, dst_name)
        else:
            diff_contents = diff(src_contents, dst_contents, src_name, dst_name)

        if write_back == WriteBack.COLOR_DIFF:
            diff_contents = color_diff(diff_contents)

        with lock or nullcontext():
            # The diff is written in the source file's encoding and newline style.
            wrapper = io.TextIOWrapper(
                sys.stdout.buffer,
                encoding=encoding,
                newline=newline,
                write_through=True,
            )
            try:
                stream = wrap_stream_for_windows(wrapper)
                stream.write(diff_contents)
            finally:
                # Always detach, even if writing fails (e.g. an encoding error):
                # a still-attached TextIOWrapper closes sys.stdout.buffer when
                # it is garbage-collected, which would break all later output.
                wrapper.detach()

    return True
846
847
def format_stdin_to_stdout(
    fast: bool,
    *,
    content: Optional[str] = None,
    write_back: WriteBack = WriteBack.NO,
    mode: Mode,
) -> bool:
    """Format file on stdin. Return True if changed.

    If content is None, it's read from sys.stdin.

    If `write_back` is YES, write reformatted code back to stdout. If it is DIFF
    (or COLOR_DIFF), write a (colored) diff to stdout. The `mode` argument is
    passed to :func:`format_file_contents`.

    Output happens in a `finally` block, so it also runs when formatting
    raises. With YES, the unmodified input is echoed back before the exception
    propagates.
    """
    then = datetime.utcnow()

    if content is None:
        src, encoding, newline = decode_bytes(sys.stdin.buffer.read())
    else:
        src, encoding, newline = content, "utf-8", ""

    dst = src
    try:
        dst = format_file_contents(src, fast=fast, mode=mode)
        return True

    except NothingChanged:
        return False

    finally:
        wrapper = io.TextIOWrapper(
            sys.stdout.buffer, encoding=encoding, newline=newline, write_through=True
        )
        try:
            if write_back == WriteBack.YES:
                # Make sure there's a newline after the content
                if dst and dst[-1] != "\n":
                    dst += "\n"
                wrapper.write(dst)
            elif write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
                now = datetime.utcnow()
                src_name = f"STDIN\t{then} +0000"
                dst_name = f"STDOUT\t{now} +0000"
                d = diff(src, dst, src_name, dst_name)
                if write_back == WriteBack.COLOR_DIFF:
                    d = color_diff(d)
                    wrap_stream_for_windows(wrapper).write(d)
                else:
                    wrapper.write(d)
        finally:
            # Always detach, even if writing fails: a still-attached
            # TextIOWrapper closes sys.stdout.buffer when garbage-collected.
            wrapper.detach()
897
898
def check_stability_and_equivalence(
    src_contents: str, dst_contents: str, *, mode: Mode
) -> None:
    """Verify that `dst_contents` is a safe and stable reformatting of `src_contents`.

    Raises AssertionError if the two are not equivalent, or if running the
    formatter over `dst_contents` once more would produce different output.
    """
    assert_equivalent(src_contents, dst_contents)

    # Forced second pass to work around optional trailing commas (becoming
    # forced trailing commas on pass 2) interacting differently with optional
    # parentheses. Admittedly ugly.
    second_pass = format_str(dst_contents, mode=mode)
    if second_pass == dst_contents:
        # Identical output already shows stability; assert_stable is not needed.
        return
    assert_equivalent(src_contents, second_pass, pass_num=2)
    assert_stable(src_contents, second_pass, mode=mode)
920
921
def format_file_contents(src_contents: str, *, fast: bool, mode: Mode) -> FileContent:
    """Return the reformatted version of a file's contents.

    Raises NothingChanged in two cases:

    - The input is blank (only whitespace).
    - Formatting produces identical text.

    Jupyter notebooks (``mode.is_ipynb``) go through
    :func:`format_ipynb_string`; everything else goes through
    :func:`format_str`. Unless `fast` is set, plain source is also checked
    with :func:`check_stability_and_equivalence`.
    """
    if src_contents.strip() == "":
        raise NothingChanged

    dst_contents = (
        format_ipynb_string(src_contents, fast=fast, mode=mode)
        if mode.is_ipynb
        else format_str(src_contents, mode=mode)
    )
    if dst_contents == src_contents:
        raise NothingChanged

    # Notebook cells are already verified inside format_cell, so only plain
    # source needs the extra safety pass here.
    if not (fast or mode.is_ipynb):
        check_stability_and_equivalence(src_contents, dst_contents, mode=mode)
    return dst_contents
943
944
def validate_cell(src: str) -> None:
    """Refuse cells that already contain IPython's transformed-magic calls.

    IPython's ``TransformerManager`` rewrites magics into plain Python calls.
    A cell written with those calls directly is indistinguishable from one
    that used the magic syntax:

        >>> TransformerManager().transform_cell("get_ipython().system('ls')")
        "get_ipython().system('ls')\n"
        >>> TransformerManager().transform_cell("!ls")
        "get_ipython().system('ls')\n"

    Such cells cannot be round-tripped safely. So when `src` contains any
    entry of ``TRANSFORMED_MAGICS``, raise NothingChanged and leave the cell
    untouched.
    """
    for magic_call in TRANSFORMED_MAGICS:
        if magic_call in src:
            raise NothingChanged
962
963
def format_cell(src: str, *, fast: bool, mode: Mode) -> str:
    """Format the source of a single Jupyter code cell.

    Steps, in order:

      - reject cells containing already-transformed magics;
      - strip a trailing semicolon, remembering whether one was present;
      - mask IPython magics so the cell parses as Python;
      - format the masked code (plus safety checks unless `fast`);
      - unmask the magics and restore the trailing semicolon;
      - drop trailing newlines.

    Cells that fail to parse while masking (possibly automagics or multi-line
    magics, which are unsupported) raise NothingChanged. So does a result
    identical to `src`.
    """
    validate_cell(src)
    stripped_src, had_semicolon = remove_trailing_semicolon(src)
    try:
        masked_src, replacements = mask_cell(stripped_src)
    except SyntaxError:
        raise NothingChanged from None

    masked_dst = format_str(masked_src, mode=mode)
    if not fast:
        check_stability_and_equivalence(masked_src, masked_dst, mode=mode)

    unmasked_dst = unmask_cell(masked_dst, replacements)
    result = put_trailing_semicolon_back(unmasked_dst, had_semicolon).rstrip("\n")
    if result == src:
        raise NothingChanged from None
    return result
999
1000
def validate_metadata(nb: MutableMapping[str, Any]) -> None:
    """Raise NothingChanged when the notebook declares a non-Python language.

    Every metadata field is optional in the notebook format
    (https://nbformat.readthedocs.io/en/latest/format_description.html). A
    missing ``metadata.language_info.name`` therefore means "try to format it
    anyway".
    """
    metadata = nb.get("metadata", {})
    language_info = metadata.get("language_info", {})
    language = language_info.get("name", None)
    if language not in (None, "python"):
        raise NothingChanged from None
1011
1012
def format_ipynb_string(src_contents: str, *, fast: bool, mode: Mode) -> FileContent:
    """Format Jupyter notebook.

    Operate cell-by-cell, only on code cells, only for Python notebooks.
    If the ``.ipynb`` originally had a trailing newline, it'll be preserved.

    Raises NothingChanged in three cases:

    - `src_contents` is empty.
    - The notebook metadata declares a non-Python language.
    - No code cell was reformatted.
    """
    if not src_contents:
        # Nothing to format. Without this guard, the trailing-newline check
        # below would index into an empty string and raise IndexError.
        raise NothingChanged

    trailing_newline = src_contents.endswith("\n")
    nb = json.loads(src_contents)
    validate_metadata(nb)

    modified = False
    for cell in nb["cells"]:
        if cell.get("cell_type", None) != "code":
            continue
        try:
            src = "".join(cell["source"])
            dst = format_cell(src, fast=fast, mode=mode)
        except NothingChanged:
            continue
        # Notebook JSON stores cell source as a list of lines with their endings.
        cell["source"] = dst.splitlines(keepends=True)
        modified = True

    if not modified:
        raise NothingChanged

    dst_contents = json.dumps(nb, indent=1, ensure_ascii=False)
    if trailing_newline:
        dst_contents += "\n"
    return dst_contents
1040
1041
def format_str(src_contents: str, *, mode: Mode) -> FileContent:
    """Reformat a string of Python source and return the result.

    `mode` holds the formatting options, e.g. the allowed line length. Target
    versions come from ``mode.target_versions`` when set; otherwise they are
    detected from the parsed source. Example:

    >>> import black
    >>> print(black.format_str("def f(arg:str='')->None:...", mode=black.Mode()))
    def f(arg: str = "") -> None:
        ...

    A more complex example:

    >>> print(
    ...   black.format_str(
    ...     "def f(arg:str='')->None: hey",
    ...     mode=black.Mode(
    ...       target_versions={black.TargetVersion.PY36},
    ...       line_length=10,
    ...       string_normalization=False,
    ...       is_pyi=False,
    ...     ),
    ...   ),
    ... )
    def f(
        arg: str = '',
    ) -> None:
        hey

    """
    src_node = lib2to3_parse(src_contents.lstrip(), mode.target_versions)
    future_imports = get_future_imports(src_node)
    versions = mode.target_versions or detect_target_versions(src_node)

    # TODO: fully drop support and this code hopefully in January 2022 :D
    targets_py2 = (
        TargetVersion.PY27 in mode.target_versions
        or versions == {TargetVersion.PY27}
    )
    if targets_py2:
        err(
            "DEPRECATION: Python 2 support will be removed in the first stable release "
            "expected in January 2022.",
            fg="yellow",
            bold=True,
        )

    normalize_fmt_off(src_node)
    remove_u_prefix = "unicode_literals" in future_imports or supports_feature(
        versions, Feature.UNICODE_LITERALS
    )
    line_generator = LineGenerator(mode=mode, remove_u_prefix=remove_u_prefix)
    tracker = EmptyLineTracker(is_pyi=mode.is_pyi)
    # Rendering of one empty line; repeated to emit vertical whitespace.
    blank = str(Line(mode=mode))
    split_line_features = {
        feature
        for feature in (Feature.TRAILING_COMMA_IN_CALL, Feature.TRAILING_COMMA_IN_DEF)
        if supports_feature(versions, feature)
    }

    chunks: List[str] = []
    pending_after = 0
    for current_line in line_generator.visit(src_node):
        # Blank lines requested *after* the previous line come first.
        chunks.append(blank * pending_after)
        before, pending_after = tracker.maybe_empty_lines(current_line)
        chunks.append(blank * before)
        chunks.extend(
            str(piece)
            for piece in transform_line(
                current_line, mode=mode, features=split_line_features
            )
        )
    return "".join(chunks)
1111
1112
def decode_bytes(src: bytes) -> Tuple[FileContent, Encoding, NewLine]:
    """Decode raw source bytes into ``(decoded_contents, encoding, newline)``.

    The encoding comes from :func:`tokenize.detect_encoding`. `newline` is
    ``"\\r\\n"`` when the first line ends with CRLF and ``"\\n"`` otherwise. The
    text itself is read with universal newlines, so it contains only LF.
    Empty input yields ``("", encoding, "\\n")``.
    """
    buffer = io.BytesIO(src)
    encoding, first_lines = tokenize.detect_encoding(buffer.readline)
    if not first_lines:
        return "", encoding, "\n"

    newline = "\r\n" if first_lines[0].endswith(b"\r\n") else "\n"
    # detect_encoding consumed part of the buffer; rewind before decoding.
    buffer.seek(0)
    with io.TextIOWrapper(buffer, encoding) as reader:
        text = reader.read()
    return text, encoding, newline
1128
1129
def get_features_used(node: Node) -> Set[Feature]:  # noqa: C901
    """Return a set of (relatively) new Python features used in this file.

    Currently looking for:
    - f-strings (``f``/``rf``/``fr`` prefixes in any letter case);
    - underscores in numeric literals;
    - trailing commas after * or ** in function signatures and calls;
    - positional only arguments in function signatures and lambdas;
    - assignment expression;
    - relaxed decorator syntax;
    - Python 2 only syntax (for its deprecation): long and old-style octal
      integer literals, print / exec statements, automatic parameter
      unpacking, comma-style except / raise, and backquote repr.
    """
    features: Set[Feature] = set()
    for n in node.pre_order():
        if n.type == token.STRING:
            value_head = n.value[:2]  # type: ignore
            # String prefixes are case-insensitive, so normalize before
            # matching. Otherwise mixed-case prefixes such as Rf"" or fR""
            # would not be recognised as f-strings.
            if value_head.lower() in {'f"', "f'", "rf", "fr"}:
                features.add(Feature.F_STRINGS)

        elif n.type == token.NUMBER:
            assert isinstance(n, Leaf)
            if "_" in n.value:
                features.add(Feature.NUMERIC_UNDERSCORES)
            elif n.value.endswith(("L", "l")):
                # Python 2: 10L
                features.add(Feature.LONG_INT_LITERAL)
            elif len(n.value) >= 2 and n.value[0] == "0" and n.value[1].isdigit():
                # Python 2: 0123; 00123; ...
                if not all(char == "0" for char in n.value):
                    # although we don't want to match 0000 or similar
                    features.add(Feature.OCTAL_INT_LITERAL)

        elif n.type == token.SLASH:
            if n.parent and n.parent.type in {
                syms.typedargslist,
                syms.arglist,
                syms.varargslist,
            }:
                features.add(Feature.POS_ONLY_ARGUMENTS)

        elif n.type == token.COLONEQUAL:
            features.add(Feature.ASSIGNMENT_EXPRESSIONS)

        elif n.type == syms.decorator:
            if len(n.children) > 1 and not is_simple_decorator_expression(
                n.children[1]
            ):
                features.add(Feature.RELAXED_DECORATORS)

        elif (
            n.type in {syms.typedargslist, syms.arglist}
            and n.children
            and n.children[-1].type == token.COMMA
        ):
            if n.type == syms.typedargslist:
                feature = Feature.TRAILING_COMMA_IN_DEF
            else:
                feature = Feature.TRAILING_COMMA_IN_CALL

            for ch in n.children:
                if ch.type in STARS:
                    features.add(feature)

                if ch.type == syms.argument:
                    for argch in ch.children:
                        if argch.type in STARS:
                            features.add(feature)

        # Python 2 only features (for its deprecation) except for integers, see above
        elif n.type == syms.print_stmt:
            features.add(Feature.PRINT_STMT)
        elif n.type == syms.exec_stmt:
            features.add(Feature.EXEC_STMT)
        elif n.type == syms.tfpdef:
            # def set_position((x, y), value):
            #     ...
            features.add(Feature.AUTOMATIC_PARAMETER_UNPACKING)
        elif n.type == syms.except_clause:
            # try:
            #     ...
            # except Exception, err:
            #     ...
            if len(n.children) >= 4:
                if n.children[-2].type == token.COMMA:
                    features.add(Feature.COMMA_STYLE_EXCEPT)
        elif n.type == syms.raise_stmt:
            # raise Exception, "msg"
            if len(n.children) >= 4:
                if n.children[-2].type == token.COMMA:
                    features.add(Feature.COMMA_STYLE_RAISE)
        elif n.type == token.BACKQUOTE:
            # `i'm surprised this ever existed`
            features.add(Feature.BACKQUOTE_REPR)

    return features
1225
1226
def detect_target_versions(node: Node) -> Set[TargetVersion]:
    """Return every target version able to express the features used in `node`.

    A version qualifies when all features found by :func:`get_features_used`
    appear in its ``VERSION_TO_FEATURES`` entry.
    """
    used = get_features_used(node)
    compatible: Set[TargetVersion] = set()
    for version in TargetVersion:
        if used <= VERSION_TO_FEATURES[version]:
            compatible.add(version)
    return compatible
1233
1234
def get_future_imports(node: Node) -> Set[str]:
    """Return the names imported via ``from __future__ import ...`` in the file.

    Only the leading run of simple statements is examined. A module docstring
    is skipped. Scanning stops at the first statement that is neither a
    docstring nor a ``from __future__`` import.
    """
    found: Set[str] = set()

    def names_in(children: List[LN]) -> Generator[str, None, None]:
        # Yield the original (pre-"as") names from an import list.
        for kid in children:
            if isinstance(kid, Leaf):
                # Punctuation leaves such as "(" or "," are ignored.
                if kid.type == token.NAME:
                    yield kid.value
                continue

            if kid.type == syms.import_as_name:
                original = kid.children[0]
                assert isinstance(original, Leaf), "Invalid syntax parsing imports"
                assert original.type == token.NAME, "Invalid syntax parsing imports"
                yield original.value
            elif kid.type == syms.import_as_names:
                yield from names_in(kid.children)
            else:
                raise AssertionError("Invalid syntax parsing imports")

    for stmt in node.children:
        if stmt.type != syms.simple_stmt:
            break

        head = stmt.children[0]
        if isinstance(head, Leaf):
            is_docstring = (
                len(stmt.children) == 2
                and head.type == token.STRING
                and stmt.children[1].type == token.NEWLINE
            )
            if not is_docstring:
                break
            continue

        if head.type != syms.import_from:
            break

        module = head.children[1]
        if not isinstance(module, Leaf) or module.value != "__future__":
            break

        # children: "from", "__future__", "import", <names...>
        found.update(names_in(head.children[3:]))

    return found
1283
1284
def assert_equivalent(src: str, dst: str, *, pass_num: int = 1) -> None:
    """Raise AssertionError unless `src` and `dst` parse to equivalent ASTs.

    - If `src` does not parse, the error is chained to the parse failure.
    - If `dst` does not parse, its traceback and text are dumped to a log file.
    - If the stringified ASTs differ, their diff is dumped to a log file.

    `pass_num` is only used in the error messages.
    """
    try:
        src_ast = parse_ast(src)
    except Exception as exc:
        raise AssertionError(
            "cannot use --safe with this file; failed to parse source file."
        ) from exc

    try:
        dst_ast = parse_ast(dst)
    except Exception as exc:
        tb_text = "".join(traceback.format_tb(exc.__traceback__))
        log = dump_to_file(tb_text, dst)
        raise AssertionError(
            f"INTERNAL ERROR: Black produced invalid code on pass {pass_num}: {exc}. "
            "Please report a bug on https://github.com/psf/black/issues.  "
            f"This invalid output might be helpful: {log}"
        ) from None

    src_dump = "\n".join(stringify_ast(src_ast))
    dst_dump = "\n".join(stringify_ast(dst_ast))
    if src_dump == dst_dump:
        return

    log = dump_to_file(diff(src_dump, dst_dump, "src", "dst"))
    raise AssertionError(
        "INTERNAL ERROR: Black produced code that is not equivalent to the"
        f" source on pass {pass_num}.  Please report a bug on "
        f"https://github.com/psf/black/issues.  This diff might be helpful: {log}"
    ) from None
1313
1314
def assert_stable(src: str, dst: str, mode: Mode) -> None:
    """Raise AssertionError if formatting `dst` again produces different output.

    On failure, the mode and both diffs (source -> first pass and
    first pass -> second pass) are dumped to a log file referenced in the
    message.
    """
    second = format_str(dst, mode=mode)
    if second == dst:
        return

    log = dump_to_file(
        str(mode),
        diff(src, dst, "source", "first pass"),
        diff(dst, second, "first pass", "second pass"),
    )
    raise AssertionError(
        "INTERNAL ERROR: Black produced different code on the second pass of the"
        " formatter.  Please report a bug on https://github.com/psf/black/issues."
        f"  This diff might be helpful: {log}"
    ) from None
1329
1330
@contextmanager
def nullcontext() -> Iterator[None]:
    """Context manager that does nothing and binds ``None``.

    A stand-in for :func:`contextlib.nullcontext` (added in Python 3.7).
    Exceptions raised inside the ``with`` body propagate unchanged.
    """
    yield None
1338
1339
def patch_click() -> None:
    """Make Click not crash on Python 3.6 with LANG=C.

    On certain misconfigured environments, Python 3 selects the ASCII encoding as the
    default which restricts paths that it can access during the lifetime of the
    application.  Click refuses to work in this scenario by raising a RuntimeError.

    In case of Black the likelihood that non-ASCII characters are going to be used in
    file paths is minimal since it's Python source code.  Moreover, this crash was
    spurious on Python 3.7 thanks to PEP 538 and PEP 540.

    The environment-verification hooks live in ``click.core`` and, in older Click
    releases, in ``click._unicodefun``. Whichever of those modules exist get their
    ``_verify_python3_env`` / ``_verify_python_env`` hooks replaced with no-ops.
    """
    try:
        from click import core
    except ImportError:
        # Click layout we don't recognise; nothing to patch.
        return

    modules: List[Any] = [core]
    try:
        # Newer Click releases no longer ship this module. For a missing
        # submodule, ``from package import name`` raises a plain ImportError,
        # not ModuleNotFoundError, so ImportError must be caught here. The
        # ``core`` module still has to be patched in that case.
        from click import _unicodefun

        modules.append(_unicodefun)
    except ImportError:
        pass

    for module in modules:
        if hasattr(module, "_verify_python3_env"):
            module._verify_python3_env = lambda: None  # type: ignore
        if hasattr(module, "_verify_python_env"):
            module._verify_python_env = lambda: None  # type: ignore
1362
1363
def patched_main() -> None:
    """Run the setup steps in a fixed order, then invoke the ``main`` command.

    The setup steps are ``maybe_install_uvloop()``,
    ``multiprocessing.freeze_support()`` and :func:`patch_click`.
    """
    for setup_step in (maybe_install_uvloop, freeze_support, patch_click):
        setup_step()
    main()
1369
1370
# When this file is executed directly as a script, run the CLI entry point.
if __name__ == "__main__":
    patched_main()