git.madduck.net Git - etc/vim.git/blob - src/black/__init__.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Bump required aiohttp version to 3.7.4 (#2509)
[etc/vim.git] / src / black / __init__.py
1 import asyncio
2 from json.decoder import JSONDecodeError
3 import json
4 from concurrent.futures import Executor, ThreadPoolExecutor, ProcessPoolExecutor
5 from contextlib import contextmanager
6 from datetime import datetime
7 from enum import Enum
8 import io
9 from multiprocessing import Manager, freeze_support
10 import os
11 from pathlib import Path
12 from pathspec.patterns.gitwildmatch import GitWildMatchPatternError
13 import regex as re
14 import signal
15 import sys
16 import tokenize
17 import traceback
18 from typing import (
19     Any,
20     Dict,
21     Generator,
22     Iterator,
23     List,
24     MutableMapping,
25     Optional,
26     Pattern,
27     Set,
28     Sized,
29     Tuple,
30     Union,
31 )
32
33 from dataclasses import replace
34 import click
35
36 from black.const import DEFAULT_LINE_LENGTH, DEFAULT_INCLUDES, DEFAULT_EXCLUDES
37 from black.const import STDIN_PLACEHOLDER
38 from black.nodes import STARS, syms, is_simple_decorator_expression
39 from black.lines import Line, EmptyLineTracker
40 from black.linegen import transform_line, LineGenerator, LN
41 from black.comments import normalize_fmt_off
42 from black.mode import Mode, TargetVersion
43 from black.mode import Feature, supports_feature, VERSION_TO_FEATURES
44 from black.cache import read_cache, write_cache, get_cache_info, filter_cached, Cache
45 from black.concurrency import cancel, shutdown, maybe_install_uvloop
46 from black.output import dump_to_file, ipynb_diff, diff, color_diff, out, err
47 from black.report import Report, Changed, NothingChanged
48 from black.files import find_project_root, find_pyproject_toml, parse_pyproject_toml
49 from black.files import gen_python_files, get_gitignore, normalize_path_maybe_ignore
50 from black.files import wrap_stream_for_windows
51 from black.parsing import InvalidInput  # noqa F401
52 from black.parsing import lib2to3_parse, parse_ast, stringify_ast
53 from black.handle_ipynb_magics import (
54     mask_cell,
55     unmask_cell,
56     remove_trailing_semicolon,
57     put_trailing_semicolon_back,
58     TRANSFORMED_MAGICS,
59     jupyter_dependencies_are_installed,
60 )
61
62
63 # lib2to3 fork
64 from blib2to3.pytree import Node, Leaf
65 from blib2to3.pgen2 import token
66
67 from _black_version import version as __version__
68
# types
# Plain aliases of `str`, used in annotations to say what a string value holds.
# `FileContent` is the return annotation of `format_file_contents`.
FileContent = str
# NOTE(review): presumably the codec name and newline sequence of the kind
# unpacked from `decode_bytes(...)` elsewhere in this file -- confirm.
Encoding = str
NewLine = str
73
74
class WriteBack(Enum):
    """How the outcome of formatting is written back or reported."""

    NO = 0
    YES = 1
    DIFF = 2
    CHECK = 3
    COLOR_DIFF = 4

    @classmethod
    def from_configuration(
        cls, *, check: bool, diff: bool, color: bool = False
    ) -> "WriteBack":
        """Translate the check/diff/color flags into a single member.

        A requested diff always wins over `check`; `color` only matters when
        a diff was requested. With neither flag set the result is ``YES``.
        """
        if not diff:
            # No diff requested: either just check, or really write back.
            return cls.CHECK if check else cls.YES

        return cls.COLOR_DIFF if color else cls.DIFF
93
94
# Legacy name, left for integrations.
# `FileMode` is bound to the very same class object as `Mode`.
FileMode = Mode
97
98
def read_pyproject_toml(
    ctx: click.Context, param: click.Parameter, value: Optional[str]
) -> Optional[str]:
    """Load Black settings from a TOML file and merge them into `ctx.default_map`.

    When `value` is empty, the file is looked up with `find_pyproject_toml`
    using the ``src`` entries already present in `ctx.params`.

    Returns the path of the configuration file that was read, or None when no
    file was found or the file contributed no settings.

    Raises `click.FileError` if the file cannot be read or parsed, and
    `click.BadOptionUsage` if ``target_version`` is present but not a list.
    """
    config_path = value or find_pyproject_toml(ctx.params.get("src", ()))
    if config_path is None:
        return None

    try:
        parsed = parse_pyproject_toml(config_path)
    except (OSError, ValueError) as e:
        raise click.FileError(
            filename=config_path, hint=f"Error reading configuration file: {e}"
        ) from None

    if not parsed:
        return None

    # Sanitize the values to be Click friendly. For more information please see:
    # https://github.com/psf/black/issues/1458
    # https://github.com/pallets/click/issues/1567
    config: Dict[str, Any] = {}
    for key, setting in parsed.items():
        config[key] = setting if isinstance(setting, (list, dict)) else str(setting)

    target_version = config.get("target_version")
    if not (target_version is None or isinstance(target_version, list)):
        raise click.BadOptionUsage(
            "target-version", "Config key target-version must be a list"
        )

    # Values from the file take precedence over any pre-existing defaults.
    merged: Dict[str, Any] = dict(ctx.default_map) if ctx.default_map else {}
    merged.update(config)
    ctx.default_map = merged
    return config_path
143
144
def target_version_option_callback(
    c: click.Context, p: Union[click.Option, click.Parameter], v: Tuple[str, ...]
) -> List[TargetVersion]:
    """Map the names given to --target-version onto `TargetVersion` members.

    Each name is upper-cased and looked up by member name. This lives in a
    named function rather than a lambda because mypy could not infer the
    lambda's type, which in turn caused trouble for mypyc.
    """
    versions: List[TargetVersion] = []
    for name in v:
        versions.append(TargetVersion[name.upper()])
    return versions
154
155
def re_compile_maybe_verbose(regex: str) -> Pattern[str]:
    """Return `regex` compiled into a pattern object.

    A pattern that spans several lines is compiled in verbose mode by
    prefixing it with the inline ``(?x)`` flag.
    """
    source = "(?x)" + regex if "\n" in regex else regex
    compiled: Pattern[str] = re.compile(source)
    return compiled
165
166
def validate_regex(
    ctx: click.Context,
    param: click.Parameter,
    value: Optional[str],
) -> Optional[Pattern]:
    """Click callback that compiles a regular-expression option value.

    Returns None when the option was not given, otherwise the pattern
    produced by `re_compile_maybe_verbose`.

    Raises `click.BadParameter` when the value does not compile. The message
    includes the underlying compile error so the user can see what is wrong
    with the pattern instead of only being told that it is invalid.
    """
    if value is None:
        return None

    try:
        return re_compile_maybe_verbose(value)
    except re.error as e:
        raise click.BadParameter(f"Not a valid regular expression: {e}") from None
176
177
# Command-line entry point. Each option/argument declared below is passed to
# `main` as the keyword argument of the same name.
@click.command(context_settings=dict(help_option_names=["-h", "--help"]))
@click.option("-c", "--code", type=str, help="Format the code passed in as a string.")
@click.option(
    "-l",
    "--line-length",
    type=int,
    default=DEFAULT_LINE_LENGTH,
    help="How many characters per line to allow.",
    show_default=True,
)
@click.option(
    "-t",
    "--target-version",
    type=click.Choice([v.name.lower() for v in TargetVersion]),
    callback=target_version_option_callback,
    multiple=True,
    help=(
        "Python versions that should be supported by Black's output. [default: per-file"
        " auto-detection]"
    ),
)
@click.option(
    "--pyi",
    is_flag=True,
    help=(
        "Format all input files like typing stubs regardless of file extension (useful"
        " when piping source on standard input)."
    ),
)
@click.option(
    "--ipynb",
    is_flag=True,
    help=(
        "Format all input files like Jupyter Notebooks regardless of file extension "
        "(useful when piping source on standard input)."
    ),
)
@click.option(
    "-S",
    "--skip-string-normalization",
    is_flag=True,
    help="Don't normalize string quotes or prefixes.",
)
@click.option(
    "-C",
    "--skip-magic-trailing-comma",
    is_flag=True,
    help="Don't use trailing commas as a reason to split lines.",
)
@click.option(
    "--experimental-string-processing",
    is_flag=True,
    hidden=True,
    help=(
        "Experimental option that performs more normalization on string literals."
        " Currently disabled because it leads to some crashes."
    ),
)
@click.option(
    "--check",
    is_flag=True,
    help=(
        "Don't write the files back, just return the status. Return code 0 means"
        " nothing would change. Return code 1 means some files would be reformatted."
        " Return code 123 means there was an internal error."
    ),
)
@click.option(
    "--diff",
    is_flag=True,
    help="Don't write the files back, just output a diff for each file on stdout.",
)
@click.option(
    "--color/--no-color",
    is_flag=True,
    help="Show colored diff. Only applies when `--diff` is given.",
)
@click.option(
    "--fast/--safe",
    is_flag=True,
    help="If --fast given, skip temporary sanity checks. [default: --safe]",
)
@click.option(
    "--required-version",
    type=str,
    help=(
        "Require a specific version of Black to be running (useful for unifying results"
        " across many environments e.g. with a pyproject.toml file)."
    ),
)
# The four regex options below are compiled (and rejected if invalid) by
# `validate_regex` before `main` runs.
@click.option(
    "--include",
    type=str,
    default=DEFAULT_INCLUDES,
    callback=validate_regex,
    help=(
        "A regular expression that matches files and directories that should be"
        " included on recursive searches. An empty value means all files are included"
        " regardless of the name. Use forward slashes for directories on all platforms"
        " (Windows, too). Exclusions are calculated first, inclusions later."
    ),
    show_default=True,
)
@click.option(
    "--exclude",
    type=str,
    callback=validate_regex,
    help=(
        "A regular expression that matches files and directories that should be"
        " excluded on recursive searches. An empty value means no paths are excluded."
        " Use forward slashes for directories on all platforms (Windows, too)."
        " Exclusions are calculated first, inclusions later. [default:"
        f" {DEFAULT_EXCLUDES}]"
    ),
    show_default=False,
)
@click.option(
    "--extend-exclude",
    type=str,
    callback=validate_regex,
    help=(
        "Like --exclude, but adds additional files and directories on top of the"
        " excluded ones. (Useful if you simply want to add to the default)"
    ),
)
@click.option(
    "--force-exclude",
    type=str,
    callback=validate_regex,
    help=(
        "Like --exclude, but files and directories matching this regex will be "
        "excluded even when they are passed explicitly as arguments."
    ),
)
@click.option(
    "--stdin-filename",
    type=str,
    help=(
        "The name of the file when passing it through stdin. Useful to make "
        "sure Black will respect --force-exclude option on some "
        "editors that rely on using stdin."
    ),
)
@click.option(
    "-q",
    "--quiet",
    is_flag=True,
    help=(
        "Don't emit non-error messages to stderr. Errors are still emitted; silence"
        " those with 2>/dev/null."
    ),
)
@click.option(
    "-v",
    "--verbose",
    is_flag=True,
    help=(
        "Also emit messages to stderr about files that were not changed or were ignored"
        " due to exclusion patterns."
    ),
)
@click.version_option(version=__version__)
# NOTE(review): `src` and `--config` are both eager; presumably this is so that
# `src` is already in `ctx.params` when the `--config` callback
# (`read_pyproject_toml`) looks it up -- confirm against click's ordering rules.
@click.argument(
    "src",
    nargs=-1,
    type=click.Path(
        exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
    ),
    is_eager=True,
    metavar="SRC ...",
)
@click.option(
    "--config",
    type=click.Path(
        exists=True,
        file_okay=True,
        dir_okay=False,
        readable=True,
        allow_dash=False,
        path_type=str,
    ),
    is_eager=True,
    callback=read_pyproject_toml,
    help="Read configuration from FILE path.",
)
@click.pass_context
def main(
    ctx: click.Context,
    code: Optional[str],
    line_length: int,
    target_version: List[TargetVersion],
    check: bool,
    diff: bool,
    color: bool,
    fast: bool,
    pyi: bool,
    ipynb: bool,
    skip_string_normalization: bool,
    skip_magic_trailing_comma: bool,
    experimental_string_processing: bool,
    quiet: bool,
    verbose: bool,
    required_version: Optional[str],
    include: Pattern,
    exclude: Optional[Pattern],
    extend_exclude: Optional[Pattern],
    force_exclude: Optional[Pattern],
    stdin_filename: Optional[str],
    src: Tuple[str, ...],
    config: Optional[str],
) -> None:
    """The uncompromising code formatter."""
    # NOTE: click shows the docstring above as the command's --help text, so it
    # is user-facing output rather than developer documentation.
    #
    # `config` is whatever `read_pyproject_toml` returned: the path of the
    # configuration file that was read, or None.
    if config and verbose:
        out(f"Using configuration from {config}.", bold=False, fg="blue")

    error_msg = "Oh no! 💥 💔 💥"
    # Refuse to run when a specific Black version was demanded and this is not it.
    if required_version and required_version != __version__:
        err(
            f"{error_msg} The required version `{required_version}` does not match"
            f" the running version `{__version__}`!"
        )
        ctx.exit(1)
    # The stub and notebook modes are mutually exclusive.
    if ipynb and pyi:
        err("Cannot pass both `pyi` and `ipynb` flags!")
        ctx.exit(1)

    write_back = WriteBack.from_configuration(check=check, diff=diff, color=color)
    if target_version:
        versions = set(target_version)
    else:
        # We'll autodetect later.
        versions = set()
    mode = Mode(
        target_versions=versions,
        line_length=line_length,
        is_pyi=pyi,
        is_ipynb=ipynb,
        string_normalization=not skip_string_normalization,
        magic_trailing_comma=not skip_magic_trailing_comma,
        experimental_string_processing=experimental_string_processing,
    )

    if code is not None:
        # Run in quiet mode by default with -c; the extra output isn't useful.
        # You can still pass -v to get verbose output.
        quiet = True

    report = Report(check=check, diff=diff, quiet=quiet, verbose=verbose)

    if code is not None:
        # -c/--code: format the given string instead of any files.
        reformat_code(
            content=code, fast=fast, write_back=write_back, mode=mode, report=report
        )
    else:
        try:
            sources = get_sources(
                ctx=ctx,
                src=src,
                quiet=quiet,
                verbose=verbose,
                include=include,
                exclude=exclude,
                extend_exclude=extend_exclude,
                force_exclude=force_exclude,
                report=report,
                stdin_filename=stdin_filename,
            )
        except GitWildMatchPatternError:
            # `ctx.exit` raises (click's documented behaviour), so `sources`
            # is always bound when execution continues below.
            ctx.exit(1)

        # Exits with status 0 when nothing was collected.
        path_empty(
            sources,
            "No Python files are present to be formatted. Nothing to do 😴",
            quiet,
            verbose,
            ctx,
        )

        # A single source is handled in-process; several are handed to
        # `reformat_many`, which uses an executor.
        if len(sources) == 1:
            reformat_one(
                src=sources.pop(),
                fast=fast,
                write_back=write_back,
                mode=mode,
                report=report,
            )
        else:
            reformat_many(
                sources=sources,
                fast=fast,
                write_back=write_back,
                mode=mode,
                report=report,
            )

    if verbose or not quiet:
        out(error_msg if report.return_code else "All done! ✨ 🍰 ✨")
        # The per-file summary is only printed when files (not -c code) were
        # processed.
        if code is None:
            click.echo(str(report), err=True)
    ctx.exit(report.return_code)
478
479
def get_sources(
    *,
    ctx: click.Context,
    src: Tuple[str, ...],
    quiet: bool,
    verbose: bool,
    include: Pattern[str],
    exclude: Optional[Pattern[str]],
    extend_exclude: Optional[Pattern[str]],
    force_exclude: Optional[Pattern[str]],
    report: "Report",
    stdin_filename: Optional[str],
) -> Set[Path]:
    """Collect the paths that should be formatted.

    Files named explicitly are kept unless `normalize_path_maybe_ignore`
    rejects them or they match `force_exclude`; directories are expanded with
    `gen_python_files`. A ``-`` entry stands for standard input and, when
    `stdin_filename` is given, is recorded under that name prefixed with
    `STDIN_PLACEHOLDER`. Anything else is reported as an invalid path.

    Exits through `path_empty` when `src` is empty.
    """
    root = find_project_root(src)
    collected: Set[Path] = set()
    path_empty(src, "No Path provided. Nothing to do 😴", quiet, verbose, ctx)

    # The gitignore file is only consulted when the default excludes are in use.
    gitignore = None
    if exclude is None:
        exclude = re_compile_maybe_verbose(DEFAULT_EXCLUDES)
        gitignore = get_gitignore(root)

    for s in src:
        from_stdin = s == "-" and bool(stdin_filename)
        path = Path(stdin_filename) if from_stdin else Path(s)

        if not (from_stdin or path.is_file()):
            if path.is_dir():
                collected.update(
                    gen_python_files(
                        path.iterdir(),
                        root,
                        include,
                        exclude,
                        extend_exclude,
                        force_exclude,
                        report,
                        gitignore,
                        verbose=verbose,
                        quiet=quiet,
                    )
                )
            elif s == "-":
                collected.add(path)
            else:
                err(f"invalid path: {s}")
            continue

        normalized = normalize_path_maybe_ignore(path, root, report)
        if normalized is None:
            continue
        normalized = "/" + normalized

        # Hard-exclude any files that matches the `--force-exclude` regex.
        # An empty match does not count as a hit.
        hit = force_exclude.search(normalized) if force_exclude else None
        if hit and hit.group(0):
            report.path_ignored(path, "matches the --force-exclude regular expression")
            continue

        if from_stdin:
            path = Path(f"{STDIN_PLACEHOLDER}{str(path)}")

        if path.suffix == ".ipynb" and not jupyter_dependencies_are_installed(
            verbose=verbose, quiet=quiet
        ):
            continue

        collected.add(path)

    return collected
557
558
def path_empty(
    src: Sized, msg: str, quiet: bool, verbose: bool, ctx: click.Context
) -> None:
    """
    Leave the program with status 0 when `src` is empty.

    `msg` is printed first unless `quiet` is set without `verbose`.
    Does nothing when `src` has at least one item.
    """
    if src:
        return

    if not (quiet and not verbose):
        out(msg)
    ctx.exit(0)
569
570
def reformat_code(
    content: str, fast: bool, write_back: WriteBack, mode: Mode, report: Report
) -> None:
    """
    Format the string `content` in-process and record the outcome in `report`.

    This is the string counterpart of `reformat_one`: the work is delegated to
    :func:`format_stdin_to_stdout`, to which `fast`, `write_back` and `mode`
    are forwarded. Any exception is recorded as a failure on `report` (with a
    traceback printed first when the report is verbose) rather than propagated.
    """
    path = Path("<string>")
    try:
        was_changed = format_stdin_to_stdout(
            content=content, fast=fast, write_back=write_back, mode=mode
        )
        report.done(path, Changed.YES if was_changed else Changed.NO)
    except Exception as exc:
        if report.verbose:
            traceback.print_exc()
        report.failed(path, str(exc))
593
594
def reformat_one(
    src: Path, fast: bool, write_back: WriteBack, mode: Mode, report: "Report"
) -> None:
    """Format the single source `src` in-process and record the outcome.

    `src` is treated as standard input when it is ``-`` or starts with
    `STDIN_PLACEHOLDER`; otherwise it is a file on disk, for which the cache
    is consulted and updated unless a diff was requested.

    `fast`, `write_back`, and `mode` options are passed to
    :func:`format_file_in_place` or :func:`format_stdin_to_stdout`.
    Exceptions are recorded on `report` instead of being propagated.
    """
    try:
        changed = Changed.NO

        src_text = str(src)
        is_stdin = src_text == "-"
        if not is_stdin and src_text.startswith(STDIN_PLACEHOLDER):
            is_stdin = True
            # Use the original name again in case we want to print something
            # to the user
            src = Path(src_text[len(STDIN_PLACEHOLDER) :])

        if is_stdin:
            # The (possibly user-supplied) file name still selects the mode.
            if src.suffix == ".pyi":
                mode = replace(mode, is_pyi=True)
            elif src.suffix == ".ipynb":
                mode = replace(mode, is_ipynb=True)
            if format_stdin_to_stdout(fast=fast, write_back=write_back, mode=mode):
                changed = Changed.YES
        else:
            cache: Cache = {}
            diff_requested = write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF)
            if not diff_requested:
                cache = read_cache(mode)
                resolved = src.resolve()
                cache_key = str(resolved)
                if cache_key in cache and cache[cache_key] == get_cache_info(
                    resolved
                ):
                    changed = Changed.CACHED

            if changed is not Changed.CACHED and format_file_in_place(
                src, fast=fast, write_back=write_back, mode=mode
            ):
                changed = Changed.YES

            # Remember files that were written, or that were checked and
            # found to need no change.
            written = write_back is WriteBack.YES and changed is not Changed.CACHED
            verified = write_back is WriteBack.CHECK and changed is Changed.NO
            if written or verified:
                write_cache(cache, [src], mode)

        report.done(src, changed)
    except Exception as exc:
        if report.verbose:
            traceback.print_exc()
        report.failed(src, str(exc))
