git.madduck.net Git - etc/vim.git/blob - src/black/__init__.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access, can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

[trivial] Use proper test cases on `unittest` (#2775)
[etc/vim.git] / src / black / __init__.py
1 import asyncio
2 from json.decoder import JSONDecodeError
3 import json
4 from concurrent.futures import Executor, ThreadPoolExecutor, ProcessPoolExecutor
5 from contextlib import contextmanager
6 from datetime import datetime
7 from enum import Enum
8 import io
9 from multiprocessing import Manager, freeze_support
10 import os
11 from pathlib import Path
12 from pathspec.patterns.gitwildmatch import GitWildMatchPatternError
13 import re
14 import signal
15 import sys
16 import tokenize
17 import traceback
18 from typing import (
19     Any,
20     Dict,
21     Generator,
22     Iterator,
23     List,
24     MutableMapping,
25     Optional,
26     Pattern,
27     Set,
28     Sized,
29     Tuple,
30     Union,
31 )
32
33 import click
34 from click.core import ParameterSource
35 from dataclasses import replace
36 from mypy_extensions import mypyc_attr
37
38 from black.const import DEFAULT_LINE_LENGTH, DEFAULT_INCLUDES, DEFAULT_EXCLUDES
39 from black.const import STDIN_PLACEHOLDER
40 from black.nodes import STARS, syms, is_simple_decorator_expression
41 from black.nodes import is_string_token
42 from black.lines import Line, EmptyLineTracker
43 from black.linegen import transform_line, LineGenerator, LN
44 from black.comments import normalize_fmt_off
45 from black.mode import FUTURE_FLAG_TO_FEATURE, Mode, TargetVersion
46 from black.mode import Feature, supports_feature, VERSION_TO_FEATURES
47 from black.cache import read_cache, write_cache, get_cache_info, filter_cached, Cache
48 from black.concurrency import cancel, shutdown, maybe_install_uvloop
49 from black.output import dump_to_file, ipynb_diff, diff, color_diff, out, err
50 from black.report import Report, Changed, NothingChanged
51 from black.files import find_project_root, find_pyproject_toml, parse_pyproject_toml
52 from black.files import gen_python_files, get_gitignore, normalize_path_maybe_ignore
53 from black.files import wrap_stream_for_windows
54 from black.parsing import InvalidInput  # noqa F401
55 from black.parsing import lib2to3_parse, parse_ast, stringify_ast
56 from black.handle_ipynb_magics import (
57     mask_cell,
58     unmask_cell,
59     remove_trailing_semicolon,
60     put_trailing_semicolon_back,
61     TRANSFORMED_MAGICS,
62     PYTHON_CELL_MAGICS,
63     jupyter_dependencies_are_installed,
64 )
65
66
67 # lib2to3 fork
68 from blib2to3.pytree import Node, Leaf
69 from blib2to3.pgen2 import token
70
71 from _black_version import version as __version__
72
# True when this module's own file has a compiled-extension suffix (".pyd" or
# ".so") rather than being plain Python source. Shown in the `--version` message.
COMPILED = Path(__file__).suffix in (".pyd", ".so")

# types
# Plain `str` aliases used in annotations to label what a given string carries.
FileContent = str
Encoding = str
NewLine = str
79
80
class WriteBack(Enum):
    """What to do with the formatted result of a source."""

    NO = 0
    YES = 1
    DIFF = 2
    CHECK = 3
    COLOR_DIFF = 4

    @classmethod
    def from_configuration(
        cls, *, check: bool, diff: bool, color: bool = False
    ) -> "WriteBack":
        """Pick the member that corresponds to the given flag combination.

        `diff` takes precedence over `check`; `color` only matters together
        with `diff`. With neither `check` nor `diff` set the result is `YES`.
        """
        if not diff:
            # Without a diff request the only question is check vs. write.
            return cls.CHECK if check else cls.YES

        return cls.COLOR_DIFF if color else cls.DIFF
99
100
# Legacy name, left for integrations.
FileMode = Mode

# Default for the `--workers` option; also the fallback in `reformat_many`
# when no worker count is passed.
# NOTE(review): `os.cpu_count()` is documented to return None when the CPU
# count cannot be determined.
DEFAULT_WORKERS = os.cpu_count()
105
106
def read_pyproject_toml(
    ctx: click.Context, param: click.Parameter, value: Optional[str]
) -> Optional[str]:
    """Load Black settings from a TOML config file into `ctx.default_map`.

    When `value` is empty the file is located with `find_pyproject_toml`,
    based on the `src` parameter already stored on `ctx`.

    Returns the path of the configuration file that was read, or None when no
    file was found or the file yielded no configuration.

    Raises `click.FileError` if the file cannot be read or parsed, and
    `click.BadOptionUsage` if `target_version` is present but is not a list.
    """
    if not value:
        value = find_pyproject_toml(ctx.params.get("src", ()))
        if value is None:
            return None

    try:
        config = parse_pyproject_toml(value)
    except (OSError, ValueError) as e:
        raise click.FileError(
            filename=value, hint=f"Error reading configuration file: {e}"
        ) from None

    if not config:
        return None

    # Sanitize the values to be Click friendly. For more information please see:
    # https://github.com/psf/black/issues/1458
    # https://github.com/pallets/click/issues/1567
    sanitized: Dict[str, Any] = {}
    for key, raw in config.items():
        # Lists and dicts are passed through; everything else becomes a string.
        sanitized[key] = raw if isinstance(raw, (list, dict)) else str(raw)

    target_version = sanitized.get("target_version")
    if not (target_version is None or isinstance(target_version, list)):
        raise click.BadOptionUsage(
            "target-version", "Config key target-version must be a list"
        )

    # Values from the file override any defaults already present on the context.
    merged: Dict[str, Any] = dict(ctx.default_map or {})
    merged.update(sanitized)
    ctx.default_map = merged
    return value
151
152
def target_version_option_callback(
    c: click.Context, p: Union[click.Option, click.Parameter], v: Tuple[str, ...]
) -> List[TargetVersion]:
    """Translate the values of a --target-version flag into `TargetVersion`s.

    Each value is upper-cased and looked up by name on the enum.

    This is its own function because mypy couldn't infer the type correctly
    when it was a lambda, causing mypyc trouble.
    """
    versions: List[TargetVersion] = []
    for name in v:
        versions.append(TargetVersion[name.upper()])
    return versions
162
163
def re_compile_maybe_verbose(regex: str) -> Pattern[str]:
    """Return the compiled form of the regular expression string `regex`.

    A pattern that spans several lines is compiled in verbose mode by
    prefixing it with the inline `(?x)` flag.
    """
    source = f"(?x){regex}" if "\n" in regex else regex
    pattern: Pattern[str] = re.compile(source)
    return pattern
173
174
def validate_regex(
    ctx: click.Context,
    param: click.Parameter,
    value: Optional[str],
) -> Optional[Pattern[str]]:
    """Click callback that compiles a regex option value.

    Returns None when no value was given, otherwise the compiled pattern.
    Raises `click.BadParameter` if the value is not a valid regular expression.
    """
    if value is None:
        return None

    try:
        return re_compile_maybe_verbose(value)
    except re.error as e:
        raise click.BadParameter(f"Not a valid regular expression: {e}") from None
184
185
@click.command(
    context_settings={"help_option_names": ["-h", "--help"]},
    # While Click does set this field automatically using the docstring, mypyc
    # (annoyingly) strips 'em so we need to set it here too.
    help="The uncompromising code formatter.",
)
@click.option("-c", "--code", type=str, help="Format the code passed in as a string.")
@click.option(
    "-l",
    "--line-length",
    type=int,
    default=DEFAULT_LINE_LENGTH,
    help="How many characters per line to allow.",
    show_default=True,
)
@click.option(
    "-t",
    "--target-version",
    type=click.Choice([v.name.lower() for v in TargetVersion]),
    callback=target_version_option_callback,
    multiple=True,
    help=(
        "Python versions that should be supported by Black's output. [default: per-file"
        " auto-detection]"
    ),
)
@click.option(
    "--pyi",
    is_flag=True,
    help=(
        "Format all input files like typing stubs regardless of file extension (useful"
        " when piping source on standard input)."
    ),
)
@click.option(
    "--ipynb",
    is_flag=True,
    help=(
        "Format all input files like Jupyter Notebooks regardless of file extension "
        "(useful when piping source on standard input)."
    ),
)
@click.option(
    "-S",
    "--skip-string-normalization",
    is_flag=True,
    help="Don't normalize string quotes or prefixes.",
)
@click.option(
    "-C",
    "--skip-magic-trailing-comma",
    is_flag=True,
    help="Don't use trailing commas as a reason to split lines.",
)
@click.option(
    "--experimental-string-processing",
    is_flag=True,
    hidden=True,
    help=(
        "Experimental option that performs more normalization on string literals."
        " Currently disabled because it leads to some crashes."
    ),
)
@click.option(
    "--check",
    is_flag=True,
    help=(
        "Don't write the files back, just return the status. Return code 0 means"
        " nothing would change. Return code 1 means some files would be reformatted."
        " Return code 123 means there was an internal error."
    ),
)
@click.option(
    "--diff",
    is_flag=True,
    help="Don't write the files back, just output a diff for each file on stdout.",
)
@click.option(
    "--color/--no-color",
    is_flag=True,
    help="Show colored diff. Only applies when `--diff` is given.",
)
@click.option(
    "--fast/--safe",
    is_flag=True,
    help="If --fast given, skip temporary sanity checks. [default: --safe]",
)
@click.option(
    "--required-version",
    type=str,
    help=(
        "Require a specific version of Black to be running (useful for unifying results"
        " across many environments e.g. with a pyproject.toml file)."
    ),
)
@click.option(
    "--include",
    type=str,
    default=DEFAULT_INCLUDES,
    callback=validate_regex,
    help=(
        "A regular expression that matches files and directories that should be"
        " included on recursive searches. An empty value means all files are included"
        " regardless of the name. Use forward slashes for directories on all platforms"
        " (Windows, too). Exclusions are calculated first, inclusions later."
    ),
    show_default=True,
)
# `--exclude` deliberately has no `default=`: `get_sources` treats a missing
# value (None) as "use DEFAULT_EXCLUDES and also consult the gitignore", so the
# default is only spelled out in the help text.
@click.option(
    "--exclude",
    type=str,
    callback=validate_regex,
    help=(
        "A regular expression that matches files and directories that should be"
        " excluded on recursive searches. An empty value means no paths are excluded."
        " Use forward slashes for directories on all platforms (Windows, too)."
        " Exclusions are calculated first, inclusions later. [default:"
        f" {DEFAULT_EXCLUDES}]"
    ),
    show_default=False,
)
@click.option(
    "--extend-exclude",
    type=str,
    callback=validate_regex,
    help=(
        "Like --exclude, but adds additional files and directories on top of the"
        " excluded ones. (Useful if you simply want to add to the default)"
    ),
)
@click.option(
    "--force-exclude",
    type=str,
    callback=validate_regex,
    help=(
        "Like --exclude, but files and directories matching this regex will be "
        "excluded even when they are passed explicitly as arguments."
    ),
)
@click.option(
    "--stdin-filename",
    type=str,
    help=(
        "The name of the file when passing it through stdin. Useful to make "
        "sure Black will respect --force-exclude option on some "
        "editors that rely on using stdin."
    ),
)
@click.option(
    "-W",
    "--workers",
    type=click.IntRange(min=1),
    default=DEFAULT_WORKERS,
    show_default=True,
    help="Number of parallel workers",
)
@click.option(
    "-q",
    "--quiet",
    is_flag=True,
    help=(
        "Don't emit non-error messages to stderr. Errors are still emitted; silence"
        " those with 2>/dev/null."
    ),
)
@click.option(
    "-v",
    "--verbose",
    is_flag=True,
    help=(
        "Also emit messages to stderr about files that were not changed or were ignored"
        " due to exclusion patterns."
    ),
)
@click.version_option(
    version=__version__,
    message=f"%(prog)s, %(version)s (compiled: {'yes' if COMPILED else 'no'})",
)
# Both `src` and `--config` are marked eager. The `--config` callback
# (`read_pyproject_toml`) looks up "src" in `ctx.params` to locate a config file.
# NOTE(review): this presumably relies on `src` being processed before
# `--config` -- confirm against Click's eager-parameter ordering.
@click.argument(
    "src",
    nargs=-1,
    type=click.Path(
        exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
    ),
    is_eager=True,
    metavar="SRC ...",
)
@click.option(
    "--config",
    type=click.Path(
        exists=True,
        file_okay=True,
        dir_okay=False,
        readable=True,
        allow_dash=False,
        path_type=str,
    ),
    is_eager=True,
    callback=read_pyproject_toml,
    help="Read configuration from FILE path.",
)
@click.pass_context
def main(
    ctx: click.Context,
    code: Optional[str],
    line_length: int,
    target_version: List[TargetVersion],
    check: bool,
    diff: bool,
    color: bool,
    fast: bool,
    pyi: bool,
    ipynb: bool,
    skip_string_normalization: bool,
    skip_magic_trailing_comma: bool,
    experimental_string_processing: bool,
    quiet: bool,
    verbose: bool,
    required_version: Optional[str],
    include: Pattern[str],
    exclude: Optional[Pattern[str]],
    extend_exclude: Optional[Pattern[str]],
    force_exclude: Optional[Pattern[str]],
    stdin_filename: Optional[str],
    workers: int,
    src: Tuple[str, ...],
    config: Optional[str],
) -> None:
    """The uncompromising code formatter."""
    ctx.ensure_object(dict)
    # No project root is looked up when formatting a `--code` string; the root is
    # stored on the context object for `get_sources` to read.
    root, method = find_project_root(src) if code is None else (None, None)
    ctx.obj["root"] = root

    if verbose:
        if root:
            out(
                f"Identified `{root}` as project root containing a {method}.",
                fg="blue",
            )

            # Pair every source with its normalized form so that sources which
            # fail to normalize can be called out in the listing below.
            normalized = [
                (normalize_path_maybe_ignore(Path(source), root), source)
                for source in src
            ]
            srcs_string = ", ".join(
                [
                    f'"{_norm}"'
                    if _norm
                    else f'\033[31m"{source} (skipping - invalid)"\033[34m'
                    for _norm, source in normalized
                ]
            )
            out(f"Sources to be formatted: {srcs_string}", fg="blue")

        if config:
            # Distinguish a config path that came from defaults from one that
            # was supplied some other way, and word the message accordingly.
            config_source = ctx.get_parameter_source("config")
            if config_source in (ParameterSource.DEFAULT, ParameterSource.DEFAULT_MAP):
                out("Using configuration from project root.", fg="blue")
            else:
                out(f"Using configuration in '{config}'.", fg="blue")

    error_msg = "Oh no! 💥 💔 💥"
    # Option validation: a mismatched required version or an incompatible flag
    # combination ends the run with exit code 1.
    if required_version and required_version != __version__:
        err(
            f"{error_msg} The required version `{required_version}` does not match"
            f" the running version `{__version__}`!"
        )
        ctx.exit(1)
    if ipynb and pyi:
        err("Cannot pass both `pyi` and `ipynb` flags!")
        ctx.exit(1)

    write_back = WriteBack.from_configuration(check=check, diff=diff, color=color)
    if target_version:
        versions = set(target_version)
    else:
        # We'll autodetect later.
        versions = set()
    mode = Mode(
        target_versions=versions,
        line_length=line_length,
        is_pyi=pyi,
        is_ipynb=ipynb,
        string_normalization=not skip_string_normalization,
        magic_trailing_comma=not skip_magic_trailing_comma,
        experimental_string_processing=experimental_string_processing,
    )

    if code is not None:
        # Run in quiet mode by default with -c; the extra output isn't useful.
        # You can still pass -v to get verbose output.
        quiet = True

    report = Report(check=check, diff=diff, quiet=quiet, verbose=verbose)

    if code is not None:
        reformat_code(
            content=code, fast=fast, write_back=write_back, mode=mode, report=report
        )
    else:
        try:
            sources = get_sources(
                ctx=ctx,
                src=src,
                quiet=quiet,
                verbose=verbose,
                include=include,
                exclude=exclude,
                extend_exclude=extend_exclude,
                force_exclude=force_exclude,
                report=report,
                stdin_filename=stdin_filename,
            )
        except GitWildMatchPatternError:
            ctx.exit(1)

        # Exits with status 0 when source discovery produced nothing to format.
        path_empty(
            sources,
            "No Python files are present to be formatted. Nothing to do 😴",
            quiet,
            verbose,
            ctx,
        )

        # A single source goes through `reformat_one`; several sources are
        # handed to `reformat_many` together with the worker count.
        if len(sources) == 1:
            reformat_one(
                src=sources.pop(),
                fast=fast,
                write_back=write_back,
                mode=mode,
                report=report,
            )
        else:
            reformat_many(
                sources=sources,
                fast=fast,
                write_back=write_back,
                mode=mode,
                report=report,
                workers=workers,
            )

    if verbose or not quiet:
        # The blank separator line and the textual report are only emitted when
        # files (not a `--code` string) were processed.
        if code is None and (verbose or report.change_count or report.failure_count):
            out()
        out(error_msg if report.return_code else "All done! ✨ 🍰 ✨")
        if code is None:
            click.echo(str(report), err=True)
    ctx.exit(report.return_code)
535
536
def get_sources(
    *,
    ctx: click.Context,
    src: Tuple[str, ...],
    quiet: bool,
    verbose: bool,
    include: Pattern[str],
    exclude: Optional[Pattern[str]],
    extend_exclude: Optional[Pattern[str]],
    force_exclude: Optional[Pattern[str]],
    report: "Report",
    stdin_filename: Optional[str],
) -> Set[Path]:
    """Work out which paths should be formatted.

    Files named in `src` (and stdin when a `stdin_filename` is given) are
    checked against `force_exclude`; directories are expanded through
    `gen_python_files`; a bare "-" is kept as-is. Anything else is reported
    as an invalid path. Exits through `path_empty` when `src` is empty.
    """
    path_empty(src, "No Path provided. Nothing to do 😴", quiet, verbose, ctx)

    # The gitignore is only consulted when the default excludes are in effect.
    gitignore = None
    if exclude is None:
        exclude = re_compile_maybe_verbose(DEFAULT_EXCLUDES)
        gitignore = get_gitignore(ctx.obj["root"])

    collected: Set[Path] = set()
    for entry in src:
        is_stdin = bool(entry == "-" and stdin_filename)
        path = Path(stdin_filename) if is_stdin else Path(entry)

        if is_stdin or path.is_file():
            normalized_path = normalize_path_maybe_ignore(
                path, ctx.obj["root"], report
            )
            if normalized_path is None:
                continue

            rooted_path = "/" + normalized_path
            # Hard-exclude any files that matches the `--force-exclude` regex.
            match = force_exclude.search(rooted_path) if force_exclude else None
            if match and match.group(0):
                report.path_ignored(
                    path, "matches the --force-exclude regular expression"
                )
                continue

            if is_stdin:
                # Tag the name so later stages can tell it stands for stdin.
                path = Path(f"{STDIN_PLACEHOLDER}{str(path)}")

            if path.suffix == ".ipynb" and not jupyter_dependencies_are_installed(
                verbose=verbose, quiet=quiet
            ):
                continue

            collected.add(path)
        elif path.is_dir():
            discovered = gen_python_files(
                path.iterdir(),
                ctx.obj["root"],
                include,
                exclude,
                extend_exclude,
                force_exclude,
                report,
                gitignore,
                verbose=verbose,
                quiet=quiet,
            )
            collected.update(discovered)
        elif entry == "-":
            collected.add(path)
        else:
            err(f"invalid path: {entry}")
    return collected
612
613
def path_empty(
    src: Sized, msg: str, quiet: bool, verbose: bool, ctx: click.Context
) -> None:
    """
    Exit with status 0 when `src` is empty, printing `msg` unless silenced.

    The message is shown when `verbose` is set or `quiet` is not.
    """
    if src:
        return

    if verbose or not quiet:
        out(msg)
    ctx.exit(0)
624
625
def reformat_code(
    content: str, fast: bool, write_back: WriteBack, mode: Mode, report: Report
) -> None:
    """
    Reformat the string `content` and print it, without child processes.

    This is the string counterpart of `reformat_one`: `fast`, `write_back`
    and `mode` are forwarded to :func:`format_stdin_to_stdout`, and the
    outcome is recorded on `report` under the pseudo-path "<string>".
    Any exception is recorded as a failure instead of being propagated.
    """
    path = Path("<string>")
    try:
        was_changed = format_stdin_to_stdout(
            content=content, fast=fast, write_back=write_back, mode=mode
        )
        report.done(path, Changed.YES if was_changed else Changed.NO)
    except Exception as exc:
        if report.verbose:
            traceback.print_exc()
        report.failed(path, str(exc))
648
649
def reformat_one(
    src: Path, fast: bool, write_back: WriteBack, mode: Mode, report: "Report"
) -> None:
    """Reformat the single file `src` in the current process.

    A `src` of "-" or one carrying the stdin placeholder prefix is read from
    stdin via :func:`format_stdin_to_stdout`; anything else is handled by
    :func:`format_file_in_place`, with the cache consulted and updated unless
    a diff was requested. The outcome, including any exception, is recorded on
    `report`.
    """
    try:
        changed = Changed.NO

        raw_name = str(src)
        if raw_name == "-":
            is_stdin = True
        else:
            is_stdin = raw_name.startswith(STDIN_PLACEHOLDER)
            if is_stdin:
                # Use the original name again in case we want to print something
                # to the user
                src = Path(raw_name[len(STDIN_PLACEHOLDER) :])

        if is_stdin:
            # The file name given for stdin still decides the formatting flavor.
            suffix = src.suffix
            if suffix == ".pyi":
                mode = replace(mode, is_pyi=True)
            elif suffix == ".ipynb":
                mode = replace(mode, is_ipynb=True)
            if format_stdin_to_stdout(fast=fast, write_back=write_back, mode=mode):
                changed = Changed.YES
        else:
            cache: Cache = {}
            is_cached = False
            if write_back not in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
                cache = read_cache(mode)
                resolved = src.resolve()
                cache_key = str(resolved)
                is_cached = cache_key in cache and cache[
                    cache_key
                ] == get_cache_info(resolved)

            if is_cached:
                changed = Changed.CACHED
            elif format_file_in_place(
                src, fast=fast, write_back=write_back, mode=mode
            ):
                changed = Changed.YES

            # Record the file as up to date after writing it back, or after a
            # check that found nothing to change.
            if write_back is WriteBack.YES:
                update_cache = not is_cached
            elif write_back is WriteBack.CHECK:
                update_cache = changed is Changed.NO
            else:
                update_cache = False
            if update_cache:
                write_cache(cache, [src], mode)
        report.done(src, changed)
    except Exception as exc:
        if report.verbose:
            traceback.print_exc()
        report.failed(src, str(exc))
699
700
# diff-shades depends on being able to monkeypatch this function to operate. I know
# it's not ideal, but this shouldn't cause any issues ... hopefully. ~ichard26
@mypyc_attr(patchable=True)
def reformat_many(
    sources: Set[Path],
    fast: bool,
    write_back: WriteBack,
    mode: Mode,
    report: "Report",
    workers: Optional[int],
) -> None:
    """Reformat multiple files using a ProcessPoolExecutor.

    `workers` is the number of worker processes; None falls back to
    `DEFAULT_WORKERS`. When a process pool cannot be created, a single-worker
    thread pool is used instead. `sources`, `fast`, `write_back`, `mode` and
    `report` are forwarded to :func:`schedule_formatting`, which is run to
    completion on an event loop owned by this call.
    """
    executor: Executor
    worker_count = workers if workers is not None else DEFAULT_WORKERS
    if sys.platform == "win32":
        # Work around https://bugs.python.org/issue26903
        assert worker_count is not None
        worker_count = min(worker_count, 60)
    try:
        executor = ProcessPoolExecutor(max_workers=worker_count)
    except (ImportError, NotImplementedError, OSError):
        # we arrive here if the underlying system does not support multi-processing
        # like in AWS Lambda or Termux, in which case we gracefully fallback to
        # a ThreadPoolExecutor with just a single worker (more workers would not do us
        # any good due to the Global Interpreter Lock)
        executor = ThreadPoolExecutor(max_workers=1)

    # Use a dedicated loop rather than `asyncio.get_event_loop()`: calling the
    # latter without a running loop is deprecated, and it ties this function to
    # whatever state the thread's current loop was left in.
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        loop.run_until_complete(
            schedule_formatting(
                sources=sources,
                fast=fast,
                write_back=write_back,
                mode=mode,
                report=report,
                loop=loop,
                executor=executor,
            )
        )
    finally:
        try:
            shutdown(loop)
        finally:
            # Always drop the reference to our loop and release the workers,
            # even if shutting the loop down failed.
            asyncio.set_event_loop(None)
            executor.shutdown()
745
746
async def schedule_formatting(
    sources: Set[Path],
    fast: bool,
    write_back: WriteBack,
    mode: Mode,
    report: "Report",
    loop: asyncio.AbstractEventLoop,
    executor: Executor,
) -> None:
    """Format `sources` concurrently on `executor`, reporting each result.

    (Use ProcessPoolExecutors for actual parallelism.)

    Unless a diff was requested, sources already present in the cache are
    reported as cached and skipped. `write_back`, `fast`, and `mode` are
    passed on to :func:`format_file_in_place` for every remaining source.
    """
    diff_requested = write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF)

    cache: Cache = {}
    if not diff_requested:
        cache = read_cache(mode)
        sources, cached = filter_cached(cache, sources)
        for cached_path in sorted(cached):
            report.done(cached_path, Changed.CACHED)
    if not sources:
        return

    cancelled = []
    sources_to_cache = []
    lock = None
    if diff_requested:
        # For diff output, we need locks to ensure we don't interleave output
        # from different processes.
        manager = Manager()
        lock = manager.Lock()

    # Map every scheduled future back to the path it is formatting.
    tasks = {}
    for path in sorted(sources):
        future = asyncio.ensure_future(
            loop.run_in_executor(
                executor, format_file_in_place, path, fast, mode, write_back, lock
            )
        )
        tasks[future] = path

    # A live view: it shrinks as finished tasks are popped from `tasks` below.
    pending = tasks.keys()
    try:
        loop.add_signal_handler(signal.SIGINT, cancel, pending)
        loop.add_signal_handler(signal.SIGTERM, cancel, pending)
    except NotImplementedError:
        # There are no good alternatives for these on Windows.
        pass

    while pending:
        finished, _ = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
        for task in finished:
            path = tasks.pop(task)
            if task.cancelled():
                cancelled.append(task)
                continue

            failure = task.exception()
            if failure:
                report.failed(path, str(failure))
                continue

            changed = Changed.YES if task.result() else Changed.NO
            # If the file was written back or was successfully checked as
            # well-formatted, store this information in the cache.
            if write_back is WriteBack.YES or (
                write_back is WriteBack.CHECK and changed is Changed.NO
            ):
                sources_to_cache.append(path)
            report.done(path, changed)

    if cancelled:
        if sys.version_info >= (3, 7):
            await asyncio.gather(*cancelled, return_exceptions=True)
        else:
            await asyncio.gather(*cancelled, loop=loop, return_exceptions=True)
    if sources_to_cache:
        write_cache(cache, sources_to_cache, mode)
