git.madduck.net Git - etc/vim.git/blob - src/black/__init__.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access, can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

c61bc8c1d60e2f2b681a663e2b94435ce1439bf0
[etc/vim.git] / src / black / __init__.py
1 import asyncio
2 from concurrent.futures import Executor, ThreadPoolExecutor, ProcessPoolExecutor
3 from contextlib import contextmanager
4 from datetime import datetime
5 from enum import Enum
6 import io
7 from multiprocessing import Manager, freeze_support
8 import os
9 from pathlib import Path
10 import regex as re
11 import signal
12 import sys
13 import tokenize
14 import traceback
15 from typing import (
16     Any,
17     Dict,
18     Generator,
19     Iterator,
20     List,
21     Optional,
22     Pattern,
23     Set,
24     Sized,
25     Tuple,
26     Union,
27 )
28
29 from dataclasses import replace
30 import click
31
32 from black.const import DEFAULT_LINE_LENGTH, DEFAULT_INCLUDES, DEFAULT_EXCLUDES
33 from black.const import STDIN_PLACEHOLDER
34 from black.nodes import STARS, syms, is_simple_decorator_expression
35 from black.lines import Line, EmptyLineTracker
36 from black.linegen import transform_line, LineGenerator, LN
37 from black.comments import normalize_fmt_off
38 from black.mode import Mode, TargetVersion
39 from black.mode import Feature, supports_feature, VERSION_TO_FEATURES
40 from black.cache import read_cache, write_cache, get_cache_info, filter_cached, Cache
41 from black.concurrency import cancel, shutdown
42 from black.output import dump_to_file, diff, color_diff, out, err
43 from black.report import Report, Changed
44 from black.files import find_project_root, find_pyproject_toml, parse_pyproject_toml
45 from black.files import gen_python_files, get_gitignore, normalize_path_maybe_ignore
46 from black.files import wrap_stream_for_windows
47 from black.parsing import InvalidInput  # noqa F401
48 from black.parsing import lib2to3_parse, parse_ast, stringify_ast
49
50
51 # lib2to3 fork
52 from blib2to3.pytree import Node, Leaf
53 from blib2to3.pgen2 import token
54
55 from _black_version import version as __version__
56
57
# types
# Aliases of `str`, used in signatures to document what a given string holds.
FileContent = str  # decoded source text
Encoding = str  # codec name, as reported by tokenize.detect_encoding
NewLine = str  # line terminator detected by decode_bytes: "\n" or "\r\n"
62
63
class NothingChanged(UserWarning):
    """Signal that formatting would leave the input exactly as it was.

    The formatting helpers raise this instead of returning unchanged (or
    blank) contents, so callers can tell "no work needed" apart from output.
    """
66
67
class WriteBack(Enum):
    """What to do with the formatted result of each source."""

    NO = 0
    YES = 1
    DIFF = 2
    CHECK = 3
    COLOR_DIFF = 4

    @classmethod
    def from_configuration(
        cls, *, check: bool, diff: bool, color: bool = False
    ) -> "WriteBack":
        """Map the --check / --diff / --color flags to a WriteBack member.

        --diff takes precedence over --check; --color only matters together
        with --diff. With neither --diff nor --check, files are written back.
        """
        if diff:
            return cls.COLOR_DIFF if color else cls.DIFF
        return cls.CHECK if check else cls.YES
86
87
# Legacy name, left for integrations.
# `FileMode` is the very same object as `Mode`; new code should use `Mode`.
FileMode = Mode
90
91
def read_pyproject_toml(
    ctx: click.Context, param: click.Parameter, value: Optional[str]
) -> Optional[str]:
    """Click callback that merges Black's pyproject.toml settings into `ctx`.

    When `value` (the --config path) is empty, the project's pyproject.toml
    is located from the `src` arguments already parsed onto `ctx.params`.
    The parsed settings are layered over any existing `ctx.default_map`, so
    they act as option defaults.

    Returns the configuration path that was read, or None when no file was
    found or it held no Black settings.

    Raises click.FileError if the file cannot be read or parsed, and
    click.BadOptionUsage if `target_version` is not a list.
    """
    if not value:
        value = find_pyproject_toml(ctx.params.get("src", ()))
        if value is None:
            return None

    try:
        raw_config = parse_pyproject_toml(value)
    except (OSError, ValueError) as e:
        raise click.FileError(
            filename=value, hint=f"Error reading configuration file: {e}"
        )

    if not raw_config:
        return None

    # Click expects scalar defaults as strings; lists and tables are kept as-is.
    # See https://github.com/psf/black/issues/1458 and
    # https://github.com/pallets/click/issues/1567
    config = {
        key: val if isinstance(val, (list, dict)) else str(val)
        for key, val in raw_config.items()
    }

    target_version = config.get("target_version")
    if target_version is not None and not isinstance(target_version, list):
        raise click.BadOptionUsage(
            "target-version", "Config key target-version must be a list"
        )

    # File settings override whatever defaults were already on the context.
    default_map: Dict[str, Any] = dict(ctx.default_map or {})
    default_map.update(config)
    ctx.default_map = default_map
    return value