644
645
def reformat_many(
    sources: Set[Path], fast: bool, write_back: WriteBack, mode: Mode, report: "Report"
) -> None:
    """Reformat multiple files using a ProcessPoolExecutor.

    Falls back to a single-worker ThreadPoolExecutor when a process pool
    cannot be created. The actual scheduling is done by
    :func:`schedule_formatting`, which is run to completion on the current
    event loop; the loop and the executor are always shut down afterwards.
    """
    executor: Executor
    loop = asyncio.get_event_loop()
    worker_count = os.cpu_count()
    if sys.platform == "win32":
        # Work around https://bugs.python.org/issue26903
        # `os.cpu_count()` may return None, which `min()` cannot compare with
        # an int; treat an undetermined count as a single worker.
        worker_count = min(worker_count or 1, 60)
    try:
        executor = ProcessPoolExecutor(max_workers=worker_count)
    except (ImportError, OSError):
        # we arrive here if the underlying system does not support multi-processing
        # like in AWS Lambda or Termux, in which case we gracefully fallback to
        # a ThreadPoolExecutor with just a single worker (more workers would not do us
        # any good due to the Global Interpreter Lock)
        executor = ThreadPoolExecutor(max_workers=1)

    try:
        loop.run_until_complete(
            schedule_formatting(
                sources=sources,
                fast=fast,
                write_back=write_back,
                mode=mode,
                report=report,
                loop=loop,
                executor=executor,
            )
        )
    finally:
        shutdown(loop)
        # `executor` is always bound here: one of the two branches above
        # assigned it or an exception left this function.
        executor.shutdown()
