]> git.madduck.net Git - etc/vim.git/blob - src/black/__init__.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access, can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Support multiple top-level as-expressions on case statements (#2716)
[etc/vim.git] / src / black / __init__.py
1 import asyncio
2 from json.decoder import JSONDecodeError
3 import json
4 from concurrent.futures import Executor, ThreadPoolExecutor, ProcessPoolExecutor
5 from contextlib import contextmanager
6 from datetime import datetime
7 from enum import Enum
8 import io
9 from multiprocessing import Manager, freeze_support
10 import os
11 from pathlib import Path
12 from pathspec.patterns.gitwildmatch import GitWildMatchPatternError
13 import re
14 import signal
15 import sys
16 import tokenize
17 import traceback
18 from typing import (
19     Any,
20     Dict,
21     Generator,
22     Iterator,
23     List,
24     MutableMapping,
25     Optional,
26     Pattern,
27     Set,
28     Sized,
29     Tuple,
30     Union,
31 )
32
33 import click
34 from dataclasses import replace
35 from mypy_extensions import mypyc_attr
36
37 from black.const import DEFAULT_LINE_LENGTH, DEFAULT_INCLUDES, DEFAULT_EXCLUDES
38 from black.const import STDIN_PLACEHOLDER
39 from black.nodes import STARS, syms, is_simple_decorator_expression
40 from black.nodes import is_string_token
41 from black.lines import Line, EmptyLineTracker
42 from black.linegen import transform_line, LineGenerator, LN
43 from black.comments import normalize_fmt_off
44 from black.mode import FUTURE_FLAG_TO_FEATURE, Mode, TargetVersion
45 from black.mode import Feature, supports_feature, VERSION_TO_FEATURES
46 from black.cache import read_cache, write_cache, get_cache_info, filter_cached, Cache
47 from black.concurrency import cancel, shutdown, maybe_install_uvloop
48 from black.output import dump_to_file, ipynb_diff, diff, color_diff, out, err
49 from black.report import Report, Changed, NothingChanged
50 from black.files import find_project_root, find_pyproject_toml, parse_pyproject_toml
51 from black.files import gen_python_files, get_gitignore, normalize_path_maybe_ignore
52 from black.files import wrap_stream_for_windows
53 from black.parsing import InvalidInput  # noqa F401
54 from black.parsing import lib2to3_parse, parse_ast, stringify_ast
55 from black.handle_ipynb_magics import (
56     mask_cell,
57     unmask_cell,
58     remove_trailing_semicolon,
59     put_trailing_semicolon_back,
60     TRANSFORMED_MAGICS,
61     PYTHON_CELL_MAGICS,
62     jupyter_dependencies_are_installed,
63 )
64
65
66 # lib2to3 fork
67 from blib2to3.pytree import Node, Leaf
68 from blib2to3.pgen2 import token
69
70 from _black_version import version as __version__
71
# True when this module was loaded from a compiled extension module (file
# suffix .pyd or .so) rather than from plain .py source; shown in the
# --version message.
COMPILED = Path(__file__).suffix in (".pyd", ".so")

# types
# Readability aliases: each documents what a plain `str` holds in signatures.
FileContent = str
Encoding = str
NewLine = str
78
79
class WriteBack(Enum):
    """What to do with the formatted result of each source."""

    NO = 0
    YES = 1
    DIFF = 2
    CHECK = 3
    COLOR_DIFF = 4

    @classmethod
    def from_configuration(
        cls, *, check: bool, diff: bool, color: bool = False
    ) -> "WriteBack":
        """Translate the --check / --diff / --color CLI flags into a member.

        --diff wins over --check, and --color is only honoured together with
        --diff. With none of the flags set, files are written back (YES).
        """
        if diff:
            return cls.COLOR_DIFF if color else cls.DIFF
        return cls.CHECK if check else cls.YES
98
99
# Legacy name, left for integrations.
# `FileMode` is simply an alias of `Mode`, so code importing the old name keeps
# working.
FileMode = Mode

# Default value of the --workers option. os.cpu_count() can return None; the
# win32 branch of reformat_many() guards against that case.
DEFAULT_WORKERS = os.cpu_count()
104
105
def read_pyproject_toml(
    ctx: click.Context, param: click.Parameter, value: Optional[str]
) -> Optional[str]:
    """Inject Black configuration from "pyproject.toml" into defaults in `ctx`.

    Used as the click callback of --config. When no path is given, a
    pyproject.toml is looked up from the `src` arguments already parsed into
    `ctx.params`. Values from the file are layered on top of any existing
    `ctx.default_map`.

    Returns the path to a successfully found and read configuration file, None
    otherwise.

    Raises click.FileError if the file cannot be read or parsed, and
    click.BadOptionUsage if `target_version` is present but not a list.
    """
    config_path = value
    if not config_path:
        config_path = find_pyproject_toml(ctx.params.get("src", ()))
        if config_path is None:
            return None

    try:
        raw_config = parse_pyproject_toml(config_path)
    except (OSError, ValueError) as e:
        raise click.FileError(
            filename=config_path, hint=f"Error reading configuration file: {e}"
        ) from None

    if not raw_config:
        return None

    # Sanitize the values to be Click friendly. For more information please see:
    # https://github.com/psf/black/issues/1458
    # https://github.com/pallets/click/issues/1567
    config: Dict[str, Any] = {}
    for key, val in raw_config.items():
        config[key] = val if isinstance(val, (list, dict)) else str(val)

    target_version = config.get("target_version")
    if not (target_version is None or isinstance(target_version, list)):
        raise click.BadOptionUsage(
            "target-version", "Config key target-version must be a list"
        )

    # Existing defaults first, then file values override them.
    ctx.default_map = {**(ctx.default_map or {}), **config}
    return config_path
150
151
def target_version_option_callback(
    c: click.Context, p: Union[click.Option, click.Parameter], v: Tuple[str, ...]
) -> List[TargetVersion]:
    """Turn the raw --target-version choices into TargetVersion members.

    Kept as a named function rather than a lambda because mypy could not infer
    the lambda's type correctly, which caused trouble for mypyc.
    """
    versions: List[TargetVersion] = []
    for name in v:
        versions.append(TargetVersion[name.upper()])
    return versions
161
162
def re_compile_maybe_verbose(regex: str) -> Pattern[str]:
    """Compile `regex`, switching to verbose mode when it spans several lines.

    Multi-line patterns get an inline ``(?x)`` prefix, so whitespace and
    ``#`` comments inside them are ignored by the regex engine.
    """
    pattern_text = f"(?x){regex}" if "\n" in regex else regex
    return re.compile(pattern_text)
172
173
def validate_regex(
    ctx: click.Context,
    param: click.Parameter,
    value: Optional[str],
) -> Optional[Pattern[str]]:
    """Click callback that compiles a regex option value.

    None (option not given) is passed through unchanged. An invalid pattern is
    reported to Click as a bad parameter value.
    """
    if value is None:
        return None
    try:
        return re_compile_maybe_verbose(value)
    except re.error as e:
        raise click.BadParameter(f"Not a valid regular expression: {e}") from None
183
184
# Command-line interface. Each decorator below declares one CLI option or
# argument; Click passes the parsed values to `main` as keyword arguments.
@click.command(
    context_settings={"help_option_names": ["-h", "--help"]},
    # While Click does set this field automatically using the docstring, mypyc
    # (annoyingly) strips 'em so we need to set it here too.
    help="The uncompromising code formatter.",
)
@click.option("-c", "--code", type=str, help="Format the code passed in as a string.")
@click.option(
    "-l",
    "--line-length",
    type=int,
    default=DEFAULT_LINE_LENGTH,
    help="How many characters per line to allow.",
    show_default=True,
)
@click.option(
    "-t",
    "--target-version",
    type=click.Choice([v.name.lower() for v in TargetVersion]),
    # Converts the chosen names into a list of TargetVersion members.
    callback=target_version_option_callback,
    multiple=True,
    help=(
        "Python versions that should be supported by Black's output. [default: per-file"
        " auto-detection]"
    ),
)
@click.option(
    "--pyi",
    is_flag=True,
    help=(
        "Format all input files like typing stubs regardless of file extension (useful"
        " when piping source on standard input)."
    ),
)
@click.option(
    "--ipynb",
    is_flag=True,
    help=(
        "Format all input files like Jupyter Notebooks regardless of file extension "
        "(useful when piping source on standard input)."
    ),
)
@click.option(
    "-S",
    "--skip-string-normalization",
    is_flag=True,
    help="Don't normalize string quotes or prefixes.",
)
@click.option(
    "-C",
    "--skip-magic-trailing-comma",
    is_flag=True,
    help="Don't use trailing commas as a reason to split lines.",
)
@click.option(
    "--experimental-string-processing",
    is_flag=True,
    # Not listed in --help.
    hidden=True,
    help=(
        "Experimental option that performs more normalization on string literals."
        " Currently disabled because it leads to some crashes."
    ),
)
@click.option(
    "--check",
    is_flag=True,
    help=(
        "Don't write the files back, just return the status. Return code 0 means"
        " nothing would change. Return code 1 means some files would be reformatted."
        " Return code 123 means there was an internal error."
    ),
)
@click.option(
    "--diff",
    is_flag=True,
    help="Don't write the files back, just output a diff for each file on stdout.",
)
@click.option(
    "--color/--no-color",
    is_flag=True,
    help="Show colored diff. Only applies when `--diff` is given.",
)
@click.option(
    "--fast/--safe",
    is_flag=True,
    help="If --fast given, skip temporary sanity checks. [default: --safe]",
)
@click.option(
    "--required-version",
    type=str,
    help=(
        "Require a specific version of Black to be running (useful for unifying results"
        " across many environments e.g. with a pyproject.toml file)."
    ),
)
@click.option(
    "--include",
    type=str,
    default=DEFAULT_INCLUDES,
    # Compiled into a Pattern (or rejected) by validate_regex.
    callback=validate_regex,
    help=(
        "A regular expression that matches files and directories that should be"
        " included on recursive searches. An empty value means all files are included"
        " regardless of the name. Use forward slashes for directories on all platforms"
        " (Windows, too). Exclusions are calculated first, inclusions later."
    ),
    show_default=True,
)
@click.option(
    "--exclude",
    # No default on purpose: get_sources() treats None as "use DEFAULT_EXCLUDES
    # and also honour .gitignore".
    type=str,
    callback=validate_regex,
    help=(
        "A regular expression that matches files and directories that should be"
        " excluded on recursive searches. An empty value means no paths are excluded."
        " Use forward slashes for directories on all platforms (Windows, too)."
        " Exclusions are calculated first, inclusions later. [default:"
        f" {DEFAULT_EXCLUDES}]"
    ),
    show_default=False,
)
@click.option(
    "--extend-exclude",
    type=str,
    callback=validate_regex,
    help=(
        "Like --exclude, but adds additional files and directories on top of the"
        " excluded ones. (Useful if you simply want to add to the default)"
    ),
)
@click.option(
    "--force-exclude",
    type=str,
    callback=validate_regex,
    help=(
        "Like --exclude, but files and directories matching this regex will be "
        "excluded even when they are passed explicitly as arguments."
    ),
)
@click.option(
    "--stdin-filename",
    type=str,
    help=(
        "The name of the file when passing it through stdin. Useful to make "
        "sure Black will respect --force-exclude option on some "
        "editors that rely on using stdin."
    ),
)
@click.option(
    "-W",
    "--workers",
    type=click.IntRange(min=1),
    default=DEFAULT_WORKERS,
    show_default=True,
    help="Number of parallel workers",
)
@click.option(
    "-q",
    "--quiet",
    is_flag=True,
    help=(
        "Don't emit non-error messages to stderr. Errors are still emitted; silence"
        " those with 2>/dev/null."
    ),
)
@click.option(
    "-v",
    "--verbose",
    is_flag=True,
    help=(
        "Also emit messages to stderr about files that were not changed or were ignored"
        " due to exclusion patterns."
    ),
)
@click.version_option(
    version=__version__,
    message=f"%(prog)s, %(version)s (compiled: {'yes' if COMPILED else 'no'})",
)
@click.argument(
    "src",
    nargs=-1,
    type=click.Path(
        exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
    ),
    # Eager: read_pyproject_toml (the --config callback) reads ctx.params["src"]
    # to locate a pyproject.toml when --config is not given explicitly.
    is_eager=True,
    metavar="SRC ...",
)
@click.option(
    "--config",
    type=click.Path(
        exists=True,
        file_okay=True,
        dir_okay=False,
        readable=True,
        allow_dash=False,
        path_type=str,
    ),
    is_eager=True,
    # Reads the config file and merges its values into ctx.default_map.
    callback=read_pyproject_toml,
    help="Read configuration from FILE path.",
)
@click.pass_context
def main(
    ctx: click.Context,
    code: Optional[str],
    line_length: int,
    target_version: List[TargetVersion],
    check: bool,
    diff: bool,
    color: bool,
    fast: bool,
    pyi: bool,
    ipynb: bool,
    skip_string_normalization: bool,
    skip_magic_trailing_comma: bool,
    experimental_string_processing: bool,
    quiet: bool,
    verbose: bool,
    required_version: Optional[str],
    include: Pattern[str],
    exclude: Optional[Pattern[str]],
    extend_exclude: Optional[Pattern[str]],
    force_exclude: Optional[Pattern[str]],
    stdin_filename: Optional[str],
    workers: int,
    src: Tuple[str, ...],
    config: Optional[str],
) -> None:
    """The uncompromising code formatter."""
    # `config` is the path returned by read_pyproject_toml (None if none used).
    if config and verbose:
        out(f"Using configuration from {config}.", bold=False, fg="blue")

    error_msg = "Oh no! 💥 💔 💥"
    # --required-version must match the running version exactly.
    if required_version and required_version != __version__:
        err(
            f"{error_msg} The required version `{required_version}` does not match"
            f" the running version `{__version__}`!"
        )
        ctx.exit(1)
    # Stub mode and notebook mode are mutually exclusive.
    if ipynb and pyi:
        err("Cannot pass both `pyi` and `ipynb` flags!")
        ctx.exit(1)

    write_back = WriteBack.from_configuration(check=check, diff=diff, color=color)
    if target_version:
        versions = set(target_version)
    else:
        # We'll autodetect later.
        versions = set()
    mode = Mode(
        target_versions=versions,
        line_length=line_length,
        is_pyi=pyi,
        is_ipynb=ipynb,
        string_normalization=not skip_string_normalization,
        magic_trailing_comma=not skip_magic_trailing_comma,
        experimental_string_processing=experimental_string_processing,
    )

    if code is not None:
        # Run in quiet mode by default with -c; the extra output isn't useful.
        # You can still pass -v to get verbose output.
        quiet = True

    report = Report(check=check, diff=diff, quiet=quiet, verbose=verbose)

    if code is not None:
        # --code: format the given string; no file discovery at all.
        reformat_code(
            content=code, fast=fast, write_back=write_back, mode=mode, report=report
        )
    else:
        try:
            sources = get_sources(
                ctx=ctx,
                src=src,
                quiet=quiet,
                verbose=verbose,
                include=include,
                exclude=exclude,
                extend_exclude=extend_exclude,
                force_exclude=force_exclude,
                report=report,
                stdin_filename=stdin_filename,
            )
        except GitWildMatchPatternError:
            # NOTE(review): presumably raised by pathspec while parsing a
            # malformed .gitignore during file discovery — confirm.
            ctx.exit(1)

        # Exits with code 0 if discovery found nothing to format.
        path_empty(
            sources,
            "No Python files are present to be formatted. Nothing to do 😴",
            quiet,
            verbose,
            ctx,
        )

        # A single source is formatted in-process (reformat_one spawns no child
        # processes); several go through the worker pool in reformat_many.
        if len(sources) == 1:
            reformat_one(
                src=sources.pop(),
                fast=fast,
                write_back=write_back,
                mode=mode,
                report=report,
            )
        else:
            reformat_many(
                sources=sources,
                fast=fast,
                write_back=write_back,
                mode=mode,
                report=report,
                workers=workers,
            )

    if verbose or not quiet:
        out(error_msg if report.return_code else "All done! ✨ 🍰 ✨")
        if code is None:
            click.echo(str(report), err=True)
    # Process exit status comes from the accumulated report.
    ctx.exit(report.return_code)
503
504
def get_sources(
    *,
    ctx: click.Context,
    src: Tuple[str, ...],
    quiet: bool,
    verbose: bool,
    include: Pattern[str],
    exclude: Optional[Pattern[str]],
    extend_exclude: Optional[Pattern[str]],
    force_exclude: Optional[Pattern[str]],
    report: "Report",
    stdin_filename: Optional[str],
) -> Set[Path]:
    """Compute the set of files to be formatted.

    Explicitly listed files (and stdin, optionally named via `stdin_filename`)
    are taken as-is unless `force_exclude` matches them. Directories are walked
    with `gen_python_files`. Exits early via `path_empty` when `src` is empty.
    Stdin entries that carry a name are tagged with STDIN_PLACEHOLDER.
    """
    root = find_project_root(src)
    path_empty(src, "No Path provided. Nothing to do 😴", quiet, verbose, ctx)

    # .gitignore rules are only consulted when no explicit --exclude was given.
    gitignore = None
    if exclude is None:
        exclude = re_compile_maybe_verbose(DEFAULT_EXCLUDES)
        gitignore = get_gitignore(root)

    collected: Set[Path] = set()
    for s in src:
        if s == "-" and stdin_filename:
            path, is_stdin = Path(stdin_filename), True
        else:
            path, is_stdin = Path(s), False

        if not (is_stdin or path.is_file()):
            if path.is_dir():
                collected.update(
                    gen_python_files(
                        path.iterdir(),
                        root,
                        include,
                        exclude,
                        extend_exclude,
                        force_exclude,
                        report,
                        gitignore,
                        verbose=verbose,
                        quiet=quiet,
                    )
                )
            elif s == "-":
                # Anonymous stdin (no --stdin-filename given).
                collected.add(path)
            else:
                err(f"invalid path: {s}")
            continue

        normalized = normalize_path_maybe_ignore(path, root, report)
        if normalized is None:
            continue

        # Hard-exclude any files that matches the `--force-exclude` regex.
        hit = force_exclude.search("/" + normalized) if force_exclude else None
        if hit and hit.group(0):
            report.path_ignored(path, "matches the --force-exclude regular expression")
            continue

        if is_stdin:
            path = Path(f"{STDIN_PLACEHOLDER}{str(path)}")

        if path.suffix == ".ipynb" and not jupyter_dependencies_are_installed(
            verbose=verbose, quiet=quiet
        ):
            continue

        collected.add(path)
    return collected
582
583
def path_empty(
    src: Sized, msg: str, quiet: bool, verbose: bool, ctx: click.Context
) -> None:
    """Exit the command with status 0 when `src` is empty.

    `msg` is printed first unless output is quieted (--verbose overrides --quiet).
    """
    if src:
        return
    if verbose or not quiet:
        out(msg)
    ctx.exit(0)
594
595
def reformat_code(
    content: str, fast: bool, write_back: WriteBack, mode: Mode, report: Report
) -> None:
    """
    Reformat and print out `content` without spawning child processes.
    Similar to `reformat_one`, but for string content.

    `fast`, `write_back`, and `mode` options are passed to
    :func:`format_file_in_place` or :func:`format_stdin_to_stdout`.

    The result is recorded in `report` under the pseudo-path "<string>". Any
    exception is reported as a failure instead of propagating (with a
    traceback printed in verbose mode).
    """
    path = Path("<string>")
    try:
        was_changed = format_stdin_to_stdout(
            content=content, fast=fast, write_back=write_back, mode=mode
        )
        report.done(path, Changed.YES if was_changed else Changed.NO)
    except Exception as exc:
        if report.verbose:
            traceback.print_exc()
        report.failed(path, str(exc))
618
619
def reformat_one(
    src: Path, fast: bool, write_back: WriteBack, mode: Mode, report: "Report"
) -> None:
    """Reformat a single file under `src` without spawning child processes.

    `fast`, `write_back`, and `mode` options are passed to
    :func:`format_file_in_place` or :func:`format_stdin_to_stdout`.

    "-" or a STDIN_PLACEHOLDER-prefixed path means stdin; otherwise the file is
    formatted in place, consulting and updating the cache outside diff modes.
    Failures are recorded in `report` rather than raised.
    """
    try:
        changed = Changed.NO

        raw = str(src)
        is_stdin = raw == "-"
        if not is_stdin and raw.startswith(STDIN_PLACEHOLDER):
            is_stdin = True
            # Use the original name again in case we want to print something
            # to the user
            src = Path(raw[len(STDIN_PLACEHOLDER) :])

        if is_stdin:
            if src.suffix == ".pyi":
                mode = replace(mode, is_pyi=True)
            elif src.suffix == ".ipynb":
                mode = replace(mode, is_ipynb=True)
            if format_stdin_to_stdout(fast=fast, write_back=write_back, mode=mode):
                changed = Changed.YES
        else:
            cache: Cache = {}
            diffing = write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF)
            if not diffing:
                cache = read_cache(mode)
                resolved = src.resolve()
                key = str(resolved)
                if key in cache and cache[key] == get_cache_info(resolved):
                    changed = Changed.CACHED
            if changed is not Changed.CACHED and format_file_in_place(
                src, fast=fast, write_back=write_back, mode=mode
            ):
                changed = Changed.YES
            # Remember files that were written back, or checked and found clean.
            written = write_back is WriteBack.YES and changed is not Changed.CACHED
            verified = write_back is WriteBack.CHECK and changed is Changed.NO
            if written or verified:
                write_cache(cache, [src], mode)
        report.done(src, changed)
    except Exception as exc:
        if report.verbose:
            traceback.print_exc()
        report.failed(src, str(exc))
