git.madduck.net Git - etc/vim.git/blob - src/black/__init__.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access, can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Jupyter notebook support (#2357)
[etc/vim.git] / src / black / __init__.py
1 import asyncio
2 from json.decoder import JSONDecodeError
3 import json
4 from concurrent.futures import Executor, ThreadPoolExecutor, ProcessPoolExecutor
5 from contextlib import contextmanager
6 from datetime import datetime
7 from enum import Enum
8 import io
9 from multiprocessing import Manager, freeze_support
10 import os
11 from pathlib import Path
12 import regex as re
13 import signal
14 import sys
15 import tokenize
16 import traceback
17 from typing import (
18     Any,
19     Dict,
20     Generator,
21     Iterator,
22     List,
23     MutableMapping,
24     Optional,
25     Pattern,
26     Set,
27     Sized,
28     Tuple,
29     Union,
30 )
31
32 from dataclasses import replace
33 import click
34
35 from black.const import DEFAULT_LINE_LENGTH, DEFAULT_INCLUDES, DEFAULT_EXCLUDES
36 from black.const import STDIN_PLACEHOLDER
37 from black.nodes import STARS, syms, is_simple_decorator_expression
38 from black.lines import Line, EmptyLineTracker
39 from black.linegen import transform_line, LineGenerator, LN
40 from black.comments import normalize_fmt_off
41 from black.mode import Mode, TargetVersion
42 from black.mode import Feature, supports_feature, VERSION_TO_FEATURES
43 from black.cache import read_cache, write_cache, get_cache_info, filter_cached, Cache
44 from black.concurrency import cancel, shutdown, maybe_install_uvloop
45 from black.output import dump_to_file, ipynb_diff, diff, color_diff, out, err
46 from black.report import Report, Changed, NothingChanged
47 from black.files import find_project_root, find_pyproject_toml, parse_pyproject_toml
48 from black.files import gen_python_files, get_gitignore, normalize_path_maybe_ignore
49 from black.files import wrap_stream_for_windows
50 from black.parsing import InvalidInput  # noqa F401
51 from black.parsing import lib2to3_parse, parse_ast, stringify_ast
52 from black.handle_ipynb_magics import (
53     mask_cell,
54     unmask_cell,
55     remove_trailing_semicolon,
56     put_trailing_semicolon_back,
57     TRANSFORMED_MAGICS,
58     jupyter_dependencies_are_installed,
59 )
60
61
62 # lib2to3 fork
63 from blib2to3.pytree import Node, Leaf
64 from blib2to3.pgen2 import token
65
66 from _black_version import version as __version__
67
# Type aliases. All three are plain `str`; the distinct names only document
# what a given string represents where it appears in a signature.
FileContent = str
Encoding = str
NewLine = str
72
73
class WriteBack(Enum):
    """Write-back modes selectable through the check/diff/color flags."""

    NO = 0
    YES = 1
    DIFF = 2
    CHECK = 3
    COLOR_DIFF = 4

    @classmethod
    def from_configuration(
        cls, *, check: bool, diff: bool, color: bool = False
    ) -> "WriteBack":
        """Translate the three boolean flags into a single member.

        `diff` takes precedence over `check`; `color` only matters when
        `diff` is set.  With neither `check` nor `diff`, the result is YES.
        """
        if diff:
            return cls.COLOR_DIFF if color else cls.DIFF

        return cls.CHECK if check else cls.YES
92
93
# Legacy name for `Mode`, left in place as an alias for integrations that
# still refer to `FileMode`.
FileMode = Mode
96
97
def read_pyproject_toml(
    ctx: click.Context, param: click.Parameter, value: Optional[str]
) -> Optional[str]:
    """Merge settings from a "pyproject.toml" file into the defaults of `ctx`.

    When `value` is empty, the file is looked up via `find_pyproject_toml`
    using the "src" entries already parsed into `ctx.params`.

    Returns the path of the configuration file that was read, or None when no
    file was found or the file held no configuration.  Raises
    `click.FileError` if the file cannot be read or parsed, and
    `click.BadOptionUsage` if its `target_version` entry is not a list.
    """
    if not value:
        value = find_pyproject_toml(ctx.params.get("src", ()))
    if value is None:
        return None

    try:
        raw_config = parse_pyproject_toml(value)
    except (OSError, ValueError) as e:
        raise click.FileError(
            filename=value, hint=f"Error reading configuration file: {e}"
        )

    if not raw_config:
        return None

    # Sanitize the values to be Click friendly. For more information please see:
    # https://github.com/psf/black/issues/1458
    # https://github.com/pallets/click/issues/1567
    config = {
        key: (item if isinstance(item, (list, dict)) else str(item))
        for key, item in raw_config.items()
    }

    configured_targets = config.get("target_version")
    if not (configured_targets is None or isinstance(configured_targets, list)):
        raise click.BadOptionUsage(
            "target-version", "Config key target-version must be a list"
        )

    # Values from the file override whatever defaults were already present.
    merged_defaults: Dict[str, Any] = {}
    if ctx.default_map:
        merged_defaults.update(ctx.default_map)
    merged_defaults.update(config)
    ctx.default_map = merged_defaults

    return value
142
143
def target_version_option_callback(
    c: click.Context, p: Union[click.Option, click.Parameter], v: Tuple[str, ...]
) -> List[TargetVersion]:
    """Convert the names given to --target-version into `TargetVersion` members.

    Each name is upper-cased and looked up on the enum.  This is a named
    function rather than a lambda because mypy could not infer the type of the
    lambda correctly, which caused trouble for mypyc.
    """
    versions: List[TargetVersion] = []
    for name in v:
        versions.append(TargetVersion[name.upper()])
    return versions
153
154
def re_compile_maybe_verbose(regex: str) -> Pattern[str]:
    """Return `regex` compiled into a pattern object.

    A pattern that spans several lines is compiled in verbose mode by
    prefixing it with the inline ``(?x)`` flag.
    """
    source = f"(?x){regex}" if "\n" in regex else regex
    compiled: Pattern[str] = re.compile(source)
    return compiled
164
165
def validate_regex(
    ctx: click.Context,
    param: click.Parameter,
    value: Optional[str],
) -> Optional[Pattern]:
    """Click callback that compiles an option value into a pattern.

    A missing value (None) is passed through unchanged.  A value that fails
    to compile is reported as `click.BadParameter`.
    """
    if value is None:
        return None

    try:
        return re_compile_maybe_verbose(value)
    except re.error:
        raise click.BadParameter("Not a valid regular expression")
175
176
@click.command(context_settings=dict(help_option_names=["-h", "--help"]))
@click.option("-c", "--code", type=str, help="Format the code passed in as a string.")
@click.option(
    "-l",
    "--line-length",
    type=int,
    default=DEFAULT_LINE_LENGTH,
    help="How many characters per line to allow.",
    show_default=True,
)
@click.option(
    "-t",
    "--target-version",
    type=click.Choice([v.name.lower() for v in TargetVersion]),
    callback=target_version_option_callback,
    multiple=True,
    help=(
        "Python versions that should be supported by Black's output. [default: per-file"
        " auto-detection]"
    ),
)
@click.option(
    "--pyi",
    is_flag=True,
    help=(
        "Format all input files like typing stubs regardless of file extension (useful"
        " when piping source on standard input)."
    ),
)
@click.option(
    "--ipynb",
    is_flag=True,
    help=(
        "Format all input files like Jupyter Notebooks regardless of file extension "
        "(useful when piping source on standard input)."
    ),
)
@click.option(
    "-S",
    "--skip-string-normalization",
    is_flag=True,
    help="Don't normalize string quotes or prefixes.",
)
@click.option(
    "-C",
    "--skip-magic-trailing-comma",
    is_flag=True,
    help="Don't use trailing commas as a reason to split lines.",
)
@click.option(
    "--experimental-string-processing",
    is_flag=True,
    hidden=True,
    help=(
        "Experimental option that performs more normalization on string literals."
        " Currently disabled because it leads to some crashes."
    ),
)
@click.option(
    "--check",
    is_flag=True,
    help=(
        "Don't write the files back, just return the status. Return code 0 means"
        " nothing would change. Return code 1 means some files would be reformatted."
        " Return code 123 means there was an internal error."
    ),
)
@click.option(
    "--diff",
    is_flag=True,
    help="Don't write the files back, just output a diff for each file on stdout.",
)
@click.option(
    "--color/--no-color",
    is_flag=True,
    help="Show colored diff. Only applies when `--diff` is given.",
)
@click.option(
    "--fast/--safe",
    is_flag=True,
    help="If --fast given, skip temporary sanity checks. [default: --safe]",
)
@click.option(
    "--required-version",
    type=str,
    help=(
        "Require a specific version of Black to be running (useful for unifying results"
        " across many environments e.g. with a pyproject.toml file)."
    ),
)
@click.option(
    "--include",
    type=str,
    default=DEFAULT_INCLUDES,
    callback=validate_regex,
    help=(
        "A regular expression that matches files and directories that should be"
        " included on recursive searches. An empty value means all files are included"
        " regardless of the name. Use forward slashes for directories on all platforms"
        " (Windows, too). Exclusions are calculated first, inclusions later."
    ),
    show_default=True,
)
@click.option(
    "--exclude",
    type=str,
    callback=validate_regex,
    help=(
        "A regular expression that matches files and directories that should be"
        " excluded on recursive searches. An empty value means no paths are excluded."
        " Use forward slashes for directories on all platforms (Windows, too)."
        " Exclusions are calculated first, inclusions later. [default:"
        f" {DEFAULT_EXCLUDES}]"
    ),
    show_default=False,
)
@click.option(
    "--extend-exclude",
    type=str,
    callback=validate_regex,
    help=(
        "Like --exclude, but adds additional files and directories on top of the"
        " excluded ones. (Useful if you simply want to add to the default)"
    ),
)
@click.option(
    "--force-exclude",
    type=str,
    callback=validate_regex,
    help=(
        "Like --exclude, but files and directories matching this regex will be "
        "excluded even when they are passed explicitly as arguments."
    ),
)
@click.option(
    "--stdin-filename",
    type=str,
    help=(
        "The name of the file when passing it through stdin. Useful to make "
        "sure Black will respect --force-exclude option on some "
        "editors that rely on using stdin."
    ),
)
@click.option(
    "-q",
    "--quiet",
    is_flag=True,
    help=(
        "Don't emit non-error messages to stderr. Errors are still emitted; silence"
        " those with 2>/dev/null."
    ),
)
@click.option(
    "-v",
    "--verbose",
    is_flag=True,
    help=(
        "Also emit messages to stderr about files that were not changed or were ignored"
        " due to exclusion patterns."
    ),
)
@click.version_option(version=__version__)
@click.argument(
    "src",
    nargs=-1,
    type=click.Path(
        exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
    ),
    is_eager=True,
    metavar="SRC ...",
)
@click.option(
    "--config",
    type=click.Path(
        exists=True,
        file_okay=True,
        dir_okay=False,
        readable=True,
        allow_dash=False,
        path_type=str,
    ),
    is_eager=True,
    callback=read_pyproject_toml,
    help="Read configuration from FILE path.",
)
@click.pass_context
def main(
    ctx: click.Context,
    code: Optional[str],
    line_length: int,
    target_version: List[TargetVersion],
    check: bool,
    diff: bool,
    color: bool,
    fast: bool,
    pyi: bool,
    ipynb: bool,
    skip_string_normalization: bool,
    skip_magic_trailing_comma: bool,
    experimental_string_processing: bool,
    quiet: bool,
    verbose: bool,
    required_version: str,
    include: Pattern,
    exclude: Optional[Pattern],
    extend_exclude: Optional[Pattern],
    force_exclude: Optional[Pattern],
    stdin_filename: Optional[str],
    src: Tuple[str, ...],
    config: Optional[str],
) -> None:
    """The uncompromising code formatter."""
    # Command-line entry point: validates the flag combination, builds the
    # `Mode` and `Report`, formats either the `--code` string or the collected
    # source paths, then exits with the report's return code.
    if verbose and config:
        out(f"Using configuration from {config}.", bold=False, fg="blue")

    error_msg = "Oh no! 💥 💔 💥"
    if required_version and required_version != __version__:
        err(
            f"{error_msg} The required version `{required_version}` does not match"
            f" the running version `{__version__}`!"
        )
        ctx.exit(1)
    if ipynb and pyi:
        err("Cannot pass both `pyi` and `ipynb` flags!")
        ctx.exit(1)

    # An empty set of target versions means they are autodetected later.
    mode = Mode(
        target_versions=set(target_version) if target_version else set(),
        line_length=line_length,
        is_pyi=pyi,
        is_ipynb=ipynb,
        string_normalization=not skip_string_normalization,
        magic_trailing_comma=not skip_magic_trailing_comma,
        experimental_string_processing=experimental_string_processing,
    )
    write_back = WriteBack.from_configuration(check=check, diff=diff, color=color)

    formatting_code_string = code is not None
    if formatting_code_string:
        # Run in quiet mode by default with -c; the extra output isn't useful.
        # You can still pass -v to get verbose output.
        quiet = True

    report = Report(check=check, diff=diff, quiet=quiet, verbose=verbose)

    if code is None:
        sources = get_sources(
            ctx=ctx,
            src=src,
            quiet=quiet,
            verbose=verbose,
            include=include,
            exclude=exclude,
            extend_exclude=extend_exclude,
            force_exclude=force_exclude,
            report=report,
            stdin_filename=stdin_filename,
        )

        path_empty(
            sources,
            "No Python files are present to be formatted. Nothing to do 😴",
            quiet,
            verbose,
            ctx,
        )

        # A single source is formatted in-process; several go through the
        # executor-based path.
        if len(sources) != 1:
            reformat_many(
                sources=sources,
                fast=fast,
                write_back=write_back,
                mode=mode,
                report=report,
            )
        else:
            reformat_one(
                src=sources.pop(),
                fast=fast,
                write_back=write_back,
                mode=mode,
                report=report,
            )
    else:
        reformat_code(
            content=code, fast=fast, write_back=write_back, mode=mode, report=report
        )

    if verbose or not quiet:
        out(error_msg if report.return_code else "All done! ✨ 🍰 ✨")
        if code is None:
            click.echo(str(report), err=True)
    ctx.exit(report.return_code)
474
475
def get_sources(
    *,
    ctx: click.Context,
    src: Tuple[str, ...],
    quiet: bool,
    verbose: bool,
    include: Pattern[str],
    exclude: Optional[Pattern[str]],
    extend_exclude: Optional[Pattern[str]],
    force_exclude: Optional[Pattern[str]],
    report: "Report",
    stdin_filename: Optional[str],
) -> Set[Path]:
    """Collect the paths that should be formatted.

    Files (and stdin given together with `stdin_filename`) are checked against
    `force_exclude` only; directories are expanded through `gen_python_files`.
    Exits through `path_empty` when `src` is empty.
    """

    root = find_project_root(src)
    collected: Set[Path] = set()
    path_empty(src, "No Path provided. Nothing to do 😴", quiet, verbose, ctx)

    # The gitignore is only consulted when the default excludes are in use.
    gitignore = None
    if exclude is None:
        exclude = re_compile_maybe_verbose(DEFAULT_EXCLUDES)
        gitignore = get_gitignore(root)

    for s in src:
        if s == "-" and stdin_filename:
            is_stdin, path = True, Path(stdin_filename)
        else:
            is_stdin, path = False, Path(s)

        if is_stdin or path.is_file():
            relative = normalize_path_maybe_ignore(path, root, report)
            if relative is None:
                continue

            rooted = "/" + relative
            # Hard-exclude any files that matches the `--force-exclude` regex.
            match = force_exclude.search(rooted) if force_exclude else None
            if match and match.group(0):
                report.path_ignored(
                    path, "matches the --force-exclude regular expression"
                )
                continue

            if is_stdin:
                # Tag the path so later stages can tell it stands for stdin.
                path = Path(f"{STDIN_PLACEHOLDER}{str(path)}")

            if path.suffix == ".ipynb" and not jupyter_dependencies_are_installed(
                verbose=verbose, quiet=quiet
            ):
                continue

            collected.add(path)
        elif path.is_dir():
            discovered = gen_python_files(
                path.iterdir(),
                root,
                include,
                exclude,
                extend_exclude,
                force_exclude,
                report,
                gitignore,
                verbose=verbose,
                quiet=quiet,
            )
            collected.update(discovered)
        elif s == "-":
            collected.add(path)
        else:
            err(f"invalid path: {s}")
    return collected
553
554
def path_empty(
    src: Sized, msg: str, quiet: bool, verbose: bool, ctx: click.Context
) -> None:
    """
    Exit with status 0 when `src` is empty, printing `msg` first unless
    output is silenced (`quiet` set without `verbose`).
    """
    if src:
        return

    if verbose or not quiet:
        out(msg)
    ctx.exit(0)
565
566
def reformat_code(
    content: str, fast: bool, write_back: WriteBack, mode: Mode, report: Report
) -> None:
    """
    Reformat the string `content` without spawning child processes.
    The string counterpart of `reformat_one`.

    `fast`, `write_back`, and `mode` are forwarded to
    :func:`format_stdin_to_stdout`.  The outcome, or any exception raised
    while formatting, is recorded on `report` under the path "<string>".
    """
    path = Path("<string>")
    try:
        was_changed = format_stdin_to_stdout(
            content=content, fast=fast, write_back=write_back, mode=mode
        )
        report.done(path, Changed.YES if was_changed else Changed.NO)
    except Exception as exc:
        if report.verbose:
            traceback.print_exc()
        report.failed(path, str(exc))
589
590
def reformat_one(
    src: Path, fast: bool, write_back: WriteBack, mode: Mode, report: "Report"
) -> None:
    """Reformat the single file `src` without spawning child processes.

    A `src` of "-" or one carrying the `STDIN_PLACEHOLDER` prefix is read
    from stdin; anything else is formatted in place, consulting the cache
    unless a diff was requested.

    `fast`, `write_back`, and `mode` options are passed to
    :func:`format_file_in_place` or :func:`format_stdin_to_stdout`.
    """
    try:
        changed = Changed.NO

        src_text = str(src)
        has_placeholder = src_text != "-" and src_text.startswith(STDIN_PLACEHOLDER)
        is_stdin = src_text == "-" or has_placeholder
        if has_placeholder:
            # Use the original name again in case we want to print something
            # to the user
            src = Path(src_text[len(STDIN_PLACEHOLDER) :])

        if is_stdin:
            if src.suffix == ".pyi":
                mode = replace(mode, is_pyi=True)
            elif src.suffix == ".ipynb":
                mode = replace(mode, is_ipynb=True)
            if format_stdin_to_stdout(fast=fast, write_back=write_back, mode=mode):
                changed = Changed.YES
        else:
            cache: Cache = {}
            wants_diff = write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF)
            if not wants_diff:
                cache = read_cache(mode)
                resolved = src.resolve()
                cache_key = str(resolved)
                if cache_key in cache and cache[cache_key] == get_cache_info(
                    resolved
                ):
                    changed = Changed.CACHED
            if changed is not Changed.CACHED and format_file_in_place(
                src, fast=fast, write_back=write_back, mode=mode
            ):
                changed = Changed.YES

            # Record the file when it was just (re)written, or when a check
            # found nothing to change.
            written_back = write_back is WriteBack.YES and changed is not Changed.CACHED
            checked_clean = write_back is WriteBack.CHECK and changed is Changed.NO
            if written_back or checked_clean:
                write_cache(cache, [src], mode)
        report.done(src, changed)
    except Exception as exc:
        if report.verbose:
            traceback.print_exc()
        report.failed(src, str(exc))