819
820
def format_file_in_place(
    src: Path,
    fast: bool,
    mode: Mode,
    write_back: WriteBack = WriteBack.NO,
    lock: Any = None,  # multiprocessing.Manager().Lock() is some crazy proxy
) -> bool:
    """Format the file at `src`. Return True if its contents would change.

    With `write_back` set to YES the reformatted code is written to the file;
    with DIFF or COLOR_DIFF a diff is written to stdout, guarded by `lock`
    when one is given. `mode` and `fast` are passed to
    :func:`format_file_contents`.

    Raises ValueError when the contents cannot be decoded as JSON while
    formatting (reported as an invalid Jupyter notebook).
    """
    suffix = src.suffix
    if suffix == ".pyi":
        mode = replace(mode, is_pyi=True)
    elif suffix == ".ipynb":
        mode = replace(mode, is_ipynb=True)

    # Modification time of the file, used as the "before" timestamp in diffs.
    then = datetime.utcfromtimestamp(src.stat().st_mtime)
    with open(src, "rb") as buf:
        src_contents, encoding, newline = decode_bytes(buf.read())
    try:
        dst_contents = format_file_contents(src_contents, fast=fast, mode=mode)
    except NothingChanged:
        return False
    except JSONDecodeError:
        raise ValueError(
            f"File '{src}' cannot be parsed as valid Jupyter notebook."
        ) from None

    if write_back == WriteBack.YES:
        with open(src, "w", encoding=encoding, newline=newline) as target:
            target.write(dst_contents)
        return True

    if write_back not in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
        return True

    now = datetime.utcnow()
    src_name = f"{src}\t{then} +0000"
    dst_name = f"{src}\t{now} +0000"
    make_diff = ipynb_diff if mode.is_ipynb else diff
    diff_contents = make_diff(src_contents, dst_contents, src_name, dst_name)

    if write_back == WriteBack.COLOR_DIFF:
        diff_contents = color_diff(diff_contents)

    with lock or nullcontext():
        stream = io.TextIOWrapper(
            sys.stdout.buffer,
            encoding=encoding,
            newline=newline,
            write_through=True,
        )
        stream = wrap_stream_for_windows(stream)
        stream.write(diff_contents)
        # Detach so that dropping the wrapper leaves stdout's buffer open.
        stream.detach()

    return True