136
137
def target_version_option_callback(
    c: click.Context, p: Union[click.Option, click.Parameter], v: Tuple[str, ...]
) -> List[TargetVersion]:
    """Translate --target-version choices into TargetVersion members.

    Each lowercase choice string is upper-cased and looked up by member name.
    Kept as a named function (not a lambda) so mypy can infer its type,
    which mypyc relies on.
    """
    versions: List[TargetVersion] = []
    for name in v:
        versions.append(TargetVersion[name.upper()])
    return versions
147
148
def re_compile_maybe_verbose(regex: str) -> Pattern[str]:
    """Compile `regex`, enabling verbose mode for multi-line patterns.

    A pattern containing a newline gets the inline ``(?x)`` flag prepended,
    so whitespace and ``#`` comments inside it are ignored.
    """
    flag_prefix = "(?x)" if "\n" in regex else ""
    pattern: Pattern[str] = re.compile(flag_prefix + regex)
    return pattern
158
159
def validate_regex(
    ctx: click.Context,
    param: click.Parameter,
    value: Optional[str],
) -> Optional[Pattern]:
    """Click callback: compile an option's regex, passing None through.

    Raises click.BadParameter when `value` is not a valid pattern.
    """
    if value is None:
        return None
    try:
        return re_compile_maybe_verbose(value)
    except re.error:
        raise click.BadParameter("Not a valid regular expression")
169
170
@click.command(context_settings=dict(help_option_names=["-h", "--help"]))
@click.option("-c", "--code", type=str, help="Format the code passed in as a string.")
@click.option(
    "-l",
    "--line-length",
    type=int,
    default=DEFAULT_LINE_LENGTH,
    help="How many characters per line to allow.",
    show_default=True,
)
@click.option(
    "-t",
    "--target-version",
    type=click.Choice([v.name.lower() for v in TargetVersion]),
    callback=target_version_option_callback,
    multiple=True,
    help=(
        "Python versions that should be supported by Black's output. [default: per-file"
        " auto-detection]"
    ),
)
@click.option(
    "--pyi",
    is_flag=True,
    help=(
        "Format all input files like typing stubs regardless of file extension (useful"
        " when piping source on standard input)."
    ),
)
@click.option(
    "-S",
    "--skip-string-normalization",
    is_flag=True,
    help="Don't normalize string quotes or prefixes.",
)
@click.option(
    "-C",
    "--skip-magic-trailing-comma",
    is_flag=True,
    help="Don't use trailing commas as a reason to split lines.",
)
@click.option(
    "--experimental-string-processing",
    is_flag=True,
    hidden=True,
    help=(
        "Experimental option that performs more normalization on string literals."
        " Currently disabled because it leads to some crashes."
    ),
)
@click.option(
    "--check",
    is_flag=True,
    help=(
        "Don't write the files back, just return the status. Return code 0 means"
        " nothing would change. Return code 1 means some files would be reformatted."
        " Return code 123 means there was an internal error."
    ),
)
@click.option(
    "--diff",
    is_flag=True,
    help="Don't write the files back, just output a diff for each file on stdout.",
)
@click.option(
    "--color/--no-color",
    is_flag=True,
    help="Show colored diff. Only applies when `--diff` is given.",
)
@click.option(
    "--fast/--safe",
    is_flag=True,
    help="If --fast given, skip temporary sanity checks. [default: --safe]",
)
@click.option(
    "--include",
    type=str,
    default=DEFAULT_INCLUDES,
    callback=validate_regex,
    help=(
        "A regular expression that matches files and directories that should be"
        " included on recursive searches. An empty value means all files are included"
        " regardless of the name. Use forward slashes for directories on all platforms"
        " (Windows, too). Exclusions are calculated first, inclusions later."
    ),
    show_default=True,
)
# --exclude deliberately has no click default: leaving it None lets get_sources
# substitute DEFAULT_EXCLUDES *and* load the project's .gitignore. The default
# is therefore spelled out in the help text instead.
@click.option(
    "--exclude",
    type=str,
    callback=validate_regex,
    help=(
        "A regular expression that matches files and directories that should be"
        " excluded on recursive searches. An empty value means no paths are excluded."
        " Use forward slashes for directories on all platforms (Windows, too)."
        " Exclusions are calculated first, inclusions later. [default:"
        f" {DEFAULT_EXCLUDES}]"
    ),
    show_default=False,
)
@click.option(
    "--extend-exclude",
    type=str,
    callback=validate_regex,
    help=(
        "Like --exclude, but adds additional files and directories on top of the"
        " excluded ones. (Useful if you simply want to add to the default)"
    ),
)
@click.option(
    "--force-exclude",
    type=str,
    callback=validate_regex,
    help=(
        "Like --exclude, but files and directories matching this regex will be "
        "excluded even when they are passed explicitly as arguments."
    ),
)
@click.option(
    "--stdin-filename",
    type=str,
    help=(
        "The name of the file when passing it through stdin. Useful to make "
        "sure Black will respect --force-exclude option on some "
        "editors that rely on using stdin."
    ),
)
@click.option(
    "-q",
    "--quiet",
    is_flag=True,
    help=(
        "Don't emit non-error messages to stderr. Errors are still emitted; silence"
        " those with 2>/dev/null."
    ),
)
@click.option(
    "-v",
    "--verbose",
    is_flag=True,
    help=(
        "Also emit messages to stderr about files that were not changed or were ignored"
        " due to exclusion patterns."
    ),
)
@click.version_option(version=__version__)
# `src` is eager so that, when --config is not given, read_pyproject_toml can
# search for pyproject.toml based on ctx.params["src"].
# NOTE(review): this relies on click processing eager parameters in order — confirm.
@click.argument(
    "src",
    nargs=-1,
    type=click.Path(
        exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
    ),
    is_eager=True,
)
# Eager, so that its callback (read_pyproject_toml) can install the file's
# settings as ctx.default_map before the remaining options take their defaults.
@click.option(
    "--config",
    type=click.Path(
        exists=True,
        file_okay=True,
        dir_okay=False,
        readable=True,
        allow_dash=False,
        path_type=str,
    ),
    is_eager=True,
    callback=read_pyproject_toml,
    help="Read configuration from FILE path.",
)
@click.pass_context
def main(
    ctx: click.Context,
    code: Optional[str],
    line_length: int,
    target_version: List[TargetVersion],
    check: bool,
    diff: bool,
    color: bool,
    fast: bool,
    pyi: bool,
    skip_string_normalization: bool,
    skip_magic_trailing_comma: bool,
    experimental_string_processing: bool,
    quiet: bool,
    verbose: bool,
    include: Pattern,
    exclude: Optional[Pattern],
    extend_exclude: Optional[Pattern],
    force_exclude: Optional[Pattern],
    stdin_filename: Optional[str],
    src: Tuple[str, ...],
    config: Optional[str],
) -> None:
    """The uncompromising code formatter."""
    # NOTE: the docstring above doubles as the `--help` text rendered by click.
    # Inside this function the `diff` parameter (a bool flag) shadows the
    # imported diff() helper; only the flag is used here.
    write_back = WriteBack.from_configuration(check=check, diff=diff, color=color)
    if target_version:
        versions = set(target_version)
    else:
        # We'll autodetect later.
        versions = set()
    mode = Mode(
        target_versions=versions,
        line_length=line_length,
        is_pyi=pyi,
        string_normalization=not skip_string_normalization,
        magic_trailing_comma=not skip_magic_trailing_comma,
        experimental_string_processing=experimental_string_processing,
    )
    if config and verbose:
        out(f"Using configuration from {config}.", bold=False, fg="blue")
    if code is not None:
        # --code short-circuits file discovery: the string is formatted, printed
        # to stdout and the process exits 0. `write_back` (--check/--diff) is
        # not consulted on this path.
        # NOTE(review): a formatting error here propagates as an exception
        # instead of being recorded in a Report.
        print(format_str(code, mode=mode))
        ctx.exit(0)
    report = Report(check=check, diff=diff, quiet=quiet, verbose=verbose)
    sources = get_sources(
        ctx=ctx,
        src=src,
        quiet=quiet,
        verbose=verbose,
        include=include,
        exclude=exclude,
        extend_exclude=extend_exclude,
        force_exclude=force_exclude,
        report=report,
        stdin_filename=stdin_filename,
    )

    path_empty(
        sources,
        "No Python files are present to be formatted. Nothing to do 😴",
        quiet,
        verbose,
        ctx,
    )

    # A single source is formatted in-process; several go through an executor.
    if len(sources) == 1:
        reformat_one(
            src=sources.pop(),
            fast=fast,
            write_back=write_back,
            mode=mode,
            report=report,
        )
    else:
        reformat_many(
            sources=sources, fast=fast, write_back=write_back, mode=mode, report=report
        )

    if verbose or not quiet:
        out("Oh no! 💥 💔 💥" if report.return_code else "All done! ✨ 🍰 ✨")
        click.secho(str(report), err=True)
    # The process exit status is whatever the report accumulated.
    ctx.exit(report.return_code)