681
682
async def schedule_formatting(
    sources: Set[Path],
    fast: bool,
    write_back: WriteBack,
    mode: Mode,
    report: "Report",
    loop: asyncio.AbstractEventLoop,
    executor: Executor,
) -> None:
    """Run formatting of `sources` in parallel using the provided `executor`.

    (Use ProcessPoolExecutors for actual parallelism.)

    `write_back`, `fast`, and `mode` options are passed to
    :func:`format_file_in_place`.

    Unless a diff was requested, sources already recorded in the cache are
    reported as cached and skipped, and sources that were written back (or
    checked and found unchanged) are added to the cache at the end. Results
    and failures of the individual tasks are recorded on `report`.
    """
    cache: Cache = {}
    if write_back not in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
        cache = read_cache(mode)
        sources, cached = filter_cached(cache, sources)
        for src in sorted(cached):
            report.done(src, Changed.CACHED)
    if not sources:
        return

    cancelled = []
    sources_to_cache = []
    lock = None
    if write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
        # For diff output, we need locks to ensure we don't interleave output
        # from different processes.
        manager = Manager()
        lock = manager.Lock()
    # One task per source, keyed by the task so the source can be recovered
    # when the task completes.
    tasks = {
        asyncio.ensure_future(
            loop.run_in_executor(
                executor, format_file_in_place, src, fast, mode, write_back, lock
            )
        ): src
        for src in sorted(sources)
    }
    # A live view: it shrinks as finished tasks are popped from `tasks` below.
    pending = tasks.keys()
    try:
        loop.add_signal_handler(signal.SIGINT, cancel, pending)
        loop.add_signal_handler(signal.SIGTERM, cancel, pending)
    except NotImplementedError:
        # There are no good alternatives for these on Windows.
        pass
    while pending:
        done, _ = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
        for task in done:
            src = tasks.pop(task)
            if task.cancelled():
                cancelled.append(task)
            elif task.exception():
                report.failed(src, str(task.exception()))
            else:
                changed = Changed.YES if task.result() else Changed.NO
                # If the file was written back or was successfully checked as
                # well-formatted, store this information in the cache.
                if write_back is WriteBack.YES or (
                    write_back is WriteBack.CHECK and changed is Changed.NO
                ):
                    sources_to_cache.append(src)
                report.done(src, changed)
    if cancelled:
        # The `loop` argument of `asyncio.gather` was removed in Python 3.10
        # and passing it raises TypeError there. It is not needed: the
        # awaited objects are futures that already belong to the running loop.
        await asyncio.gather(*cancelled, return_exceptions=True)
    if sources_to_cache:
        write_cache(cache, sources_to_cache, mode)