878
879
def format_stdin_to_stdout(
    fast: bool,
    *,
    content: Optional[str] = None,
    write_back: WriteBack = WriteBack.NO,
    mode: Mode,
) -> bool:
    """Format `content`, or stdin when `content` is None. Return True if changed.

    With `write_back` set to YES the resulting code is written to stdout
    (unchanged input is echoed back); with DIFF or COLOR_DIFF a diff is
    written to stdout instead. Output is produced even when formatting raises,
    in which case the original source is used. The `mode` argument is passed to
    :func:`format_file_contents`.
    """
    then = datetime.utcnow()

    if content is not None:
        src, encoding, newline = content, "utf-8", ""
    else:
        src, encoding, newline = decode_bytes(sys.stdin.buffer.read())

    dst = src
    try:
        dst = format_file_contents(src, fast=fast, mode=mode)
        changed = True

    except NothingChanged:
        changed = False

    finally:
        stream = io.TextIOWrapper(
            sys.stdout.buffer, encoding=encoding, newline=newline, write_through=True
        )
        if write_back == WriteBack.YES:
            # Make sure there's a newline after the content
            if dst and not dst.endswith("\n"):
                dst += "\n"
            stream.write(dst)
        elif write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
            now = datetime.utcnow()
            src_name = f"STDIN\t{then} +0000"
            dst_name = f"STDOUT\t{now} +0000"
            diff_text = diff(src, dst, src_name, dst_name)
            if write_back == WriteBack.COLOR_DIFF:
                diff_text = color_diff(diff_text)
                stream = wrap_stream_for_windows(stream)
            stream.write(diff_text)
        # Detach so that dropping the wrapper leaves stdout's buffer open.
        stream.detach()

    return changed