669
670
# diff-shades depends on being able to monkeypatch this function to operate. I know
# it's not ideal, but this shouldn't cause any issues ... hopefully. ~ichard26
@mypyc_attr(patchable=True)
def reformat_many(
    sources: Set[Path],
    fast: bool,
    write_back: WriteBack,
    mode: Mode,
    report: "Report",
    workers: Optional[int],
) -> None:
    """Reformat multiple files using a ProcessPoolExecutor.

    Falls back to a single-worker ThreadPoolExecutor when a process pool
    cannot be created. `workers` is the pool size; None means DEFAULT_WORKERS,
    or 1 when the CPU count is unknown. The work is driven by
    `schedule_formatting` on a dedicated event loop created for this call and
    uninstalled afterwards, so repeated calls in one process are safe.
    """
    executor: Executor
    # No `assert` for None-narrowing: asserts vanish under `python -O`, which
    # would let `min(None, 60)` blow up below when os.cpu_count() is None.
    worker_count = workers if workers is not None else (DEFAULT_WORKERS or 1)
    if sys.platform == "win32":
        # Work around https://bugs.python.org/issue26903
        worker_count = min(worker_count, 60)
    try:
        executor = ProcessPoolExecutor(max_workers=worker_count)
    except (ImportError, NotImplementedError, OSError):
        # we arrive here if the underlying system does not support multi-processing
        # like in AWS Lambda or Termux, in which case we gracefully fallback to
        # a ThreadPoolExecutor with just a single worker (more workers would not do us
        # any good due to the Global Interpreter Lock)
        executor = ThreadPoolExecutor(max_workers=1)

    # asyncio.get_event_loop() is deprecated when no loop is running, and on a
    # second call it would hand back the loop `shutdown` closed last time, so
    # always work on a fresh loop.
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        loop.run_until_complete(
            schedule_formatting(
                sources=sources,
                fast=fast,
                write_back=write_back,
                mode=mode,
                report=report,
                loop=loop,
                executor=executor,
            )
        )
    finally:
        try:
            # `shutdown` (black.concurrency) is expected to cancel leftover
            # tasks and close the loop.
            shutdown(loop)
        finally:
            # Don't leave our (now closed) loop installed as the current one.
            asyncio.set_event_loop(None)
            # Release the pool even if shutting the loop down failed.
            executor.shutdown()