752
753
def format_file_in_place(
    src: Path,
    fast: bool,
    mode: Mode,
    write_back: WriteBack = WriteBack.NO,
    lock: Any = None,  # multiprocessing.Manager().Lock() is some crazy proxy
) -> bool:
    """Format the file at `src`; return True when its formatting changed.

    With `write_back` set to YES the reformatted text replaces the file's
    contents; with DIFF or COLOR_DIFF a diff is written to stdout instead
    (holding `lock`, if one is given, while writing). A ``.pyi`` or ``.ipynb``
    suffix switches `mode` to stub or notebook formatting.

    `mode` and `fast` options are passed to :func:`format_file_contents`.
    Raises ValueError when the contents cannot be decoded as JSON (reported as
    an invalid Jupyter notebook).
    """
    suffix = src.suffix
    if suffix == ".pyi":
        mode = replace(mode, is_pyi=True)
    elif suffix == ".ipynb":
        mode = replace(mode, is_ipynb=True)

    # Taken before formatting: the file's modification time labels the "old"
    # side of a diff.
    then = datetime.utcfromtimestamp(src.stat().st_mtime)
    with open(src, "rb") as buf:
        src_contents, encoding, newline = decode_bytes(buf.read())

    try:
        dst_contents = format_file_contents(src_contents, fast=fast, mode=mode)
    except NothingChanged:
        return False
    except JSONDecodeError:
        raise ValueError(
            f"File '{src}' cannot be parsed as valid Jupyter notebook."
        ) from None

    if write_back == WriteBack.YES:
        with open(src, "w", encoding=encoding, newline=newline) as f:
            f.write(dst_contents)
        return True

    if write_back not in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
        return True

    now = datetime.utcnow()
    src_name = f"{src}\t{then} +0000"
    dst_name = f"{src}\t{now} +0000"
    make_diff = ipynb_diff if mode.is_ipynb else diff
    diff_contents = make_diff(src_contents, dst_contents, src_name, dst_name)
    if write_back == WriteBack.COLOR_DIFF:
        diff_contents = color_diff(diff_contents)

    with lock or nullcontext():
        # Wrap stdout's byte buffer so the diff is emitted with the file's own
        # encoding and newline convention; detach afterwards so the underlying
        # buffer is not closed along with the wrapper.
        stream = io.TextIOWrapper(
            sys.stdout.buffer,
            encoding=encoding,
            newline=newline,
            write_through=True,
        )
        stream = wrap_stream_for_windows(stream)
        stream.write(diff_contents)
        stream.detach()

    return True