422
423
def get_sources(
    *,
    ctx: click.Context,
    src: Tuple[str, ...],
    quiet: bool,
    verbose: bool,
    include: Pattern[str],
    exclude: Optional[Pattern[str]],
    extend_exclude: Optional[Pattern[str]],
    force_exclude: Optional[Pattern[str]],
    report: "Report",
    stdin_filename: Optional[str],
) -> Set[Path]:
    """Expand the command-line `src` arguments into the set of paths to format.

    - Files are kept as given, unless they match --force-exclude.
    - Directories are walked with :func:`gen_python_files`.
    - "-" stands for stdin. When `stdin_filename` is set, stdin is represented
      by that name prefixed with STDIN_PLACEHOLDER.
    - Any other argument is reported through `err` and skipped.
    """
    root = find_project_root(src)
    found: Set[Path] = set()
    path_empty(src, "No Path provided. Nothing to do 😴", quiet, verbose, ctx)

    # The project's .gitignore is only loaded (and handed to gen_python_files)
    # when the user did not pass an explicit --exclude.
    if exclude is None:
        exclude = re_compile_maybe_verbose(DEFAULT_EXCLUDES)
        gitignore = get_gitignore(root)
    else:
        gitignore = None

    for arg in src:
        if arg == "-" and stdin_filename:
            path, is_stdin = Path(stdin_filename), True
        else:
            path, is_stdin = Path(arg), False

        if not (is_stdin or path.is_file()):
            if path.is_dir():
                found.update(
                    gen_python_files(
                        path.iterdir(),
                        root,
                        include,
                        exclude,
                        extend_exclude,
                        force_exclude,
                        report,
                        gitignore,
                    )
                )
            elif arg == "-":
                found.add(path)
            else:
                err(f"invalid path: {arg}")
            continue

        relative = normalize_path_maybe_ignore(path, root, report)
        if relative is None:
            continue

        # Hard-exclude any file matching the `--force-exclude` regex, even
        # though it was named explicitly.
        hit = force_exclude.search("/" + relative) if force_exclude else None
        if hit and hit.group(0):
            report.path_ignored(path, "matches the --force-exclude regular expression")
            continue

        if is_stdin:
            path = Path(f"{STDIN_PLACEHOLDER}{str(path)}")

        found.add(path)

    return found