715
716
async def schedule_formatting(
    sources: Set[Path],
    fast: bool,
    write_back: WriteBack,
    mode: Mode,
    report: "Report",
    loop: asyncio.AbstractEventLoop,
    executor: Executor,
) -> None:
    """Run formatting of `sources` in parallel using the provided `executor`.

    (Use ProcessPoolExecutors for actual parallelism.)

    `write_back`, `fast`, and `mode` options are passed to
    :func:`format_file_in_place`.

    Outside DIFF/COLOR_DIFF modes, files that `filter_cached` reports as cached
    are marked CACHED without being formatted. Every other file is reported to
    `report` as it finishes. Files that were written back (YES) or checked and
    found unchanged (CHECK) are written to the cache at the end.
    """
    cache: Cache = {}
    if write_back not in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
        cache = read_cache(mode)
        # Split into files still to format and files already up to date.
        sources, cached = filter_cached(cache, sources)
        for src in sorted(cached):
            report.done(src, Changed.CACHED)
    if not sources:
        return

    cancelled = []
    sources_to_cache = []
    lock = None
    if write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
        # For diff output, we need locks to ensure we don't interleave output
        # from different processes.
        manager = Manager()
        lock = manager.Lock()
    # Map each scheduled future back to the file it is formatting; files are
    # submitted in sorted order.
    tasks = {
        asyncio.ensure_future(
            loop.run_in_executor(
                executor, format_file_in_place, src, fast, mode, write_back, lock
            )
        ): src
        for src in sorted(sources)
    }
    # `pending` is a live view of the dict's keys, not a copy: tasks popped in
    # the loop below disappear from it, so the `cancel` signal handlers only
    # ever see work that has not been handled yet.
    pending = tasks.keys()
    try:
        loop.add_signal_handler(signal.SIGINT, cancel, pending)
        loop.add_signal_handler(signal.SIGTERM, cancel, pending)
    except NotImplementedError:
        # There are no good alternatives for these on Windows.
        pass
    # Handle results as soon as any task finishes, not in submission order.
    while pending:
        done, _ = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
        for task in done:
            src = tasks.pop(task)
            if task.cancelled():
                cancelled.append(task)
            elif task.exception():
                report.failed(src, str(task.exception()))
            else:
                # format_file_in_place returns True when the file changed.
                changed = Changed.YES if task.result() else Changed.NO
                # If the file was written back or was successfully checked as
                # well-formatted, store this information in the cache.
                if write_back is WriteBack.YES or (
                    write_back is WriteBack.CHECK and changed is Changed.NO
                ):
                    sources_to_cache.append(src)
                report.done(src, changed)
    if cancelled:
        # Await the cancelled tasks, collecting (not raising) their exceptions;
        # the `loop` argument is only passed on Python < 3.7.
        if sys.version_info >= (3, 7):
            await asyncio.gather(*cancelled, return_exceptions=True)
        else:
            await asyncio.gather(*cancelled, loop=loop, return_exceptions=True)
    if sources_to_cache:
        write_cache(cache, sources_to_cache, mode)