640
641
def reformat_many(
    sources: Set[Path], fast: bool, write_back: WriteBack, mode: Mode, report: "Report"
) -> None:
    """Reformat multiple files using a ProcessPoolExecutor.

    Falls back to a single-worker ThreadPoolExecutor when the process pool
    cannot be created.  `fast`, `write_back`, `mode` and `report` are handed
    to :func:`schedule_formatting`, which is run to completion on the event
    loop.  The loop and the executor are shut down afterwards, even if
    scheduling raised.
    """
    executor: Executor
    loop = asyncio.get_event_loop()
    # os.cpu_count() returns None when the count cannot be determined; fall
    # back to one worker so the `min()` below never compares None with an int.
    worker_count = os.cpu_count() or 1
    if sys.platform == "win32":
        # Work around https://bugs.python.org/issue26903
        worker_count = min(worker_count, 60)
    try:
        executor = ProcessPoolExecutor(max_workers=worker_count)
    except (ImportError, OSError):
        # we arrive here if the underlying system does not support multi-processing
        # like in AWS Lambda or Termux, in which case we gracefully fallback to
        # a ThreadPoolExecutor with just a single worker (more workers would not do us
        # any good due to the Global Interpreter Lock)
        executor = ThreadPoolExecutor(max_workers=1)

    try:
        loop.run_until_complete(
            schedule_formatting(
                sources=sources,
                fast=fast,
                write_back=write_back,
                mode=mode,
                report=report,
                loop=loop,
                executor=executor,
            )
        )
    finally:
        shutdown(loop)
        executor.shutdown()