811
812
def format_stdin_to_stdout(
    fast: bool,
    *,
    content: Optional[str] = None,
    write_back: WriteBack = WriteBack.NO,
    mode: Mode,
) -> bool:
    """Format source text and emit the result on stdout; return True if changed.

    The text is `content` when given, otherwise it is read from sys.stdin.

    With `write_back` set to YES the resulting code is written to stdout; with
    DIFF or COLOR_DIFF a diff is written instead. The output step runs in a
    ``finally`` clause, so it also happens when nothing changed or when
    formatting raised; in those cases the unmodified input is what gets
    written or diffed. `mode` is passed to :func:`format_file_contents`.
    """
    then = datetime.utcnow()

    if content is not None:
        src, encoding, newline = content, "utf-8", ""
    else:
        src, encoding, newline = decode_bytes(sys.stdin.buffer.read())

    # Start from the input so that the output step below has something to
    # emit even if formatting does not produce a new result.
    dst = src
    try:
        dst = format_file_contents(src, fast=fast, mode=mode)
    except NothingChanged:
        return False
    else:
        return True
    finally:
        stream = io.TextIOWrapper(
            sys.stdout.buffer, encoding=encoding, newline=newline, write_through=True
        )
        if write_back == WriteBack.YES:
            # Make sure there's a newline after the content
            if dst and dst[-1] != "\n":
                dst += "\n"
            stream.write(dst)
        elif write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
            now = datetime.utcnow()
            src_name = f"STDIN\t{then} +0000"
            dst_name = f"STDOUT\t{now} +0000"
            patch = diff(src, dst, src_name, dst_name)
            if write_back == WriteBack.COLOR_DIFF:
                patch = color_diff(patch)
                stream = wrap_stream_for_windows(stream)
            stream.write(patch)
        # Detach so the wrapper does not close stdout's underlying buffer.
        stream.detach()