929
930
def check_stability_and_equivalence(
    src_contents: str, dst_contents: str, *, mode: Mode
) -> None:
    """Verify that formatted output is both equivalent and stable.

    Raise AssertionError if `dst_contents` is not equivalent to
    `src_contents`, or if running the formatter again keeps changing the
    output.
    """
    assert_equivalent(src_contents, dst_contents)

    # A second pass is always run: optional trailing commas turn into forced
    # trailing commas on pass 2, and that interacts differently with optional
    # parentheses.  Admittedly ugly.
    second_pass = format_str(dst_contents, mode=mode)
    if second_pass == dst_contents:
        # The first pass is already a fixed point, so calling `assert_stable`
        # explicitly would add nothing.
        return

    assert_equivalent(src_contents, second_pass, pass_num=2)
    assert_stable(src_contents, second_pass, mode=mode)
952
953
def format_file_contents(src_contents: str, *, fast: bool, mode: Mode) -> FileContent:
    """Return the reformatted version of `src_contents`.

    Raise NothingChanged when the input is empty or whitespace-only, or when
    formatting leaves it untouched.

    Unless `fast` is True, the result is additionally validated with
    :func:`assert_equivalent` and :func:`assert_stable`.  `mode` is handed on
    to :func:`format_str` (or :func:`format_ipynb_string` for notebooks).
    """
    if not src_contents.strip():
        raise NothingChanged

    dst_contents = (
        format_ipynb_string(src_contents, fast=fast, mode=mode)
        if mode.is_ipynb
        else format_str(src_contents, mode=mode)
    )
    if dst_contents == src_contents:
        raise NothingChanged

    # Notebooks are skipped here because `format_ipynb_string` already ran
    # the same checks cell by cell.
    if not (fast or mode.is_ipynb):
        check_stability_and_equivalence(src_contents, dst_contents, mode=mode)
    return dst_contents