677
678
async def schedule_formatting(
    sources: Set[Path],
    fast: bool,
    write_back: WriteBack,
    mode: Mode,
    report: "Report",
    loop: asyncio.AbstractEventLoop,
    executor: Executor,
) -> None:
    """Run formatting of `sources` in parallel using the provided `executor`.

    (Use ProcessPoolExecutors for actual parallelism.)

    `write_back`, `fast`, and `mode` options are passed to
    :func:`format_file_in_place`.  Unless a diff was requested, sources that
    are already in the cache are reported as cached and skipped, and the
    cache is updated afterwards with the files that were written back or
    checked as well-formatted.
    """
    cache: Cache = {}
    if write_back not in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
        cache = read_cache(mode)
        sources, cached = filter_cached(cache, sources)
        for src in sorted(cached):
            report.done(src, Changed.CACHED)
    if not sources:
        return

    cancelled = []
    sources_to_cache = []
    lock = None
    if write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
        # For diff output, we need locks to ensure we don't interleave output
        # from different processes.
        manager = Manager()
        lock = manager.Lock()
    tasks = {
        asyncio.ensure_future(
            loop.run_in_executor(
                executor, format_file_in_place, src, fast, mode, write_back, lock
            )
        ): src
        for src in sorted(sources)
    }
    # `pending` is a live view of the dict's keys: it shrinks as finished
    # tasks are popped from `tasks` below.
    pending = tasks.keys()
    try:
        loop.add_signal_handler(signal.SIGINT, cancel, pending)
        loop.add_signal_handler(signal.SIGTERM, cancel, pending)
    except NotImplementedError:
        # There are no good alternatives for these on Windows.
        pass
    while pending:
        done, _ = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
        for task in done:
            src = tasks.pop(task)
            if task.cancelled():
                cancelled.append(task)
            elif task.exception():
                report.failed(src, str(task.exception()))
            else:
                changed = Changed.YES if task.result() else Changed.NO
                # If the file was written back or was successfully checked as
                # well-formatted, store this information in the cache.
                if write_back is WriteBack.YES or (
                    write_back is WriteBack.CHECK and changed is Changed.NO
                ):
                    sources_to_cache.append(src)
                report.done(src, changed)
    if cancelled:
        # No explicit `loop=` argument: asyncio.gather() dropped that
        # parameter in Python 3.10 and uses the running loop, so passing it
        # raises TypeError there.
        await asyncio.gather(*cancelled, return_exceptions=True)
    if sources_to_cache:
        write_cache(cache, sources_to_cache, mode)