789
790
def format_file_in_place(
    src: Path,
    fast: bool,
    mode: Mode,
    write_back: WriteBack = WriteBack.NO,
    lock: Any = None,  # multiprocessing.Manager().Lock() is some crazy proxy
) -> bool:
    """Format file under `src` path. Return True if changed.

    A ``.pyi`` suffix switches on stub mode and ``.ipynb`` switches on notebook
    mode. If `write_back` is YES, the reformatted code is written back to the
    file, keeping the detected encoding and newline style. If it is DIFF or
    COLOR_DIFF, a (colored) diff is written to stdout while holding `lock`
    (when given). `mode` and `fast` options are passed to
    :func:`format_file_contents`.

    Raises ValueError if an ``.ipynb`` file is not valid JSON.
    """
    suffix = src.suffix
    if suffix == ".pyi":
        mode = replace(mode, is_pyi=True)
    elif suffix == ".ipynb":
        mode = replace(mode, is_ipynb=True)

    # The file's modification time labels the "before" side of a diff.
    then = datetime.utcfromtimestamp(src.stat().st_mtime)
    with open(src, "rb") as buf:
        raw_bytes = buf.read()
    src_contents, encoding, newline = decode_bytes(raw_bytes)
    try:
        dst_contents = format_file_contents(src_contents, fast=fast, mode=mode)
    except NothingChanged:
        return False
    except JSONDecodeError:
        raise ValueError(
            f"File '{src}' cannot be parsed as valid Jupyter notebook."
        ) from None

    if write_back == WriteBack.YES:
        with open(src, "w", encoding=encoding, newline=newline) as out_file:
            out_file.write(dst_contents)
        return True

    if write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
        now = datetime.utcnow()
        src_name = f"{src}\t{then} +0000"
        dst_name = f"{src}\t{now} +0000"
        if mode.is_ipynb:
            patch = ipynb_diff(src_contents, dst_contents, src_name, dst_name)
        else:
            patch = diff(src_contents, dst_contents, src_name, dst_name)

        if write_back == WriteBack.COLOR_DIFF:
            patch = color_diff(patch)

        # Serialize stdout writes across workers when a lock was provided.
        with lock or nullcontext():
            stream = io.TextIOWrapper(
                sys.stdout.buffer,
                encoding=encoding,
                newline=newline,
                write_through=True,
            )
            stream = wrap_stream_for_windows(stream)
            stream.write(patch)
            # Detach so discarding the wrapper leaves sys.stdout.buffer open.
            stream.detach()

    return True
