git.madduck.net Git - etc/vim.git/blob - src/black/__init__.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Add a test case to torture.py (#2822)
[etc/vim.git] / src / black / __init__.py
1 import asyncio
2 from json.decoder import JSONDecodeError
3 import json
4 from concurrent.futures import Executor, ThreadPoolExecutor, ProcessPoolExecutor
5 from contextlib import contextmanager
6 from datetime import datetime
7 from enum import Enum
8 import io
9 from multiprocessing import Manager, freeze_support
10 import os
11 from pathlib import Path
12 from pathspec.patterns.gitwildmatch import GitWildMatchPatternError
13 import re
14 import signal
15 import sys
16 import tokenize
17 import traceback
18 from typing import (
19     Any,
20     Dict,
21     Generator,
22     Iterator,
23     List,
24     MutableMapping,
25     Optional,
26     Pattern,
27     Sequence,
28     Set,
29     Sized,
30     Tuple,
31     Union,
32 )
33
34 import click
35 from click.core import ParameterSource
36 from dataclasses import replace
37 from mypy_extensions import mypyc_attr
38
39 from black.const import DEFAULT_LINE_LENGTH, DEFAULT_INCLUDES, DEFAULT_EXCLUDES
40 from black.const import STDIN_PLACEHOLDER
41 from black.nodes import STARS, syms, is_simple_decorator_expression
42 from black.nodes import is_string_token
43 from black.lines import Line, EmptyLineTracker
44 from black.linegen import transform_line, LineGenerator, LN
45 from black.comments import normalize_fmt_off
46 from black.mode import FUTURE_FLAG_TO_FEATURE, Mode, TargetVersion
47 from black.mode import Feature, supports_feature, VERSION_TO_FEATURES
48 from black.cache import read_cache, write_cache, get_cache_info, filter_cached, Cache
49 from black.concurrency import cancel, shutdown, maybe_install_uvloop
50 from black.output import dump_to_file, ipynb_diff, diff, color_diff, out, err
51 from black.report import Report, Changed, NothingChanged
52 from black.files import find_project_root, find_pyproject_toml, parse_pyproject_toml
53 from black.files import gen_python_files, get_gitignore, normalize_path_maybe_ignore
54 from black.files import wrap_stream_for_windows
55 from black.parsing import InvalidInput  # noqa F401
56 from black.parsing import lib2to3_parse, parse_ast, stringify_ast
57 from black.handle_ipynb_magics import (
58     mask_cell,
59     unmask_cell,
60     remove_trailing_semicolon,
61     put_trailing_semicolon_back,
62     TRANSFORMED_MAGICS,
63     PYTHON_CELL_MAGICS,
64     jupyter_dependencies_are_installed,
65 )
66
67
68 # lib2to3 fork
69 from blib2to3.pytree import Node, Leaf
70 from blib2to3.pgen2 import token
71
72 from _black_version import version as __version__
73
# True when this module was loaded from a file with a ".pyd" or ".so" suffix,
# i.e. a compiled extension rather than a plain ".py" source file.
COMPILED = Path(__file__).suffix in (".pyd", ".so")

# types (all three are plain aliases of `str`)
FileContent = str
Encoding = str
NewLine = str
80
81
class WriteBack(Enum):
    """What to do with the formatted result of a source."""

    NO = 0
    YES = 1
    DIFF = 2
    CHECK = 3
    COLOR_DIFF = 4

    @classmethod
    def from_configuration(
        cls, *, check: bool, diff: bool, color: bool = False
    ) -> "WriteBack":
        """Pick the write-back mode matching the given command-line flags.

        `diff` takes precedence over `check`; `color` only matters together
        with `diff`. With neither `check` nor `diff`, the result is `YES`.
        """
        if diff:
            return cls.COLOR_DIFF if color else cls.DIFF

        return cls.CHECK if check else cls.YES
100
101
# Legacy name, left for integrations.
FileMode = Mode

# Default value of the `--workers` option: the CPU count reported by the OS.
# Note that os.cpu_count() returns None when the count cannot be determined.
DEFAULT_WORKERS = os.cpu_count()
106
107
def read_pyproject_toml(
    ctx: click.Context, param: click.Parameter, value: Optional[str]
) -> Optional[str]:
    """Merge Black settings from a "pyproject.toml" into the defaults of `ctx`.

    When `value` is empty, the file is searched for based on the "src"
    parameter already present in `ctx`. The parsed settings are layered on
    top of any existing `ctx.default_map`.

    Returns the path of the configuration file that was read, or None when no
    file was found or the file held no Black configuration.

    Raises `click.FileError` if the file cannot be read or parsed, and
    `click.BadOptionUsage` if `target_version` is present but not a list.
    """
    if not value:
        value = find_pyproject_toml(ctx.params.get("src", ()))
        if value is None:
            return None

    try:
        parsed = parse_pyproject_toml(value)
    except (OSError, ValueError) as exc:
        raise click.FileError(
            filename=value, hint=f"Error reading configuration file: {exc}"
        ) from None

    if not parsed:
        return None

    # Sanitize the values to be Click friendly. For more information please see:
    # https://github.com/psf/black/issues/1458
    # https://github.com/pallets/click/issues/1567
    config: Dict[str, Any] = {}
    for key, item in parsed.items():
        config[key] = item if isinstance(item, (list, dict)) else str(item)

    target_version = config.get("target_version")
    if target_version is not None and not isinstance(target_version, list):
        raise click.BadOptionUsage(
            "target-version", "Config key target-version must be a list"
        )

    # Existing defaults first, so that the file's settings win on conflicts.
    merged: Dict[str, Any] = {}
    merged.update(ctx.default_map or {})
    merged.update(config)

    ctx.default_map = merged
    return value
152
153
def target_version_option_callback(
    c: click.Context, p: Union[click.Option, click.Parameter], v: Tuple[str, ...]
) -> List[TargetVersion]:
    """Translate the values of a --target-version flag into `TargetVersion`s.

    Each value is upper-cased and looked up by name on the enum.

    Kept as a named function because mypy couldn't infer the type correctly
    when it was a lambda, causing mypyc trouble.
    """
    versions: List[TargetVersion] = []
    for name in v:
        versions.append(TargetVersion[name.upper()])
    return versions
163
164
def re_compile_maybe_verbose(regex: str) -> Pattern[str]:
    """Return `regex` compiled into a pattern object.

    A multi-line expression (one containing a newline) is compiled in verbose
    mode by prefixing it with the inline `(?x)` flag.
    """
    source = f"(?x){regex}" if "\n" in regex else regex
    pattern: Pattern[str] = re.compile(source)
    return pattern
174
175
def validate_regex(
    ctx: click.Context,
    param: click.Parameter,
    value: Optional[str],
) -> Optional[Pattern[str]]:
    """Click callback that compiles an option value into a regex pattern.

    Returns None when no value was given. Raises `click.BadParameter` if the
    value is not a valid regular expression.
    """
    if value is None:
        return None

    try:
        return re_compile_maybe_verbose(value)
    except re.error as exc:
        raise click.BadParameter(f"Not a valid regular expression: {exc}") from None
185
186
# Command-line entry point. Every option declared below maps onto the
# same-named keyword parameter of `main`.
@click.command(
    context_settings={"help_option_names": ["-h", "--help"]},
    # While Click does set this field automatically using the docstring, mypyc
    # (annoyingly) strips 'em so we need to set it here too.
    help="The uncompromising code formatter.",
)
@click.option("-c", "--code", type=str, help="Format the code passed in as a string.")
@click.option(
    "-l",
    "--line-length",
    type=int,
    default=DEFAULT_LINE_LENGTH,
    help="How many characters per line to allow.",
    show_default=True,
)
@click.option(
    "-t",
    "--target-version",
    type=click.Choice([v.name.lower() for v in TargetVersion]),
    callback=target_version_option_callback,
    multiple=True,
    help=(
        "Python versions that should be supported by Black's output. [default: per-file"
        " auto-detection]"
    ),
)
@click.option(
    "--pyi",
    is_flag=True,
    help=(
        "Format all input files like typing stubs regardless of file extension (useful"
        " when piping source on standard input)."
    ),
)
@click.option(
    "--ipynb",
    is_flag=True,
    help=(
        "Format all input files like Jupyter Notebooks regardless of file extension "
        "(useful when piping source on standard input)."
    ),
)
@click.option(
    "--python-cell-magics",
    multiple=True,
    help=(
        "When processing Jupyter Notebooks, add the given magic to the list"
        f" of known python-magics ({', '.join(PYTHON_CELL_MAGICS)})."
        " Useful for formatting cells with custom python magics."
    ),
    default=[],
)
@click.option(
    "-S",
    "--skip-string-normalization",
    is_flag=True,
    help="Don't normalize string quotes or prefixes.",
)
@click.option(
    "-C",
    "--skip-magic-trailing-comma",
    is_flag=True,
    help="Don't use trailing commas as a reason to split lines.",
)
@click.option(
    "--experimental-string-processing",
    is_flag=True,
    hidden=True,
    help="(DEPRECATED and now included in --preview) Normalize string literals.",
)
@click.option(
    "--preview",
    is_flag=True,
    help=(
        "Enable potentially disruptive style changes that will be added to Black's main"
        " functionality in the next major release."
    ),
)
@click.option(
    "--check",
    is_flag=True,
    help=(
        "Don't write the files back, just return the status. Return code 0 means"
        " nothing would change. Return code 1 means some files would be reformatted."
        " Return code 123 means there was an internal error."
    ),
)
@click.option(
    "--diff",
    is_flag=True,
    help="Don't write the files back, just output a diff for each file on stdout.",
)
@click.option(
    "--color/--no-color",
    is_flag=True,
    help="Show colored diff. Only applies when `--diff` is given.",
)
@click.option(
    "--fast/--safe",
    is_flag=True,
    help="If --fast given, skip temporary sanity checks. [default: --safe]",
)
@click.option(
    "--required-version",
    type=str,
    help=(
        "Require a specific version of Black to be running (useful for unifying results"
        " across many environments e.g. with a pyproject.toml file)."
    ),
)
@click.option(
    "--include",
    type=str,
    default=DEFAULT_INCLUDES,
    callback=validate_regex,
    help=(
        "A regular expression that matches files and directories that should be"
        " included on recursive searches. An empty value means all files are included"
        " regardless of the name. Use forward slashes for directories on all platforms"
        " (Windows, too). Exclusions are calculated first, inclusions later."
    ),
    show_default=True,
)
@click.option(
    "--exclude",
    type=str,
    callback=validate_regex,
    help=(
        "A regular expression that matches files and directories that should be"
        " excluded on recursive searches. An empty value means no paths are excluded."
        " Use forward slashes for directories on all platforms (Windows, too)."
        " Exclusions are calculated first, inclusions later. [default:"
        f" {DEFAULT_EXCLUDES}]"
    ),
    show_default=False,
)
@click.option(
    "--extend-exclude",
    type=str,
    callback=validate_regex,
    help=(
        "Like --exclude, but adds additional files and directories on top of the"
        " excluded ones. (Useful if you simply want to add to the default)"
    ),
)
@click.option(
    "--force-exclude",
    type=str,
    callback=validate_regex,
    help=(
        "Like --exclude, but files and directories matching this regex will be "
        "excluded even when they are passed explicitly as arguments."
    ),
)
@click.option(
    "--stdin-filename",
    type=str,
    help=(
        "The name of the file when passing it through stdin. Useful to make "
        "sure Black will respect --force-exclude option on some "
        "editors that rely on using stdin."
    ),
)
@click.option(
    "-W",
    "--workers",
    type=click.IntRange(min=1),
    default=DEFAULT_WORKERS,
    show_default=True,
    help="Number of parallel workers",
)
@click.option(
    "-q",
    "--quiet",
    is_flag=True,
    help=(
        "Don't emit non-error messages to stderr. Errors are still emitted; silence"
        " those with 2>/dev/null."
    ),
)
@click.option(
    "-v",
    "--verbose",
    is_flag=True,
    help=(
        "Also emit messages to stderr about files that were not changed or were ignored"
        " due to exclusion patterns."
    ),
)
@click.version_option(
    version=__version__,
    message=f"%(prog)s, %(version)s (compiled: {'yes' if COMPILED else 'no'})",
)
# `src` and `--config` are both eager, and `--config` runs `read_pyproject_toml`
# as its callback, which reads the "src" parameter from the context.
@click.argument(
    "src",
    nargs=-1,
    type=click.Path(
        exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
    ),
    is_eager=True,
    metavar="SRC ...",
)
@click.option(
    "--config",
    type=click.Path(
        exists=True,
        file_okay=True,
        dir_okay=False,
        readable=True,
        allow_dash=False,
        path_type=str,
    ),
    is_eager=True,
    callback=read_pyproject_toml,
    help="Read configuration from FILE path.",
)
@click.pass_context
def main(
    ctx: click.Context,
    code: Optional[str],
    line_length: int,
    target_version: List[TargetVersion],
    check: bool,
    diff: bool,
    color: bool,
    fast: bool,
    pyi: bool,
    ipynb: bool,
    python_cell_magics: Sequence[str],
    skip_string_normalization: bool,
    skip_magic_trailing_comma: bool,
    experimental_string_processing: bool,
    preview: bool,
    quiet: bool,
    verbose: bool,
    required_version: Optional[str],
    include: Pattern[str],
    exclude: Optional[Pattern[str]],
    extend_exclude: Optional[Pattern[str]],
    force_exclude: Optional[Pattern[str]],
    stdin_filename: Optional[str],
    workers: int,
    src: Tuple[str, ...],
    config: Optional[str],
) -> None:
    """The uncompromising code formatter."""
    ctx.ensure_object(dict)

    # SRC paths and --code are mutually exclusive, and exactly one is required.
    if src and code is not None:
        out(
            main.get_usage(ctx)
            + "\n\n'SRC' and 'code' cannot be passed simultaneously."
        )
        ctx.exit(1)
    if not src and code is None:
        out(main.get_usage(ctx) + "\n\nOne of 'SRC' or 'code' is required.")
        ctx.exit(1)

    # The project root is only looked up when formatting paths, not --code.
    # It is stored on the context so that `get_sources` can read it.
    root, method = find_project_root(src) if code is None else (None, None)
    ctx.obj["root"] = root

    if verbose:
        if root:
            out(
                f"Identified `{root}` as project root containing a {method}.",
                fg="blue",
            )

            # Pair every source with its path normalized against the root; a
            # falsy normalized path is reported as skipped, wrapped in ANSI
            # color escape codes.
            normalized = [
                (normalize_path_maybe_ignore(Path(source), root), source)
                for source in src
            ]
            srcs_string = ", ".join(
                [
                    f'"{_norm}"'
                    if _norm
                    else f'\033[31m"{source} (skipping - invalid)"\033[34m'
                    for _norm, source in normalized
                ]
            )
            out(f"Sources to be formatted: {srcs_string}", fg="blue")

        if config:
            # Tell apart a config that came from defaults from one that was
            # supplied some other way (reported with its path).
            config_source = ctx.get_parameter_source("config")
            if config_source in (ParameterSource.DEFAULT, ParameterSource.DEFAULT_MAP):
                out("Using configuration from project root.", fg="blue")
            else:
                out(f"Using configuration in '{config}'.", fg="blue")

    error_msg = "Oh no! 💥 💔 💥"
    if required_version and required_version != __version__:
        err(
            f"{error_msg} The required version `{required_version}` does not match"
            f" the running version `{__version__}`!"
        )
        ctx.exit(1)
    if ipynb and pyi:
        err("Cannot pass both `pyi` and `ipynb` flags!")
        ctx.exit(1)

    write_back = WriteBack.from_configuration(check=check, diff=diff, color=color)
    if target_version:
        versions = set(target_version)
    else:
        # We'll autodetect later.
        versions = set()
    mode = Mode(
        target_versions=versions,
        line_length=line_length,
        is_pyi=pyi,
        is_ipynb=ipynb,
        string_normalization=not skip_string_normalization,
        magic_trailing_comma=not skip_magic_trailing_comma,
        experimental_string_processing=experimental_string_processing,
        preview=preview,
        python_cell_magics=set(python_cell_magics),
    )

    if code is not None:
        # Run in quiet mode by default with -c; the extra output isn't useful.
        # You can still pass -v to get verbose output.
        quiet = True

    report = Report(check=check, diff=diff, quiet=quiet, verbose=verbose)

    if code is not None:
        reformat_code(
            content=code, fast=fast, write_back=write_back, mode=mode, report=report
        )
    else:
        try:
            sources = get_sources(
                ctx=ctx,
                src=src,
                quiet=quiet,
                verbose=verbose,
                include=include,
                exclude=exclude,
                extend_exclude=extend_exclude,
                force_exclude=force_exclude,
                report=report,
                stdin_filename=stdin_filename,
            )
        except GitWildMatchPatternError:
            # A pathspec pattern error escaped from get_sources; exit with
            # status 1.
            ctx.exit(1)

        # Exits with status 0 when there is nothing to format.
        path_empty(
            sources,
            "No Python files are present to be formatted. Nothing to do 😴",
            quiet,
            verbose,
            ctx,
        )

        # A single source goes through `reformat_one`; several sources are
        # handed to `reformat_many` together with the `workers` count.
        if len(sources) == 1:
            reformat_one(
                src=sources.pop(),
                fast=fast,
                write_back=write_back,
                mode=mode,
                report=report,
            )
        else:
            reformat_many(
                sources=sources,
                fast=fast,
                write_back=write_back,
                mode=mode,
                report=report,
                workers=workers,
            )

    # The closing summary is skipped under --quiet unless --verbose is also
    # set. The per-file report is only printed when formatting paths.
    if verbose or not quiet:
        if code is None and (verbose or report.change_count or report.failure_count):
            out()
        out(error_msg if report.return_code else "All done! ✨ 🍰 ✨")
        if code is None:
            click.echo(str(report), err=True)
    ctx.exit(report.return_code)
566
567
def get_sources(
    *,
    ctx: click.Context,
    src: Tuple[str, ...],
    quiet: bool,
    verbose: bool,
    include: Pattern[str],
    exclude: Optional[Pattern[str]],
    extend_exclude: Optional[Pattern[str]],
    force_exclude: Optional[Pattern[str]],
    report: "Report",
    stdin_filename: Optional[str],
) -> Set[Path]:
    """Collect the paths that should be formatted.

    Explicitly given files (and stdin) are only filtered by `force_exclude`;
    directories are expanded through `gen_python_files`, which receives all of
    the include/exclude patterns. Paths that are neither a file, a directory
    nor "-" are reported through `err` and left out.
    """
    collected: Set[Path] = set()

    # The gitignore is only consulted when --exclude was not given.
    gitignore = None
    if exclude is None:
        exclude = re_compile_maybe_verbose(DEFAULT_EXCLUDES)
        gitignore = get_gitignore(ctx.obj["root"])

    for s in src:
        is_stdin = bool(s == "-" and stdin_filename)
        p = Path(stdin_filename) if is_stdin else Path(s)

        if is_stdin or p.is_file():
            normalized_path = normalize_path_maybe_ignore(p, ctx.obj["root"], report)
            if normalized_path is None:
                continue

            rooted_path = "/" + normalized_path
            # Hard-exclude any files that matches the `--force-exclude` regex.
            match = force_exclude.search(rooted_path) if force_exclude else None
            if match and match.group(0):
                report.path_ignored(p, "matches the --force-exclude regular expression")
                continue

            if is_stdin:
                # Tag the name so that it can be recognized as stdin later on.
                p = Path(f"{STDIN_PLACEHOLDER}{str(p)}")

            if p.suffix == ".ipynb" and not jupyter_dependencies_are_installed(
                verbose=verbose, quiet=quiet
            ):
                continue

            collected.add(p)
        elif p.is_dir():
            discovered = gen_python_files(
                p.iterdir(),
                ctx.obj["root"],
                include,
                exclude,
                extend_exclude,
                force_exclude,
                report,
                gitignore,
                verbose=verbose,
                quiet=quiet,
            )
            collected.update(discovered)
        elif s == "-":
            collected.add(p)
        else:
            err(f"invalid path: {s}")

    return collected
642
643
def path_empty(
    src: Sized, msg: str, quiet: bool, verbose: bool, ctx: click.Context
) -> None:
    """
    Exit with status 0 if `src` is empty, printing `msg` unless silenced.

    `msg` is printed when `verbose` is set or `quiet` is not.
    """
    if src:
        return

    if verbose or not quiet:
        out(msg)
    ctx.exit(0)
654
655
def reformat_code(
    content: str, fast: bool, write_back: WriteBack, mode: Mode, report: Report
) -> None:
    """
    Reformat the string `content` and print it, without spawning child processes.

    The string counterpart of `reformat_one`: `fast`, `write_back`, and `mode`
    are forwarded to :func:`format_stdin_to_stdout`. The outcome is recorded on
    `report` under the pseudo-path "<string>"; any exception is recorded as a
    failure instead of being raised.
    """
    path = Path("<string>")
    try:
        was_changed = format_stdin_to_stdout(
            content=content, fast=fast, write_back=write_back, mode=mode
        )
        report.done(path, Changed.YES if was_changed else Changed.NO)
    except Exception as exc:
        if report.verbose:
            traceback.print_exc()
        report.failed(path, str(exc))
678
679
def reformat_one(
    src: Path, fast: bool, write_back: WriteBack, mode: Mode, report: "Report"
) -> None:
    """Reformat the single file `src` without spawning child processes.

    A `src` of "-" or one carrying the stdin placeholder prefix is read from
    stdin; anything else is formatted in place, going through the cache unless
    a diff was requested.

    `fast`, `write_back`, and `mode` options are passed to
    :func:`format_file_in_place` or :func:`format_stdin_to_stdout`. The result
    is recorded on `report`; exceptions are recorded as failures.
    """
    try:
        changed = Changed.NO

        src_text = str(src)
        is_stdin = src_text == "-" or src_text.startswith(STDIN_PLACEHOLDER)
        if is_stdin and src_text != "-":
            # Use the original name again in case we want to print something
            # to the user
            src = Path(src_text[len(STDIN_PLACEHOLDER) :])

        if is_stdin:
            if src.suffix == ".pyi":
                mode = replace(mode, is_pyi=True)
            elif src.suffix == ".ipynb":
                mode = replace(mode, is_ipynb=True)
            if format_stdin_to_stdout(fast=fast, write_back=write_back, mode=mode):
                changed = Changed.YES
        else:
            cache: Cache = {}
            wants_diff = write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF)
            if not wants_diff:
                cache = read_cache(mode)
                resolved = src.resolve()
                cache_key = str(resolved)
                if cache_key in cache and cache[cache_key] == get_cache_info(
                    resolved
                ):
                    changed = Changed.CACHED

            if changed is not Changed.CACHED and format_file_in_place(
                src, fast=fast, write_back=write_back, mode=mode
            ):
                changed = Changed.YES

            # Remember files that were written back, or that passed a check
            # without needing changes.
            written = write_back is WriteBack.YES and changed is not Changed.CACHED
            checked_clean = write_back is WriteBack.CHECK and changed is Changed.NO
            if written or checked_clean:
                write_cache(cache, [src], mode)
        report.done(src, changed)
    except Exception as exc:
        if report.verbose:
            traceback.print_exc()
        report.failed(src, str(exc))