975
976
def validate_cell(src: str) -> None:
    """Raise NothingChanged for cells that cannot be safely round-tripped.

    Two kinds of cell are rejected: those that already contain the output of a
    TransformerManager transformation, and those that start with a cell magic
    (``%%name``) not listed in PYTHON_CELL_MAGICS, which might cause
    tokenizer_rt to break because of indentations.

    The first restriction exists because ``!ls`` gets transformed to
    ``get_ipython().system('ls')``, yet a cell that originally contained
    ``get_ipython().system('ls')`` transforms to exactly the same text:

        >>> TransformerManager().transform_cell("get_ipython().system('ls')")
        "get_ipython().system('ls')\n"
        >>> TransformerManager().transform_cell("!ls")
        "get_ipython().system('ls')\n"

    The two cannot be told apart afterwards, so such cells are left alone.
    """
    for transformed_magic in TRANSFORMED_MAGICS:
        if transformed_magic in src:
            raise NothingChanged

    if src.startswith("%%"):
        # First whitespace-delimited word, minus the leading "%%".
        magic_name = src.split()[0][2:]
        if magic_name not in PYTHON_CELL_MAGICS:
            raise NothingChanged
998
999
def format_cell(src: str, *, fast: bool, mode: Mode) -> str:
    """Return the formatted source of a single Jupyter notebook code cell.

    The steps are:

      - reject cells that `validate_cell` considers unsafe;
      - drop a trailing semicolon, remembering that it was there;
      - mask IPython magics;
      - format the masked code (and, unless `fast`, check the result);
      - restore the magics and the trailing semicolon;
      - strip trailing newlines.

    Raise NothingChanged if the cell is rejected, if masking hits a syntax
    error (such cells could be automagics or multi-line magics, which are
    currently not supported), or if the result equals the input.
    """
    validate_cell(src)
    stripped_src, had_semicolon = remove_trailing_semicolon(src)
    try:
        masked_src, replacements = mask_cell(stripped_src)
    except SyntaxError:
        raise NothingChanged from None

    masked_dst = format_str(masked_src, mode=mode)
    if not fast:
        check_stability_and_equivalence(masked_src, masked_dst, mode=mode)

    unmasked_dst = unmask_cell(masked_dst, replacements)
    dst = put_trailing_semicolon_back(unmasked_dst, had_semicolon).rstrip("\n")
    if dst == src:
        raise NothingChanged from None
    return dst