848
849
def format_stdin_to_stdout(
    fast: bool,
    *,
    content: Optional[str] = None,
    write_back: WriteBack = WriteBack.NO,
    mode: Mode,
) -> bool:
    """Format file on stdin. Return True if changed.

    If content is None, it's read from sys.stdin.

    If `write_back` is YES, write reformatted code back to stdout. If it is DIFF,
    write a diff to stdout. The `mode` argument is passed to
    :func:`format_file_contents`.

    Output happens in a `finally` block, so it is produced on every exit path:
    when nothing changed or formatting raises, YES echoes the input back
    unchanged. Exceptions other than NothingChanged still propagate.
    """
    # Timestamp for the "STDIN" side of a diff header.
    then = datetime.utcnow()

    if content is None:
        src, encoding, newline = decode_bytes(sys.stdin.buffer.read())
    else:
        # String input (e.g. from --code via reformat_code): write it out as
        # UTF-8 and pass newline="" to the output wrapper below.
        src, encoding, newline = content, "utf-8", ""

    # Start from the input so the `finally` block emits it unchanged when
    # formatting does not produce a new result.
    dst = src
    try:
        dst = format_file_contents(src, fast=fast, mode=mode)
        return True

    except NothingChanged:
        return False

    finally:
        # Runs after the `return`s above as well as when an exception escapes.
        f = io.TextIOWrapper(
            sys.stdout.buffer, encoding=encoding, newline=newline, write_through=True
        )
        if write_back == WriteBack.YES:
            # Make sure there's a newline after the content
            if dst and dst[-1] != "\n":
                dst += "\n"
            f.write(dst)
        elif write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
            now = datetime.utcnow()
            src_name = f"STDIN\t{then} +0000"
            dst_name = f"STDOUT\t{now} +0000"
            d = diff(src, dst, src_name, dst_name)
            if write_back == WriteBack.COLOR_DIFF:
                d = color_diff(d)
                # Only colored output goes through the Windows stream wrapper.
                f = wrap_stream_for_windows(f)
            f.write(d)
        # Detach so that discarding the wrapper leaves sys.stdout.buffer open.
        f.detach()
899
900
def check_stability_and_equivalence(
    src_contents: str, dst_contents: str, *, mode: Mode
) -> None:
    """Perform stability and equivalence checks.

    Raise AssertionError if source and destination contents are not
    equivalent, or if a second pass of the formatter would format the
    content differently.
    """
    assert_equivalent(src_contents, dst_contents)

    # Forced second pass to work around optional trailing commas (becoming
    # forced trailing commas on pass 2) interacting differently with optional
    # parentheses.  Admittedly ugly.
    second_pass = format_str(dst_contents, mode=mode)
    if second_pass == dst_contents:
        # The second pass reproduced the first, so the output is already
        # stable and an explicit `assert_stable` call is unnecessary.
        return
    assert_equivalent(src_contents, second_pass, pass_num=2)
    assert_stable(src_contents, second_pass, mode=mode)