729
730
# diff-shades depends on being able to monkeypatch this function to operate. I know
# it's not ideal, but this shouldn't cause any issues ... hopefully. ~ichard26
@mypyc_attr(patchable=True)
def reformat_many(
    sources: Set[Path],
    fast: bool,
    write_back: WriteBack,
    mode: Mode,
    report: "Report",
    workers: Optional[int],
) -> None:
    """Reformat multiple files using a ProcessPoolExecutor.

    `workers` is the size of the process pool; when it is None,
    `DEFAULT_WORKERS` is used instead. If process pools are not available on
    this system, a single-worker ThreadPoolExecutor is used as a fallback.

    `sources`, `fast`, `write_back`, `mode` and `report` are handed over to
    :func:`schedule_formatting`, which is run to completion on the event loop.
    The loop and the executor are always shut down afterwards.
    """
    executor: Executor
    loop = asyncio.get_event_loop()
    worker_count = workers if workers is not None else DEFAULT_WORKERS
    if sys.platform == "win32":
        # Work around https://bugs.python.org/issue26903
        # DEFAULT_WORKERS comes from os.cpu_count(), which may be None; fall
        # back to one worker instead of asserting, as asserts vanish under -O.
        worker_count = min(1 if worker_count is None else worker_count, 60)
    try:
        executor = ProcessPoolExecutor(max_workers=worker_count)
    except (ImportError, NotImplementedError, OSError):
        # we arrive here if the underlying system does not support multi-processing
        # like in AWS Lambda or Termux, in which case we gracefully fallback to
        # a ThreadPoolExecutor with just a single worker (more workers would not do us
        # any good due to the Global Interpreter Lock)
        executor = ThreadPoolExecutor(max_workers=1)

    try:
        loop.run_until_complete(
            schedule_formatting(
                sources=sources,
                fast=fast,
                write_back=write_back,
                mode=mode,
                report=report,
                loop=loop,
                executor=executor,
            )
        )
    finally:
        shutdown(loop)
        # `executor` is always bound here: this point is only reached once one
        # of the two constructors above has succeeded.
        executor.shutdown()