494
495
def path_empty(
    src: Sized, msg: str, quiet: bool, verbose: bool, ctx: click.Context
) -> None:
    """
    Exit if there is no `src` provided for formatting

    `msg` is printed first unless output is silenced (`quiet` without
    `verbose`). The exit (status 0) happens regardless of verbosity, so that
    --quiet only suppresses the message. Previously it also stopped the early
    exit, letting Black carry on with an empty set of sources.
    """
    if not src:
        if verbose or not quiet:
            out(msg)
        ctx.exit(0)
505
506
def reformat_one(
    src: Path, fast: bool, write_back: WriteBack, mode: Mode, report: "Report"
) -> None:
    """Reformat a single file under `src` without spawning child processes.

    `fast`, `write_back`, and `mode` options are passed to
    :func:`format_file_in_place` or :func:`format_stdin_to_stdout`.

    `src` is a real path, "-" (plain stdin), or a path prefixed with
    STDIN_PLACEHOLDER (stdin whose display name came from --stdin-filename).
    The outcome is recorded with `report.done`. Any exception is caught and
    recorded with `report.failed` instead of propagating.
    """
    try:
        changed = Changed.NO

        if str(src) == "-":
            is_stdin = True
        elif str(src).startswith(STDIN_PLACEHOLDER):
            is_stdin = True
            # Use the original name again in case we want to print something
            # to the user
            src = Path(str(src)[len(STDIN_PLACEHOLDER) :])
        else:
            is_stdin = False

        if is_stdin:
            # The (possibly user-supplied) stdin filename decides whether the
            # input is formatted as a typing stub.
            if src.suffix == ".pyi":
                mode = replace(mode, is_pyi=True)
            if format_stdin_to_stdout(fast=fast, write_back=write_back, mode=mode):
                changed = Changed.YES
        else:
            cache: Cache = {}
            # Diff modes neither read nor write the cache, so the file is always
            # re-formatted and its diff printed.
            if write_back not in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
                cache = read_cache(mode)
                res_src = src.resolve()
                res_src_s = str(res_src)
                # Skip formatting when the cached info matches the file's
                # current info.
                if res_src_s in cache and cache[res_src_s] == get_cache_info(res_src):
                    changed = Changed.CACHED
            if changed is not Changed.CACHED and format_file_in_place(
                src, fast=fast, write_back=write_back, mode=mode
            ):
                changed = Changed.YES
            # Cache the file after it was (re)written, or when --check confirmed
            # it is already well-formatted.
            if (write_back is WriteBack.YES and changed is not Changed.CACHED) or (
                write_back is WriteBack.CHECK and changed is Changed.NO
            ):
                write_cache(cache, [src], mode)
        report.done(src, changed)
    except Exception as exc:
        if report.verbose:
            traceback.print_exc()
        report.failed(src, str(exc))
554
555
def reformat_many(
    sources: Set[Path], fast: bool, write_back: WriteBack, mode: Mode, report: "Report"
) -> None:
    """Reformat multiple files using a ProcessPoolExecutor.

    Falls back to a single-worker ThreadPoolExecutor when process pools are
    unavailable on this platform. Per-file outcomes are recorded in `report`
    by :func:`schedule_formatting`.
    """
    executor: Executor
    worker_count = os.cpu_count()
    if sys.platform == "win32" and worker_count is not None:
        # Work around https://bugs.python.org/issue26903
        # os.cpu_count() may return None, and min(None, 60) would raise
        # TypeError. In that case ProcessPoolExecutor picks its own default.
        worker_count = min(worker_count, 60)
    try:
        executor = ProcessPoolExecutor(max_workers=worker_count)
    except (ImportError, NotImplementedError, OSError):
        # we arrive here if the underlying system does not support multi-processing
        # like in AWS Lambda or Termux, in which case we gracefully fallback to
        # a ThreadPoolExecutor with just a single worker (more workers would not do us
        # any good due to the Global Interpreter Lock).
        # concurrent.futures raises NotImplementedError when the platform lacks
        # working named semaphores (multiprocessing.synchronize).
        executor = ThreadPoolExecutor(max_workers=1)

    # Use a private event loop rather than asyncio.get_event_loop(). The latter
    # is deprecated when no loop is running (Python 3.10+). Also, the loop is
    # handed to shutdown() below, so it should not be one the caller may still
    # be using.
    loop = asyncio.new_event_loop()
    try:
        loop.run_until_complete(
            schedule_formatting(
                sources=sources,
                fast=fast,
                write_back=write_back,
                mode=mode,
                report=report,
                loop=loop,
                executor=executor,
            )
        )
    finally:
        try:
            shutdown(loop)
        finally:
            # We created this loop, so make sure it is closed even if shutdown()
            # did not close it. close() on a closed loop would be a no-op, but
            # the guard keeps the intent explicit.
            if not loop.is_closed():
                loop.close()
            executor.shutdown()
