git.madduck.net Git - etc/vim.git/blob - src/black/__init__.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Fix feature detection for positional-only arguments in lambdas (#2532)
[etc/vim.git] / src / black / __init__.py
1 import asyncio
2 from json.decoder import JSONDecodeError
3 import json
4 from concurrent.futures import Executor, ThreadPoolExecutor, ProcessPoolExecutor
5 from contextlib import contextmanager
6 from datetime import datetime
7 from enum import Enum
8 import io
9 from multiprocessing import Manager, freeze_support
10 import os
11 from pathlib import Path
12 from pathspec.patterns.gitwildmatch import GitWildMatchPatternError
13 import regex as re
14 import signal
15 import sys
16 import tokenize
17 import traceback
18 from typing import (
19     Any,
20     Dict,
21     Generator,
22     Iterator,
23     List,
24     MutableMapping,
25     Optional,
26     Pattern,
27     Set,
28     Sized,
29     Tuple,
30     Union,
31 )
32
33 from dataclasses import replace
34 import click
35
36 from black.const import DEFAULT_LINE_LENGTH, DEFAULT_INCLUDES, DEFAULT_EXCLUDES
37 from black.const import STDIN_PLACEHOLDER
38 from black.nodes import STARS, syms, is_simple_decorator_expression
39 from black.lines import Line, EmptyLineTracker
40 from black.linegen import transform_line, LineGenerator, LN
41 from black.comments import normalize_fmt_off
42 from black.mode import Mode, TargetVersion
43 from black.mode import Feature, supports_feature, VERSION_TO_FEATURES
44 from black.cache import read_cache, write_cache, get_cache_info, filter_cached, Cache
45 from black.concurrency import cancel, shutdown, maybe_install_uvloop
46 from black.output import dump_to_file, ipynb_diff, diff, color_diff, out, err
47 from black.report import Report, Changed, NothingChanged
48 from black.files import find_project_root, find_pyproject_toml, parse_pyproject_toml
49 from black.files import gen_python_files, get_gitignore, normalize_path_maybe_ignore
50 from black.files import wrap_stream_for_windows
51 from black.parsing import InvalidInput  # noqa F401
52 from black.parsing import lib2to3_parse, parse_ast, stringify_ast
53 from black.handle_ipynb_magics import (
54     mask_cell,
55     unmask_cell,
56     remove_trailing_semicolon,
57     put_trailing_semicolon_back,
58     TRANSFORMED_MAGICS,
59     jupyter_dependencies_are_installed,
60 )
61
62
63 # lib2to3 fork
64 from blib2to3.pytree import Node, Leaf
65 from blib2to3.pgen2 import token
66
67 from _black_version import version as __version__
68
69 # types
70 FileContent = str
71 Encoding = str
72 NewLine = str
73
74
75 class WriteBack(Enum):
76     NO = 0
77     YES = 1
78     DIFF = 2
79     CHECK = 3
80     COLOR_DIFF = 4
81
82     @classmethod
83     def from_configuration(
84         cls, *, check: bool, diff: bool, color: bool = False
85     ) -> "WriteBack":
86         if check and not diff:
87             return cls.CHECK
88
89         if diff and color:
90             return cls.COLOR_DIFF
91
92         return cls.DIFF if diff else cls.YES
93
94
95 # Legacy name, left for integrations.
96 FileMode = Mode
97
98 DEFAULT_WORKERS = os.cpu_count()
99
100
101 def read_pyproject_toml(
102     ctx: click.Context, param: click.Parameter, value: Optional[str]
103 ) -> Optional[str]:
104     """Inject Black configuration from "pyproject.toml" into defaults in `ctx`.
105
106     Returns the path to a successfully found and read configuration file, None
107     otherwise.
108     """
109     if not value:
110         value = find_pyproject_toml(ctx.params.get("src", ()))
111         if value is None:
112             return None
113
114     try:
115         config = parse_pyproject_toml(value)
116     except (OSError, ValueError) as e:
117         raise click.FileError(
118             filename=value, hint=f"Error reading configuration file: {e}"
119         ) from None
120
121     if not config:
122         return None
123     else:
124         # Sanitize the values to be Click friendly. For more information please see:
125         # https://github.com/psf/black/issues/1458
126         # https://github.com/pallets/click/issues/1567
127         config = {
128             k: str(v) if not isinstance(v, (list, dict)) else v
129             for k, v in config.items()
130         }
131
132     target_version = config.get("target_version")
133     if target_version is not None and not isinstance(target_version, list):
134         raise click.BadOptionUsage(
135             "target-version", "Config key target-version must be a list"
136         )
137
138     default_map: Dict[str, Any] = {}
139     if ctx.default_map:
140         default_map.update(ctx.default_map)
141     default_map.update(config)
142
143     ctx.default_map = default_map
144     return value
145
146
147 def target_version_option_callback(
148     c: click.Context, p: Union[click.Option, click.Parameter], v: Tuple[str, ...]
149 ) -> List[TargetVersion]:
150     """Compute the target versions from a --target-version flag.
151
152     This is its own function because mypy couldn't infer the type correctly
153     when it was a lambda, causing mypyc trouble.
154     """
155     return [TargetVersion[val.upper()] for val in v]
156
157
158 def re_compile_maybe_verbose(regex: str) -> Pattern[str]:
159     """Compile a regular expression string in `regex`.
160
161     If it contains newlines, use verbose mode.
162     """
163     if "\n" in regex:
164         regex = "(?x)" + regex
165     compiled: Pattern[str] = re.compile(regex)
166     return compiled
167
168
169 def validate_regex(
170     ctx: click.Context,
171     param: click.Parameter,
172     value: Optional[str],
173 ) -> Optional[Pattern]:
174     try:
175         return re_compile_maybe_verbose(value) if value is not None else None
176     except re.error:
177         raise click.BadParameter("Not a valid regular expression") from None
178
179
180 @click.command(context_settings=dict(help_option_names=["-h", "--help"]))
181 @click.option("-c", "--code", type=str, help="Format the code passed in as a string.")
182 @click.option(
183     "-l",
184     "--line-length",
185     type=int,
186     default=DEFAULT_LINE_LENGTH,
187     help="How many characters per line to allow.",
188     show_default=True,
189 )
190 @click.option(
191     "-t",
192     "--target-version",
193     type=click.Choice([v.name.lower() for v in TargetVersion]),
194     callback=target_version_option_callback,
195     multiple=True,
196     help=(
197         "Python versions that should be supported by Black's output. [default: per-file"
198         " auto-detection]"
199     ),
200 )
201 @click.option(
202     "--pyi",
203     is_flag=True,
204     help=(
205         "Format all input files like typing stubs regardless of file extension (useful"
206         " when piping source on standard input)."
207     ),
208 )
209 @click.option(
210     "--ipynb",
211     is_flag=True,
212     help=(
213         "Format all input files like Jupyter Notebooks regardless of file extension "
214         "(useful when piping source on standard input)."
215     ),
216 )
217 @click.option(
218     "-S",
219     "--skip-string-normalization",
220     is_flag=True,
221     help="Don't normalize string quotes or prefixes.",
222 )
223 @click.option(
224     "-C",
225     "--skip-magic-trailing-comma",
226     is_flag=True,
227     help="Don't use trailing commas as a reason to split lines.",
228 )
229 @click.option(
230     "--experimental-string-processing",
231     is_flag=True,
232     hidden=True,
233     help=(
234         "Experimental option that performs more normalization on string literals."
235         " Currently disabled because it leads to some crashes."
236     ),
237 )
238 @click.option(
239     "--check",
240     is_flag=True,
241     help=(
242         "Don't write the files back, just return the status. Return code 0 means"
243         " nothing would change. Return code 1 means some files would be reformatted."
244         " Return code 123 means there was an internal error."
245     ),
246 )
247 @click.option(
248     "--diff",
249     is_flag=True,
250     help="Don't write the files back, just output a diff for each file on stdout.",
251 )
252 @click.option(
253     "--color/--no-color",
254     is_flag=True,
255     help="Show colored diff. Only applies when `--diff` is given.",
256 )
257 @click.option(
258     "--fast/--safe",
259     is_flag=True,
260     help="If --fast given, skip temporary sanity checks. [default: --safe]",
261 )
262 @click.option(
263     "--required-version",
264     type=str,
265     help=(
266         "Require a specific version of Black to be running (useful for unifying results"
267         " across many environments e.g. with a pyproject.toml file)."
268     ),
269 )
270 @click.option(
271     "--include",
272     type=str,
273     default=DEFAULT_INCLUDES,
274     callback=validate_regex,
275     help=(
276         "A regular expression that matches files and directories that should be"
277         " included on recursive searches. An empty value means all files are included"
278         " regardless of the name. Use forward slashes for directories on all platforms"
279         " (Windows, too). Exclusions are calculated first, inclusions later."
280     ),
281     show_default=True,
282 )
283 @click.option(
284     "--exclude",
285     type=str,
286     callback=validate_regex,
287     help=(
288         "A regular expression that matches files and directories that should be"
289         " excluded on recursive searches. An empty value means no paths are excluded."
290         " Use forward slashes for directories on all platforms (Windows, too)."
291         " Exclusions are calculated first, inclusions later. [default:"
292         f" {DEFAULT_EXCLUDES}]"
293     ),
294     show_default=False,
295 )
296 @click.option(
297     "--extend-exclude",
298     type=str,
299     callback=validate_regex,
300     help=(
301         "Like --exclude, but adds additional files and directories on top of the"
302         " excluded ones. (Useful if you simply want to add to the default)"
303     ),
304 )
305 @click.option(
306     "--force-exclude",
307     type=str,
308     callback=validate_regex,
309     help=(
310         "Like --exclude, but files and directories matching this regex will be "
311         "excluded even when they are passed explicitly as arguments."
312     ),
313 )
314 @click.option(
315     "--stdin-filename",
316     type=str,
317     help=(
318         "The name of the file when passing it through stdin. Useful to make "
319         "sure Black will respect --force-exclude option on some "
320         "editors that rely on using stdin."
321     ),
322 )
323 @click.option(
324     "-W",
325     "--workers",
326     type=click.IntRange(min=1),
327     default=DEFAULT_WORKERS,
328     show_default=True,
329     help="Number of parallel workers",
330 )
331 @click.option(
332     "-q",
333     "--quiet",
334     is_flag=True,
335     help=(
336         "Don't emit non-error messages to stderr. Errors are still emitted; silence"
337         " those with 2>/dev/null."
338     ),
339 )
340 @click.option(
341     "-v",
342     "--verbose",
343     is_flag=True,
344     help=(
345         "Also emit messages to stderr about files that were not changed or were ignored"
346         " due to exclusion patterns."
347     ),
348 )
349 @click.version_option(version=__version__)
350 @click.argument(
351     "src",
352     nargs=-1,
353     type=click.Path(
354         exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
355     ),
356     is_eager=True,
357     metavar="SRC ...",
358 )
359 @click.option(
360     "--config",
361     type=click.Path(
362         exists=True,
363         file_okay=True,
364         dir_okay=False,
365         readable=True,
366         allow_dash=False,
367         path_type=str,
368     ),
369     is_eager=True,
370     callback=read_pyproject_toml,
371     help="Read configuration from FILE path.",
372 )
373 @click.pass_context
374 def main(
375     ctx: click.Context,
376     code: Optional[str],
377     line_length: int,
378     target_version: List[TargetVersion],
379     check: bool,
380     diff: bool,
381     color: bool,
382     fast: bool,
383     pyi: bool,
384     ipynb: bool,
385     skip_string_normalization: bool,
386     skip_magic_trailing_comma: bool,
387     experimental_string_processing: bool,
388     quiet: bool,
389     verbose: bool,
390     required_version: str,
391     include: Pattern,
392     exclude: Optional[Pattern],
393     extend_exclude: Optional[Pattern],
394     force_exclude: Optional[Pattern],
395     stdin_filename: Optional[str],
396     workers: int,
397     src: Tuple[str, ...],
398     config: Optional[str],
399 ) -> None:
400     """The uncompromising code formatter."""
401     if config and verbose:
402         out(f"Using configuration from {config}.", bold=False, fg="blue")
403
404     error_msg = "Oh no! 💥 💔 💥"
405     if required_version and required_version != __version__:
406         err(
407             f"{error_msg} The required version `{required_version}` does not match"
408             f" the running version `{__version__}`!"
409         )
410         ctx.exit(1)
411     if ipynb and pyi:
412         err("Cannot pass both `pyi` and `ipynb` flags!")
413         ctx.exit(1)
414
415     write_back = WriteBack.from_configuration(check=check, diff=diff, color=color)
416     if target_version:
417         versions = set(target_version)
418     else:
419         # We'll autodetect later.
420         versions = set()
421     mode = Mode(
422         target_versions=versions,
423         line_length=line_length,
424         is_pyi=pyi,
425         is_ipynb=ipynb,
426         string_normalization=not skip_string_normalization,
427         magic_trailing_comma=not skip_magic_trailing_comma,
428         experimental_string_processing=experimental_string_processing,
429     )
430
431     if code is not None:
432         # Run in quiet mode by default with -c; the extra output isn't useful.
433         # You can still pass -v to get verbose output.
434         quiet = True
435
436     report = Report(check=check, diff=diff, quiet=quiet, verbose=verbose)
437
438     if code is not None:
439         reformat_code(
440             content=code, fast=fast, write_back=write_back, mode=mode, report=report
441         )
442     else:
443         try:
444             sources = get_sources(
445                 ctx=ctx,
446                 src=src,
447                 quiet=quiet,
448                 verbose=verbose,
449                 include=include,
450                 exclude=exclude,
451                 extend_exclude=extend_exclude,
452                 force_exclude=force_exclude,
453                 report=report,
454                 stdin_filename=stdin_filename,
455             )
456         except GitWildMatchPatternError:
457             ctx.exit(1)
458
459         path_empty(
460             sources,
461             "No Python files are present to be formatted. Nothing to do 😴",
462             quiet,
463             verbose,
464             ctx,
465         )
466
467         if len(sources) == 1:
468             reformat_one(
469                 src=sources.pop(),
470                 fast=fast,
471                 write_back=write_back,
472                 mode=mode,
473                 report=report,
474             )
475         else:
476             reformat_many(
477                 sources=sources,
478                 fast=fast,
479                 write_back=write_back,
480                 mode=mode,
481                 report=report,
482                 workers=workers,
483             )
484
485     if verbose or not quiet:
486         out(error_msg if report.return_code else "All done! ✨ 🍰 ✨")
487         if code is None:
488             click.echo(str(report), err=True)
489     ctx.exit(report.return_code)
490
491
492 def get_sources(
493     *,
494     ctx: click.Context,
495     src: Tuple[str, ...],
496     quiet: bool,
497     verbose: bool,
498     include: Pattern[str],
499     exclude: Optional[Pattern[str]],
500     extend_exclude: Optional[Pattern[str]],
501     force_exclude: Optional[Pattern[str]],
502     report: "Report",
503     stdin_filename: Optional[str],
504 ) -> Set[Path]:
505     """Compute the set of files to be formatted."""
506
507     root = find_project_root(src)
508     sources: Set[Path] = set()
509     path_empty(src, "No Path provided. Nothing to do 😴", quiet, verbose, ctx)
510
511     if exclude is None:
512         exclude = re_compile_maybe_verbose(DEFAULT_EXCLUDES)
513         gitignore = get_gitignore(root)
514     else:
515         gitignore = None
516
517     for s in src:
518         if s == "-" and stdin_filename:
519             p = Path(stdin_filename)
520             is_stdin = True
521         else:
522             p = Path(s)
523             is_stdin = False
524
525         if is_stdin or p.is_file():
526             normalized_path = normalize_path_maybe_ignore(p, root, report)
527             if normalized_path is None:
528                 continue
529
530             normalized_path = "/" + normalized_path
531             # Hard-exclude any files that matches the `--force-exclude` regex.
532             if force_exclude:
533                 force_exclude_match = force_exclude.search(normalized_path)
534             else:
535                 force_exclude_match = None
536             if force_exclude_match and force_exclude_match.group(0):
537                 report.path_ignored(p, "matches the --force-exclude regular expression")
538                 continue
539
540             if is_stdin:
541                 p = Path(f"{STDIN_PLACEHOLDER}{str(p)}")
542
543             if p.suffix == ".ipynb" and not jupyter_dependencies_are_installed(
544                 verbose=verbose, quiet=quiet
545             ):
546                 continue
547
548             sources.add(p)
549         elif p.is_dir():
550             sources.update(
551                 gen_python_files(
552                     p.iterdir(),
553                     root,
554                     include,
555                     exclude,
556                     extend_exclude,
557                     force_exclude,
558                     report,
559                     gitignore,
560                     verbose=verbose,
561                     quiet=quiet,
562                 )
563             )
564         elif s == "-":
565             sources.add(p)
566         else:
567             err(f"invalid path: {s}")
568     return sources
569
570
571 def path_empty(
572     src: Sized, msg: str, quiet: bool, verbose: bool, ctx: click.Context
573 ) -> None:
574     """
575     Exit if there is no `src` provided for formatting
576     """
577     if not src:
578         if verbose or not quiet:
579             out(msg)
580         ctx.exit(0)
581
582
583 def reformat_code(
584     content: str, fast: bool, write_back: WriteBack, mode: Mode, report: Report
585 ) -> None:
586     """
587     Reformat and print out `content` without spawning child processes.
588     Similar to `reformat_one`, but for string content.
589
590     `fast`, `write_back`, and `mode` options are passed to
591     :func:`format_file_in_place` or :func:`format_stdin_to_stdout`.
592     """
593     path = Path("<string>")
594     try:
595         changed = Changed.NO
596         if format_stdin_to_stdout(
597             content=content, fast=fast, write_back=write_back, mode=mode
598         ):
599             changed = Changed.YES
600         report.done(path, changed)
601     except Exception as exc:
602         if report.verbose:
603             traceback.print_exc()
604         report.failed(path, str(exc))
605
606
607 def reformat_one(
608     src: Path, fast: bool, write_back: WriteBack, mode: Mode, report: "Report"
609 ) -> None:
610     """Reformat a single file under `src` without spawning child processes.
611
612     `fast`, `write_back`, and `mode` options are passed to
613     :func:`format_file_in_place` or :func:`format_stdin_to_stdout`.
614     """
615     try:
616         changed = Changed.NO
617
618         if str(src) == "-":
619             is_stdin = True
620         elif str(src).startswith(STDIN_PLACEHOLDER):
621             is_stdin = True
622             # Use the original name again in case we want to print something
623             # to the user
624             src = Path(str(src)[len(STDIN_PLACEHOLDER) :])
625         else:
626             is_stdin = False
627
628         if is_stdin:
629             if src.suffix == ".pyi":
630                 mode = replace(mode, is_pyi=True)
631             elif src.suffix == ".ipynb":
632                 mode = replace(mode, is_ipynb=True)
633             if format_stdin_to_stdout(fast=fast, write_back=write_back, mode=mode):
634                 changed = Changed.YES
635         else:
636             cache: Cache = {}
637             if write_back not in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
638                 cache = read_cache(mode)
639                 res_src = src.resolve()
640                 res_src_s = str(res_src)
641                 if res_src_s in cache and cache[res_src_s] == get_cache_info(res_src):
642                     changed = Changed.CACHED
643             if changed is not Changed.CACHED and format_file_in_place(
644                 src, fast=fast, write_back=write_back, mode=mode
645             ):
646                 changed = Changed.YES
647             if (write_back is WriteBack.YES and changed is not Changed.CACHED) or (
648                 write_back is WriteBack.CHECK and changed is Changed.NO
649             ):
650                 write_cache(cache, [src], mode)
651         report.done(src, changed)
652     except Exception as exc:
653         if report.verbose:
654             traceback.print_exc()
655         report.failed(src, str(exc))
656
657
658 def reformat_many(
659     sources: Set[Path],
660     fast: bool,
661     write_back: WriteBack,
662     mode: Mode,
663     report: "Report",
664     workers: Optional[int],
665 ) -> None:
666     """Reformat multiple files using a ProcessPoolExecutor."""
667     executor: Executor
668     loop = asyncio.get_event_loop()
669     worker_count = workers if workers is not None else DEFAULT_WORKERS
670     if sys.platform == "win32":
671         # Work around https://bugs.python.org/issue26903
672         worker_count = min(worker_count, 60)
673     try:
674         executor = ProcessPoolExecutor(max_workers=worker_count)
675     except (ImportError, OSError):
676         # we arrive here if the underlying system does not support multi-processing
677         # like in AWS Lambda or Termux, in which case we gracefully fallback to
678         # a ThreadPoolExecutor with just a single worker (more workers would not do us
679         # any good due to the Global Interpreter Lock)
680         executor = ThreadPoolExecutor(max_workers=1)
681
682     try:
683         loop.run_until_complete(
684             schedule_formatting(
685                 sources=sources,
686                 fast=fast,
687                 write_back=write_back,
688                 mode=mode,
689                 report=report,
690                 loop=loop,
691                 executor=executor,
692             )
693         )
694     finally:
695         shutdown(loop)
696         if executor is not None:
697             executor.shutdown()
698
699
700 async def schedule_formatting(
701     sources: Set[Path],
702     fast: bool,
703     write_back: WriteBack,
704     mode: Mode,
705     report: "Report",
706     loop: asyncio.AbstractEventLoop,
707     executor: Executor,
708 ) -> None:
709     """Run formatting of `sources` in parallel using the provided `executor`.
710
711     (Use ProcessPoolExecutors for actual parallelism.)
712
713     `write_back`, `fast`, and `mode` options are passed to
714     :func:`format_file_in_place`.
715     """
716     cache: Cache = {}
717     if write_back not in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
718         cache = read_cache(mode)
719         sources, cached = filter_cached(cache, sources)
720         for src in sorted(cached):
721             report.done(src, Changed.CACHED)
722     if not sources:
723         return
724
725     cancelled = []
726     sources_to_cache = []
727     lock = None
728     if write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
729         # For diff output, we need locks to ensure we don't interleave output
730         # from different processes.
731         manager = Manager()
732         lock = manager.Lock()
733     tasks = {
734         asyncio.ensure_future(
735             loop.run_in_executor(
736                 executor, format_file_in_place, src, fast, mode, write_back, lock
737             )
738         ): src
739         for src in sorted(sources)
740     }
741     pending = tasks.keys()
742     try:
743         loop.add_signal_handler(signal.SIGINT, cancel, pending)
744         loop.add_signal_handler(signal.SIGTERM, cancel, pending)
745     except NotImplementedError:
746         # There are no good alternatives for these on Windows.
747         pass
748     while pending:
749         done, _ = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
750         for task in done:
751             src = tasks.pop(task)
752             if task.cancelled():
753                 cancelled.append(task)
754             elif task.exception():
755                 report.failed(src, str(task.exception()))
756             else:
757                 changed = Changed.YES if task.result() else Changed.NO
758                 # If the file was written back or was successfully checked as
759                 # well-formatted, store this information in the cache.
760                 if write_back is WriteBack.YES or (
761                     write_back is WriteBack.CHECK and changed is Changed.NO
762                 ):
763                     sources_to_cache.append(src)
764                 report.done(src, changed)
765     if cancelled:
766         await asyncio.gather(*cancelled, loop=loop, return_exceptions=True)
767     if sources_to_cache:
768         write_cache(cache, sources_to_cache, mode)
769
770
771 def format_file_in_place(
772     src: Path,
773     fast: bool,
774     mode: Mode,
775     write_back: WriteBack = WriteBack.NO,
776     lock: Any = None,  # multiprocessing.Manager().Lock() is some crazy proxy
777 ) -> bool:
778     """Format file under `src` path. Return True if changed.
779
780     If `write_back` is DIFF, write a diff to stdout. If it is YES, write reformatted
781     code to the file.
782     `mode` and `fast` options are passed to :func:`format_file_contents`.
783     """
784     if src.suffix == ".pyi":
785         mode = replace(mode, is_pyi=True)
786     elif src.suffix == ".ipynb":
787         mode = replace(mode, is_ipynb=True)
788
789     then = datetime.utcfromtimestamp(src.stat().st_mtime)
790     with open(src, "rb") as buf:
791         src_contents, encoding, newline = decode_bytes(buf.read())
792     try:
793         dst_contents = format_file_contents(src_contents, fast=fast, mode=mode)
794     except NothingChanged:
795         return False
796     except JSONDecodeError:
797         raise ValueError(
798             f"File '{src}' cannot be parsed as valid Jupyter notebook."
799         ) from None
800
801     if write_back == WriteBack.YES:
802         with open(src, "w", encoding=encoding, newline=newline) as f:
803             f.write(dst_contents)
804     elif write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
805         now = datetime.utcnow()
806         src_name = f"{src}\t{then} +0000"
807         dst_name = f"{src}\t{now} +0000"
808         if mode.is_ipynb:
809             diff_contents = ipynb_diff(src_contents, dst_contents, src_name, dst_name)
810         else:
811             diff_contents = diff(src_contents, dst_contents, src_name, dst_name)
812
813         if write_back == WriteBack.COLOR_DIFF:
814             diff_contents = color_diff(diff_contents)
815
816         with lock or nullcontext():
817             f = io.TextIOWrapper(
818                 sys.stdout.buffer,
819                 encoding=encoding,
820                 newline=newline,
821                 write_through=True,
822             )
823             f = wrap_stream_for_windows(f)
824             f.write(diff_contents)
825             f.detach()
826
827     return True
828
829
830 def format_stdin_to_stdout(
831     fast: bool,
832     *,
833     content: Optional[str] = None,
834     write_back: WriteBack = WriteBack.NO,
835     mode: Mode,
836 ) -> bool:
837     """Format file on stdin. Return True if changed.
838
839     If content is None, it's read from sys.stdin.
840
841     If `write_back` is YES, write reformatted code back to stdout. If it is DIFF,
842     write a diff to stdout. The `mode` argument is passed to
843     :func:`format_file_contents`.
844     """
845     then = datetime.utcnow()
846
847     if content is None:
848         src, encoding, newline = decode_bytes(sys.stdin.buffer.read())
849     else:
850         src, encoding, newline = content, "utf-8", ""
851
852     dst = src
853     try:
854         dst = format_file_contents(src, fast=fast, mode=mode)
855         return True
856
857     except NothingChanged:
858         return False
859
860     finally:
861         f = io.TextIOWrapper(
862             sys.stdout.buffer, encoding=encoding, newline=newline, write_through=True
863         )
864         if write_back == WriteBack.YES:
865             # Make sure there's a newline after the content
866             if dst and dst[-1] != "\n":
867                 dst += "\n"
868             f.write(dst)
869         elif write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
870             now = datetime.utcnow()
871             src_name = f"STDIN\t{then} +0000"
872             dst_name = f"STDOUT\t{now} +0000"
873             d = diff(src, dst, src_name, dst_name)
874             if write_back == WriteBack.COLOR_DIFF:
875                 d = color_diff(d)
876                 f = wrap_stream_for_windows(f)
877             f.write(d)
878         f.detach()
879
880
881 def check_stability_and_equivalence(
882     src_contents: str, dst_contents: str, *, mode: Mode
883 ) -> None:
884     """Perform stability and equivalence checks.
885
886     Raise AssertionError if source and destination contents are not
887     equivalent, or if a second pass of the formatter would format the
888     content differently.
889     """
890     assert_equivalent(src_contents, dst_contents)
891
892     # Forced second pass to work around optional trailing commas (becoming
893     # forced trailing commas on pass 2) interacting differently with optional
894     # parentheses.  Admittedly ugly.
895     dst_contents_pass2 = format_str(dst_contents, mode=mode)
896     if dst_contents != dst_contents_pass2:
897         dst_contents = dst_contents_pass2
898         assert_equivalent(src_contents, dst_contents, pass_num=2)
899         assert_stable(src_contents, dst_contents, mode=mode)
900     # Note: no need to explicitly call `assert_stable` if `dst_contents` was
901     # the same as `dst_contents_pass2`.
902
903
def format_file_contents(src_contents: str, *, fast: bool, mode: Mode) -> FileContent:
    """Return the reformatted version of a whole file's contents.

    Notebooks (``mode.is_ipynb``) go through :func:`format_ipynb_string`. Plain
    Python goes through :func:`format_str`. Unless ``fast`` is set, plain
    Python output is verified with :func:`check_stability_and_equivalence`.

    Raises NothingChanged when the input is blank or formatting is a no-op.
    """
    # Whitespace-only input has nothing to reformat.
    if src_contents.strip() == "":
        raise NothingChanged

    is_notebook = mode.is_ipynb
    dst_contents = (
        format_ipynb_string(src_contents, fast=fast, mode=mode)
        if is_notebook
        else format_str(src_contents, mode=mode)
    )
    if dst_contents == src_contents:
        raise NothingChanged

    # Notebook cells are already verified one by one inside format_cell
    # (when not fast), so the whole-file check applies only to plain Python.
    if not (fast or is_notebook):
        check_stability_and_equivalence(src_contents, dst_contents, mode=mode)
    return dst_contents
925
926
def validate_cell(src: str) -> None:
    """Refuse cells that already contain IPython-transformed magics.

    IPython rewrites ``!ls`` into ``get_ipython().system('ls')``. A cell that
    literally contains ``get_ipython().system('ls')`` transforms to the same
    text, so the two cannot be told apart:

        >>> TransformerManager().transform_cell("get_ipython().system('ls')")
        "get_ipython().system('ls')\n"
        >>> TransformerManager().transform_cell("!ls")
        "get_ipython().system('ls')\n"

    Such cells cannot be round-tripped safely. Raise NothingChanged so they
    are skipped.
    """
    for transformed_magic in TRANSFORMED_MAGICS:
        if transformed_magic in src:
            raise NothingChanged
944
945
def format_cell(src: str, *, fast: bool, mode: Mode) -> str:
    """Format the source of a single Jupyter notebook code cell.

    Steps:

      - drop a trailing semicolon, remembering whether it was there;
      - mask IPython magics so the cell parses as plain Python;
      - format the masked code (and verify it unless ``fast``);
      - restore the magics and the trailing semicolon;
      - strip trailing newlines.

    A cell that fails to parse while masking is left alone, because it may be
    an automagic or a multi-line magic, which are not supported. Raises
    NothingChanged when the cell is skipped or already formatted.
    """
    validate_cell(src)
    stripped_src, had_semicolon = remove_trailing_semicolon(src)

    try:
        masked_src, replacements = mask_cell(stripped_src)
    except SyntaxError:
        raise NothingChanged from None

    masked_dst = format_str(masked_src, mode=mode)
    if not fast:
        check_stability_and_equivalence(masked_src, masked_dst, mode=mode)

    unmasked_dst = unmask_cell(masked_dst, replacements)
    dst = put_trailing_semicolon_back(unmasked_dst, had_semicolon).rstrip("\n")
    if dst == src:
        raise NothingChanged from None
    return dst
981
982
def validate_metadata(nb: MutableMapping[str, Any]) -> None:
    """Raise NothingChanged if the notebook declares a non-Python language.

    Every metadata field is optional in the notebook format (see
    https://nbformat.readthedocs.io/en/latest/format_description.html). A
    notebook that declares no language is therefore treated as Python.
    """
    language_info = nb.get("metadata", {}).get("language_info", {})
    language = language_info.get("name", None)
    if language is None or language == "python":
        return
    raise NothingChanged from None
993
994
def format_ipynb_string(src_contents: str, *, fast: bool, mode: Mode) -> FileContent:
    """Format the code cells of a Jupyter notebook given as JSON text.

    Only cells with ``cell_type == "code"`` are touched, and only when the
    notebook is a Python notebook (see :func:`validate_metadata`). The result
    is re-serialized with ``indent=1``. A trailing newline in the input is
    kept.

    Raises NothingChanged when no cell was modified.
    """
    had_trailing_newline = src_contents[-1] == "\n"
    nb = json.loads(src_contents)
    validate_metadata(nb)

    changed_any = False
    for cell in nb["cells"]:
        if cell.get("cell_type", None) != "code":
            continue
        try:
            src = "".join(cell["source"])
            dst = format_cell(src, fast=fast, mode=mode)
        except NothingChanged:
            continue
        cell["source"] = dst.splitlines(keepends=True)
        changed_any = True

    if not changed_any:
        raise NothingChanged

    dst_contents = json.dumps(nb, indent=1, ensure_ascii=False)
    if had_trailing_newline:
        dst_contents += "\n"
    return dst_contents
1022
1023
def format_str(src_contents: str, *, mode: Mode) -> FileContent:
    """Reformat a string and return new contents.

    `mode` determines formatting options, such as how many characters per line are
    allowed.  If `mode.target_versions` is empty, the target versions are inferred
    from the syntax features used in `src_contents`.  Example:

    >>> import black
    >>> print(black.format_str("def f(arg:str='')->None:...", mode=black.Mode()))
    def f(arg: str = "") -> None:
        ...

    A more complex example:

    >>> print(
    ...   black.format_str(
    ...     "def f(arg:str='')->None: hey",
    ...     mode=black.Mode(
    ...       target_versions={black.TargetVersion.PY36},
    ...       line_length=10,
    ...       string_normalization=False,
    ...       is_pyi=False,
    ...     ),
    ...   ),
    ... )
    def f(
        arg: str = '',
    ) -> None:
        hey

    """
    # Leading whitespace is dropped before parsing. The requested target
    # versions are passed to the parser (presumably to pick grammars; TODO confirm).
    src_node = lib2to3_parse(src_contents.lstrip(), mode.target_versions)
    dst_contents: List[str] = []
    future_imports = get_future_imports(src_node)
    # With no explicit targets, infer the versions compatible with the
    # features this source already uses.
    if mode.target_versions:
        versions = mode.target_versions
    else:
        versions = detect_target_versions(src_node)
    normalize_fmt_off(src_node)
    lines = LineGenerator(
        mode=mode,
        # u"" prefixes can be dropped when unicode_literals is imported from
        # __future__ or `versions` support Feature.UNICODE_LITERALS.
        remove_u_prefix="unicode_literals" in future_imports
        or supports_feature(versions, Feature.UNICODE_LITERALS),
    )
    elt = EmptyLineTracker(is_pyi=mode.is_pyi)
    empty_line = Line(mode=mode)
    # Blank lines owed *after* the previous line. They are emitted at the top
    # of the next iteration, so none are written after the final line.
    after = 0
    # Only the trailing-comma features that `versions` support are handed to
    # the line splitter.
    split_line_features = {
        feature
        for feature in {Feature.TRAILING_COMMA_IN_CALL, Feature.TRAILING_COMMA_IN_DEF}
        if supports_feature(versions, feature)
    }
    for current_line in lines.visit(src_node):
        dst_contents.append(str(empty_line) * after)
        # maybe_empty_lines is called once per logical line, in source order.
        # The tracker presumably keeps state between calls, so order matters.
        before, after = elt.maybe_empty_lines(current_line)
        dst_contents.append(str(empty_line) * before)
        # One logical line may be emitted as several physical lines.
        for line in transform_line(
            current_line, mode=mode, features=split_line_features
        ):
            dst_contents.append(str(line))
    return "".join(dst_contents)
1084
1085
def decode_bytes(src: bytes) -> Tuple[FileContent, Encoding, NewLine]:
    """Decode raw source bytes into text.

    Returns ``(decoded_contents, encoding, newline)``. The encoding comes from
    :func:`tokenize.detect_encoding` (coding cookie / BOM). ``newline`` is CRLF
    when the first line ends in CRLF and LF otherwise. The decoded contents
    are read with universal newlines, so they only ever contain LF.
    """
    buffer = io.BytesIO(src)
    encoding, first_lines = tokenize.detect_encoding(buffer.readline)
    if not first_lines:
        # No line could be read at all (empty input).
        return "", encoding, "\n"

    newline = "\r\n" if first_lines[0].endswith(b"\r\n") else "\n"
    buffer.seek(0)
    with io.TextIOWrapper(buffer, encoding) as text_stream:
        return text_stream.read(), encoding, newline
1101
1102
def get_features_used(node: Node) -> Set[Feature]:
    """Return a set of (relatively) new Python features used in this file.

    Currently looking for:
    - f-strings (any capitalization of the ``f`` / ``rf`` / ``fr`` prefixes);
    - underscores in numeric literals;
    - trailing commas after * or ** in function signatures and calls;
    - positional only arguments in function signatures and lambdas;
    - assignment expression;
    - relaxed decorator syntax;
    """
    features: Set[Feature] = set()
    for n in node.pre_order():
        if n.type == token.STRING:
            # String prefixes are case-insensitive, and "r" may come before or
            # after "f" (f, F, rf, rF, Rf, RF, fr, fR, Fr, FR are all valid).
            # Lower-case the first two characters so mixed-case prefixes such
            # as Rf"..." are not missed. Missing them would let target-version
            # inference include versions without f-string support.
            value_head = n.value[:2].lower()  # type: ignore
            if value_head in {'f"', "f'", "rf", "fr"}:
                features.add(Feature.F_STRINGS)

        elif n.type == token.NUMBER:
            if "_" in n.value:  # type: ignore
                features.add(Feature.NUMERIC_UNDERSCORES)

        elif n.type == token.SLASH:
            # A bare "/" directly inside a parameter list marks positional-only
            # parameters (typedargslist for def, varargslist for lambda).
            if n.parent and n.parent.type in {
                syms.typedargslist,
                syms.arglist,
                syms.varargslist,
            }:
                features.add(Feature.POS_ONLY_ARGUMENTS)

        elif n.type == token.COLONEQUAL:
            features.add(Feature.ASSIGNMENT_EXPRESSIONS)

        elif n.type == syms.decorator:
            if len(n.children) > 1 and not is_simple_decorator_expression(
                n.children[1]
            ):
                features.add(Feature.RELAXED_DECORATORS)

        elif (
            n.type in {syms.typedargslist, syms.arglist}
            and n.children
            and n.children[-1].type == token.COMMA
        ):
            # The list ends in a trailing comma. It only counts as a new
            # feature if a star / double-star appears somewhere in the list.
            if n.type == syms.typedargslist:
                feature = Feature.TRAILING_COMMA_IN_DEF
            else:
                feature = Feature.TRAILING_COMMA_IN_CALL

            for ch in n.children:
                if ch.type in STARS:
                    features.add(feature)

                if ch.type == syms.argument:
                    for argch in ch.children:
                        if argch.type in STARS:
                            features.add(feature)

    return features
1162
1163
def detect_target_versions(node: Node) -> Set[TargetVersion]:
    """Return every target version that supports all features used in `node`."""
    used_features = get_features_used(node)
    compatible: Set[TargetVersion] = set()
    for version in TargetVersion:
        if used_features <= VERSION_TO_FEATURES[version]:
            compatible.add(version)
    return compatible
1170
1171
def get_future_imports(node: Node) -> Set[str]:
    """Collect the names imported with ``from __future__ import ...``.

    Only the leading run of simple statements is scanned. Docstring-like
    string statements are skipped, and scanning stops at the first statement
    that is neither such a string nor a ``from __future__`` import.
    """

    def names_from(children: List[LN]) -> Generator[str, None, None]:
        # Yield the original (pre-"as") name of each imported item.
        for item in children:
            if isinstance(item, Leaf):
                if item.type == token.NAME:
                    yield item.value
                continue

            if item.type == syms.import_as_names:
                yield from names_from(item.children)
                continue

            if item.type != syms.import_as_name:
                raise AssertionError("Invalid syntax parsing imports")

            orig_name = item.children[0]
            assert isinstance(orig_name, Leaf), "Invalid syntax parsing imports"
            assert orig_name.type == token.NAME, "Invalid syntax parsing imports"
            yield orig_name.value

    found: Set[str] = set()
    for stmt in node.children:
        if stmt.type != syms.simple_stmt:
            break

        head = stmt.children[0]
        if isinstance(head, Leaf):
            is_docstring = (
                len(stmt.children) == 2
                and head.type == token.STRING
                and stmt.children[1].type == token.NEWLINE
            )
            if not is_docstring:
                break
            continue

        if head.type != syms.import_from:
            break

        module_name = head.children[1]
        if not isinstance(module_name, Leaf) or module_name.value != "__future__":
            break

        found.update(names_from(head.children[3:]))

    return found
1220
1221
def assert_equivalent(src: str, dst: str, *, pass_num: int = 1) -> None:
    """Raise AssertionError unless `src` and `dst` parse to the same AST.

    A `src` that cannot be parsed means --safe cannot be used. A `dst` that
    cannot be parsed, or an AST mismatch, is an internal error. Debug output
    is then written via dump_to_file and referenced in the message.
    """
    try:
        src_ast = parse_ast(src)
    except Exception as exc:
        raise AssertionError(
            "cannot use --safe with this file; failed to parse source file."
        ) from exc

    try:
        dst_ast = parse_ast(dst)
    except Exception as exc:
        traceback_text = "".join(traceback.format_tb(exc.__traceback__))
        log = dump_to_file(traceback_text, dst)
        raise AssertionError(
            f"INTERNAL ERROR: Black produced invalid code on pass {pass_num}: {exc}. "
            "Please report a bug on https://github.com/psf/black/issues.  "
            f"This invalid output might be helpful: {log}"
        ) from None

    src_ast_str, dst_ast_str = (
        "\n".join(stringify_ast(tree)) for tree in (src_ast, dst_ast)
    )
    if src_ast_str == dst_ast_str:
        return

    log = dump_to_file(diff(src_ast_str, dst_ast_str, "src", "dst"))
    raise AssertionError(
        "INTERNAL ERROR: Black produced code that is not equivalent to the"
        f" source on pass {pass_num}.  Please report a bug on "
        f"https://github.com/psf/black/issues.  This diff might be helpful: {log}"
    ) from None
1250
1251
def assert_stable(src: str, dst: str, mode: Mode) -> None:
    """Raise AssertionError if formatting `dst` again changes it.

    On failure, the mode and both diffs (source -> first pass, first pass ->
    second pass) are written via dump_to_file for the bug report.
    """
    reformatted = format_str(dst, mode=mode)
    if reformatted == dst:
        return

    mode_repr = str(mode)
    first_diff = diff(src, dst, "source", "first pass")
    second_diff = diff(dst, reformatted, "first pass", "second pass")
    log = dump_to_file(mode_repr, first_diff, second_diff)
    raise AssertionError(
        "INTERNAL ERROR: Black produced different code on the second pass of the"
        " formatter.  Please report a bug on https://github.com/psf/black/issues."
        f"  This diff might be helpful: {log}"
    ) from None
1266
1267
@contextmanager
def nullcontext() -> Iterator[None]:
    """Provide a context manager that does nothing and yields None.

    A stand-in for ``contextlib.nullcontext``, which exists only on Python 3.7+.
    """
    yield None
1275
1276
def patch_click() -> None:
    """Make Click not crash on Python 3.6 with LANG=C.

    On certain misconfigured environments, Python 3 selects the ASCII encoding as the
    default which restricts paths that it can access during the lifetime of the
    application.  Click refuses to work in this scenario by raising a RuntimeError.

    In case of Black the likelihood that non-ASCII characters are going to be used in
    file paths is minimal since it's Python source code.  Moreover, this crash was
    spurious on Python 3.7 thanks to PEP 538 and PEP 540.

    Each Click module is imported on its own, and any ImportError is tolerated.
    Newer Click releases (8.1+) no longer ship ``click._unicodefun``. There,
    ``from click import _unicodefun`` raises a plain ImportError ("cannot
    import name"), not a ModuleNotFoundError. Catching only the latter would
    crash Black at startup.
    """
    modules: List[Any] = []
    try:
        from click import core
    except ImportError:
        pass
    else:
        modules.append(core)

    try:
        from click import _unicodefun  # type: ignore
    except ImportError:
        pass
    else:
        modules.append(_unicodefun)

    # The check function was named differently across Click versions, so
    # neutralize whichever spelling each module actually has.
    for module in modules:
        if hasattr(module, "_verify_python3_env"):
            module._verify_python3_env = lambda: None  # type: ignore
        if hasattr(module, "_verify_python_env"):
            module._verify_python_env = lambda: None  # type: ignore
1299
1300
def patched_main() -> None:
    """Apply Black's runtime workarounds, then run the `main` CLI.

    This is what the module calls when it is executed as a script.
    """
    # NOTE(review): maybe_install_uvloop is defined elsewhere; presumably it
    # switches asyncio to uvloop when available. TODO confirm.
    maybe_install_uvloop()
    # multiprocessing.freeze_support(): needed when running as a frozen Windows
    # executable that spawns worker processes; otherwise it has no effect.
    freeze_support()
    # Disable Click's Python 3 locale/encoding check before the CLI runs.
    patch_click()
    main()
1306
1307
if __name__ == "__main__":
    # Executed as a script: run the CLI through patched_main.
    patched_main()