776
async def schedule_formatting(
    sources: Set[Path],
    fast: bool,
    write_back: WriteBack,
    mode: Mode,
    report: "Report",
    loop: asyncio.AbstractEventLoop,
    executor: Executor,
) -> None:
    """Run formatting of `sources` in parallel using the provided `executor`.

    (Use ProcessPoolExecutors for actual parallelism.)

    `write_back`, `fast`, and `mode` options are passed to
    :func:`format_file_in_place`.

    Every source is reported on `report` as done or failed. Sources that were
    written back, or that passed a check unchanged, are added to the cache.
    """
    # The cache is bypassed for diff output. Otherwise sources that are
    # already cached are reported as CACHED and dropped from the work list.
    cache: Cache = {}
    if write_back not in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
        cache = read_cache(mode)
        sources, cached = filter_cached(cache, sources)
        for src in sorted(cached):
            report.done(src, Changed.CACHED)
    if not sources:
        return

    cancelled = []
    sources_to_cache = []
    lock = None
    if write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
        # For diff output, we need locks to ensure we don't interleave output
        # from different processes.
        manager = Manager()
        lock = manager.Lock()
    # Map every scheduled task to its source path so that results can be
    # attributed. The arguments are positional and follow the parameter order
    # of `format_file_in_place`: src, fast, mode, write_back, lock.
    tasks = {
        asyncio.ensure_future(
            loop.run_in_executor(
                executor, format_file_in_place, src, fast, mode, write_back, lock
            )
        ): src
        for src in sorted(sources)
    }
    # A live view of the dict keys: it shrinks as finished tasks are popped
    # from `tasks` in the loop below.
    pending = tasks.keys()
    try:
        loop.add_signal_handler(signal.SIGINT, cancel, pending)
        loop.add_signal_handler(signal.SIGTERM, cancel, pending)
    except NotImplementedError:
        # There are no good alternatives for these on Windows.
        pass
    while pending:
        done, _ = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
        for task in done:
            src = tasks.pop(task)
            if task.cancelled():
                cancelled.append(task)
            elif task.exception():
                report.failed(src, str(task.exception()))
            else:
                changed = Changed.YES if task.result() else Changed.NO
                # If the file was written back or was successfully checked as
                # well-formatted, store this information in the cache.
                if write_back is WriteBack.YES or (
                    write_back is WriteBack.CHECK and changed is Changed.NO
                ):
                    sources_to_cache.append(src)
                report.done(src, changed)
    if cancelled:
        # Wait on the cancelled tasks, collecting exceptions instead of raising
        # them. The `loop` argument is only passed on Python versions below 3.7.
        if sys.version_info >= (3, 7):
            await asyncio.gather(*cancelled, return_exceptions=True)
        else:
            await asyncio.gather(*cancelled, loop=loop, return_exceptions=True)
    if sources_to_cache:
        write_cache(cache, sources_to_cache, mode)