748
749
def format_file_in_place(
    src: Path,
    fast: bool,
    mode: Mode,
    write_back: WriteBack = WriteBack.NO,
    lock: Any = None,  # multiprocessing.Manager().Lock() is some crazy proxy
) -> bool:
    """Format the file at `src`. Return True if its contents would change.

    With `write_back` YES the reformatted code is written to the file; with
    DIFF or COLOR_DIFF a diff is written to stdout (holding `lock`, if one is
    given, while writing).  A ".pyi" or ".ipynb" suffix switches `mode`
    accordingly.  `mode` and `fast` are passed to
    :func:`format_file_contents`.

    Raises ValueError when the contents fail to decode as JSON during
    formatting.
    """
    suffix = src.suffix
    if suffix == ".pyi":
        mode = replace(mode, is_pyi=True)
    elif suffix == ".ipynb":
        mode = replace(mode, is_ipynb=True)

    # The file's modification time goes into the diff header for the source.
    then = datetime.utcfromtimestamp(src.stat().st_mtime)
    with open(src, "rb") as buf:
        src_contents, encoding, newline = decode_bytes(buf.read())
    try:
        dst_contents = format_file_contents(src_contents, fast=fast, mode=mode)
    except NothingChanged:
        return False
    except JSONDecodeError:
        raise ValueError(f"File '{src}' cannot be parsed as valid Jupyter notebook.")

    if write_back == WriteBack.YES:
        with open(src, "w", encoding=encoding, newline=newline) as target:
            target.write(dst_contents)
        return True

    if write_back not in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
        return True

    now = datetime.utcnow()
    src_name = f"{src}\t{then} +0000"
    dst_name = f"{src}\t{now} +0000"
    make_diff = ipynb_diff if mode.is_ipynb else diff
    diff_contents = make_diff(src_contents, dst_contents, src_name, dst_name)
    if write_back == WriteBack.COLOR_DIFF:
        diff_contents = color_diff(diff_contents)

    with lock or nullcontext():
        stream = io.TextIOWrapper(
            sys.stdout.buffer,
            encoding=encoding,
            newline=newline,
            write_through=True,
        )
        stream = wrap_stream_for_windows(stream)
        stream.write(diff_contents)
        # Detach so the wrapper does not close stdout's buffer with it.
        stream.detach()

    return True