1035
1036
def validate_metadata(nb: MutableMapping[str, Any]) -> None:
    """Raise NothingChanged when the notebook declares a non-Python language.

    The language is looked up at ``nb["metadata"]["language_info"]["name"]``.
    Every notebook metadata field is optional (see
    https://nbformat.readthedocs.io/en/latest/format_description.html), so a
    notebook that does not declare a language is accepted and will be parsed
    anyway.
    """
    metadata = nb.get("metadata", {})
    language_info = metadata.get("language_info", {})
    language = language_info.get("name", None)
    if language is None:
        return
    if language != "python":
        raise NothingChanged from None
1047
1048
def format_ipynb_string(src_contents: str, *, fast: bool, mode: Mode) -> FileContent:
    """Format Jupyter notebook.

    Operate cell-by-cell, only on code cells, only for Python notebooks.
    If the ``.ipynb`` originally had a trailing newline, it'll be preserved.

    Raise NothingChanged if `src_contents` is empty, if the notebook metadata
    names a language other than Python, or if no code cell was modified.
    `json.loads` raises JSONDecodeError if the content is not valid JSON, and
    a KeyError is raised if the notebook has no top-level ``"cells"`` entry.
    """
    if not src_contents:
        # Nothing to parse.  Without this guard the `[-1]` lookup below fails
        # with an IndexError on empty input.
        raise NothingChanged

    trailing_newline = src_contents[-1] == "\n"
    modified = False
    nb = json.loads(src_contents)
    validate_metadata(nb)
    for cell in nb["cells"]:
        if cell.get("cell_type", None) == "code":
            try:
                src = "".join(cell["source"])
                dst = format_cell(src, fast=fast, mode=mode)
            except NothingChanged:
                # This cell stays as it is; keep going with the next one.
                pass
            else:
                cell["source"] = dst.splitlines(keepends=True)
                modified = True
    if modified:
        dst_contents = json.dumps(nb, indent=1, ensure_ascii=False)
        if trailing_newline:
            dst_contents = dst_contents + "\n"
        return dst_contents
    else:
        raise NothingChanged