862
863
def check_stability_and_equivalence(
    src_contents: str, dst_contents: str, *, mode: Mode
) -> None:
    """Verify that formatted output is equivalent to its source and stable.

    Equivalence of `src_contents` and `dst_contents` is asserted first. The
    output is then formatted once more; only if that second pass differs are
    equivalence and stability asserted for the second-pass result.
    """
    assert_equivalent(src_contents, dst_contents)

    # Forced second pass to work around optional trailing commas (becoming
    # forced trailing commas on pass 2) interacting differently with optional
    # parentheses.  Admittedly ugly.
    second_pass = format_str(dst_contents, mode=mode)
    if second_pass == dst_contents:
        # Note: no need to explicitly call `assert_stable` when the second
        # pass reproduced `dst_contents` exactly.
        return

    assert_equivalent(src_contents, second_pass, pass_num=2)
    assert_stable(src_contents, second_pass, mode=mode)
885
886
def format_file_contents(src_contents: str, *, fast: bool, mode: Mode) -> FileContent:
    """Return the reformatted version of `src_contents`.

    Raises NothingChanged when the input is blank or already formatted.

    Notebook input (``mode.is_ipynb``) goes through
    :func:`format_ipynb_string`, everything else through :func:`format_str`.
    Unless `fast` is set, non-notebook output is additionally validated with
    :func:`check_stability_and_equivalence`.
    """
    if src_contents.strip() == "":
        raise NothingChanged

    is_notebook = mode.is_ipynb
    if is_notebook:
        formatted = format_ipynb_string(src_contents, fast=fast, mode=mode)
    else:
        formatted = format_str(src_contents, mode=mode)

    if formatted == src_contents:
        raise NothingChanged

    if not (fast or is_notebook):
        # Jupyter notebooks will already have been checked above.
        check_stability_and_equivalence(src_contents, formatted, mode=mode)

    return formatted
908
909
def validate_cell(src: str) -> None:
    """Reject cells that already contain TransformerManager output.

    A cell holding ``!ls`` is transformed into ``get_ipython().system('ls')``,
    but a cell that literally contains ``get_ipython().system('ls')`` comes
    out exactly the same:

        >>> TransformerManager().transform_cell("get_ipython().system('ls')")
        "get_ipython().system('ls')\n"
        >>> TransformerManager().transform_cell("!ls")
        "get_ipython().system('ls')\n"

    The two cases cannot be told apart afterwards, so a safe round trip is
    impossible; such cells are skipped by raising NothingChanged.
    """
    for marker in TRANSFORMED_MAGICS:
        if marker in src:
            raise NothingChanged
927
928
def format_cell(src: str, *, fast: bool, mode: Mode) -> str:
    """Format the code of a single Jupyter notebook cell.

    The steps are, in order:

      - drop a trailing semicolon, remembering whether there was one;
      - mask IPython magics;
      - format the masked code;
      - put the IPython magics back;
      - put the trailing semicolon back (if it was there originally);
      - strip trailing newlines.

    Cells with syntax errors are left alone, since they might be automagics
    or multi-line magics, which are currently not supported.

    Raises NothingChanged when the cell is skipped or comes out unchanged.
    """
    validate_cell(src)
    stripped_src, had_semicolon = remove_trailing_semicolon(src)

    try:
        masked_src, replacements = mask_cell(stripped_src)
    except SyntaxError:
        raise NothingChanged from None

    masked_dst = format_str(masked_src, mode=mode)
    if not fast:
        check_stability_and_equivalence(masked_src, masked_dst, mode=mode)

    unmasked = unmask_cell(masked_dst, replacements)
    formatted = put_trailing_semicolon_back(unmasked, had_semicolon).rstrip("\n")
    if formatted == src:
        raise NothingChanged from None
    return formatted
964
965
def validate_metadata(nb: MutableMapping[str, Any]) -> None:
    """Refuse to format notebooks that declare a non-Python language.

    Every notebook metadata field is optional, see
    https://nbformat.readthedocs.io/en/latest/format_description.html.
    A notebook without a declared language is therefore still accepted and
    an attempt to parse it will be made.

    Raises NothingChanged when a language other than "python" is declared.
    """
    metadata = nb.get("metadata", {})
    language_info = metadata.get("language_info", {})
    language = language_info.get("name", None)
    if language is None or language == "python":
        return
    raise NothingChanged from None
976
977
def format_ipynb_string(src_contents: str, *, fast: bool, mode: Mode) -> FileContent:
    """Format Jupyter notebook.

    Operate cell-by-cell, only on code cells, only for Python notebooks.
    If the ``.ipynb`` originally had a trailing newline, it'll be preserved.

    `src_contents` is the notebook as JSON text; `fast` and `mode` are
    forwarded to :func:`format_cell` for every code cell.

    Returns the re-serialized notebook (``indent=1``, non-ASCII characters
    kept as-is).

    Raises NothingChanged if the input is empty, if the notebook is rejected
    by :func:`validate_metadata`, or if no cell was modified.  Errors from
    ``json.loads`` and a missing ``"cells"`` key are not handled here and
    propagate to the caller.
    """
    # An empty string is not a notebook; report it as "nothing to do" instead
    # of failing with an IndexError while probing the last character.
    if not src_contents:
        raise NothingChanged

    trailing_newline = src_contents.endswith("\n")
    modified = False
    nb = json.loads(src_contents)
    validate_metadata(nb)
    for cell in nb["cells"]:
        if cell.get("cell_type", None) == "code":
            try:
                # Cell source is stored as a list of lines; format it as one string.
                src = "".join(cell["source"])
                dst = format_cell(src, fast=fast, mode=mode)
            except NothingChanged:
                # This cell is skipped or already formatted; keep it untouched.
                pass
            else:
                cell["source"] = dst.splitlines(keepends=True)
                modified = True
    if modified:
        dst_contents = json.dumps(nb, indent=1, ensure_ascii=False)
        if trailing_newline:
            dst_contents = dst_contents + "\n"
        return dst_contents
    else:
        raise NothingChanged