591
592
async def schedule_formatting(
    sources: Set[Path],
    fast: bool,
    write_back: WriteBack,
    mode: Mode,
    report: "Report",
    loop: asyncio.AbstractEventLoop,
    executor: Executor,
) -> None:
    """Run formatting of `sources` in parallel using the provided `executor`.

    (Use ProcessPoolExecutors for actual parallelism.)

    `write_back`, `fast`, and `mode` options are passed to
    :func:`format_file_in_place`.

    Files whose cache entry is current are reported as CACHED and skipped
    (except in diff modes, which bypass the cache). Every other file's outcome
    is reported through `report`. SIGINT/SIGTERM cancel the pending work where
    the loop supports signal handlers.
    """
    cache: Cache = {}
    if write_back not in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
        cache = read_cache(mode)
        sources, cached = filter_cached(cache, sources)
        for src in sorted(cached):
            report.done(src, Changed.CACHED)
    if not sources:
        return

    cancelled = []
    sources_to_cache = []
    lock = None
    if write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
        # For diff output, we need locks to ensure we don't interleave output
        # from different processes.
        manager = Manager()
        lock = manager.Lock()
    tasks = {
        asyncio.ensure_future(
            loop.run_in_executor(
                executor, format_file_in_place, src, fast, mode, write_back, lock
            )
        ): src
        for src in sorted(sources)
    }
    # A live view: it shrinks as finished tasks are popped from `tasks`, which
    # is what terminates the loop below.
    pending = tasks.keys()
    try:
        loop.add_signal_handler(signal.SIGINT, cancel, pending)
        loop.add_signal_handler(signal.SIGTERM, cancel, pending)
    except NotImplementedError:
        # There are no good alternatives for these on Windows.
        pass
    while pending:
        done, _ = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
        for task in done:
            src = tasks.pop(task)
            if task.cancelled():
                cancelled.append(task)
            elif task.exception():
                report.failed(src, str(task.exception()))
            else:
                changed = Changed.YES if task.result() else Changed.NO
                # If the file was written back or was successfully checked as
                # well-formatted, store this information in the cache.
                if write_back is WriteBack.YES or (
                    write_back is WriteBack.CHECK and changed is Changed.NO
                ):
                    sources_to_cache.append(src)
                report.done(src, changed)
    if cancelled:
        # asyncio.gather()'s `loop` argument was deprecated in 3.8 and removed
        # in 3.10 (passing it raises TypeError). The loop is inferred from the
        # futures themselves.
        await asyncio.gather(*cancelled, return_exceptions=True)
    if sources_to_cache:
        write_cache(cache, sources_to_cache, mode)
662
663
def format_file_in_place(
    src: Path,
    fast: bool,
    mode: Mode,
    write_back: WriteBack = WriteBack.NO,
    lock: Any = None,  # multiprocessing.Manager().Lock() is some crazy proxy
) -> bool:
    """Format file under `src` path. Return True if changed.

    If `write_back` is DIFF, write a diff to stdout. If it is YES, write reformatted
    code to the file.
    `mode` and `fast` options are passed to :func:`format_file_contents`.

    With COLOR_DIFF the diff is colorized before printing. With NO or CHECK
    nothing is written; only the return value says whether the file would
    change. `lock`, when given, is held while the diff is printed, so output
    from concurrent workers does not interleave.
    """
    # `.pyi` files are always formatted as typing stubs.
    if src.suffix == ".pyi":
        mode = replace(mode, is_pyi=True)

    # Source mtime (UTC), used as the "before" timestamp in the diff header.
    then = datetime.utcfromtimestamp(src.stat().st_mtime)
    with open(src, "rb") as buf:
        src_contents, encoding, newline = decode_bytes(buf.read())
    try:
        dst_contents = format_file_contents(src_contents, fast=fast, mode=mode)
    except NothingChanged:
        return False

    if write_back == WriteBack.YES:
        # Write back with the file's original encoding and newline style.
        with open(src, "w", encoding=encoding, newline=newline) as f:
            f.write(dst_contents)
    elif write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
        now = datetime.utcnow()
        src_name = f"{src}\t{then} +0000"
        dst_name = f"{src}\t{now} +0000"
        diff_contents = diff(src_contents, dst_contents, src_name, dst_name)

        if write_back == WriteBack.COLOR_DIFF:
            diff_contents = color_diff(diff_contents)

        # NOTE(review): `nullcontext` is not among the imports at the top of
        # this module; presumably it is defined elsewhere in the module — confirm.
        with lock or nullcontext():
            # Encode the diff with the file's encoding/newline straight onto
            # stdout's binary buffer, then detach() so discarding the wrapper
            # does not close sys.stdout.
            f = io.TextIOWrapper(
                sys.stdout.buffer,
                encoding=encoding,
                newline=newline,
                write_through=True,
            )
            f = wrap_stream_for_windows(f)
            f.write(diff_contents)
            f.detach()

    return True