849
850
def format_file_in_place(
    src: Path,
    fast: bool,
    mode: Mode,
    write_back: WriteBack = WriteBack.NO,
    lock: Any = None,  # multiprocessing.Manager().Lock() is some crazy proxy
) -> bool:
    """Format the file at `src`. Return True if its contents changed.

    With `write_back` set to YES the reformatted code is written to the file;
    with DIFF or COLOR_DIFF a diff is written to stdout instead, holding `lock`
    (when given) while writing. Files ending in ".pyi" or ".ipynb" switch
    `mode` to the matching flavor.

    `mode` and `fast` options are passed to :func:`format_file_contents`.

    Raises ValueError if the contents cannot be decoded as JSON while
    formatting.
    """
    suffix = src.suffix
    if suffix == ".pyi":
        mode = replace(mode, is_pyi=True)
    elif suffix == ".ipynb":
        mode = replace(mode, is_ipynb=True)

    # Taken before formatting, for the "before" header of a diff.
    then = datetime.utcfromtimestamp(src.stat().st_mtime)
    with open(src, "rb") as buf:
        src_contents, encoding, newline = decode_bytes(buf.read())
    try:
        dst_contents = format_file_contents(src_contents, fast=fast, mode=mode)
    except NothingChanged:
        return False
    except JSONDecodeError:
        raise ValueError(
            f"File '{src}' cannot be parsed as valid Jupyter notebook."
        ) from None

    if write_back == WriteBack.YES:
        with open(src, "w", encoding=encoding, newline=newline) as f:
            f.write(dst_contents)
        return True

    if write_back not in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
        return True

    now = datetime.utcnow()
    src_name = f"{src}\t{then} +0000"
    dst_name = f"{src}\t{now} +0000"
    make_diff = ipynb_diff if mode.is_ipynb else diff
    diff_contents = make_diff(src_contents, dst_contents, src_name, dst_name)

    if write_back == WriteBack.COLOR_DIFF:
        diff_contents = color_diff(diff_contents)

    with lock or nullcontext():
        stream = io.TextIOWrapper(
            sys.stdout.buffer,
            encoding=encoding,
            newline=newline,
            write_through=True,
        )
        stream = wrap_stream_for_windows(stream)
        stream.write(diff_contents)
        # Detach so that the wrapper does not close sys.stdout's buffer.
        stream.detach()

    return True