805
806
def format_stdin_to_stdout(
    fast: bool,
    *,
    content: Optional[str] = None,
    write_back: WriteBack = WriteBack.NO,
    mode: Mode,
) -> bool:
    """Format source taken from stdin or from `content`. Return True if changed.

    When `content` is None the source is read from sys.stdin; otherwise
    `content` itself is formatted and treated as UTF-8.

    With `write_back` YES the result is written to stdout; with DIFF or
    COLOR_DIFF a diff is written to stdout instead.  `mode` is passed on to
    :func:`format_file_contents`.
    """
    then = datetime.utcnow()

    if content is not None:
        src, encoding, newline = content, "utf-8", ""
    else:
        src, encoding, newline = decode_bytes(sys.stdin.buffer.read())

    # Start from the unformatted source so the `finally` block still has
    # something to emit when formatting does not produce a new result.
    dst = src
    try:
        dst = format_file_contents(src, fast=fast, mode=mode)
        return True

    except NothingChanged:
        return False

    finally:
        stream = io.TextIOWrapper(
            sys.stdout.buffer, encoding=encoding, newline=newline, write_through=True
        )
        if write_back == WriteBack.YES:
            # Make sure there's a newline after the content
            if dst and not dst.endswith("\n"):
                dst += "\n"
            stream.write(dst)
        elif write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
            now = datetime.utcnow()
            diff_text = diff(
                src, dst, f"STDIN\t{then} +0000", f"STDOUT\t{now} +0000"
            )
            if write_back == WriteBack.COLOR_DIFF:
                diff_text = color_diff(diff_text)
                stream = wrap_stream_for_windows(stream)
            stream.write(diff_text)
        stream.detach()
856
857
def check_stability_and_equivalence(
    src_contents: str, dst_contents: str, *, mode: Mode
) -> None:
    """Verify that `dst_contents` is an equivalent, stable reformatting.

    Raise AssertionError if source and destination contents are not
    equivalent, or if a second pass of the formatter would format the
    content differently.
    """
    assert_equivalent(src_contents, dst_contents)

    # Forced second pass to work around optional trailing commas (becoming
    # forced trailing commas on pass 2) interacting differently with optional
    # parentheses.  Admittedly ugly.
    second_pass = format_str(dst_contents, mode=mode)
    if second_pass == dst_contents:
        # Note: no need to explicitly call `assert_stable` when the second
        # pass reproduced `dst_contents` exactly.
        return

    # NOTE(review): the second-pass text is only checked here; this function
    # returns None, so callers never receive it.
    assert_equivalent(src_contents, second_pass, pass_num=2)
    assert_stable(src_contents, second_pass, mode=mode)