712
713
def format_stdin_to_stdout(
    fast: bool, *, write_back: WriteBack = WriteBack.NO, mode: Mode
) -> bool:
    """Format file on stdin. Return True if changed.

    If `write_back` is YES, write reformatted code back to stdout. If it is DIFF,
    write a diff to stdout. The `mode` argument is passed to
    :func:`format_file_contents`.

    Output is produced in the `finally` clause. With YES, the original input is
    therefore echoed to stdout unchanged when nothing changed, and also when
    formatting raised. Exceptions other than NothingChanged still propagate
    after that output.
    """
    then = datetime.utcnow()
    src, encoding, newline = decode_bytes(sys.stdin.buffer.read())
    # Fallback output if formatting does not produce a new result.
    dst = src
    try:
        dst = format_file_contents(src, fast=fast, mode=mode)
        return True

    except NothingChanged:
        return False

    finally:
        # Runs on every exit path above, including an exception raised by
        # format_file_contents.
        f = io.TextIOWrapper(
            sys.stdout.buffer, encoding=encoding, newline=newline, write_through=True
        )
        if write_back == WriteBack.YES:
            f.write(dst)
        elif write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
            now = datetime.utcnow()
            src_name = f"STDIN\t{then} +0000"
            dst_name = f"STDOUT\t{now} +0000"
            d = diff(src, dst, src_name, dst_name)
            if write_back == WriteBack.COLOR_DIFF:
                d = color_diff(d)
                f = wrap_stream_for_windows(f)
            f.write(d)
        # Detach so the wrapper does not close sys.stdout's buffer when discarded.
        f.detach()
749
750
def format_file_contents(src_contents: str, *, fast: bool, mode: Mode) -> FileContent:
    """Format the text of a whole file and return the new text.

    Raises NothingChanged when the input is blank or already formatted.
    Unless `fast` is set, the result is verified with :func:`assert_equivalent`
    and then formatted a second time. If that second pass still changes
    something, its output is used and re-verified with :func:`assert_equivalent`
    and :func:`assert_stable`. `mode` is forwarded to :func:`format_str`.
    """
    if src_contents.strip() == "":
        raise NothingChanged

    formatted = format_str(src_contents, mode=mode)
    if formatted == src_contents:
        raise NothingChanged

    if fast:
        return formatted

    assert_equivalent(src_contents, formatted)

    # Optional trailing commas can turn into forced (magic) trailing commas on
    # a second run and interact differently with optional parentheses, so a
    # second pass is forced. Admittedly ugly.
    second_pass = format_str(formatted, mode=mode)
    if second_pass == formatted:
        # Already a fixed point: no separate stability check is required.
        return formatted

    assert_equivalent(src_contents, second_pass, pass_num=2)
    assert_stable(src_contents, second_pass, mode=mode)
    return second_pass
779
780
def format_str(src_contents: str, *, mode: Mode) -> FileContent:
    """Reformat a string and return new contents.

    `mode` determines formatting options, such as how many characters per line are
    allowed.  Example:

    >>> import black
    >>> print(black.format_str("def f(arg:str='')->None:...", mode=black.Mode()))
    def f(arg: str = "") -> None:
        ...

    A more complex example:

    >>> print(
    ...   black.format_str(
    ...     "def f(arg:str='')->None: hey",
    ...     mode=black.Mode(
    ...       target_versions={black.TargetVersion.PY36},
    ...       line_length=10,
    ...       string_normalization=False,
    ...       is_pyi=False,
    ...     ),
    ...   ),
    ... )
    def f(
        arg: str = '',
    ) -> None:
        hey

    """
    # Leading whitespace of the whole input is dropped before parsing.
    src_node = lib2to3_parse(src_contents.lstrip(), mode.target_versions)
    dst_contents = []
    future_imports = get_future_imports(src_node)
    # Explicit target versions win; otherwise infer them from the syntax used.
    if mode.target_versions:
        versions = mode.target_versions
    else:
        versions = detect_target_versions(src_node)
    # Its return value is unused, so this works on `src_node` in place. It runs
    # only after version detection above.
    normalize_fmt_off(src_node)
    lines = LineGenerator(
        mode=mode,
        remove_u_prefix="unicode_literals" in future_imports
        or supports_feature(versions, Feature.UNICODE_LITERALS),
    )
    elt = EmptyLineTracker(is_pyi=mode.is_pyi)
    empty_line = Line(mode=mode)
    # Blank lines owed after the previous logical line. They are emitted only
    # once another line follows, so the final line's `after` is never written.
    after = 0
    # Trailing-comma features that the line splitter may rely on.
    # NOTE(review): assumes supports_feature() requires every target version to
    # support the feature — confirm.
    split_line_features = {
        feature
        for feature in {Feature.TRAILING_COMMA_IN_CALL, Feature.TRAILING_COMMA_IN_DEF}
        if supports_feature(versions, feature)
    }
    for current_line in lines.visit(src_node):
        dst_contents.append(str(empty_line) * after)
        before, after = elt.maybe_empty_lines(current_line)
        dst_contents.append(str(empty_line) * before)
        for line in transform_line(
            current_line, mode=mode, features=split_line_features
        ):
            dst_contents.append(str(line))
    return "".join(dst_contents)