908
909
def format_stdin_to_stdout(
    fast: bool,
    *,
    content: Optional[str] = None,
    write_back: WriteBack = WriteBack.NO,
    mode: Mode,
) -> bool:
    """Format source taken from stdin (or `content`). Return True if changed.

    When `content` is None the source is read from ``sys.stdin.buffer`` and
    its encoding and newline style are detected; otherwise `content` is used
    as-is with UTF-8 and no newline translation.

    With `write_back` YES the resulting code is written to stdout; with DIFF
    or COLOR_DIFF a diff is written instead.  The output step runs in a
    ``finally`` clause, so it also happens when nothing changed or when
    formatting raised.  `fast` and `mode` are forwarded to
    :func:`format_file_contents`.
    """
    started = datetime.utcnow()

    if content is not None:
        src, encoding, newline = content, "utf-8", ""
    else:
        src, encoding, newline = decode_bytes(sys.stdin.buffer.read())

    # Until formatting succeeds, the "result" is simply the unchanged input.
    dst = src
    try:
        dst = format_file_contents(src, fast=fast, mode=mode)
    except NothingChanged:
        return False
    else:
        return True
    finally:
        out = io.TextIOWrapper(
            sys.stdout.buffer, encoding=encoding, newline=newline, write_through=True
        )
        if write_back == WriteBack.YES:
            # Make sure there's a newline after the content
            if dst and not dst.endswith("\n"):
                dst += "\n"
            out.write(dst)
        elif write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
            finished = datetime.utcnow()
            src_name = f"STDIN\t{started} +0000"
            dst_name = f"STDOUT\t{finished} +0000"
            delta = diff(src, dst, src_name, dst_name)
            if write_back == WriteBack.COLOR_DIFF:
                delta = color_diff(delta)
                out = wrap_stream_for_windows(out)
            out.write(delta)
        # Detach rather than close so sys.stdout.buffer stays usable.
        out.detach()