879
880
def format_file_contents(src_contents: str, *, fast: bool, mode: Mode) -> FileContent:
    """Return the reformatted version of `src_contents`.

    Raises `NothingChanged` when the input is blank or formatting leaves it
    untouched.  Unless `fast` is set, non-notebook output is additionally
    verified with :func:`check_stability_and_equivalence`.  `mode` is passed
    to :func:`format_str` (or :func:`format_ipynb_string` for notebooks).
    """
    if not src_contents.strip():
        raise NothingChanged

    dst_contents = (
        format_ipynb_string(src_contents, fast=fast, mode=mode)
        if mode.is_ipynb
        else format_str(src_contents, mode=mode)
    )
    if dst_contents == src_contents:
        raise NothingChanged

    if not (fast or mode.is_ipynb):
        # Jupyter notebooks will already have been checked above.
        check_stability_and_equivalence(src_contents, dst_contents, mode=mode)
    return dst_contents
902
903
def validate_cell(src: str) -> None:
    """Reject cells that already contain TransformerManager output.

    A cell holding ``!ls`` is transformed into ``get_ipython().system('ls')``,
    but a cell that literally contains ``get_ipython().system('ls')`` comes
    out of the transformation looking exactly the same:

        >>> TransformerManager().transform_cell("get_ipython().system('ls')")
        "get_ipython().system('ls')\n"
        >>> TransformerManager().transform_cell("!ls")
        "get_ipython().system('ls')\n"

    Such cells cannot be roundtripped safely, so NothingChanged is raised for
    any cell that contains one of the transformed magics.
    """
    for transformed_magic in TRANSFORMED_MAGICS:
        if transformed_magic in src:
            raise NothingChanged
921
922
def format_cell(src: str, *, fast: bool, mode: Mode) -> str:
    """Return the formatted source of a single Jupyter notebook code cell.

    The steps are:

      - validate the cell, then drop a trailing semicolon if there is one;
      - mask IPython magics;
      - format the masked code (and, unless `fast`, run the stability and
        equivalence checks on it);
      - unmask the magics and put the trailing semicolon back;
      - strip trailing newlines.

    Cells with syntax errors are left alone, since they might be automagics
    or multi-line magics, which are currently not supported.

    Raise NothingChanged if the cell is skipped or comes out unchanged.
    """
    validate_cell(src)
    stripped_src, had_semicolon = remove_trailing_semicolon(src)
    try:
        masked_src, replacements = mask_cell(stripped_src)
    except SyntaxError:
        raise NothingChanged

    masked_dst = format_str(masked_src, mode=mode)
    if not fast:
        check_stability_and_equivalence(masked_src, masked_dst, mode=mode)

    unmasked_dst = unmask_cell(masked_dst, replacements)
    result = put_trailing_semicolon_back(unmasked_dst, had_semicolon).rstrip("\n")
    if result == src:
        raise NothingChanged
    return result
958
959
def validate_metadata(nb: MutableMapping[str, Any]) -> None:
    """Raise NothingChanged for notebooks declared as non-Python.

    Every notebook metadata field is optional, see
    https://nbformat.readthedocs.io/en/latest/format_description.html, so a
    notebook that does not name its language is accepted.
    """
    metadata = nb.get("metadata", {})
    language_info = metadata.get("language_info", {})
    language = language_info.get("name", None)
    if language is None or language == "python":
        return
    raise NothingChanged
970
971
def format_ipynb_string(src_contents: str, *, fast: bool, mode: Mode) -> FileContent:
    """Format Jupyter notebook.

    Operate cell-by-cell, only on code cells, only for Python notebooks.
    If the ``.ipynb`` originally had a trailing newline, it'll be preserved.

    `fast` and `mode` are forwarded to :func:`format_cell` for every code cell.

    Raise NothingChanged if `src_contents` is empty, if the notebook metadata
    is rejected by :func:`validate_metadata`, or if no cell was modified.
    """
    if not src_contents:
        # An empty string has no last character to inspect and is not a
        # notebook; report it as unchanged instead of failing with IndexError.
        raise NothingChanged

    trailing_newline = src_contents.endswith("\n")
    modified = False
    nb = json.loads(src_contents)
    validate_metadata(nb)
    for cell in nb["cells"]:
        if cell.get("cell_type", None) == "code":
            try:
                # The cell source is stored as a list of strings; join it
                # into one string for formatting.
                src = "".join(cell["source"])
                dst = format_cell(src, fast=fast, mode=mode)
            except NothingChanged:
                # This cell is skipped or already formatted; leave it as is.
                pass
            else:
                cell["source"] = dst.splitlines(keepends=True)
                modified = True
    if modified:
        dst_contents = json.dumps(nb, indent=1, ensure_ascii=False)
        if trailing_newline:
            dst_contents = dst_contents + "\n"
        return dst_contents
    else:
        raise NothingChanged