841
842
def decode_bytes(src: bytes) -> Tuple[FileContent, Encoding, NewLine]:
    """Decode raw source bytes; return (text, encoding, newline).

    The encoding comes from :func:`tokenize.detect_encoding` (coding cookie or
    BOM, UTF-8 otherwise). `newline` is "\\r\\n" if the first line ends with
    CRLF and "\\n" otherwise. The text itself is read with universal newlines,
    so it only ever contains "\\n". Empty input yields ("", encoding, "\\n").
    """
    buffer = io.BytesIO(src)
    encoding, first_lines = tokenize.detect_encoding(buffer.readline)
    if not first_lines:
        return "", encoding, "\n"

    newline = "\r\n" if first_lines[0].endswith(b"\r\n") else "\n"
    buffer.seek(0)
    with io.TextIOWrapper(buffer, encoding) as reader:
        decoded = reader.read()
    return decoded, encoding, newline
858
859
def get_features_used(node: Node) -> Set[Feature]:
    """Return a set of (relatively) new Python features used in this file.

    Currently looking for:
    - f-strings;
    - underscores in numeric literals;
    - trailing commas after * or ** in function signatures and calls;
    - positional only arguments in function signatures and lambdas;
    - assignment expression;
    - relaxed decorator syntax;
    """
    features: Set[Feature] = set()
    for n in node.pre_order():
        if n.type == token.STRING:
            # String prefixes are case-insensitive, so mixed-case spellings such
            # as Rf"..." or fR"..." are f-strings too. Compare the first two
            # characters case-insensitively.
            value_head = n.value[:2].lower()  # type: ignore
            if value_head in {'f"', "f'", "rf", "fr"}:
                features.add(Feature.F_STRINGS)

        elif n.type == token.NUMBER:
            if "_" in n.value:  # type: ignore
                features.add(Feature.NUMERIC_UNDERSCORES)

        elif n.type == token.SLASH:
            # A bare `/` in a parameter list marks positional-only parameters.
            # Lambda parameters are parsed as `varargslist`, so it is included
            # to cover lambdas, as promised by this function's docstring.
            if n.parent and n.parent.type in {
                syms.typedargslist,
                syms.arglist,
                syms.varargslist,
            }:
                features.add(Feature.POS_ONLY_ARGUMENTS)

        elif n.type == token.COLONEQUAL:
            features.add(Feature.ASSIGNMENT_EXPRESSIONS)

        elif n.type == syms.decorator:
            if len(n.children) > 1 and not is_simple_decorator_expression(
                n.children[1]
            ):
                features.add(Feature.RELAXED_DECORATORS)

        elif (
            n.type in {syms.typedargslist, syms.arglist}
            and n.children
            and n.children[-1].type == token.COMMA
        ):
            # A trailing comma after *args / **kwargs (directly, or nested in an
            # `argument` node).
            if n.type == syms.typedargslist:
                feature = Feature.TRAILING_COMMA_IN_DEF
            else:
                feature = Feature.TRAILING_COMMA_IN_CALL

            for ch in n.children:
                if ch.type in STARS:
                    features.add(feature)

                if ch.type == syms.argument:
                    for argch in ch.children:
                        if argch.type in STARS:
                            features.add(feature)

    return features
915
916
def detect_target_versions(node: Node) -> Set[TargetVersion]:
    """Return every target version able to run the code rooted at `node`.

    A version qualifies when all features reported by `get_features_used` are
    contained in that version's entry in `VERSION_TO_FEATURES`.
    """
    used = get_features_used(node)
    compatible: Set[TargetVersion] = set()
    for version in TargetVersion:
        if used.issubset(VERSION_TO_FEATURES[version]):
            compatible.add(version)
    return compatible
923
924
def get_future_imports(node: Node) -> Set[str]:
    """Collect the names imported from ``__future__`` at the top of the module.

    Top-level statements are scanned in order.  Scanning continues past bare
    string statements (docstrings) and ``from __future__ import ...`` lines, and
    stops at the first statement of any other kind.
    """
    found: Set[str] = set()

    def names_in(nodes: List[LN]) -> Generator[str, None, None]:
        """Yield imported names from the children of an `import_from` node."""
        for item in nodes:
            if isinstance(item, Leaf):
                # Plain names are imports; other leaves (parens, commas) are skipped.
                if item.type == token.NAME:
                    yield item.value
                continue

            if item.type == syms.import_as_names:
                yield from names_in(item.children)
            elif item.type == syms.import_as_name:
                # `name as alias`: the feature is identified by the original name.
                original = item.children[0]
                assert isinstance(original, Leaf), "Invalid syntax parsing imports"
                assert original.type == token.NAME, "Invalid syntax parsing imports"
                yield original.value
            else:
                raise AssertionError("Invalid syntax parsing imports")

    for stmt in node.children:
        if stmt.type != syms.simple_stmt:
            break

        head = stmt.children[0]
        if isinstance(head, Leaf):
            # A lone string followed by NEWLINE is a docstring; keep looking.
            is_docstring = (
                len(stmt.children) == 2
                and head.type == token.STRING
                and stmt.children[1].type == token.NEWLINE
            )
            if not is_docstring:
                break
            continue

        if head.type != syms.import_from:
            break

        # children: "from", module, "import", <names...>
        source_module = head.children[1]
        if not isinstance(source_module, Leaf) or source_module.value != "__future__":
            break

        found.update(names_in(head.children[3:]))

    return found