922
923
def format_file_contents(src_contents: str, *, fast: bool, mode: Mode) -> FileContent:
    """Return the reformatted version of a file's contents.

    Whitespace-only input and input that formats to itself both raise
    NothingChanged.  Notebooks (``mode.is_ipynb``) go through
    :func:`format_ipynb_string`; everything else goes through :func:`format_str`.
    When `fast` is False, regular Python output is additionally checked with
    :func:`check_stability_and_equivalence`.
    """
    if not src_contents.strip():
        raise NothingChanged

    dst_contents = (
        format_ipynb_string(src_contents, fast=fast, mode=mode)
        if mode.is_ipynb
        else format_str(src_contents, mode=mode)
    )
    if dst_contents == src_contents:
        raise NothingChanged

    # Notebook cells were already checked one by one inside format_cell.
    if not (fast or mode.is_ipynb):
        check_stability_and_equivalence(src_contents, dst_contents, mode=mode)
    return dst_contents
945
946
def validate_cell(src: str) -> None:
    """Reject notebook cells that cannot be safely masked and formatted.

    A cell is skipped (NothingChanged is raised) when it:

    - already contains text that TransformerManager produces from magics.
      Formatting cannot tell ``!ls`` apart from a literal
      ``get_ipython().system('ls')``, because both transform identically:

        >>> TransformerManager().transform_cell("get_ipython().system('ls')")
        "get_ipython().system('ls')\n"
        >>> TransformerManager().transform_cell("!ls")
        "get_ipython().system('ls')\n"

      so the round trip would not be safe; or
    - starts with a ``%%`` cell magic that is not a known Python cell magic.
      Such cells could break tokenizer_rt because of their indentation.
    """
    contains_transformed_magic = any(
        magic in src for magic in TRANSFORMED_MAGICS
    )
    if contains_transformed_magic:
        raise NothingChanged

    if src.startswith("%%"):
        # Name of the cell magic without its leading "%%".
        cell_magic_name = src.split()[0][2:]
        if cell_magic_name not in PYTHON_CELL_MAGICS:
            raise NothingChanged
968
969
def format_cell(src: str, *, fast: bool, mode: Mode) -> str:
    """Format the source of a single Jupyter notebook code cell.

    Steps:

      - reject cells that cannot be handled safely (see validate_cell);
      - temporarily drop a trailing semicolon, if any;
      - mask IPython magics so the cell parses as plain Python;
      - format, and unless `fast` is set, verify the masked result;
      - restore the magics and the trailing semicolon;
      - strip trailing newlines.

    Raises NothingChanged if the cell is rejected, if masking fails with a
    SyntaxError, or if the result equals the input.  A SyntaxError can come
    from automagics or multi-line magics, which are not supported.
    """
    validate_cell(src)
    body, had_semicolon = remove_trailing_semicolon(src)

    try:
        masked_src, replacements = mask_cell(body)
    except SyntaxError:
        raise NothingChanged from None

    masked_dst = format_str(masked_src, mode=mode)
    if not fast:
        check_stability_and_equivalence(masked_src, masked_dst, mode=mode)

    unmasked = unmask_cell(masked_dst, replacements)
    restored = put_trailing_semicolon_back(unmasked, had_semicolon)
    dst = restored.rstrip("\n")
    if dst == src:
        raise NothingChanged from None
    return dst
1005
1006
def validate_metadata(nb: MutableMapping[str, Any]) -> None:
    """Raise NothingChanged if the notebook declares a non-Python language.

    Every notebook metadata field is optional (see
    https://nbformat.readthedocs.io/en/latest/format_description.html). A
    missing ``metadata``, ``language_info`` or ``name`` entry is therefore
    accepted, and the notebook is still formatted.
    """
    metadata = nb.get("metadata", {})
    language_info = metadata.get("language_info", {})
    language = language_info.get("name", None)
    if language is None or language == "python":
        return
    raise NothingChanged from None
1017
1018
def format_ipynb_string(src_contents: str, *, fast: bool, mode: Mode) -> FileContent:
    """Format Jupyter notebook.

    Operate cell-by-cell, only on code cells, only for Python notebooks.
    If the ``.ipynb`` originally had a trailing newline, it'll be preserved.

    Raises NothingChanged when `src_contents` is empty, when the notebook is
    marked as non-Python (see validate_metadata), or when no code cell changed.
    """
    # Guard before indexing: src_contents[-1] would raise IndexError on "".
    if not src_contents:
        raise NothingChanged

    trailing_newline = src_contents[-1] == "\n"
    modified = False
    nb = json.loads(src_contents)
    validate_metadata(nb)
    for cell in nb["cells"]:
        if cell.get("cell_type", None) != "code":
            continue
        try:
            # The cell's "source" is joined into one string before formatting.
            src = "".join(cell["source"])
            dst = format_cell(src, fast=fast, mode=mode)
        except NothingChanged:
            continue
        cell["source"] = dst.splitlines(keepends=True)
        modified = True

    if not modified:
        raise NothingChanged

    dst_contents = json.dumps(nb, indent=1, ensure_ascii=False)
    if trailing_newline:
        dst_contents += "\n"
    return dst_contents
1046
1047
def format_str(src_contents: str, *, mode: Mode) -> FileContent:
    """Reformat a string and return new contents.

    `mode` holds the formatting options, for example the maximum number of
    characters allowed per line.  Example:

    >>> import black
    >>> print(black.format_str("def f(arg:str='')->None:...", mode=black.Mode()))
    def f(arg: str = "") -> None:
        ...

    A more complex example:

    >>> print(
    ...   black.format_str(
    ...     "def f(arg:str='')->None: hey",
    ...     mode=black.Mode(
    ...       target_versions={black.TargetVersion.PY36},
    ...       line_length=10,
    ...       string_normalization=False,
    ...       is_pyi=False,
    ...     ),
    ...   ),
    ... )
    def f(
        arg: str = '',
    ) -> None:
        hey

    """
    src_node = lib2to3_parse(src_contents.lstrip(), mode.target_versions)
    future_imports = get_future_imports(src_node)
    # Explicitly requested targets win; otherwise infer them from the syntax used.
    versions = mode.target_versions or detect_target_versions(
        src_node, future_imports=future_imports
    )

    # TODO: fully drop support and this code hopefully in January 2022 :D
    targets_python2 = (
        TargetVersion.PY27 in mode.target_versions or versions == {TargetVersion.PY27}
    )
    if targets_python2:
        err(
            "DEPRECATION: Python 2 support will be removed in the first stable release "
            "expected in January 2022.",
            fg="yellow",
            bold=True,
        )

    normalize_fmt_off(src_node)
    strip_u_prefix = "unicode_literals" in future_imports or supports_feature(
        versions, Feature.UNICODE_LITERALS
    )
    line_generator = LineGenerator(mode=mode, remove_u_prefix=strip_u_prefix)
    empty_line_tracker = EmptyLineTracker(is_pyi=mode.is_pyi)
    blank_line = Line(mode=mode)
    # Trailing-comma features that supports_feature reports for the chosen
    # versions; transform_line receives them when splitting lines.
    split_line_features = {
        feature
        for feature in (Feature.TRAILING_COMMA_IN_CALL, Feature.TRAILING_COMMA_IN_DEF)
        if supports_feature(versions, feature)
    }

    chunks: List[str] = []
    # Blank lines requested *after* a line are only emitted once the next line
    # arrives, so none are written after the final line.
    pending_after = 0
    for current_line in line_generator.visit(src_node):
        chunks.append(str(blank_line) * pending_after)
        before, pending_after = empty_line_tracker.maybe_empty_lines(current_line)
        chunks.append(str(blank_line) * before)
        chunks.extend(
            str(split_line)
            for split_line in transform_line(
                current_line, mode=mode, features=split_line_features
            )
        )
    return "".join(chunks)