959
960
def check_stability_and_equivalence(
    src_contents: str, dst_contents: str, *, mode: Mode
) -> None:
    """Run the equivalence check followed by the stability check.

    Delegates to :func:`assert_equivalent` and then :func:`assert_stable`;
    either one raises AssertionError when `dst_contents` is not equivalent to
    `src_contents`, or when formatting `dst_contents` again would change it.
    """
    # Equivalence first: there is no point checking stability of wrong output.
    assert_equivalent(src_contents, dst_contents)
    assert_stable(src_contents, dst_contents, mode=mode)
972
973
def format_file_contents(src_contents: str, *, fast: bool, mode: Mode) -> FileContent:
    """Return the reformatted version of `src_contents`.

    Raise NothingChanged when the input is empty or whitespace-only, or when
    formatting leaves it untouched.

    Notebook sources (`mode.is_ipynb`) go through :func:`format_ipynb_string`,
    everything else through :func:`format_str`.  Unless `fast` is set,
    non-notebook output is additionally verified with
    :func:`check_stability_and_equivalence`.
    """
    if not src_contents.strip():
        raise NothingChanged

    if mode.is_ipynb:
        # `fast` is handed down: notebooks are verified cell by cell, so no
        # whole-document check is run here.
        formatted = format_ipynb_string(src_contents, fast=fast, mode=mode)
        if formatted == src_contents:
            raise NothingChanged
        return formatted

    formatted = format_str(src_contents, mode=mode)
    if formatted == src_contents:
        raise NothingChanged

    if not fast:
        check_stability_and_equivalence(src_contents, formatted, mode=mode)
    return formatted
995
996
def validate_cell(src: str, mode: Mode) -> None:
    """Reject cells that cannot be safely round-tripped.

    Raise NothingChanged when the cell already contains one of the
    TRANSFORMED_MAGICS, or when it starts with a ``%%`` cell magic that is
    neither in PYTHON_CELL_MAGICS nor in `mode.python_cell_magics`.

    The first rule exists because transformation is not reversible: a cell
    containing ``!ls`` becomes ``get_ipython().system('ls')``, but a cell that
    was written as ``get_ipython().system('ls')`` transforms to the same text:

        >>> TransformerManager().transform_cell("get_ipython().system('ls')")
        "get_ipython().system('ls')\n"
        >>> TransformerManager().transform_cell("!ls")
        "get_ipython().system('ls')\n"

    The two cannot be told apart afterwards, so such cells are left alone.
    """
    for transformed_magic in TRANSFORMED_MAGICS:
        if transformed_magic in src:
            raise NothingChanged

    if src.startswith("%%"):
        # First whitespace-delimited word, minus the leading "%%".
        magic_name = src.split()[0][2:]
        if magic_name not in PYTHON_CELL_MAGICS | mode.python_cell_magics:
            raise NothingChanged
1021
1022
def format_cell(src: str, *, fast: bool, mode: Mode) -> str:
    """Return the formatted source of a single Jupyter notebook code cell.

    Steps, in order:

      1. validate the cell (see :func:`validate_cell`);
      2. drop a trailing semicolon, remembering whether there was one;
      3. mask IPython magics;
      4. format the masked source, verifying the result unless `fast`;
      5. unmask the magics and restore the trailing semicolon;
      6. strip trailing newlines.

    Raise NothingChanged when the cell is rejected by validation, when masking
    hits a SyntaxError (such cells could be automagics or multi-line magics,
    which are currently not supported), or when the result equals `src`.
    """
    validate_cell(src, mode)
    stripped_src, had_semicolon = remove_trailing_semicolon(src)
    try:
        masked_src, replacements = mask_cell(stripped_src)
    except SyntaxError:
        raise NothingChanged from None

    masked_dst = format_str(masked_src, mode=mode)
    if not fast:
        check_stability_and_equivalence(masked_src, masked_dst, mode=mode)

    unmasked_dst = unmask_cell(masked_dst, replacements)
    dst = put_trailing_semicolon_back(unmasked_dst, had_semicolon).rstrip("\n")
    if dst == src:
        raise NothingChanged from None
    return dst
1058
1059
def validate_metadata(nb: MutableMapping[str, Any]) -> None:
    """Raise NothingChanged for notebooks explicitly marked as non-Python.

    The language is looked up at ``metadata.language_info.name``.  All notebook
    metadata fields are optional, see
    https://nbformat.readthedocs.io/en/latest/format_description.html, so a
    notebook that does not declare a language is accepted.
    """
    metadata = nb.get("metadata", {})
    language_info = metadata.get("language_info", {})
    language = language_info.get("name", None)
    if language is None or language == "python":
        return
    raise NothingChanged from None
1070
1071
def format_ipynb_string(src_contents: str, *, fast: bool, mode: Mode) -> FileContent:
    """Format Jupyter notebook.

    Operate cell-by-cell, only on code cells, only for Python notebooks.
    If the ``.ipynb`` originally had a trailing newline, it'll be preserved.

    `fast` and `mode` are forwarded to :func:`format_cell` for every code cell.

    Raise NothingChanged when `src_contents` is empty, when the notebook is
    rejected by :func:`validate_metadata`, or when no cell was modified.
    Malformed JSON propagates as the exception raised by ``json.loads``.
    """
    # An empty document has no last character to inspect and nothing to
    # format; bail out before indexing into it.
    if not src_contents:
        raise NothingChanged

    trailing_newline = src_contents[-1] == "\n"
    modified = False
    nb = json.loads(src_contents)
    validate_metadata(nb)
    for cell in nb["cells"]:
        if cell.get("cell_type", None) == "code":
            try:
                src = "".join(cell["source"])
                dst = format_cell(src, fast=fast, mode=mode)
            except NothingChanged:
                # This cell stays as it is; keep going with the others.
                pass
            else:
                cell["source"] = dst.splitlines(keepends=True)
                modified = True
    if modified:
        dst_contents = json.dumps(nb, indent=1, ensure_ascii=False)
        if trailing_newline:
            dst_contents = dst_contents + "\n"
        return dst_contents
    else:
        raise NothingChanged