973
974
def assert_equivalent(src: str, dst: str, *, pass_num: int = 1) -> None:
    """Check that `src` and `dst` parse to the same AST, raising otherwise.

    Both strings are parsed with `parse_ast` and compared through their
    `stringify_ast` dumps.  An AssertionError is raised when:

    - `src` cannot be parsed (the file is not eligible for --safe);
    - `dst` cannot be parsed (the traceback and `dst` are written to a file
      via `dump_to_file` and its path is included in the message);
    - the two dumps differ (a diff is written via `dump_to_file`).

    `pass_num` is only used to label the error messages.
    """
    try:
        src_ast = parse_ast(src)
    except Exception as exc:
        raise AssertionError(
            "cannot use --safe with this file; failed to parse source file.  AST"
            f" error message: {exc}"
        )

    try:
        dst_ast = parse_ast(dst)
    except Exception as exc:
        tb_text = "".join(traceback.format_tb(exc.__traceback__))
        log = dump_to_file(tb_text, dst)
        raise AssertionError(
            f"INTERNAL ERROR: Black produced invalid code on pass {pass_num}: {exc}. "
            "Please report a bug on https://github.com/psf/black/issues.  "
            f"This invalid output might be helpful: {log}"
        ) from None

    src_dump, dst_dump = ("\n".join(stringify_ast(tree)) for tree in (src_ast, dst_ast))
    if src_dump == dst_dump:
        return

    log = dump_to_file(diff(src_dump, dst_dump, "src", "dst"))
    raise AssertionError(
        "INTERNAL ERROR: Black produced code that is not equivalent to the"
        f" source on pass {pass_num}.  Please report a bug on "
        f"https://github.com/psf/black/issues.  This diff might be helpful: {log}"
    ) from None
1004
1005
def assert_stable(src: str, dst: str, mode: Mode) -> None:
    """Reformat `dst` once more and raise AssertionError if the output changes.

    On failure, the mode plus diffs of both passes are written via
    `dump_to_file`, and that log is referenced in the error message.
    """
    second_pass = format_str(dst, mode=mode)
    if second_pass == dst:
        return

    mode_text = str(mode)
    first_diff = diff(src, dst, "source", "first pass")
    second_diff = diff(dst, second_pass, "first pass", "second pass")
    log = dump_to_file(mode_text, first_diff, second_diff)
    raise AssertionError(
        "INTERNAL ERROR: Black produced different code on the second pass of the"
        " formatter.  Please report a bug on https://github.com/psf/black/issues."
        f"  This diff might be helpful: {log}"
    ) from None
1020
1021
@contextmanager
def nullcontext() -> Iterator[None]:
    """Provide a context manager that does nothing and binds ``None``.

    Behaves like ``contextlib.nullcontext()`` (Python 3.7+) called without
    arguments: entering yields ``None`` and exceptions propagate unchanged.
    """
    yield None
1029
1030
def patch_click() -> None:
    """Make Click not crash.

    On certain misconfigured environments, Python 3 selects the ASCII encoding as the
    default which restricts paths that it can access during the lifetime of the
    application.  Click refuses to work in this scenario by raising a RuntimeError.

    In case of Black the likelihood that non-ASCII characters are going to be used in
    file paths is minimal since it's Python source code.  Moreover, this crash was
    spurious on Python 3.7 thanks to PEP 538 and PEP 540.

    Where the check lives depends on the installed Click version.  Each candidate
    module is therefore imported independently, and only the hooks it actually
    defines are replaced.  Missing modules and hooks are skipped silently.
    """
    modules: List[Any] = []
    try:
        from click import core
    except ImportError:
        pass
    else:
        modules.append(core)

    try:
        # `_unicodefun` was removed in Click 8.1.0.  `from package import name` on
        # a missing submodule raises a plain ImportError (not ModuleNotFoundError),
        # so catch the broader class.  Its absence must not prevent patching `core`.
        from click import _unicodefun  # type: ignore
    except ImportError:
        pass
    else:
        modules.append(_unicodefun)

    for module in modules:
        # Click < 8.0 calls the hook `_verify_python3_env`; Click 8.0 renamed it
        # to `_verify_python_env`.  Neutralize whichever one is present.
        for hook in ("_verify_python3_env", "_verify_python_env"):
            if hasattr(module, hook):
                setattr(module, hook, lambda: None)
1051
1052
def patched_main() -> None:
    """Console entry point: prepare the process, then run `main`."""
    # Called first, before any other work; multiprocessing requires
    # `freeze_support()` at the start of the main module in frozen executables.
    freeze_support()
    # Disable Click's locale/encoding check before the CLI starts.
    patch_click()
    main()
1057
1058
# Allow running this module directly as a script; uses the same patched entry point.
if __name__ == "__main__":
    patched_main()