1005
1006
def format_str(src_contents: str, *, mode: Mode) -> FileContent:
    """Return `src_contents` reformatted according to `mode`.

    `mode` carries the formatting options, for instance the maximum number of
    characters allowed per line.  Example:

    >>> import black
    >>> print(black.format_str("def f(arg:str='')->None:...", mode=black.Mode()))
    def f(arg: str = "") -> None:
        ...

    A more complex example:

    >>> print(
    ...   black.format_str(
    ...     "def f(arg:str='')->None: hey",
    ...     mode=black.Mode(
    ...       target_versions={black.TargetVersion.PY36},
    ...       line_length=10,
    ...       string_normalization=False,
    ...       is_pyi=False,
    ...     ),
    ...   ),
    ... )
    def f(
        arg: str = '',
    ) -> None:
        hey

    """
    tree = lib2to3_parse(src_contents.lstrip(), mode.target_versions)
    chunks = []
    future_imports = get_future_imports(tree)
    # Explicit target versions win; otherwise infer them from the source.
    versions = mode.target_versions or detect_target_versions(tree)
    normalize_fmt_off(tree)

    strip_u_prefix = "unicode_literals" in future_imports or supports_feature(
        versions, Feature.UNICODE_LITERALS
    )
    line_generator = LineGenerator(mode=mode, remove_u_prefix=strip_u_prefix)
    tracker = EmptyLineTracker(is_pyi=mode.is_pyi)
    blank_line = Line(mode=mode)

    split_line_features = set()
    for candidate in (Feature.TRAILING_COMMA_IN_CALL, Feature.TRAILING_COMMA_IN_DEF):
        if supports_feature(versions, candidate):
            split_line_features.add(candidate)

    # Blank lines requested *after* a logical line are only emitted once the
    # next logical line shows up, so none are written after the last one.
    pending_after = 0
    for current_line in line_generator.visit(tree):
        chunks.append(str(blank_line) * pending_after)
        before, pending_after = tracker.maybe_empty_lines(current_line)
        chunks.append(str(blank_line) * before)
        chunks.extend(
            str(piece)
            for piece in transform_line(
                current_line, mode=mode, features=split_line_features
            )
        )
    return "".join(chunks)
1067
1068
def decode_bytes(src: bytes) -> Tuple[FileContent, Encoding, NewLine]:
    """Decode `src` and report its encoding and newline style.

    The result is ``(decoded_contents, encoding, newline)``.  `newline` is
    CRLF or LF, judged from the first line only, while `decoded_contents` is
    read with universal newlines and therefore contains only LF.
    """
    buffer = io.BytesIO(src)
    encoding, first_lines = tokenize.detect_encoding(buffer.readline)
    if not first_lines:
        return "", encoding, "\n"

    if first_lines[0].endswith(b"\r\n"):
        newline = "\r\n"
    else:
        newline = "\n"

    # Encoding detection consumed the start of the buffer; rewind before the
    # full read.
    buffer.seek(0)
    with io.TextIOWrapper(buffer, encoding) as reader:
        text = reader.read()
    return text, encoding, newline
1084
1085
def get_features_used(node: Node) -> Set[Feature]:
    """Return a set of (relatively) new Python features used in this file.

    Walks `node` in pre-order and collects a `Feature` for each construct
    found.  Currently looking for:
    - f-strings (any capitalization of the prefix, e.g. ``f``, ``Rf``, ``fR``);
    - underscores in numeric literals;
    - trailing commas after * or ** in function signatures and calls;
    - positional only arguments in function signatures and lambdas;
    - assignment expression;
    - relaxed decorator syntax;
    """
    features: Set[Feature] = set()
    for n in node.pre_order():
        if n.type == token.STRING:
            # String prefixes are case-insensitive and "r"/"f" may come in
            # either order, so compare the lowercased head; otherwise prefixes
            # such as `Rf` or `fR` would go undetected.
            value_head = n.value[:2].lower()  # type: ignore
            if value_head in {'f"', "f'", "rf", "fr"}:
                features.add(Feature.F_STRINGS)

        elif n.type == token.NUMBER:
            if "_" in n.value:  # type: ignore
                features.add(Feature.NUMERIC_UNDERSCORES)

        elif n.type == token.SLASH:
            # `varargslist` is the parameter list of a lambda; without it a
            # positional-only marker in a lambda would not be detected.
            if n.parent and n.parent.type in {
                syms.typedargslist,
                syms.arglist,
                syms.varargslist,
            }:
                features.add(Feature.POS_ONLY_ARGUMENTS)

        elif n.type == token.COLONEQUAL:
            features.add(Feature.ASSIGNMENT_EXPRESSIONS)

        elif n.type == syms.decorator:
            if len(n.children) > 1 and not is_simple_decorator_expression(
                n.children[1]
            ):
                features.add(Feature.RELAXED_DECORATORS)

        elif (
            n.type in {syms.typedargslist, syms.arglist}
            and n.children
            and n.children[-1].type == token.COMMA
        ):
            # The list ends in a trailing comma; that only counts as a new
            # feature when the list also contains a * or ** unpacking.
            if n.type == syms.typedargslist:
                feature = Feature.TRAILING_COMMA_IN_DEF
            else:
                feature = Feature.TRAILING_COMMA_IN_CALL

            for ch in n.children:
                if ch.type in STARS:
                    features.add(feature)

                if ch.type == syms.argument:
                    for argch in ch.children:
                        if argch.type in STARS:
                            features.add(feature)

    return features
1141
1142
def detect_target_versions(node: Node) -> Set[TargetVersion]:
    """Return every target version that supports all features used in `node`."""
    used_features = get_features_used(node)
    compatible: Set[TargetVersion] = set()
    for candidate in TargetVersion:
        if used_features <= VERSION_TO_FEATURES[candidate]:
            compatible.add(candidate)
    return compatible
