]> git.madduck.net Git - etc/vim.git/blob - src/black/__init__.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

MNT: add pull request template (#2443)
[etc/vim.git] / src / black / __init__.py
1 import asyncio
2 from json.decoder import JSONDecodeError
3 import json
4 from concurrent.futures import Executor, ThreadPoolExecutor, ProcessPoolExecutor
5 from contextlib import contextmanager
6 from datetime import datetime
7 from enum import Enum
8 import io
9 from multiprocessing import Manager, freeze_support
10 import os
11 from pathlib import Path
12 from pathspec.patterns.gitwildmatch import GitWildMatchPatternError
13 import regex as re
14 import signal
15 import sys
16 import tokenize
17 import traceback
18 from typing import (
19     Any,
20     Dict,
21     Generator,
22     Iterator,
23     List,
24     MutableMapping,
25     Optional,
26     Pattern,
27     Set,
28     Sized,
29     Tuple,
30     Union,
31 )
32
33 from dataclasses import replace
34 import click
35
36 from black.const import DEFAULT_LINE_LENGTH, DEFAULT_INCLUDES, DEFAULT_EXCLUDES
37 from black.const import STDIN_PLACEHOLDER
38 from black.nodes import STARS, syms, is_simple_decorator_expression
39 from black.lines import Line, EmptyLineTracker
40 from black.linegen import transform_line, LineGenerator, LN
41 from black.comments import normalize_fmt_off
42 from black.mode import Mode, TargetVersion
43 from black.mode import Feature, supports_feature, VERSION_TO_FEATURES
44 from black.cache import read_cache, write_cache, get_cache_info, filter_cached, Cache
45 from black.concurrency import cancel, shutdown, maybe_install_uvloop
46 from black.output import dump_to_file, ipynb_diff, diff, color_diff, out, err
47 from black.report import Report, Changed, NothingChanged
48 from black.files import find_project_root, find_pyproject_toml, parse_pyproject_toml
49 from black.files import gen_python_files, get_gitignore, normalize_path_maybe_ignore
50 from black.files import wrap_stream_for_windows
51 from black.parsing import InvalidInput  # noqa F401
52 from black.parsing import lib2to3_parse, parse_ast, stringify_ast
53 from black.handle_ipynb_magics import (
54     mask_cell,
55     unmask_cell,
56     remove_trailing_semicolon,
57     put_trailing_semicolon_back,
58     TRANSFORMED_MAGICS,
59     jupyter_dependencies_are_installed,
60 )
61
62
63 # lib2to3 fork
64 from blib2to3.pytree import Node, Leaf
65 from blib2to3.pgen2 import token
66
67 from _black_version import version as __version__
68
# Type aliases for annotations.
FileContent = str  # decoded source text
Encoding = str  # text codec name, as passed to open(..., encoding=...)
NewLine = str  # newline string, as passed to open(..., newline=...)
73
74
class WriteBack(Enum):
    """What to do with formatted output: write it back, diff it, or only check it."""

    NO = 0
    YES = 1
    DIFF = 2
    CHECK = 3
    COLOR_DIFF = 4

    @classmethod
    def from_configuration(
        cls, *, check: bool, diff: bool, color: bool = False
    ) -> "WriteBack":
        """Translate the `--check`, `--diff` and `--color` flags into a member.

        `diff` takes precedence over `check`; `color` only has an effect
        together with `diff`. With none of the flags set the result is YES.
        """
        if diff:
            return cls.COLOR_DIFF if color else cls.DIFF
        return cls.CHECK if check else cls.YES
93
94
# Legacy name, left for integrations: `FileMode` is a plain alias of `Mode`,
# kept so that external code still referring to the old name keeps working.
FileMode = Mode
97
98
def read_pyproject_toml(
    ctx: click.Context, param: click.Parameter, value: Optional[str]
) -> Optional[str]:
    """Inject Black configuration from "pyproject.toml" into defaults in `ctx`.

    Used as the click callback of `--config`. When `value` is empty, a
    pyproject.toml is searched for via `find_pyproject_toml`, starting from the
    `src` paths if click has already parsed them. Scalar values are converted
    to strings so click can process them like command-line input (lists and
    dicts are kept as-is), and the result is merged over any existing
    `ctx.default_map`.

    Returns the path to a successfully found and read configuration file, None
    otherwise.

    Raises click.FileError if `parse_pyproject_toml` raises OSError or
    ValueError, and click.BadOptionUsage if `target-version` is not a list.
    """
    if not value:
        value = find_pyproject_toml(ctx.params.get("src", ()))
        if value is None:
            return None

    try:
        config = parse_pyproject_toml(value)
    except (OSError, ValueError) as e:
        # Chain the original error so its cause stays visible in tracebacks.
        raise click.FileError(
            filename=value, hint=f"Error reading configuration file: {e}"
        ) from e

    if not config:
        return None

    # Sanitize the values to be Click friendly. For more information please see:
    # https://github.com/psf/black/issues/1458
    # https://github.com/pallets/click/issues/1567
    config = {
        k: str(v) if not isinstance(v, (list, dict)) else v
        for k, v in config.items()
    }

    target_version = config.get("target_version")
    if target_version is not None and not isinstance(target_version, list):
        raise click.BadOptionUsage(
            "target-version", "Config key target-version must be a list"
        )

    # Values already present in ctx.default_map are overridden by the file.
    default_map: Dict[str, Any] = {}
    if ctx.default_map:
        default_map.update(ctx.default_map)
    default_map.update(config)

    ctx.default_map = default_map
    return value
143
144
def target_version_option_callback(
    c: click.Context, p: Union[click.Option, click.Parameter], v: Tuple[str, ...]
) -> List[TargetVersion]:
    """Map the lower-case `--target-version` choices onto TargetVersion members.

    Kept as a named function instead of a lambda because mypy could not infer
    the lambda's type correctly, which caused problems for mypyc.
    """
    upper_names = (name.upper() for name in v)
    return [TargetVersion[name] for name in upper_names]
154
155
def re_compile_maybe_verbose(regex: str) -> Pattern[str]:
    """Compile `regex`, switching on verbose mode for multi-line patterns.

    A pattern containing a newline is prefixed with the inline ``(?x)`` flag so
    that whitespace and ``#`` comments inside it are ignored.
    """
    source = "(?x)" + regex if "\n" in regex else regex
    # Explicit annotation keeps the return type precise for mypy/mypyc.
    pattern: Pattern[str] = re.compile(source)
    return pattern
165
166
def validate_regex(
    ctx: click.Context,
    param: click.Parameter,
    value: Optional[str],
) -> Optional[Pattern]:
    """Click callback that compiles an option's value into a regex pattern.

    Returns None when the option was not supplied. Raises click.BadParameter
    when `value` is not a valid regular expression.
    """
    if value is None:
        return None
    try:
        return re_compile_maybe_verbose(value)
    except re.error:
        raise click.BadParameter("Not a valid regular expression")
176
177
@click.command(context_settings=dict(help_option_names=["-h", "--help"]))
@click.option("-c", "--code", type=str, help="Format the code passed in as a string.")
@click.option(
    "-l",
    "--line-length",
    type=int,
    default=DEFAULT_LINE_LENGTH,
    help="How many characters per line to allow.",
    show_default=True,
)
@click.option(
    "-t",
    "--target-version",
    type=click.Choice([v.name.lower() for v in TargetVersion]),
    callback=target_version_option_callback,
    multiple=True,
    help=(
        "Python versions that should be supported by Black's output. [default: per-file"
        " auto-detection]"
    ),
)
@click.option(
    "--pyi",
    is_flag=True,
    help=(
        "Format all input files like typing stubs regardless of file extension (useful"
        " when piping source on standard input)."
    ),
)
@click.option(
    "--ipynb",
    is_flag=True,
    help=(
        "Format all input files like Jupyter Notebooks regardless of file extension "
        "(useful when piping source on standard input)."
    ),
)
@click.option(
    "-S",
    "--skip-string-normalization",
    is_flag=True,
    help="Don't normalize string quotes or prefixes.",
)
@click.option(
    "-C",
    "--skip-magic-trailing-comma",
    is_flag=True,
    help="Don't use trailing commas as a reason to split lines.",
)
@click.option(
    "--experimental-string-processing",
    is_flag=True,
    hidden=True,
    help=(
        "Experimental option that performs more normalization on string literals."
        " Currently disabled because it leads to some crashes."
    ),
)
@click.option(
    "--check",
    is_flag=True,
    help=(
        "Don't write the files back, just return the status. Return code 0 means"
        " nothing would change. Return code 1 means some files would be reformatted."
        " Return code 123 means there was an internal error."
    ),
)
@click.option(
    "--diff",
    is_flag=True,
    help="Don't write the files back, just output a diff for each file on stdout.",
)
@click.option(
    "--color/--no-color",
    is_flag=True,
    help="Show colored diff. Only applies when `--diff` is given.",
)
@click.option(
    "--fast/--safe",
    is_flag=True,
    help="If --fast given, skip temporary sanity checks. [default: --safe]",
)
@click.option(
    "--required-version",
    type=str,
    help=(
        "Require a specific version of Black to be running (useful for unifying results"
        " across many environments e.g. with a pyproject.toml file)."
    ),
)
@click.option(
    "--include",
    type=str,
    default=DEFAULT_INCLUDES,
    callback=validate_regex,
    help=(
        "A regular expression that matches files and directories that should be"
        " included on recursive searches. An empty value means all files are included"
        " regardless of the name. Use forward slashes for directories on all platforms"
        " (Windows, too). Exclusions are calculated first, inclusions later."
    ),
    show_default=True,
)
@click.option(
    "--exclude",
    type=str,
    callback=validate_regex,
    help=(
        "A regular expression that matches files and directories that should be"
        " excluded on recursive searches. An empty value means no paths are excluded."
        " Use forward slashes for directories on all platforms (Windows, too)."
        " Exclusions are calculated first, inclusions later. [default:"
        f" {DEFAULT_EXCLUDES}]"
    ),
    show_default=False,
)
@click.option(
    "--extend-exclude",
    type=str,
    callback=validate_regex,
    help=(
        "Like --exclude, but adds additional files and directories on top of the"
        " excluded ones. (Useful if you simply want to add to the default)"
    ),
)
@click.option(
    "--force-exclude",
    type=str,
    callback=validate_regex,
    help=(
        "Like --exclude, but files and directories matching this regex will be "
        "excluded even when they are passed explicitly as arguments."
    ),
)
@click.option(
    "--stdin-filename",
    type=str,
    help=(
        "The name of the file when passing it through stdin. Useful to make "
        "sure Black will respect --force-exclude option on some "
        "editors that rely on using stdin."
    ),
)
@click.option(
    "-q",
    "--quiet",
    is_flag=True,
    help=(
        "Don't emit non-error messages to stderr. Errors are still emitted; silence"
        " those with 2>/dev/null."
    ),
)
@click.option(
    "-v",
    "--verbose",
    is_flag=True,
    help=(
        "Also emit messages to stderr about files that were not changed or were ignored"
        " due to exclusion patterns."
    ),
)
@click.version_option(version=__version__)
@click.argument(
    "src",
    nargs=-1,
    type=click.Path(
        exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
    ),
    is_eager=True,
    metavar="SRC ...",
)
@click.option(
    "--config",
    type=click.Path(
        exists=True,
        file_okay=True,
        dir_okay=False,
        readable=True,
        allow_dash=False,
        path_type=str,
    ),
    is_eager=True,
    callback=read_pyproject_toml,
    help="Read configuration from FILE path.",
)
@click.pass_context
def main(
    ctx: click.Context,
    code: Optional[str],
    line_length: int,
    target_version: List[TargetVersion],
    check: bool,
    diff: bool,
    color: bool,
    fast: bool,
    pyi: bool,
    ipynb: bool,
    skip_string_normalization: bool,
    skip_magic_trailing_comma: bool,
    experimental_string_processing: bool,
    quiet: bool,
    verbose: bool,
    required_version: Optional[str],
    include: Pattern,
    exclude: Optional[Pattern],
    extend_exclude: Optional[Pattern],
    force_exclude: Optional[Pattern],
    stdin_filename: Optional[str],
    src: Tuple[str, ...],
    config: Optional[str],
) -> None:
    """The uncompromising code formatter."""
    # NOTE: the docstring above is rendered by click as the `--help` text, so
    # implementation notes for this entry point live in comments instead.
    if config and verbose:
        out(f"Using configuration from {config}.", bold=False, fg="blue")

    error_msg = "Oh no! 💥 💔 💥"
    # `--required-version` is optional: click passes None when it is omitted,
    # hence the Optional[str] annotation and the truthiness test.
    if required_version and required_version != __version__:
        err(
            f"{error_msg} The required version `{required_version}` does not match"
            f" the running version `{__version__}`!"
        )
        ctx.exit(1)
    if ipynb and pyi:
        err("Cannot pass both `pyi` and `ipynb` flags!")
        ctx.exit(1)

    write_back = WriteBack.from_configuration(check=check, diff=diff, color=color)
    if target_version:
        versions = set(target_version)
    else:
        # We'll autodetect later.
        versions = set()
    mode = Mode(
        target_versions=versions,
        line_length=line_length,
        is_pyi=pyi,
        is_ipynb=ipynb,
        string_normalization=not skip_string_normalization,
        magic_trailing_comma=not skip_magic_trailing_comma,
        experimental_string_processing=experimental_string_processing,
    )

    if code is not None:
        # Run in quiet mode by default with -c; the extra output isn't useful.
        # You can still pass -v to get verbose output.
        quiet = True

    report = Report(check=check, diff=diff, quiet=quiet, verbose=verbose)

    if code is not None:
        reformat_code(
            content=code, fast=fast, write_back=write_back, mode=mode, report=report
        )
    else:
        try:
            sources = get_sources(
                ctx=ctx,
                src=src,
                quiet=quiet,
                verbose=verbose,
                include=include,
                exclude=exclude,
                extend_exclude=extend_exclude,
                force_exclude=force_exclude,
                report=report,
                stdin_filename=stdin_filename,
            )
        except GitWildMatchPatternError:
            # An invalid gitignore-style pattern aborts the run with status 1.
            ctx.exit(1)

        path_empty(
            sources,
            "No Python files are present to be formatted. Nothing to do 😴",
            quiet,
            verbose,
            ctx,
        )

        # A single source is formatted in-process; several go through an
        # executor in reformat_many.
        if len(sources) == 1:
            reformat_one(
                src=sources.pop(),
                fast=fast,
                write_back=write_back,
                mode=mode,
                report=report,
            )
        else:
            reformat_many(
                sources=sources,
                fast=fast,
                write_back=write_back,
                mode=mode,
                report=report,
            )

    if verbose or not quiet:
        out(error_msg if report.return_code else "All done! ✨ 🍰 ✨")
        if code is None:
            click.echo(str(report), err=True)
    ctx.exit(report.return_code)
478
479
def get_sources(
    *,
    ctx: click.Context,
    src: Tuple[str, ...],
    quiet: bool,
    verbose: bool,
    include: Pattern[str],
    exclude: Optional[Pattern[str]],
    extend_exclude: Optional[Pattern[str]],
    force_exclude: Optional[Pattern[str]],
    report: "Report",
    stdin_filename: Optional[str],
) -> Set[Path]:
    """Collect the paths Black should format from the command-line `src` entries.

    Explicit files are checked against `force_exclude` and added directly,
    directories are expanded with `gen_python_files`, and a bare `-` (stdin) is
    passed through. When `stdin_filename` is given, `-` is treated as a file of
    that name and tagged with STDIN_PLACEHOLDER so `reformat_one` recognises it
    as stdin. Exits early via `path_empty` when `src` is empty.
    """
    root = find_project_root(src)
    collected: Set[Path] = set()
    path_empty(src, "No Path provided. Nothing to do 😴", quiet, verbose, ctx)

    # gitignore rules are only loaded when no explicit --exclude was given.
    if exclude is None:
        exclude = re_compile_maybe_verbose(DEFAULT_EXCLUDES)
        gitignore = get_gitignore(root)
    else:
        gitignore = None

    for entry in src:
        if entry == "-" and stdin_filename:
            path, is_stdin = Path(stdin_filename), True
        else:
            path, is_stdin = Path(entry), False

        if not (is_stdin or path.is_file()):
            if path.is_dir():
                collected.update(
                    gen_python_files(
                        path.iterdir(),
                        root,
                        include,
                        exclude,
                        extend_exclude,
                        force_exclude,
                        report,
                        gitignore,
                        verbose=verbose,
                        quiet=quiet,
                    )
                )
            elif entry == "-":
                collected.add(path)
            else:
                err(f"invalid path: {entry}")
            continue

        normalized = normalize_path_maybe_ignore(path, root, report)
        if normalized is None:
            continue

        # Hard-exclude files matching `--force-exclude`, even though they were
        # passed explicitly.
        match = force_exclude.search("/" + normalized) if force_exclude else None
        if match and match.group(0):
            report.path_ignored(
                path, "matches the --force-exclude regular expression"
            )
            continue

        if is_stdin:
            path = Path(f"{STDIN_PLACEHOLDER}{str(path)}")

        if path.suffix == ".ipynb" and not jupyter_dependencies_are_installed(
            verbose=verbose, quiet=quiet
        ):
            continue

        collected.add(path)
    return collected
557
558
def path_empty(
    src: Sized, msg: str, quiet: bool, verbose: bool, ctx: click.Context
) -> None:
    """Exit with status 0 when `src` is empty; otherwise do nothing.

    Before exiting, `msg` is printed unless output is silenced (`quiet`
    without `verbose`).
    """
    if src:
        return
    if verbose or not quiet:
        out(msg)
    ctx.exit(0)
569
570
def reformat_code(
    content: str, fast: bool, write_back: WriteBack, mode: Mode, report: Report
) -> None:
    """Format the string `content` in-process and print the result.

    This is the `--code` counterpart of `reformat_one`: the work is delegated to
    :func:`format_stdin_to_stdout` with `content` supplied directly, and the
    outcome is recorded in `report` under the pseudo-path ``<string>``.
    `fast`, `write_back`, and `mode` are forwarded unchanged. Any exception is
    recorded as a failure in `report` rather than propagated.
    """
    path = Path("<string>")
    try:
        was_changed = format_stdin_to_stdout(
            content=content, fast=fast, write_back=write_back, mode=mode
        )
        report.done(path, Changed.YES if was_changed else Changed.NO)
    except Exception as exc:
        if report.verbose:
            traceback.print_exc()
        report.failed(path, str(exc))
593
594
def reformat_one(
    src: Path, fast: bool, write_back: WriteBack, mode: Mode, report: "Report"
) -> None:
    """Format one path in the current process and record the outcome in `report`.

    `src` is either a real file, `-` for stdin, or a stdin path tagged with
    STDIN_PLACEHOLDER. For files, the cache is consulted (except in diff modes)
    and updated after a successful write or a clean check. `fast`,
    `write_back`, and `mode` are forwarded to :func:`format_file_in_place` or
    :func:`format_stdin_to_stdout`. Any exception is recorded as a failure in
    `report` rather than propagated.
    """
    try:
        changed = Changed.NO
        raw = str(src)
        is_stdin = raw == "-" or raw.startswith(STDIN_PLACEHOLDER)
        if is_stdin and raw != "-":
            # Drop the placeholder so messages show the user's original name.
            src = Path(raw[len(STDIN_PLACEHOLDER) :])

        if is_stdin:
            if src.suffix == ".pyi":
                mode = replace(mode, is_pyi=True)
            elif src.suffix == ".ipynb":
                mode = replace(mode, is_ipynb=True)
            if format_stdin_to_stdout(fast=fast, write_back=write_back, mode=mode):
                changed = Changed.YES
        else:
            cache: Cache = {}
            if write_back not in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
                cache = read_cache(mode)
                resolved = src.resolve()
                key = str(resolved)
                if key in cache and cache[key] == get_cache_info(resolved):
                    changed = Changed.CACHED
            if changed is not Changed.CACHED:
                reformatted = format_file_in_place(
                    src, fast=fast, write_back=write_back, mode=mode
                )
                if reformatted:
                    changed = Changed.YES
            wrote_back = write_back is WriteBack.YES and changed is not Changed.CACHED
            checked_clean = write_back is WriteBack.CHECK and changed is Changed.NO
            if wrote_back or checked_clean:
                write_cache(cache, [src], mode)
        report.done(src, changed)
    except Exception as exc:
        if report.verbose:
            traceback.print_exc()
        report.failed(src, str(exc))
644
645
def reformat_many(
    sources: Set[Path], fast: bool, write_back: WriteBack, mode: Mode, report: "Report"
) -> None:
    """Reformat multiple files using a ProcessPoolExecutor.

    Falls back to a single-worker ThreadPoolExecutor where process pools are
    unavailable. The run uses a freshly created event loop that is detached
    from the current thread afterwards, so repeated calls in one process never
    pick up a loop that a previous call already shut down. The executor is
    shut down even if shutting down the loop fails.
    """
    executor: Executor
    # os.cpu_count() returns None when the CPU count cannot be determined.
    worker_count = os.cpu_count() or 1
    if sys.platform == "win32":
        # Work around https://bugs.python.org/issue26903
        worker_count = min(worker_count, 60)
    try:
        executor = ProcessPoolExecutor(max_workers=worker_count)
    except (ImportError, OSError):
        # we arrive here if the underlying system does not support multi-processing
        # like in AWS Lambda or Termux, in which case we gracefully fallback to
        # a ThreadPoolExecutor with just a single worker (more workers would not do us
        # any good due to the Global Interpreter Lock)
        executor = ThreadPoolExecutor(max_workers=1)

    # Use a private loop instead of asyncio.get_event_loop(): `shutdown(loop)`
    # presumably closes the loop it is given (NOTE(review): confirm in
    # black.concurrency), which would break any later user of a shared loop.
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        loop.run_until_complete(
            schedule_formatting(
                sources=sources,
                fast=fast,
                write_back=write_back,
                mode=mode,
                report=report,
                loop=loop,
                executor=executor,
            )
        )
    finally:
        try:
            shutdown(loop)
        finally:
            asyncio.set_event_loop(None)
            executor.shutdown()
681
682
async def schedule_formatting(
    sources: Set[Path],
    fast: bool,
    write_back: WriteBack,
    mode: Mode,
    report: "Report",
    loop: asyncio.AbstractEventLoop,
    executor: Executor,
) -> None:
    """Run formatting of `sources` in parallel using the provided `executor`.

    (Use ProcessPoolExecutors for actual parallelism.)

    Outside diff modes, files that `filter_cached` reports as cached are marked
    done as CACHED and skipped. Every remaining file is formatted with
    :func:`format_file_in_place` on `executor`; results are reported as they
    complete, and files that were written back (or checked clean) are added to
    the cache at the end. SIGINT/SIGTERM cancel the outstanding tasks where the
    loop supports signal handlers.

    `write_back`, `fast`, and `mode` options are passed to
    :func:`format_file_in_place`. `loop` is the event loop this coroutine runs on.
    """
    cache: Cache = {}
    if write_back not in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
        cache = read_cache(mode)
        sources, cached = filter_cached(cache, sources)
        for src in sorted(cached):
            report.done(src, Changed.CACHED)
    if not sources:
        return

    cancelled = []
    sources_to_cache = []
    lock = None
    if write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
        # For diff output, we need locks to ensure we don't interleave output
        # from different processes.
        manager = Manager()
        lock = manager.Lock()
    tasks = {
        asyncio.ensure_future(
            loop.run_in_executor(
                executor, format_file_in_place, src, fast, mode, write_back, lock
            )
        ): src
        for src in sorted(sources)
    }
    # A live view of the dict keys: it shrinks as finished tasks are popped, so
    # the signal handlers always see only the still-pending tasks.
    pending = tasks.keys()
    try:
        loop.add_signal_handler(signal.SIGINT, cancel, pending)
        loop.add_signal_handler(signal.SIGTERM, cancel, pending)
    except NotImplementedError:
        # There are no good alternatives for these on Windows.
        pass
    while pending:
        done, _ = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
        for task in done:
            src = tasks.pop(task)
            if task.cancelled():
                cancelled.append(task)
            elif task.exception():
                report.failed(src, str(task.exception()))
            else:
                changed = Changed.YES if task.result() else Changed.NO
                # If the file was written back or was successfully checked as
                # well-formatted, store this information in the cache.
                if write_back is WriteBack.YES or (
                    write_back is WriteBack.CHECK and changed is Changed.NO
                ):
                    sources_to_cache.append(src)
                report.done(src, changed)
    if cancelled:
        # asyncio.gather() lost its `loop` parameter in Python 3.10 (passing it
        # raises TypeError); we are already running on `loop`, so omit it.
        await asyncio.gather(*cancelled, return_exceptions=True)
    if sources_to_cache:
        write_cache(cache, sources_to_cache, mode)
752
753
def format_file_in_place(
    src: Path,
    fast: bool,
    mode: Mode,
    write_back: WriteBack = WriteBack.NO,
    lock: Any = None,  # multiprocessing.Manager().Lock() is some crazy proxy
) -> bool:
    """Format the file at `src`; return True if its contents would change.

    A `.pyi` or `.ipynb` suffix switches `mode` to stub or notebook formatting.
    With `write_back` YES the new contents are written back to `src` using the
    encoding and newline style reported by `decode_bytes`; with DIFF or
    COLOR_DIFF a (possibly colored) diff is written to stdout while holding
    `lock`, if one is given. `mode` and `fast` are passed to
    :func:`format_file_contents`.

    Raises ValueError when a notebook cannot be decoded as JSON.
    """
    if src.suffix == ".pyi":
        mode = replace(mode, is_pyi=True)
    elif src.suffix == ".ipynb":
        mode = replace(mode, is_ipynb=True)

    mtime = datetime.utcfromtimestamp(src.stat().st_mtime)
    with open(src, "rb") as buf:
        src_contents, encoding, newline = decode_bytes(buf.read())
    try:
        dst_contents = format_file_contents(src_contents, fast=fast, mode=mode)
    except NothingChanged:
        return False
    except JSONDecodeError:
        raise ValueError(f"File '{src}' cannot be parsed as valid Jupyter notebook.")

    if write_back == WriteBack.YES:
        with open(src, "w", encoding=encoding, newline=newline) as f:
            f.write(dst_contents)
        return True
    if write_back not in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
        # NO / CHECK: report the change without writing anything.
        return True

    now = datetime.utcnow()
    src_name = f"{src}\t{mtime} +0000"
    dst_name = f"{src}\t{now} +0000"
    if mode.is_ipynb:
        diff_contents = ipynb_diff(src_contents, dst_contents, src_name, dst_name)
    else:
        diff_contents = diff(src_contents, dst_contents, src_name, dst_name)
    if write_back == WriteBack.COLOR_DIFF:
        diff_contents = color_diff(diff_contents)

    # Serialise stdout writes across workers when a lock is provided.
    with lock or nullcontext():
        stream = io.TextIOWrapper(
            sys.stdout.buffer,
            encoding=encoding,
            newline=newline,
            write_through=True,
        )
        stream = wrap_stream_for_windows(stream)
        stream.write(diff_contents)
        # Detach so the wrapper never closes sys.stdout.buffer.
        stream.detach()
    return True
809
810
def format_stdin_to_stdout(
    fast: bool,
    *,
    content: Optional[str] = None,
    write_back: WriteBack = WriteBack.NO,
    mode: Mode,
) -> bool:
    """Format source from stdin (or from `content`); return True if changed.

    When `content` is None the source is read from sys.stdin and decoded;
    otherwise `content` is used directly and written out as UTF-8 with no
    newline translation.

    With `write_back` YES the formatted code is echoed to stdout, ending with a
    newline; with DIFF or COLOR_DIFF a diff is written instead. Output is
    produced even if formatting raises (the unchanged input is echoed), after
    which the exception propagates. `mode` is passed to
    :func:`format_file_contents`.
    """
    then = datetime.utcnow()

    if content is None:
        src, encoding, newline = decode_bytes(sys.stdin.buffer.read())
    else:
        src, encoding, newline = content, "utf-8", ""

    dst = src
    changed = False
    try:
        dst = format_file_contents(src, fast=fast, mode=mode)
        changed = True
    except NothingChanged:
        # Nothing to do: `dst` still equals `src` and is echoed back below.
        pass
    finally:
        f = io.TextIOWrapper(
            sys.stdout.buffer, encoding=encoding, newline=newline, write_through=True
        )
        if write_back == WriteBack.YES:
            # Make sure there's a newline after the content
            if dst and not dst.endswith("\n"):
                dst += "\n"
            f.write(dst)
        elif write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
            now = datetime.utcnow()
            d = diff(src, dst, f"STDIN\t{then} +0000", f"STDOUT\t{now} +0000")
            if write_back == WriteBack.COLOR_DIFF:
                d = color_diff(d)
                f = wrap_stream_for_windows(f)
            f.write(d)
        f.detach()
    return changed
860
861
def check_stability_and_equivalence(
    src_contents: str, dst_contents: str, *, mode: Mode
) -> None:
    """Verify that `dst_contents` is a safe and stable reformatting of the source.

    Raise AssertionError if source and destination are not equivalent, or if
    formatting `dst_contents` once more would produce different output.
    """
    assert_equivalent(src_contents, dst_contents)

    # Forced second pass: optional trailing commas can become forced trailing
    # commas on pass 2 and interact differently with optional parentheses.
    second_pass = format_str(dst_contents, mode=mode)
    if second_pass == dst_contents:
        # Already stable, so an explicit `assert_stable` call is unnecessary.
        return
    assert_equivalent(src_contents, second_pass, pass_num=2)
    assert_stable(src_contents, second_pass, mode=mode)
883
884
def format_file_contents(src_contents: str, *, fast: bool, mode: Mode) -> FileContent:
    """Return the reformatted version of `src_contents`.

    Raises NothingChanged when the input is blank or formatting leaves it
    unchanged. Notebooks (`mode.is_ipynb`) are formatted by
    `format_ipynb_string`; anything else goes through :func:`format_str` and,
    unless `fast` is set, is verified with
    :func:`check_stability_and_equivalence`.
    """
    if not src_contents.strip():
        raise NothingChanged

    if mode.is_ipynb:
        dst_contents = format_ipynb_string(src_contents, fast=fast, mode=mode)
        # Notebooks are not re-verified here; assumed to be checked inside
        # format_ipynb_string (NOTE(review): confirm).
        verify = False
    else:
        dst_contents = format_str(src_contents, mode=mode)
        verify = not fast

    if dst_contents == src_contents:
        raise NothingChanged
    if verify:
        check_stability_and_equivalence(src_contents, dst_contents, mode=mode)
    return dst_contents
906
907
def validate_cell(src: str) -> None:
    """Reject cells that already contain IPython-transformed magics.

    IPython's TransformerManager rewrites magics into ordinary calls, so
    ``!ls`` becomes ``get_ipython().system('ls')``.  A cell that *already*
    contains that call is transformed to exactly the same text:

        >>> TransformerManager().transform_cell("get_ipython().system('ls')")
        "get_ipython().system('ls')\n"
        >>> TransformerManager().transform_cell("!ls")
        "get_ipython().system('ls')\n"

    The original cannot be recovered reliably after formatting.  So any cell
    containing one of TRANSFORMED_MAGICS raises NothingChanged and is left
    alone.
    """
    for magic_call in TRANSFORMED_MAGICS:
        if magic_call in src:
            raise NothingChanged
925
926
def format_cell(src: str, *, fast: bool, mode: Mode) -> str:
    """Format the source of a single Jupyter notebook code cell.

    Steps:

      - drop a trailing semicolon, remembering whether one was present;
      - mask IPython magics so the cell parses as Python;
      - format the masked code (and verify it, unless `fast`);
      - unmask the magics and put the trailing semicolon back;
      - strip trailing newlines.

    Raises NothingChanged in three cases: the cell already contains transformed
    magics; masking fails with a SyntaxError (such cells could be automagics or
    multi-line magics, which are not supported); or the result equals the input.
    """
    validate_cell(src)
    code, had_semicolon = remove_trailing_semicolon(src)
    try:
        masked_src, replacements = mask_cell(code)
    except SyntaxError:
        raise NothingChanged
    masked_dst = format_str(masked_src, mode=mode)
    if not fast:
        check_stability_and_equivalence(masked_src, masked_dst, mode=mode)
    restored = put_trailing_semicolon_back(
        unmask_cell(masked_dst, replacements), had_semicolon
    )
    dst = restored.rstrip("\n")
    if dst != src:
        return dst
    raise NothingChanged
962
963
def validate_metadata(nb: MutableMapping[str, Any]) -> None:
    """Raise NothingChanged if the notebook's metadata names a non-Python language.

    Every metadata field is optional in the notebook format, see
    https://nbformat.readthedocs.io/en/latest/format_description.html.  A
    notebook with no ``metadata.language_info.name`` is therefore still
    treated as Python and formatting is attempted.
    """
    language_info = nb.get("metadata", {}).get("language_info", {})
    language = language_info.get("name")
    if language not in (None, "python"):
        raise NothingChanged
974
975
def format_ipynb_string(src_contents: str, *, fast: bool, mode: Mode) -> FileContent:
    """Format Jupyter notebook.

    Operate cell-by-cell, only on code cells, only for Python notebooks.
    If the ``.ipynb`` originally had a trailing newline, it'll be preserved.

    Raises NothingChanged if `src_contents` is empty, if validate_metadata()
    rejects the notebook, or if no code cell was modified.  The notebook is
    re-serialized with ``json.dumps(indent=1, ensure_ascii=False)`` only when
    at least one cell changed.
    """
    # Guard before indexing: an empty string has no last character, and
    # blank input never needs formatting (mirrors format_file_contents).
    if not src_contents:
        raise NothingChanged

    trailing_newline = src_contents[-1] == "\n"
    modified = False
    nb = json.loads(src_contents)
    validate_metadata(nb)
    for cell in nb["cells"]:
        if cell.get("cell_type", None) == "code":
            try:
                src = "".join(cell["source"])
                dst = format_cell(src, fast=fast, mode=mode)
            except NothingChanged:
                # Cell was skipped or already formatted; leave it untouched.
                pass
            else:
                cell["source"] = dst.splitlines(keepends=True)
                modified = True
    if modified:
        dst_contents = json.dumps(nb, indent=1, ensure_ascii=False)
        if trailing_newline:
            dst_contents = dst_contents + "\n"
        return dst_contents
    else:
        raise NothingChanged
1003
1004
def format_str(src_contents: str, *, mode: Mode) -> FileContent:
    """Format Python source held in a string and return the formatted text.

    All formatting options (line length, target versions, string
    normalization, stub mode, ...) are taken from `mode`.  Example:

    >>> import black
    >>> print(black.format_str("def f(arg:str='')->None:...", mode=black.Mode()))
    def f(arg: str = "") -> None:
        ...

    A more complex example:

    >>> print(
    ...   black.format_str(
    ...     "def f(arg:str='')->None: hey",
    ...     mode=black.Mode(
    ...       target_versions={black.TargetVersion.PY36},
    ...       line_length=10,
    ...       string_normalization=False,
    ...       is_pyi=False,
    ...     ),
    ...   ),
    ... )
    def f(
        arg: str = '',
    ) -> None:
        hey

    """
    tree = lib2to3_parse(src_contents.lstrip(), mode.target_versions)
    future_imports = get_future_imports(tree)
    # Keep version detection ahead of normalize_fmt_off(), which is called
    # only for its in-place effect on the tree.
    versions = mode.target_versions or detect_target_versions(tree)
    normalize_fmt_off(tree)

    drop_u_prefix = "unicode_literals" in future_imports or supports_feature(
        versions, Feature.UNICODE_LITERALS
    )
    line_generator = LineGenerator(mode=mode, remove_u_prefix=drop_u_prefix)
    empty_line_tracker = EmptyLineTracker(is_pyi=mode.is_pyi)
    blank_line = str(Line(mode=mode))
    split_line_features = {
        feature
        for feature in (Feature.TRAILING_COMMA_IN_CALL, Feature.TRAILING_COMMA_IN_DEF)
        if supports_feature(versions, feature)
    }

    pieces: List[str] = []
    pending_after = 0
    for logical_line in line_generator.visit(tree):
        # Blank lines owed after the previous line come first, then the ones
        # requested before this line.  Blank lines requested after the final
        # line are never emitted.
        pieces.append(blank_line * pending_after)
        before, pending_after = empty_line_tracker.maybe_empty_lines(logical_line)
        pieces.append(blank_line * before)
        pieces.extend(
            str(output_line)
            for output_line in transform_line(
                logical_line, mode=mode, features=split_line_features
            )
        )
    return "".join(pieces)
1065
1066
def decode_bytes(src: bytes) -> Tuple[FileContent, Encoding, NewLine]:
    """Decode raw file bytes and report the encoding and newline style found.

    The encoding comes from :func:`tokenize.detect_encoding`.  The newline
    style is judged from the first line only: CRLF if that line ends with
    ``\\r\\n``, LF otherwise.  The returned text is read with universal
    newlines, so it contains LF line endings only.  Empty input yields
    ``("", encoding, "\\n")``.
    """
    buffer = io.BytesIO(src)
    encoding, first_lines = tokenize.detect_encoding(buffer.readline)
    if not first_lines:
        return "", encoding, "\n"

    newline = "\r\n" if first_lines[0].endswith(b"\r\n") else "\n"
    buffer.seek(0)
    with io.TextIOWrapper(buffer, encoding) as reader:
        text = reader.read()
    return text, encoding, newline
1082
1083
def get_features_used(node: Node) -> Set[Feature]:
    """Return a set of (relatively) new Python features used in this file.

    Currently looking for:
    - f-strings (``f``/``rf``/``fr`` prefixes in any letter case);
    - underscores in numeric literals;
    - trailing commas after * or ** in function signatures and calls;
    - positional only arguments in function signatures and lambdas;
    - assignment expression;
    - relaxed decorator syntax;
    """
    features: Set[Feature] = set()
    for n in node.pre_order():
        if n.type == token.STRING:
            # String prefixes are case-insensitive, so normalize the first two
            # characters before matching.  This also catches the mixed-case
            # raw f-string prefixes ``Rf``, ``rF``, ``Fr`` and ``fR``.
            value_head = n.value[:2].lower()  # type: ignore
            if value_head in {'f"', "f'", "rf", "fr"}:
                features.add(Feature.F_STRINGS)

        elif n.type == token.NUMBER:
            if "_" in n.value:  # type: ignore
                features.add(Feature.NUMERIC_UNDERSCORES)

        elif n.type == token.SLASH:
            # `def` parameters live in a typedargslist; lambda parameters live
            # in a varargslist (per the lib2to3-style grammar's `lambdef` rule).
            if n.parent and n.parent.type in {
                syms.typedargslist,
                syms.arglist,
                syms.varargslist,
            }:
                features.add(Feature.POS_ONLY_ARGUMENTS)

        elif n.type == token.COLONEQUAL:
            features.add(Feature.ASSIGNMENT_EXPRESSIONS)

        elif n.type == syms.decorator:
            if len(n.children) > 1 and not is_simple_decorator_expression(
                n.children[1]
            ):
                features.add(Feature.RELAXED_DECORATORS)

        elif (
            n.type in {syms.typedargslist, syms.arglist}
            and n.children
            and n.children[-1].type == token.COMMA
        ):
            # A trailing comma is only a "new" feature when the list also
            # contains a * or ** (directly or inside an `argument` node).
            if n.type == syms.typedargslist:
                feature = Feature.TRAILING_COMMA_IN_DEF
            else:
                feature = Feature.TRAILING_COMMA_IN_CALL

            for ch in n.children:
                if ch.type in STARS:
                    features.add(feature)

                if ch.type == syms.argument:
                    for argch in ch.children:
                        if argch.type in STARS:
                            features.add(feature)

    return features
1139
1140
def detect_target_versions(node: Node) -> Set[TargetVersion]:
    """Return every target version whose feature set covers all features in `node`."""
    used = get_features_used(node)
    compatible: Set[TargetVersion] = set()
    for version in TargetVersion:
        if used <= VERSION_TO_FEATURES[version]:
            compatible.add(version)
    return compatible
1147
1148
def get_future_imports(node: Node) -> Set[str]:
    """Collect the names imported from ``__future__`` at the top of the module.

    Only the leading run of simple statements is inspected.  A docstring-only
    statement is skipped, ``from __future__ import ...`` statements contribute
    their names, and scanning stops at the first statement of any other kind.
    """

    def names_in(children: List[LN]) -> Generator[str, None, None]:
        # Yield imported names.  NAME leaves are yielded directly; other leaves
        # (e.g. commas, parentheses) are skipped.  For ``x as y`` the original
        # name ``x`` is yielded, and name lists are walked recursively.
        for child in children:
            if isinstance(child, Leaf):
                if child.type == token.NAME:
                    yield child.value
                continue

            if child.type == syms.import_as_name:
                original = child.children[0]
                assert isinstance(original, Leaf), "Invalid syntax parsing imports"
                assert original.type == token.NAME, "Invalid syntax parsing imports"
                yield original.value
            elif child.type == syms.import_as_names:
                yield from names_in(child.children)
            else:
                raise AssertionError("Invalid syntax parsing imports")

    found: Set[str] = set()
    for stmt in node.children:
        if stmt.type != syms.simple_stmt:
            break

        head = stmt.children[0]
        if isinstance(head, Leaf):
            # A bare string followed by NEWLINE is a docstring: keep scanning.
            # Any other leaf-led statement ends the search.
            is_docstring = (
                len(stmt.children) == 2
                and head.type == token.STRING
                and stmt.children[1].type == token.NEWLINE
            )
            if not is_docstring:
                break
            continue

        if head.type != syms.import_from:
            break

        module_name = head.children[1]
        if not isinstance(module_name, Leaf) or module_name.value != "__future__":
            break

        # children[3:] skips the ``from``, module name and ``import`` tokens.
        found.update(names_in(head.children[3:]))

    return found
1197
1198
def assert_equivalent(src: str, dst: str, *, pass_num: int = 1) -> None:
    """Raise AssertionError unless `src` and `dst` parse to equivalent ASTs.

    `pass_num` is only used in the error messages, to say which formatting
    pass produced `dst`.  When `dst` fails to parse or differs, debugging data
    is written with dump_to_file() and its returned location is included in
    the message.
    """
    try:
        src_ast = parse_ast(src)
    except Exception as exc:
        message = (
            "cannot use --safe with this file; failed to parse source file.  AST"
            f" error message: {exc}"
        )
        raise AssertionError(message)

    try:
        dst_ast = parse_ast(dst)
    except Exception as exc:
        tb_text = "".join(traceback.format_tb(exc.__traceback__))
        log = dump_to_file(tb_text, dst)
        raise AssertionError(
            f"INTERNAL ERROR: Black produced invalid code on pass {pass_num}: {exc}. "
            "Please report a bug on https://github.com/psf/black/issues.  "
            f"This invalid output might be helpful: {log}"
        ) from None

    src_dump = "\n".join(stringify_ast(src_ast))
    dst_dump = "\n".join(stringify_ast(dst_ast))
    if src_dump == dst_dump:
        return

    log = dump_to_file(diff(src_dump, dst_dump, "src", "dst"))
    raise AssertionError(
        "INTERNAL ERROR: Black produced code that is not equivalent to the"
        f" source on pass {pass_num}.  Please report a bug on "
        f"https://github.com/psf/black/issues.  This diff might be helpful: {log}"
    ) from None
1228
1229
def assert_stable(src: str, dst: str, mode: Mode) -> None:
    """Raise AssertionError if formatting `dst` again with `mode` changes it.

    On failure, the mode plus the source→first-pass and first→second-pass
    diffs are written with dump_to_file(), and its returned location is
    included in the error message.
    """
    second = format_str(dst, mode=mode)
    if second == dst:
        return

    log = dump_to_file(
        str(mode),
        diff(src, dst, "source", "first pass"),
        diff(dst, second, "first pass", "second pass"),
    )
    raise AssertionError(
        "INTERNAL ERROR: Black produced different code on the second pass of the"
        " formatter.  Please report a bug on https://github.com/psf/black/issues."
        f"  This diff might be helpful: {log}"
    ) from None
1244
1245
@contextmanager
def nullcontext() -> Iterator[None]:
    """Context manager that does nothing and binds ``None``.

    Stand-in for :func:`contextlib.nullcontext`, which only exists from
    Python 3.7 onwards.
    """
    yield None
1253
1254
def patch_click() -> None:
    """Make Click not crash on Python 3.6 with LANG=C.

    On certain misconfigured environments, Python 3 selects the ASCII encoding as the
    default which restricts paths that it can access during the lifetime of the
    application.  Click refuses to work in this scenario by raising a RuntimeError.

    In case of Black the likelihood that non-ASCII characters are going to be used in
    file paths is minimal since it's Python source code.  Moreover, this crash was
    spurious on Python 3.7 thanks to PEP 538 and PEP 540.

    ``click.core`` and ``click._unicodefun`` are imported and patched
    independently.  Newer Click releases no longer ship ``_unicodefun``, and
    ``from click import _unicodefun`` then raises a plain ImportError rather
    than ModuleNotFoundError.  Any ImportError therefore just means "nothing
    to patch there"; it must not abort the whole function.
    """
    modules: List[Any] = []
    try:
        from click import core
    except ImportError:
        pass
    else:
        modules.append(core)
    try:
        from click import _unicodefun  # type: ignore
    except ImportError:
        pass
    else:
        modules.append(_unicodefun)

    for module in modules:
        # Replace whichever environment-check hook this Click version defines
        # with a no-op.
        if hasattr(module, "_verify_python3_env"):
            module._verify_python3_env = lambda: None  # type: ignore
        if hasattr(module, "_verify_python_env"):
            module._verify_python_env = lambda: None  # type: ignore
1277
1278
def patched_main() -> None:
    """Prepare the runtime, then run :func:`main`.

    This is what the ``__main__`` guard at the bottom of this module calls.
    """
    # Defined elsewhere; presumably installs uvloop as the asyncio event loop
    # when it is available — TODO confirm against its definition.
    maybe_install_uvloop()
    # multiprocessing support for frozen (e.g. Windows executable) builds;
    # keep it before main() so it runs before any worker processes start.
    freeze_support()
    # Disable Click's locale/encoding environment check (see patch_click).
    patch_click()
    main()
1284
1285
if __name__ == "__main__":
    # Allow running this module directly as a script.
    patched_main()