1117
1118
def decode_bytes(src: bytes) -> Tuple[FileContent, Encoding, NewLine]:
    """Decode raw source bytes; return (decoded_contents, encoding, newline).

    The encoding comes from :func:`tokenize.detect_encoding`.  `newline` is
    "\\r\\n" when the first line ends in CRLF and "\\n" otherwise, while the
    decoded text itself uses universal newlines (it only contains LF).  Empty
    input yields ("", encoding, "\\n").
    """
    buffer = io.BytesIO(src)
    encoding, first_lines = tokenize.detect_encoding(buffer.readline)
    if not first_lines:
        return "", encoding, "\n"

    newline = "\r\n" if first_lines[0].endswith(b"\r\n") else "\n"
    # Rewind: detect_encoding consumed up to two lines from the buffer.
    buffer.seek(0)
    with io.TextIOWrapper(buffer, encoding) as wrapper:
        return wrapper.read(), encoding, newline
1134
1135
def get_features_used(  # noqa: C901
    node: Node, *, future_imports: Optional[Set[str]] = None
) -> Set[Feature]:
    """Return a set of (relatively) new Python features used in this file.

    Currently looking for:
    - f-strings (``f``, ``rf`` and ``fr`` prefixes in any letter case);
    - underscores in numeric literals;
    - trailing commas after * or ** in function signatures and calls;
    - positional only arguments in function signatures and lambdas;
    - assignment expression;
    - relaxed decorator syntax;
    - starred unpacking in return / yield expressions;
    - a ``testlist_star_expr`` right-hand side in annotated assignments;
    - usage of __future__ flags (annotations);
    - Python 2-only syntax: print / exec statements, long and old-style octal
      integer literals, tuple parameter unpacking, comma-style except / raise,
      and backquote repr.
    """
    features: Set[Feature] = set()
    if future_imports:
        features |= {
            FUTURE_FLAG_TO_FEATURE[future_import]
            for future_import in future_imports
            if future_import in FUTURE_FLAG_TO_FEATURE
        }

    for n in node.pre_order():
        if is_string_token(n):
            # String prefixes are case-insensitive, so normalize before comparing;
            # this also catches mixed-case spellings such as Rf"..." or fR'...'.
            value_head = n.value[:2].lower()
            if value_head in {'f"', "f'", "rf", "fr"}:
                features.add(Feature.F_STRINGS)

        elif n.type == token.NUMBER:
            assert isinstance(n, Leaf)
            if "_" in n.value:
                features.add(Feature.NUMERIC_UNDERSCORES)
            elif n.value.endswith(("L", "l")):
                # Python 2: 10L
                features.add(Feature.LONG_INT_LITERAL)
            elif len(n.value) >= 2 and n.value[0] == "0" and n.value[1].isdigit():
                # Python 2: 0123; 00123; ...
                if not all(char == "0" for char in n.value):
                    # although we don't want to match 0000 or similar
                    features.add(Feature.OCTAL_INT_LITERAL)

        elif n.type == token.SLASH:
            if n.parent and n.parent.type in {
                syms.typedargslist,
                syms.arglist,
                syms.varargslist,
            }:
                features.add(Feature.POS_ONLY_ARGUMENTS)

        elif n.type == token.COLONEQUAL:
            features.add(Feature.ASSIGNMENT_EXPRESSIONS)

        elif n.type == syms.decorator:
            if len(n.children) > 1 and not is_simple_decorator_expression(
                n.children[1]
            ):
                features.add(Feature.RELAXED_DECORATORS)

        elif (
            n.type in {syms.typedargslist, syms.arglist}
            and n.children
            and n.children[-1].type == token.COMMA
        ):
            if n.type == syms.typedargslist:
                feature = Feature.TRAILING_COMMA_IN_DEF
            else:
                feature = Feature.TRAILING_COMMA_IN_CALL

            for ch in n.children:
                if ch.type in STARS:
                    features.add(feature)

                if ch.type == syms.argument:
                    for argch in ch.children:
                        if argch.type in STARS:
                            features.add(feature)

        elif (
            n.type in {syms.return_stmt, syms.yield_expr}
            and len(n.children) >= 2
            and n.children[1].type == syms.testlist_star_expr
            and any(child.type == syms.star_expr for child in n.children[1].children)
        ):
            features.add(Feature.UNPACKING_ON_FLOW)

        elif (
            n.type == syms.annassign
            and len(n.children) >= 4
            and n.children[3].type == syms.testlist_star_expr
        ):
            features.add(Feature.ANN_ASSIGN_EXTENDED_RHS)

        # Python 2 only features (for its deprecation) except for integers, see above
        elif n.type == syms.print_stmt:
            features.add(Feature.PRINT_STMT)
        elif n.type == syms.exec_stmt:
            features.add(Feature.EXEC_STMT)
        elif n.type == syms.tfpdef:
            # def set_position((x, y), value):
            #     ...
            features.add(Feature.AUTOMATIC_PARAMETER_UNPACKING)
        elif n.type == syms.except_clause:
            # try:
            #     ...
            # except Exception, err:
            #     ...
            if len(n.children) >= 4:
                if n.children[-2].type == token.COMMA:
                    features.add(Feature.COMMA_STYLE_EXCEPT)
        elif n.type == syms.raise_stmt:
            # raise Exception, "msg"
            if len(n.children) >= 4:
                if n.children[-2].type == token.COMMA:
                    features.add(Feature.COMMA_STYLE_RAISE)
        elif n.type == token.BACKQUOTE:
            # `i'm surprised this ever existed`
            features.add(Feature.BACKQUOTE_REPR)

    return features