1099
1100
def format_str(src_contents: str, *, mode: Mode) -> str:
    """Reformat a string and return new contents.

    `mode` carries the formatting options, such as how many characters per
    line are allowed.  Example:

    >>> import black
    >>> print(black.format_str("def f(arg:str='')->None:...", mode=black.Mode()))
    def f(arg: str = "") -> None:
        ...

    The source is formatted once; if that pass changed anything, its output is
    formatted a second time and the second result is returned.
    """
    first_pass = _format_str_once(src_contents, mode=mode)
    if first_pass == src_contents:
        return first_pass

    # Forced second pass to work around optional trailing commas (becoming
    # forced trailing commas on pass 2) interacting differently with optional
    # parentheses.  Admittedly ugly.
    return _format_str_once(first_pass, mode=mode)
1138
1139
def _format_str_once(src_contents: str, *, mode: Mode) -> str:
    """Run a single formatting pass over `src_contents` and return the result.

    Leading whitespace is stripped before parsing.  Target versions come from
    `mode.target_versions` when set and are otherwise detected from the parsed
    tree.
    """
    src_node = lib2to3_parse(src_contents.lstrip(), mode.target_versions)
    future_imports = get_future_imports(src_node)
    versions = mode.target_versions or detect_target_versions(
        src_node, future_imports=future_imports
    )

    normalize_fmt_off(src_node)

    split_line_features = set()
    for feature in (Feature.TRAILING_COMMA_IN_CALL, Feature.TRAILING_COMMA_IN_DEF):
        if supports_feature(versions, feature):
            split_line_features.add(feature)

    line_generator = LineGenerator(mode=mode)
    tracker = EmptyLineTracker(is_pyi=mode.is_pyi)
    empty_line = Line(mode=mode)
    pieces = []
    # Empty lines requested *after* a line are emitted just before the next
    # one, so none are emitted after the final line.
    after = 0
    for current_line in line_generator.visit(src_node):
        pieces.append(str(empty_line) * after)
        before, after = tracker.maybe_empty_lines(current_line)
        pieces.append(str(empty_line) * before)
        pieces.extend(
            str(line)
            for line in transform_line(
                current_line, mode=mode, features=split_line_features
            )
        )
    return "".join(pieces)
1168
1169
def decode_bytes(src: bytes) -> Tuple[FileContent, Encoding, NewLine]:
    """Decode `src` and return ``(decoded_contents, encoding, newline)``.

    The encoding is determined by ``tokenize.detect_encoding``.  `newline` is
    CRLF when the first line ends with it and LF otherwise, while
    `decoded_contents` is read with universal newlines (i.e. only contains LF).
    Empty input yields ``("", encoding, "\\n")``.
    """
    buffer = io.BytesIO(src)
    encoding, first_lines = tokenize.detect_encoding(buffer.readline)
    if not first_lines:
        return "", encoding, "\n"

    newline = "\r\n" if first_lines[0].endswith(b"\r\n") else "\n"
    # detect_encoding consumed part of the buffer; start over for the decode.
    buffer.seek(0)
    with io.TextIOWrapper(buffer, encoding) as reader:
        text = reader.read()
    return text, encoding, newline
1185
1186
def get_features_used(  # noqa: C901
    node: Node, *, future_imports: Optional[Set[str]] = None
) -> Set[Feature]:
    """Return a set of (relatively) new Python features used in this file.

    `node` is walked in pre-order.  `future_imports`, when given, contributes
    the features that FUTURE_FLAG_TO_FEATURE maps those names to.

    Currently looking for:
    - f-strings (including raw f-strings, with any casing of the prefix);
    - underscores in numeric literals;
    - trailing commas after * or ** in function signatures and calls;
    - positional only arguments in function signatures and lambdas;
    - assignment expression;
    - relaxed decorator syntax;
    - usage of __future__ flags known to FUTURE_FLAG_TO_FEATURE;
    - starred expressions in the value of ``return`` / ``yield``;
    - a ``testlist_star_expr`` as the value of an annotated assignment.
    """
    features: Set[Feature] = set()
    if future_imports:
        features |= {
            FUTURE_FLAG_TO_FEATURE[future_import]
            for future_import in future_imports
            if future_import in FUTURE_FLAG_TO_FEATURE
        }

    for n in node.pre_order():
        if is_string_token(n):
            # String prefixes are case-insensitive and "r" may come before or
            # after "f" (rf, fR, Rf, FR, ...), so compare in lower case.
            value_head = n.value[:2].lower()
            if value_head in {'f"', "f'", "rf", "fr"}:
                features.add(Feature.F_STRINGS)

        elif n.type == token.NUMBER:
            assert isinstance(n, Leaf)
            if "_" in n.value:
                features.add(Feature.NUMERIC_UNDERSCORES)

        elif n.type == token.SLASH:
            # A slash only counts when it sits directly in an argument list.
            if n.parent and n.parent.type in {
                syms.typedargslist,
                syms.arglist,
                syms.varargslist,
            }:
                features.add(Feature.POS_ONLY_ARGUMENTS)

        elif n.type == token.COLONEQUAL:
            features.add(Feature.ASSIGNMENT_EXPRESSIONS)

        elif n.type == syms.decorator:
            if len(n.children) > 1 and not is_simple_decorator_expression(
                n.children[1]
            ):
                features.add(Feature.RELAXED_DECORATORS)

        elif (
            n.type in {syms.typedargslist, syms.arglist}
            and n.children
            and n.children[-1].type == token.COMMA
        ):
            # The list ends with a trailing comma; it only counts as a feature
            # when the list also contains * or ** (bare or inside an argument).
            if n.type == syms.typedargslist:
                feature = Feature.TRAILING_COMMA_IN_DEF
            else:
                feature = Feature.TRAILING_COMMA_IN_CALL

            for ch in n.children:
                if ch.type in STARS:
                    features.add(feature)

                if ch.type == syms.argument:
                    for argch in ch.children:
                        if argch.type in STARS:
                            features.add(feature)

        elif (
            n.type in {syms.return_stmt, syms.yield_expr}
            and len(n.children) >= 2
            and n.children[1].type == syms.testlist_star_expr
            and any(child.type == syms.star_expr for child in n.children[1].children)
        ):
            features.add(Feature.UNPACKING_ON_FLOW)

        elif (
            n.type == syms.annassign
            and len(n.children) >= 4
            and n.children[3].type == syms.testlist_star_expr
        ):
            features.add(Feature.ANN_ASSIGN_EXTENDED_RHS)

    return features