999
1000
def format_str(src_contents: str, *, mode: Mode) -> FileContent:
    """Return `src_contents` reformatted according to `mode`.

    `mode` carries the formatting options, for instance the maximum number of
    characters allowed per line.  Example:

    >>> import black
    >>> print(black.format_str("def f(arg:str='')->None:...", mode=black.Mode()))
    def f(arg: str = "") -> None:
        ...

    A more complex example:

    >>> print(
    ...   black.format_str(
    ...     "def f(arg:str='')->None: hey",
    ...     mode=black.Mode(
    ...       target_versions={black.TargetVersion.PY36},
    ...       line_length=10,
    ...       string_normalization=False,
    ...       is_pyi=False,
    ...     ),
    ...   ),
    ... )
    def f(
        arg: str = '',
    ) -> None:
        hey

    """
    src_node = lib2to3_parse(src_contents.lstrip(), mode.target_versions)
    future_imports = get_future_imports(src_node)
    # Fall back to auto-detection when no target versions were requested.
    versions = mode.target_versions or detect_target_versions(src_node)
    normalize_fmt_off(src_node)

    strip_u_prefix = "unicode_literals" in future_imports or supports_feature(
        versions, Feature.UNICODE_LITERALS
    )
    line_generator = LineGenerator(mode=mode, remove_u_prefix=strip_u_prefix)
    empty_line_tracker = EmptyLineTracker(is_pyi=mode.is_pyi)
    blank_line = Line(mode=mode)
    trailing_comma_features = {
        Feature.TRAILING_COMMA_IN_CALL,
        Feature.TRAILING_COMMA_IN_DEF,
    }
    split_line_features = {
        candidate
        for candidate in trailing_comma_features
        if supports_feature(versions, candidate)
    }

    pieces = []
    # Blank lines requested after the previous line are emitted right before
    # the next one, so nothing is added after the final line.
    pending_after = 0
    for current_line in line_generator.visit(src_node):
        pieces.append(str(blank_line) * pending_after)
        before, pending_after = empty_line_tracker.maybe_empty_lines(current_line)
        pieces.append(str(blank_line) * before)
        pieces.extend(
            str(line)
            for line in transform_line(
                current_line, mode=mode, features=split_line_features
            )
        )
    return "".join(pieces)
1061
1062
def decode_bytes(src: bytes) -> Tuple[FileContent, Encoding, NewLine]:
    """Decode `src` and return (decoded_contents, encoding, newline).

    The encoding is detected with `tokenize.detect_encoding`.  `newline` is
    CRLF when the first line ends with it and LF otherwise, whereas
    `decoded_contents` is read with universal newlines, so it only ever
    contains LF.
    """
    buffer = io.BytesIO(src)
    encoding, first_lines = tokenize.detect_encoding(buffer.readline)
    if not first_lines:
        return "", encoding, "\n"

    if first_lines[0].endswith(b"\r\n"):
        newline = "\r\n"
    else:
        newline = "\n"

    # Encoding detection consumed part of the buffer; rewind to decode it all.
    buffer.seek(0)
    with io.TextIOWrapper(buffer, encoding) as reader:
        decoded = reader.read()
    return decoded, encoding, newline
1078
1079
def get_features_used(node: Node) -> Set[Feature]:
    """Return a set of (relatively) new Python features used in this file.

    Walks `node` in pre-order and inspects each visited node.

    Currently looking for:
    - f-strings;
    - underscores in numeric literals;
    - trailing commas after * or ** in function signatures and calls;
    - positional only arguments in function signatures and lambdas;
    - assignment expression;
    - relaxed decorator syntax;
    """
    features: Set[Feature] = set()
    for n in node.pre_order():
        if n.type == token.STRING:
            # String prefixes are case-insensitive and may mix cases (e.g.
            # `Rf"..."` or `fR"..."`), so compare the lower-cased prefix.
            value_head = n.value[:2].lower()  # type: ignore
            if value_head in {'f"', "f'", "rf", "fr"}:
                features.add(Feature.F_STRINGS)

        elif n.type == token.NUMBER:
            if "_" in n.value:  # type: ignore
                features.add(Feature.NUMERIC_UNDERSCORES)

        elif n.type == token.SLASH:
            # A slash only counts when it sits directly in an argument list.
            if n.parent and n.parent.type in {syms.typedargslist, syms.arglist}:
                features.add(Feature.POS_ONLY_ARGUMENTS)

        elif n.type == token.COLONEQUAL:
            features.add(Feature.ASSIGNMENT_EXPRESSIONS)

        elif n.type == syms.decorator:
            if len(n.children) > 1 and not is_simple_decorator_expression(
                n.children[1]
            ):
                features.add(Feature.RELAXED_DECORATORS)

        elif (
            n.type in {syms.typedargslist, syms.arglist}
            and n.children
            and n.children[-1].type == token.COMMA
        ):
            # The argument list ends with a trailing comma; it only counts as
            # a feature when the list also contains a star argument.
            if n.type == syms.typedargslist:
                feature = Feature.TRAILING_COMMA_IN_DEF
            else:
                feature = Feature.TRAILING_COMMA_IN_CALL

            for ch in n.children:
                if ch.type in STARS:
                    features.add(feature)

                if ch.type == syms.argument:
                    for argch in ch.children:
                        if argch.type in STARS:
                            features.add(feature)

    return features
1135
1136
def detect_target_versions(node: Node) -> Set[TargetVersion]:
    """Return every target version whose features cover those used in `node`."""
    used_features = get_features_used(node)
    compatible: Set[TargetVersion] = set()
    for candidate in TargetVersion:
        if used_features.issubset(VERSION_TO_FEATURES[candidate]):
            compatible.add(candidate)
    return compatible