1256
1257
def detect_target_versions(
    node: Node, *, future_imports: Optional[Set[str]] = None
) -> Set[TargetVersion]:
    """Return every target version whose feature set covers the features used.

    The features come from :func:`get_features_used`. A version qualifies when
    those features are a subset of ``VERSION_TO_FEATURES[version]``.
    """
    used = get_features_used(node, future_imports=future_imports)
    compatible: Set[TargetVersion] = set()
    for version in TargetVersion:
        if used <= VERSION_TO_FEATURES[version]:
            compatible.add(version)
    return compatible
1266
1267
def get_future_imports(node: Node) -> Set[str]:
    """Return the names imported via ``from __future__ import ...`` in the file.

    Only the top of the module is scanned.  Bare string statements (docstrings)
    are skipped.  Scanning stops at the first statement that is neither such a
    string nor a ``from __future__ import`` statement.
    """
    found: Set[str] = set()

    def _names(nodes: List[LN]) -> Generator[str, None, None]:
        # Yield the original (pre-``as``) names from the import list.
        for item in nodes:
            if isinstance(item, Leaf):
                # Punctuation such as "(", ")" and "," is skipped.
                if item.type == token.NAME:
                    yield item.value
                continue

            if item.type == syms.import_as_names:
                yield from _names(item.children)
                continue

            if item.type != syms.import_as_name:
                raise AssertionError("Invalid syntax parsing imports")

            original = item.children[0]
            assert isinstance(original, Leaf), "Invalid syntax parsing imports"
            assert original.type == token.NAME, "Invalid syntax parsing imports"
            yield original.value

    for stmt in node.children:
        if stmt.type != syms.simple_stmt:
            break

        head = stmt.children[0]
        if isinstance(head, Leaf):
            is_docstring = (
                len(stmt.children) == 2
                and head.type == token.STRING
                and stmt.children[1].type == token.NEWLINE
            )
            if not is_docstring:
                break
            continue

        if head.type != syms.import_from:
            break

        module = head.children[1]
        if not isinstance(module, Leaf) or module.value != "__future__":
            break

        # children[3:] is whatever follows ``from __future__ import``.
        found.update(_names(head.children[3:]))

    return found
1316
1317
def assert_equivalent(src: str, dst: str, *, pass_num: int = 1) -> None:
    """Raise AssertionError if `src` and `dst` aren't equivalent.

    Both inputs are parsed with ``parse_ast``, and their ``stringify_ast``
    renderings are compared.  `pass_num` is only used in error messages.
    """
    try:
        original_tree = parse_ast(src)
    except Exception as exc:
        # The input itself does not parse, so nothing can be verified.
        raise AssertionError(
            f"cannot use --safe with this file; failed to parse source file: {exc}"
        ) from exc

    try:
        formatted_tree = parse_ast(dst)
    except Exception as exc:
        # Black's own output does not parse: dump the traceback and the output
        # so the path can be quoted in the bug report message.
        trace_text = "".join(traceback.format_tb(exc.__traceback__))
        log = dump_to_file(trace_text, dst)
        raise AssertionError(
            f"INTERNAL ERROR: Black produced invalid code on pass {pass_num}: {exc}. "
            "Please report a bug on https://github.com/psf/black/issues.  "
            f"This invalid output might be helpful: {log}"
        ) from None

    original_dump = "\n".join(stringify_ast(original_tree))
    formatted_dump = "\n".join(stringify_ast(formatted_tree))
    if original_dump == formatted_dump:
        return

    log = dump_to_file(diff(original_dump, formatted_dump, "src", "dst"))
    raise AssertionError(
        "INTERNAL ERROR: Black produced code that is not equivalent to the"
        f" source on pass {pass_num}.  Please report a bug on "
        f"https://github.com/psf/black/issues.  This diff might be helpful: {log}"
    ) from None
1346
1347
def assert_stable(src: str, dst: str, mode: Mode) -> None:
    """Raise AssertionError if formatting `dst` again changes it.

    On failure, the mode and diffs of both passes (source -> first pass and
    first pass -> second pass) are written with ``dump_to_file``. That file is
    referenced in the error message.
    """
    second_pass = format_str(dst, mode=mode)
    if second_pass == dst:
        return

    log = dump_to_file(
        str(mode),
        diff(src, dst, "source", "first pass"),
        diff(dst, second_pass, "first pass", "second pass"),
    )
    raise AssertionError(
        "INTERNAL ERROR: Black produced different code on the second pass of the"
        " formatter.  Please report a bug on https://github.com/psf/black/issues."
        f"  This diff might be helpful: {log}"
    ) from None
1362
1363
@contextmanager
def nullcontext() -> Iterator[None]:
    """A context manager that does nothing and binds ``None``.

    Stand-in for :func:`contextlib.nullcontext`, which is only available on
    Python 3.7 and newer.
    """
    yield None
1371
1372
def patch_click() -> None:
    """Make Click not crash on Python 3.6 with LANG=C.

    On certain misconfigured environments, Python 3 selects the ASCII encoding as the
    default which restricts paths that it can access during the lifetime of the
    application.  Click refuses to work in this scenario by raising a RuntimeError.

    In case of Black the likelihood that non-ASCII characters are going to be used in
    file paths is minimal since it's Python source code.  Moreover, this crash was
    spurious on Python 3.7 thanks to PEP 538 and PEP 540.

    Each Click module that may host the environment check is imported on its own.
    ``click._unicodefun`` was removed in Click 8.1, and
    ``from click import _unicodefun`` then raises a plain ImportError (not
    ModuleNotFoundError).  That failure must neither crash Black nor stop
    ``click.core`` from being patched.
    """
    modules: List[Any] = []
    try:
        from click import core
    except ImportError:
        pass
    else:
        modules.append(core)
    try:
        # Removed in Click 8.1.0 and newer; kept for users with older Click.
        from click import _unicodefun  # type: ignore
    except ImportError:
        pass
    else:
        modules.append(_unicodefun)

    for module in modules:
        if hasattr(module, "_verify_python3_env"):
            module._verify_python3_env = lambda: None  # type: ignore
        if hasattr(module, "_verify_python_env"):
            module._verify_python_env = lambda: None  # type: ignore
1395
1396
def patched_main() -> None:
    """Entry point for running this module as a script: prepare the runtime, then run `main`."""
    # NOTE(review): presumably installs uvloop as the asyncio event loop when it
    # is available; confirm against maybe_install_uvloop's definition.
    maybe_install_uvloop()
    # Lets multiprocessing work in frozen Windows executables; a no-op elsewhere.
    freeze_support()
    # Disable Click's ASCII-locale environment check before the CLI starts.
    patch_click()
    main()
1402
1403
# Run the CLI when this file is executed directly as a script.
if __name__ == "__main__":
    patched_main()