1076
1077
def format_str(src_contents: str, *, mode: Mode) -> FileContent:
    """Return `src_contents` reformatted according to `mode`.

    `mode` carries the formatting options, such as how many characters per
    line are allowed.  Example:

    >>> import black
    >>> print(black.format_str("def f(arg:str='')->None:...", mode=black.Mode()))
    def f(arg: str = "") -> None:
        ...

    A more complex example:

    >>> print(
    ...   black.format_str(
    ...     "def f(arg:str='')->None: hey",
    ...     mode=black.Mode(
    ...       target_versions={black.TargetVersion.PY36},
    ...       line_length=10,
    ...       string_normalization=False,
    ...       is_pyi=False,
    ...     ),
    ...   ),
    ... )
    def f(
        arg: str = '',
    ) -> None:
        hey

    """
    src_node = lib2to3_parse(src_contents.lstrip(), mode.target_versions)
    future_imports = get_future_imports(src_node)
    # Explicitly configured target versions win; otherwise they are inferred
    # from the features the source itself uses.
    versions = mode.target_versions or detect_target_versions(
        src_node, future_imports=future_imports
    )

    normalize_fmt_off(src_node)
    line_generator = LineGenerator(mode=mode)
    blank_line_tracker = EmptyLineTracker(is_pyi=mode.is_pyi)
    empty_line = Line(mode=mode)
    trailing_comma_features = {
        feature
        for feature in (Feature.TRAILING_COMMA_IN_CALL, Feature.TRAILING_COMMA_IN_DEF)
        if supports_feature(versions, feature)
    }

    chunks = []
    # Blank lines requested *after* the previous line; they are emitted at the
    # top of the next iteration, hence never after the very last line.
    pending_after = 0
    for current_line in line_generator.visit(src_node):
        chunks.append(str(empty_line) * pending_after)
        before, pending_after = blank_line_tracker.maybe_empty_lines(current_line)
        chunks.append(str(empty_line) * before)
        chunks.extend(
            str(line)
            for line in transform_line(
                current_line, mode=mode, features=trailing_comma_features
            )
        )
    return "".join(chunks)
1135
1136
def decode_bytes(src: bytes) -> Tuple[FileContent, Encoding, NewLine]:
    """Decode `src` and return ``(decoded_contents, encoding, newline)``.

    The encoding is detected with :func:`tokenize.detect_encoding`.  `newline`
    is CRLF if the first line ends in CRLF and LF otherwise, whereas
    `decoded_contents` is read with universal newlines (i.e. it only contains
    LF).
    """
    buffer = io.BytesIO(src)
    encoding, first_lines = tokenize.detect_encoding(buffer.readline)
    if not first_lines:
        return "", encoding, "\n"

    newline = "\r\n" if first_lines[0].endswith(b"\r\n") else "\n"
    # Encoding detection consumed the first line(s); rewind before decoding.
    buffer.seek(0)
    with io.TextIOWrapper(buffer, encoding) as reader:
        decoded = reader.read()
    return decoded, encoding, newline
1152
1153
def get_features_used(  # noqa: C901
    node: Node, *, future_imports: Optional[Set[str]] = None
) -> Set[Feature]:
    """Return a set of (relatively) new Python features used in this file.

    Currently looking for:
    - f-strings (any capitalisation of the prefix);
    - underscores in numeric literals;
    - trailing commas after * or ** in function signatures and calls;
    - positional only arguments in function signatures and lambdas;
    - assignment expression;
    - relaxed decorator syntax;
    - star-unpacking in ``return`` and ``yield`` without parentheses;
    - a ``testlist_star_expr`` as the right hand side of an annotated
      assignment;
    - usage of __future__ flags listed in FUTURE_FLAG_TO_FEATURE.

    `future_imports` are the names imported from ``__future__``, if any.
    """
    features: Set[Feature] = set()
    if future_imports:
        features |= {
            FUTURE_FLAG_TO_FEATURE[future_import]
            for future_import in future_imports
            if future_import in FUTURE_FLAG_TO_FEATURE
        }

    for n in node.pre_order():
        if is_string_token(n):
            # String prefixes are case-insensitive and may combine `f` with `r`
            # in either order, so compare the lower-cased head.  Otherwise
            # mixed-case spellings such as Rf"..." or fR'...' go undetected.
            value_head = n.value[:2].lower()
            if value_head in {'f"', "f'", "rf", "fr"}:
                features.add(Feature.F_STRINGS)

        elif n.type == token.NUMBER:
            assert isinstance(n, Leaf)
            if "_" in n.value:
                features.add(Feature.NUMERIC_UNDERSCORES)

        elif n.type == token.SLASH:
            # A slash directly inside an argument list is the positional-only
            # marker rather than a division operator.
            if n.parent and n.parent.type in {
                syms.typedargslist,
                syms.arglist,
                syms.varargslist,
            }:
                features.add(Feature.POS_ONLY_ARGUMENTS)

        elif n.type == token.COLONEQUAL:
            features.add(Feature.ASSIGNMENT_EXPRESSIONS)

        elif n.type == syms.decorator:
            if len(n.children) > 1 and not is_simple_decorator_expression(
                n.children[1]
            ):
                features.add(Feature.RELAXED_DECORATORS)

        elif (
            n.type in {syms.typedargslist, syms.arglist}
            and n.children
            and n.children[-1].type == token.COMMA
        ):
            # The argument list ends in a comma; that only counts as a feature
            # when the list also contains a * or ** argument.
            if n.type == syms.typedargslist:
                feature = Feature.TRAILING_COMMA_IN_DEF
            else:
                feature = Feature.TRAILING_COMMA_IN_CALL

            for ch in n.children:
                if ch.type in STARS:
                    features.add(feature)

                if ch.type == syms.argument:
                    for argch in ch.children:
                        if argch.type in STARS:
                            features.add(feature)

        elif (
            n.type in {syms.return_stmt, syms.yield_expr}
            and len(n.children) >= 2
            and n.children[1].type == syms.testlist_star_expr
            and any(child.type == syms.star_expr for child in n.children[1].children)
        ):
            features.add(Feature.UNPACKING_ON_FLOW)

        elif (
            n.type == syms.annassign
            and len(n.children) >= 4
            and n.children[3].type == syms.testlist_star_expr
        ):
            features.add(Feature.ANN_ASSIGN_EXTENDED_RHS)

    return features