1273
1274
def detect_target_versions(
    node: Node, *, future_imports: Optional[Set[str]] = None
) -> Set[TargetVersion]:
    """Return every target version able to run the code under `node`.

    A version qualifies when all features reported by
    :func:`get_features_used` are contained in VERSION_TO_FEATURES for it.
    """
    features = get_features_used(node, future_imports=future_imports)
    compatible: Set[TargetVersion] = set()
    for version in TargetVersion:
        if features <= VERSION_TO_FEATURES[version]:
            compatible.add(version)
    return compatible
1283
1284
def get_future_imports(node: Node) -> Set[str]:
    """Collect the names imported from ``__future__`` at the top of the file.

    Only the leading run of simple statements is examined: docstring-like
    string statements are skipped, ``from __future__ import ...`` statements
    are collected, and anything else ends the scan.
    """
    imports: Set[str] = set()

    def imported_names(children: List[LN]) -> Generator[str, None, None]:
        # Yield the original (pre-``as``) name of every imported item.
        for item in children:
            if isinstance(item, Leaf):
                # Non-NAME leaves are silently skipped.
                if item.type == token.NAME:
                    yield item.value
                continue

            if item.type == syms.import_as_name:
                orig_name = item.children[0]
                assert isinstance(orig_name, Leaf), "Invalid syntax parsing imports"
                assert orig_name.type == token.NAME, "Invalid syntax parsing imports"
                yield orig_name.value
            elif item.type == syms.import_as_names:
                yield from imported_names(item.children)
            else:
                raise AssertionError("Invalid syntax parsing imports")

    for statement in node.children:
        if statement.type != syms.simple_stmt:
            break

        head = statement.children[0]
        if isinstance(head, Leaf):
            # Continue looking if we see a docstring; otherwise stop.
            is_docstring = (
                len(statement.children) == 2
                and head.type == token.STRING
                and statement.children[1].type == token.NEWLINE
            )
            if not is_docstring:
                break
            continue

        if head.type != syms.import_from:
            break

        module_name = head.children[1]
        if not isinstance(module_name, Leaf) or module_name.value != "__future__":
            break

        imports.update(imported_names(head.children[3:]))

    return imports
1333
1334
def assert_equivalent(src: str, dst: str) -> None:
    """Raise AssertionError if `src` and `dst` aren't equivalent.

    Both strings are parsed with `parse_ast` and their `stringify_ast`
    renderings are compared.  An unparseable `src` is reported as a user-side
    problem.  An unparseable `dst`, or a mismatch between the two renderings,
    is reported as an internal error, with diagnostics saved through
    `dump_to_file`.
    """
    try:
        src_ast = parse_ast(src)
    except Exception as parse_error:
        raise AssertionError(
            f"cannot use --safe with this file; failed to parse source file AST: "
            f"{parse_error}\n"
            f"This could be caused by running Black with an older Python version "
            f"that does not support new syntax used in your source file."
        ) from parse_error

    try:
        dst_ast = parse_ast(dst)
    except Exception as parse_error:
        trace = "".join(traceback.format_tb(parse_error.__traceback__))
        log = dump_to_file(trace, dst)
        raise AssertionError(
            f"INTERNAL ERROR: Black produced invalid code: {parse_error}. "
            "Please report a bug on https://github.com/psf/black/issues.  "
            f"This invalid output might be helpful: {log}"
        ) from None

    src_ast_str = "\n".join(stringify_ast(src_ast))
    dst_ast_str = "\n".join(stringify_ast(dst_ast))
    if src_ast_str == dst_ast_str:
        return

    log = dump_to_file(diff(src_ast_str, dst_ast_str, "src", "dst"))
    raise AssertionError(
        "INTERNAL ERROR: Black produced code that is not equivalent to the"
        f" source.  Please report a bug on "
        f"https://github.com/psf/black/issues.  This diff might be helpful: {log}"
    ) from None
1366
1367
def assert_stable(src: str, dst: str, mode: Mode) -> None:
    """Raise AssertionError if `dst` reformats differently the second time.

    `src` is only used to build the diagnostic dump written on failure.
    """
    # We shouldn't call format_str() here, because that formats the string
    # twice and may hide a bug where we bounce back and forth between two
    # versions.
    second_pass = _format_str_once(dst, mode=mode)
    if second_pass == dst:
        return

    log = dump_to_file(
        str(mode),
        diff(src, dst, "source", "first pass"),
        diff(dst, second_pass, "first pass", "second pass"),
    )
    raise AssertionError(
        "INTERNAL ERROR: Black produced different code on the second pass of the"
        " formatter.  Please report a bug on https://github.com/psf/black/issues."
        f"  This diff might be helpful: {log}"
    ) from None
1385
1386
@contextmanager
def nullcontext() -> Iterator[None]:
    """Return a context manager that does nothing on entry or exit.

    Binds None as the ``as`` target.  To be used like `nullcontext` in
    Python 3.7.
    """
    yield None
1394
1395
def patch_click() -> None:
    """Make Click not crash on Python 3.6 with LANG=C.

    On certain misconfigured environments, Python 3 selects the ASCII encoding as the
    default which restricts paths that it can access during the lifetime of the
    application.  Click refuses to work in this scenario by raising a RuntimeError.

    In case of Black the likelihood that non-ASCII characters are going to be used in
    file paths is minimal since it's Python source code.  Moreover, this crash was
    spurious on Python 3.7 thanks to PEP 538 and PEP 540.

    Each Click module is imported independently and a missing one is simply
    skipped, so this never raises because of the installed Click version.
    """
    modules = []
    try:
        from click import core
    except ImportError:
        pass
    else:
        modules.append(core)
    try:
        # Not present in Click 8.1.0 and newer.  `from click import X` raises a
        # plain ImportError (not ModuleNotFoundError) when only the submodule is
        # missing, hence the broader except clause.
        from click import _unicodefun  # type: ignore
    except ImportError:
        pass
    else:
        modules.append(_unicodefun)

    for module in modules:
        # Replace whichever environment check this Click version defines.
        if hasattr(module, "_verify_python3_env"):
            module._verify_python3_env = lambda: None  # type: ignore
        if hasattr(module, "_verify_python_env"):
            module._verify_python_env = lambda: None  # type: ignore
1418
1419
def patched_main() -> None:
    """Run the start-up hooks, then hand over to :func:`main`.

    In order: `maybe_install_uvloop`, multiprocessing's `freeze_support`, and
    :func:`patch_click` are called before `main`.
    """
    maybe_install_uvloop()
    freeze_support()
    patch_click()
    main()
1425
1426
# Script entry point: only runs when the module is executed directly.
if __name__ == "__main__":
    patched_main()