1149
1150
def get_future_imports(node: Node) -> Set[str]:
    """Collect the names imported from ``__future__`` at the top of the file.

    Only the leading run of simple statements is inspected: a docstring is
    skipped over, and the scan stops at the first statement that is not a
    ``from __future__ import ...``.
    """
    found: Set[str] = set()

    def _imported_names(children: List[LN]) -> Generator[str, None, None]:
        # Yield the original (pre-`as`) name of every imported item.
        for item in children:
            if isinstance(item, Leaf):
                # Leaves that are not names are simply ignored.
                if item.type == token.NAME:
                    yield item.value
                continue

            if item.type == syms.import_as_names:
                yield from _imported_names(item.children)
                continue

            if item.type != syms.import_as_name:
                raise AssertionError("Invalid syntax parsing imports")

            original = item.children[0]
            assert isinstance(original, Leaf), "Invalid syntax parsing imports"
            assert original.type == token.NAME, "Invalid syntax parsing imports"
            yield original.value

    for statement in node.children:
        if statement.type != syms.simple_stmt:
            break

        head = statement.children[0]
        if isinstance(head, Leaf):
            # A docstring may precede the imports; any other leaf ends the scan.
            is_docstring = (
                len(statement.children) == 2
                and head.type == token.STRING
                and statement.children[1].type == token.NEWLINE
            )
            if is_docstring:
                continue
            break

        if head.type != syms.import_from:
            break

        module_name = head.children[1]
        if not (isinstance(module_name, Leaf) and module_name.value == "__future__"):
            break

        found.update(_imported_names(head.children[3:]))

    return found
1199
1200
def assert_equivalent(src: str, dst: str, *, pass_num: int = 1) -> None:
    """Check that `src` and `dst` parse to the same AST.

    Raises AssertionError when either text fails to parse or when the
    stringified ASTs differ.  `pass_num` only appears in the error messages.
    """
    try:
        src_ast = parse_ast(src)
    except Exception as exc:
        message = "cannot use --safe with this file; failed to parse source file."
        raise AssertionError(message) from exc

    try:
        dst_ast = parse_ast(dst)
    except Exception as exc:
        # Save the traceback together with the broken output for bug reports.
        trace = "".join(traceback.format_tb(exc.__traceback__))
        log = dump_to_file(trace, dst)
        message = (
            f"INTERNAL ERROR: Black produced invalid code on pass {pass_num}: {exc}. "
            "Please report a bug on https://github.com/psf/black/issues.  "
            f"This invalid output might be helpful: {log}"
        )
        raise AssertionError(message) from None

    src_ast_str = "\n".join(stringify_ast(src_ast))
    dst_ast_str = "\n".join(stringify_ast(dst_ast))
    if src_ast_str == dst_ast_str:
        return

    log = dump_to_file(diff(src_ast_str, dst_ast_str, "src", "dst"))
    message = (
        "INTERNAL ERROR: Black produced code that is not equivalent to the"
        f" source on pass {pass_num}.  Please report a bug on "
        f"https://github.com/psf/black/issues.  This diff might be helpful: {log}"
    )
    raise AssertionError(message) from None
1229
1230
def assert_stable(src: str, dst: str, mode: Mode) -> None:
    """Check that formatting `dst` again leaves it unchanged.

    Raises AssertionError if another run of the formatter over `dst` gives a
    different result; `src` is only used to build the diagnostic dump.
    """
    second_pass = format_str(dst, mode=mode)
    if second_pass == dst:
        return

    first_pass_diff = diff(src, dst, "source", "first pass")
    second_pass_diff = diff(dst, second_pass, "first pass", "second pass")
    log = dump_to_file(str(mode), first_pass_diff, second_pass_diff)
    message = (
        "INTERNAL ERROR: Black produced different code on the second pass of the"
        " formatter.  Please report a bug on https://github.com/psf/black/issues."
        f"  This diff might be helpful: {log}"
    )
    raise AssertionError(message) from None
1245
1246
@contextmanager
def nullcontext() -> Iterator[None]:
    """Provide a context manager that does nothing.

    Stand-in for the `nullcontext` available in Python 3.7; the ``with``
    target is bound to None.
    """
    yield None
1254
1255
def patch_click() -> None:
    """Make Click not crash on Python 3.6 with LANG=C.

    On certain misconfigured environments, Python 3 selects the ASCII encoding as the
    default which restricts paths that it can access during the lifetime of the
    application.  Click refuses to work in this scenario by raising a RuntimeError.

    In case of Black the likelihood that non-ASCII characters are going to be used in
    file paths is minimal since it's Python source code.  Moreover, this crash was
    spurious on Python 3.7 thanks to PEP 538 and PEP 540.

    The environment check is replaced with a no-op in whichever of
    ``click.core`` and ``click._unicodefun`` can be imported.  A module that
    is missing is skipped; this function never raises because of it.
    """
    modules: List[Any] = []

    # Import each module on its own: `from click import <name>` raises a plain
    # ImportError (not ModuleNotFoundError) when the package exists but the
    # submodule does not, and one missing module must not prevent patching
    # the other.
    try:
        from click import core
    except ImportError:
        pass
    else:
        modules.append(core)

    try:
        from click import _unicodefun  # type: ignore
    except ImportError:
        pass
    else:
        modules.append(_unicodefun)

    for module in modules:
        # The check has been named differently across Click releases; replace
        # whichever spelling is present.
        if hasattr(module, "_verify_python3_env"):
            module._verify_python3_env = lambda: None  # type: ignore
        if hasattr(module, "_verify_python_env"):
            module._verify_python_env = lambda: None  # type: ignore
1278
1279
def patched_main() -> None:
    """Run the setup steps below, in order, and then call `main`."""
    # NOTE(review): presumably switches asyncio to uvloop when it is
    # installed -- confirm against maybe_install_uvloop.
    maybe_install_uvloop()
    # multiprocessing.freeze_support(); per the multiprocessing docs this
    # supports frozen Windows executables and has no effect elsewhere.
    freeze_support()
    # Neutralize Click's environment check (see patch_click).
    patch_click()
    main()
1285
1286
# When executed as a script, go through the patched entry point rather than
# calling main() directly.
if __name__ == "__main__":
    patched_main()