1240
1241
def detect_target_versions(
    node: Node, *, future_imports: Optional[Set[str]] = None
) -> Set[TargetVersion]:
    """Return every target version that supports all features used in `node`.

    `future_imports` is passed through to :func:`get_features_used`.
    """
    used_features = get_features_used(node, future_imports=future_imports)
    compatible: Set[TargetVersion] = set()
    for version in TargetVersion:
        # Keep a version only if it offers everything the source relies on.
        if used_features <= VERSION_TO_FEATURES[version]:
            compatible.add(version)
    return compatible
1250
1251
def get_future_imports(node: Node) -> Set[str]:
    """Collect the names imported from ``__future__`` at the top of the file.

    Only the leading statements of `node` are looked at.  A statement made of
    a single string (a docstring) is skipped, ``from __future__ import ...``
    statements are collected, and anything else ends the search.
    """
    found: Set[str] = set()

    def imported_names(children: List[LN]) -> Generator[str, None, None]:
        """Yield the imported names found among `children`."""
        for item in children:
            if isinstance(item, Leaf):
                # Only NAME leaves are imports; other leaves are skipped.
                if item.type == token.NAME:
                    yield item.value
                continue

            if item.type == syms.import_as_name:
                # For an aliased import, report the first child rather than
                # the alias.
                original = item.children[0]
                assert isinstance(original, Leaf), "Invalid syntax parsing imports"
                assert original.type == token.NAME, "Invalid syntax parsing imports"
                yield original.value
            elif item.type == syms.import_as_names:
                yield from imported_names(item.children)
            else:
                raise AssertionError("Invalid syntax parsing imports")

    for stmt in node.children:
        if stmt.type != syms.simple_stmt:
            break

        head = stmt.children[0]
        if isinstance(head, Leaf):
            # Continue looking if we see a docstring; otherwise stop.
            is_docstring = (
                len(stmt.children) == 2
                and head.type == token.STRING
                and stmt.children[1].type == token.NEWLINE
            )
            if not is_docstring:
                break
            continue

        if head.type != syms.import_from:
            break

        module_name = head.children[1]
        if not isinstance(module_name, Leaf) or module_name.value != "__future__":
            break

        # Everything from the fourth child onwards is scanned for names.
        found.update(imported_names(head.children[3:]))

    return found
1300
1301
def assert_equivalent(src: str, dst: str, *, pass_num: int = 1) -> None:
    """Raise AssertionError unless `src` and `dst` have equivalent ASTs.

    `pass_num` only shows up in the error messages, to tell apart which
    formatting pass produced `dst`.
    """
    try:
        src_ast = parse_ast(src)
    except Exception as err:
        raise AssertionError(
            f"cannot use --safe with this file; failed to parse source file: {err}"
        ) from err

    try:
        dst_ast = parse_ast(dst)
    except Exception as err:
        # Save the traceback together with the invalid output for the report.
        tb_text = "".join(traceback.format_tb(err.__traceback__))
        log = dump_to_file(tb_text, dst)
        raise AssertionError(
            f"INTERNAL ERROR: Black produced invalid code on pass {pass_num}: {err}. "
            "Please report a bug on https://github.com/psf/black/issues.  "
            f"This invalid output might be helpful: {log}"
        ) from None

    src_ast_str = "\n".join(stringify_ast(src_ast))
    dst_ast_str = "\n".join(stringify_ast(dst_ast))
    if src_ast_str == dst_ast_str:
        return

    log = dump_to_file(diff(src_ast_str, dst_ast_str, "src", "dst"))
    raise AssertionError(
        "INTERNAL ERROR: Black produced code that is not equivalent to the"
        f" source on pass {pass_num}.  Please report a bug on "
        f"https://github.com/psf/black/issues.  This diff might be helpful: {log}"
    ) from None
1330
1331
def assert_stable(src: str, dst: str, mode: Mode) -> None:
    """Raise AssertionError if formatting `dst` again changes it.

    `src` is only used for the diagnostic dump written on failure.
    """
    second_pass = format_str(dst, mode=mode)
    if second_pass == dst:
        return

    log = dump_to_file(
        str(mode),
        diff(src, dst, "source", "first pass"),
        diff(dst, second_pass, "first pass", "second pass"),
    )
    raise AssertionError(
        "INTERNAL ERROR: Black produced different code on the second pass of the"
        " formatter.  Please report a bug on https://github.com/psf/black/issues."
        f"  This diff might be helpful: {log}"
    ) from None
1346
1347
@contextmanager
def nullcontext() -> Iterator[None]:
    """Context manager that does nothing on entry and nothing on exit.

    Stand-in for `contextlib.nullcontext`, which is available from Python 3.7.
    """
    yield None
1355
1356
def patch_click() -> None:
    """Make Click not crash on Python 3.6 with LANG=C.

    On certain misconfigured environments, Python 3 selects the ASCII encoding as the
    default which restricts paths that it can access during the lifetime of the
    application.  Click refuses to work in this scenario by raising a RuntimeError.

    In case of Black the likelihood that non-ASCII characters are going to be used in
    file paths is minimal since it's Python source code.  Moreover, this crash was
    spurious on Python 3.7 thanks to PEP 538 and PEP 540.

    The environment check is replaced with a no-op in every Click module that
    defines it.  Modules that cannot be imported are skipped.
    """
    modules: List[Any] = []
    try:
        from click import core
    except ImportError:
        pass
    else:
        modules.append(core)
    try:
        # Click 8.1.0 removed this private module.  `from click import name`
        # then raises a plain ImportError rather than ModuleNotFoundError, so
        # the broader exception has to be caught, and separately from `core`
        # so that `core` still gets patched.
        from click import _unicodefun
    except ImportError:
        pass
    else:
        modules.append(_unicodefun)

    for module in modules:
        for hook in ("_verify_python3_env", "_verify_python_env"):
            if hasattr(module, hook):
                setattr(module, hook, lambda: None)
1379
1380
def patched_main() -> None:
    """Run the start-up hooks in order, then hand over to :func:`main`."""
    # NOTE(review): presumably switches asyncio to uvloop when it is installed;
    # the helper is defined elsewhere in this file -- confirm there.
    maybe_install_uvloop()
    # `multiprocessing.freeze_support` (imported at the top of the file).
    freeze_support()
    # Neutralise Click's environment check before Click handles the command line.
    patch_click()
    main()
1386
1387
# Only run the command line interface when this file is executed as a script,
# not when it is imported.
if __name__ == "__main__":
    patched_main()