1143
1144
def get_future_imports(node: Node) -> Set[str]:
    """Collect the names imported from `__future__` at the top of the file.

    Only the leading statements are inspected: docstring-like statements are
    skipped, `from __future__ import ...` statements are collected, and the
    scan stops at the first statement of any other kind.
    """

    def collect_names(nodes: List[LN]) -> Generator[str, None, None]:
        # Yield the original (pre-`as`) name of each imported item.
        for item in nodes:
            if isinstance(item, Leaf):
                if item.type == token.NAME:
                    yield item.value
                continue

            if item.type == syms.import_as_name:
                original = item.children[0]
                assert isinstance(original, Leaf), "Invalid syntax parsing imports"
                assert original.type == token.NAME, "Invalid syntax parsing imports"
                yield original.value
            elif item.type == syms.import_as_names:
                yield from collect_names(item.children)
            else:
                raise AssertionError("Invalid syntax parsing imports")

    found: Set[str] = set()
    for statement in node.children:
        if statement.type != syms.simple_stmt:
            break

        head = statement.children[0]
        if isinstance(head, Leaf):
            # A lone string followed by a newline is a docstring: keep
            # looking.  Any other leaf ends the scan.
            is_docstring = (
                len(statement.children) == 2
                and head.type == token.STRING
                and statement.children[1].type == token.NEWLINE
            )
            if is_docstring:
                continue
            break

        if head.type != syms.import_from:
            break

        module_name = head.children[1]
        if not isinstance(module_name, Leaf) or module_name.value != "__future__":
            break

        found.update(collect_names(head.children[3:]))

    return found
1193
1194
def assert_equivalent(src: str, dst: str, *, pass_num: int = 1) -> None:
    """Raise AssertionError if `src` and `dst` aren't equivalent.

    Both strings are parsed with `parse_ast` and the `stringify_ast` output of
    the two trees is compared.  `pass_num` is only used in the error messages
    to say which formatting pass produced `dst`.

    An AssertionError is also raised when either string fails to parse.
    """
    try:
        src_ast = parse_ast(src)
    except Exception as exc:
        # Chain explicitly so the traceback reports the parse failure as the
        # direct cause rather than as an error raised while handling it.
        raise AssertionError(
            "cannot use --safe with this file; failed to parse source file.  AST"
            f" error message: {exc}"
        ) from exc

    try:
        dst_ast = parse_ast(dst)
    except Exception as exc:
        # Dump the traceback together with the invalid output so that it can
        # be attached to a bug report.
        log = dump_to_file("".join(traceback.format_tb(exc.__traceback__)), dst)
        raise AssertionError(
            f"INTERNAL ERROR: Black produced invalid code on pass {pass_num}: {exc}. "
            "Please report a bug on https://github.com/psf/black/issues.  "
            f"This invalid output might be helpful: {log}"
        ) from None

    src_ast_str = "\n".join(stringify_ast(src_ast))
    dst_ast_str = "\n".join(stringify_ast(dst_ast))
    if src_ast_str != dst_ast_str:
        log = dump_to_file(diff(src_ast_str, dst_ast_str, "src", "dst"))
        raise AssertionError(
            "INTERNAL ERROR: Black produced code that is not equivalent to the"
            f" source on pass {pass_num}.  Please report a bug on "
            f"https://github.com/psf/black/issues.  This diff might be helpful: {log}"
        ) from None
1224
1225
def assert_stable(src: str, dst: str, mode: Mode) -> None:
    """Raise AssertionError unless formatting `dst` again leaves it unchanged.

    `src` is only used to include a source-to-first-pass diff in the dumped
    log when the check fails.
    """
    second_pass = format_str(dst, mode=mode)
    if second_pass == dst:
        return

    log = dump_to_file(
        str(mode),
        diff(src, dst, "source", "first pass"),
        diff(dst, second_pass, "first pass", "second pass"),
    )
    message = (
        "INTERNAL ERROR: Black produced different code on the second pass of the"
        " formatter.  Please report a bug on https://github.com/psf/black/issues."
        f"  This diff might be helpful: {log}"
    )
    raise AssertionError(message) from None
1240
1241
@contextmanager
def nullcontext() -> Iterator[None]:
    """Provide a context manager that does nothing.

    Stand-in for the `nullcontext` available in Python 3.7; the value bound
    by ``with ... as`` is None.
    """
    yield None
1249
1250
def patch_click() -> None:
    """Make Click not crash on Python 3.6 with LANG=C.

    On certain misconfigured environments, Python 3 selects the ASCII encoding as the
    default which restricts paths that it can access during the lifetime of the
    application.  Click refuses to work in this scenario by raising a RuntimeError.

    In case of Black the likelihood that non-ASCII characters are going to be used in
    file paths is minimal since it's Python source code.  Moreover, this crash was
    spurious on Python 3.7 thanks to PEP 538 and PEP 540.

    Each Click module is looked up independently, and a module that cannot be
    imported is simply skipped, so a Click release that lacks one of them
    neither crashes here nor prevents the other module from being patched.
    """
    modules: List[Any] = []
    # `from click import name` raises plain ImportError (not its
    # ModuleNotFoundError subclass) when `click` is importable but `name` is
    # missing, so the broader exception has to be caught.
    try:
        from click import core
    except ImportError:
        pass
    else:
        modules.append(core)
    try:
        from click import _unicodefun  # type: ignore
    except ImportError:
        pass
    else:
        modules.append(_unicodefun)

    # Replace whichever environment check each module defines with a no-op.
    for module in modules:
        if hasattr(module, "_verify_python3_env"):
            module._verify_python3_env = lambda: None  # type: ignore
        if hasattr(module, "_verify_python_env"):
            module._verify_python_env = lambda: None  # type: ignore
1273
1274
def patched_main() -> None:
    """Run `main` after the start-up hooks have been applied.

    Calls `maybe_install_uvloop`, `freeze_support` and `patch_click`, in that
    order, before handing control to `main`.
    """
    maybe_install_uvloop()
    freeze_support()
    # Disable Click's environment check before Click gets to run in `main`.
    patch_click()
    main()
1280
1281
# Run the patched entry point only when this module is executed as a script.
if __name__ == "__main__":
    patched_main()