git.madduck.net Git - etc/vim.git/blob - src/black/__init__.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access, can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Add --experimental-string-processing to future changes (#2273)
[etc/vim.git] / src / black / __init__.py
1 import asyncio
2 from concurrent.futures import Executor, ThreadPoolExecutor, ProcessPoolExecutor
3 from contextlib import contextmanager
4 from datetime import datetime
5 from enum import Enum
6 import io
7 from multiprocessing import Manager, freeze_support
8 import os
9 from pathlib import Path
10 import regex as re
11 import signal
12 import sys
13 import tokenize
14 import traceback
15 from typing import (
16     Any,
17     Dict,
18     Generator,
19     Iterator,
20     List,
21     Optional,
22     Pattern,
23     Set,
24     Sized,
25     Tuple,
26     Union,
27 )
28
29 from dataclasses import replace
30 import click
31
32 from black.const import DEFAULT_LINE_LENGTH, DEFAULT_INCLUDES, DEFAULT_EXCLUDES
33 from black.const import STDIN_PLACEHOLDER
34 from black.nodes import STARS, syms, is_simple_decorator_expression
35 from black.lines import Line, EmptyLineTracker
36 from black.linegen import transform_line, LineGenerator, LN
37 from black.comments import normalize_fmt_off
38 from black.mode import Mode, TargetVersion
39 from black.mode import Feature, supports_feature, VERSION_TO_FEATURES
40 from black.cache import read_cache, write_cache, get_cache_info, filter_cached, Cache
41 from black.concurrency import cancel, shutdown
42 from black.output import dump_to_file, diff, color_diff, out, err
43 from black.report import Report, Changed
44 from black.files import find_project_root, find_pyproject_toml, parse_pyproject_toml
45 from black.files import gen_python_files, get_gitignore, normalize_path_maybe_ignore
46 from black.files import wrap_stream_for_windows
47 from black.parsing import InvalidInput  # noqa F401
48 from black.parsing import lib2to3_parse, parse_ast, stringify_ast
49
50
51 # lib2to3 fork
52 from blib2to3.pytree import Node, Leaf
53 from blib2to3.pgen2 import token
54
55 from _black_version import version as __version__
56
# If our environment has uvloop installed, let's use it.
try:
    import uvloop

    uvloop.install()
except ImportError:
    # uvloop is an optional dependency; carry on without it.
    pass

# types
# Decoded text of a source file (see `decode_bytes`).
FileContent = str
# Name of the text encoding detected for a source file.
Encoding = str
# Line terminator of a source file: either "\r\n" or "\n".
NewLine = str
69
70
class NothingChanged(UserWarning):
    """Raised when reformatted code is the same as source.

    `format_file_contents` raises it for blank input and for input that is
    already formatted; `format_file_in_place` and `format_stdin_to_stdout`
    catch it and report "not changed" by returning False.
    """
73
74
class WriteBack(Enum):
    """What to do with the formatted result of a file."""

    NO = 0
    YES = 1
    DIFF = 2
    CHECK = 3
    COLOR_DIFF = 4

    @classmethod
    def from_configuration(
        cls, *, check: bool, diff: bool, color: bool = False
    ) -> "WriteBack":
        """Translate the `--check`, `--diff` and `--color` flags into a member.

        `diff` takes precedence over `check`; `color` only matters together
        with `diff`.  Without either flag the result is `YES`.
        """
        if diff:
            return cls.COLOR_DIFF if color else cls.DIFF

        return cls.CHECK if check else cls.YES
93
94
95 # Legacy name, left for integrations.
96 FileMode = Mode
97
98
def read_pyproject_toml(
    ctx: click.Context, param: click.Parameter, value: Optional[str]
) -> Optional[str]:
    """Inject Black configuration from "pyproject.toml" into defaults in `ctx`.

    When `value` is empty the configuration file is looked up from the "src"
    parameter already present on `ctx`.  The parsed settings are layered on
    top of any existing `ctx.default_map`.

    Returns the path to a successfully found and read configuration file, None
    otherwise.  Raises `click.FileError` if the file cannot be read or parsed
    and `click.BadOptionUsage` if `target_version` is not a list.
    """
    if not value:
        value = find_pyproject_toml(ctx.params.get("src", ()))
    if value is None:
        return None

    try:
        raw_config = parse_pyproject_toml(value)
    except (OSError, ValueError) as e:
        raise click.FileError(
            filename=value, hint=f"Error reading configuration file: {e}"
        )

    if not raw_config:
        return None

    # Sanitize the values to be Click friendly. For more information please see:
    # https://github.com/psf/black/issues/1458
    # https://github.com/pallets/click/issues/1567
    config: Dict[str, Any] = {}
    for key, item in raw_config.items():
        config[key] = item if isinstance(item, (list, dict)) else str(item)

    target_version = config.get("target_version")
    if not (target_version is None or isinstance(target_version, list)):
        raise click.BadOptionUsage(
            "target-version", "Config key target-version must be a list"
        )

    # Values from the file override whatever defaults were already in place.
    merged: Dict[str, Any] = dict(ctx.default_map) if ctx.default_map else {}
    merged.update(config)
    ctx.default_map = merged
    return value
143
144
def target_version_option_callback(
    c: click.Context, p: Union[click.Option, click.Parameter], v: Tuple[str, ...]
) -> List[TargetVersion]:
    """Compute the target versions from a --target-version flag.

    Each given name is upper-cased and looked up as a `TargetVersion` member.

    This is its own function because mypy couldn't infer the type correctly
    when it was a lambda, causing mypyc trouble.
    """
    selected: List[TargetVersion] = []
    for name in v:
        selected.append(TargetVersion[name.upper()])
    return selected
154
155
def re_compile_maybe_verbose(regex: str) -> Pattern[str]:
    """Compile the regular expression string `regex`.

    A pattern that spans several lines is compiled in verbose mode by
    prefixing it with the inline `(?x)` flag.
    """
    source = f"(?x){regex}" if "\n" in regex else regex
    compiled: Pattern[str] = re.compile(source)
    return compiled
165
166
def validate_regex(
    ctx: click.Context,
    param: click.Parameter,
    value: Optional[str],
) -> Optional[Pattern]:
    """Click callback that compiles an option value into a pattern.

    Returns None when the option was not given.  Raises `click.BadParameter`
    if the value is not a valid regular expression.
    """
    if value is None:
        return None

    try:
        return re_compile_maybe_verbose(value)
    except re.error:
        raise click.BadParameter("Not a valid regular expression")
176
177
# Command-line entry point.  The decorators below declare every option; the
# function docstring doubles as the summary line shown by `--help`, so it is
# kept short on purpose.
@click.command(context_settings=dict(help_option_names=["-h", "--help"]))
@click.option("-c", "--code", type=str, help="Format the code passed in as a string.")
@click.option(
    "-l",
    "--line-length",
    type=int,
    default=DEFAULT_LINE_LENGTH,
    help="How many characters per line to allow.",
    show_default=True,
)
@click.option(
    "-t",
    "--target-version",
    type=click.Choice([v.name.lower() for v in TargetVersion]),
    callback=target_version_option_callback,
    multiple=True,
    help=(
        "Python versions that should be supported by Black's output. [default: per-file"
        " auto-detection]"
    ),
)
@click.option(
    "--pyi",
    is_flag=True,
    help=(
        "Format all input files like typing stubs regardless of file extension (useful"
        " when piping source on standard input)."
    ),
)
@click.option(
    "-S",
    "--skip-string-normalization",
    is_flag=True,
    help="Don't normalize string quotes or prefixes.",
)
@click.option(
    "-C",
    "--skip-magic-trailing-comma",
    is_flag=True,
    help="Don't use trailing commas as a reason to split lines.",
)
@click.option(
    "--experimental-string-processing",
    is_flag=True,
    hidden=True,
    help=(
        "Experimental option that performs more normalization on string literals."
        " Currently disabled because it leads to some crashes."
    ),
)
@click.option(
    "--check",
    is_flag=True,
    help=(
        "Don't write the files back, just return the status. Return code 0 means"
        " nothing would change. Return code 1 means some files would be reformatted."
        " Return code 123 means there was an internal error."
    ),
)
@click.option(
    "--diff",
    is_flag=True,
    help="Don't write the files back, just output a diff for each file on stdout.",
)
@click.option(
    "--color/--no-color",
    is_flag=True,
    help="Show colored diff. Only applies when `--diff` is given.",
)
@click.option(
    "--fast/--safe",
    is_flag=True,
    help="If --fast given, skip temporary sanity checks. [default: --safe]",
)
@click.option(
    "--include",
    type=str,
    default=DEFAULT_INCLUDES,
    callback=validate_regex,
    help=(
        "A regular expression that matches files and directories that should be"
        " included on recursive searches. An empty value means all files are included"
        " regardless of the name. Use forward slashes for directories on all platforms"
        " (Windows, too). Exclusions are calculated first, inclusions later."
    ),
    show_default=True,
)
# `--exclude` has no click-level default: `get_sources` falls back to
# DEFAULT_EXCLUDES itself when the option is left unset, which is why the
# default is spelled out in the help text instead.
@click.option(
    "--exclude",
    type=str,
    callback=validate_regex,
    help=(
        "A regular expression that matches files and directories that should be"
        " excluded on recursive searches. An empty value means no paths are excluded."
        " Use forward slashes for directories on all platforms (Windows, too)."
        " Exclusions are calculated first, inclusions later. [default:"
        f" {DEFAULT_EXCLUDES}]"
    ),
    show_default=False,
)
@click.option(
    "--extend-exclude",
    type=str,
    callback=validate_regex,
    help=(
        "Like --exclude, but adds additional files and directories on top of the"
        " excluded ones. (Useful if you simply want to add to the default)"
    ),
)
@click.option(
    "--force-exclude",
    type=str,
    callback=validate_regex,
    help=(
        "Like --exclude, but files and directories matching this regex will be "
        "excluded even when they are passed explicitly as arguments."
    ),
)
@click.option(
    "--stdin-filename",
    type=str,
    help=(
        "The name of the file when passing it through stdin. Useful to make "
        "sure Black will respect --force-exclude option on some "
        "editors that rely on using stdin."
    ),
)
@click.option(
    "-q",
    "--quiet",
    is_flag=True,
    help=(
        "Don't emit non-error messages to stderr. Errors are still emitted; silence"
        " those with 2>/dev/null."
    ),
)
@click.option(
    "-v",
    "--verbose",
    is_flag=True,
    help=(
        "Also emit messages to stderr about files that were not changed or were ignored"
        " due to exclusion patterns."
    ),
)
@click.version_option(version=__version__)
# `src` and `--config` are both eager: `read_pyproject_toml` reads "src" from
# `ctx.params` to locate the configuration file and then installs the file's
# values as defaults on the context.
@click.argument(
    "src",
    nargs=-1,
    type=click.Path(
        exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
    ),
    is_eager=True,
)
@click.option(
    "--config",
    type=click.Path(
        exists=True,
        file_okay=True,
        dir_okay=False,
        readable=True,
        allow_dash=False,
        path_type=str,
    ),
    is_eager=True,
    callback=read_pyproject_toml,
    help="Read configuration from FILE path.",
)
@click.pass_context
def main(
    ctx: click.Context,
    code: Optional[str],
    line_length: int,
    target_version: List[TargetVersion],
    check: bool,
    diff: bool,
    color: bool,
    fast: bool,
    pyi: bool,
    skip_string_normalization: bool,
    skip_magic_trailing_comma: bool,
    experimental_string_processing: bool,
    quiet: bool,
    verbose: bool,
    include: Pattern,
    exclude: Optional[Pattern],
    extend_exclude: Optional[Pattern],
    force_exclude: Optional[Pattern],
    stdin_filename: Optional[str],
    src: Tuple[str, ...],
    config: Optional[str],
) -> None:
    """The uncompromising code formatter."""
    write_back = WriteBack.from_configuration(check=check, diff=diff, color=color)
    if target_version:
        versions = set(target_version)
    else:
        # We'll autodetect later.
        versions = set()
    # The "skip" flags are negated because Mode expresses them as features
    # that are switched on.
    mode = Mode(
        target_versions=versions,
        line_length=line_length,
        is_pyi=pyi,
        string_normalization=not skip_string_normalization,
        magic_trailing_comma=not skip_magic_trailing_comma,
        experimental_string_processing=experimental_string_processing,
    )
    if config and verbose:
        out(f"Using configuration from {config}.", bold=False, fg="blue")
    # `--code` formats the given string, prints the result and exits without
    # looking at `src` at all.
    if code is not None:
        print(format_str(code, mode=mode))
        ctx.exit(0)
    report = Report(check=check, diff=diff, quiet=quiet, verbose=verbose)
    sources = get_sources(
        ctx=ctx,
        src=src,
        quiet=quiet,
        verbose=verbose,
        include=include,
        exclude=exclude,
        extend_exclude=extend_exclude,
        force_exclude=force_exclude,
        report=report,
        stdin_filename=stdin_filename,
    )

    path_empty(
        sources,
        "No Python files are present to be formatted. Nothing to do 😴",
        quiet,
        verbose,
        ctx,
    )

    # A single source is formatted in this process; several sources are handed
    # to an executor by `reformat_many`.
    if len(sources) == 1:
        reformat_one(
            src=sources.pop(),
            fast=fast,
            write_back=write_back,
            mode=mode,
            report=report,
        )
    else:
        reformat_many(
            sources=sources, fast=fast, write_back=write_back, mode=mode, report=report
        )

    if verbose or not quiet:
        out("Oh no! 💥 💔 💥" if report.return_code else "All done! ✨ 🍰 ✨")
        click.secho(str(report), err=True)
    # The process exit status is whatever the report accumulated.
    ctx.exit(report.return_code)
429
430
def get_sources(
    *,
    ctx: click.Context,
    src: Tuple[str, ...],
    quiet: bool,
    verbose: bool,
    include: Pattern[str],
    exclude: Optional[Pattern[str]],
    extend_exclude: Optional[Pattern[str]],
    force_exclude: Optional[Pattern[str]],
    report: "Report",
    stdin_filename: Optional[str],
) -> Set[Path]:
    """Compute the set of files to be formatted.

    Files named explicitly in `src` are only subject to `force_exclude`;
    directories are expanded through `gen_python_files`, which receives all
    include/exclude patterns.  "-" stands for standard input and, when
    `stdin_filename` is given, is recorded under that name prefixed with
    STDIN_PLACEHOLDER.
    """

    root = find_project_root(src)
    sources: Set[Path] = set()
    path_empty(src, "No Path provided. Nothing to do 😴", quiet, verbose, ctx)

    # The .gitignore file is only consulted when the user did not pass
    # `--exclude` explicitly.
    gitignore = None
    if exclude is None:
        exclude = re_compile_maybe_verbose(DEFAULT_EXCLUDES)
        gitignore = get_gitignore(root)

    for s in src:
        if stdin_filename and s == "-":
            path, is_stdin = Path(stdin_filename), True
        else:
            path, is_stdin = Path(s), False

        if is_stdin or path.is_file():
            relative = normalize_path_maybe_ignore(path, root, report)
            if relative is None:
                continue

            rooted = "/" + relative
            # Hard-exclude any files that matches the `--force-exclude` regex.
            match = force_exclude.search(rooted) if force_exclude else None
            if match and match.group(0):
                report.path_ignored(
                    path, "matches the --force-exclude regular expression"
                )
                continue

            sources.add(Path(f"{STDIN_PLACEHOLDER}{str(path)}") if is_stdin else path)
        elif path.is_dir():
            discovered = gen_python_files(
                path.iterdir(),
                root,
                include,
                exclude,
                extend_exclude,
                force_exclude,
                report,
                gitignore,
            )
            sources.update(discovered)
        elif s == "-":
            sources.add(path)
        else:
            err(f"invalid path: {s}")
    return sources
501
502
def path_empty(
    src: Sized, msg: str, quiet: bool, verbose: bool, ctx: click.Context
) -> None:
    """Exit with status 0 if there is no `src` provided for formatting.

    `msg` is printed first unless output is silenced (`quiet` set and
    `verbose` not set).  The exit itself must not depend on the verbosity
    flags: previously a quiet run with nothing to format skipped the exit and
    carried on into the formatting code with an empty set of sources.
    """
    if src:
        return

    if verbose or not quiet:
        out(msg)
    ctx.exit(0)
512
513
def reformat_one(
    src: Path, fast: bool, write_back: WriteBack, mode: Mode, report: "Report"
) -> None:
    """Reformat a single file under `src` without spawning child processes.

    `src` may be "-" or a path prefixed with STDIN_PLACEHOLDER, in which case
    standard input is formatted instead of a file on disk.  Any exception is
    recorded on `report` rather than propagated.

    `fast`, `write_back`, and `mode` options are passed to
    :func:`format_file_in_place` or :func:`format_stdin_to_stdout`.
    """
    try:
        outcome = Changed.NO

        src_text = str(src)
        is_stdin = src_text == "-"
        if not is_stdin and src_text.startswith(STDIN_PLACEHOLDER):
            is_stdin = True
            # Use the original name again in case we want to print something
            # to the user
            src = Path(src_text[len(STDIN_PLACEHOLDER) :])

        if is_stdin:
            if src.suffix == ".pyi":
                mode = replace(mode, is_pyi=True)
            if format_stdin_to_stdout(fast=fast, write_back=write_back, mode=mode):
                outcome = Changed.YES
        else:
            cache: Cache = {}
            producing_diff = write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF)
            if not producing_diff:
                # The cache is bypassed entirely when a diff was requested.
                cache = read_cache(mode)
                resolved = src.resolve()
                cache_key = str(resolved)
                if cache_key in cache and cache[cache_key] == get_cache_info(resolved):
                    outcome = Changed.CACHED
            if outcome is not Changed.CACHED and format_file_in_place(
                src, fast=fast, write_back=write_back, mode=mode
            ):
                outcome = Changed.YES
            # Remember the file once it has been written back, or once a
            # check confirmed that it needs no changes.
            written = write_back is WriteBack.YES and outcome is not Changed.CACHED
            verified = write_back is WriteBack.CHECK and outcome is Changed.NO
            if written or verified:
                write_cache(cache, [src], mode)
        report.done(src, outcome)
    except Exception as exc:
        if report.verbose:
            traceback.print_exc()
        report.failed(src, str(exc))
561
562
def reformat_many(
    sources: Set[Path], fast: bool, write_back: WriteBack, mode: Mode, report: "Report"
) -> None:
    """Reformat multiple files using a ProcessPoolExecutor.

    Falls back to a single-worker ThreadPoolExecutor when the platform cannot
    create a process pool.  The event loop and the executor are always shut
    down, even if scheduling fails.
    """
    executor: Executor
    loop = asyncio.get_event_loop()
    # `os.cpu_count()` returns None when the count cannot be determined; fall
    # back to one worker (which is also what ProcessPoolExecutor does for
    # max_workers=None) so that the `min()` below never compares None with an
    # int.
    worker_count = os.cpu_count() or 1
    if sys.platform == "win32":
        # Work around https://bugs.python.org/issue26903
        worker_count = min(worker_count, 60)
    try:
        executor = ProcessPoolExecutor(max_workers=worker_count)
    except (ImportError, OSError):
        # we arrive here if the underlying system does not support multi-processing
        # like in AWS Lambda or Termux, in which case we gracefully fallback to
        # a ThreadPoolExecutor with just a single worker (more workers would not do us
        # any good due to the Global Interpreter Lock)
        executor = ThreadPoolExecutor(max_workers=1)

    try:
        loop.run_until_complete(
            schedule_formatting(
                sources=sources,
                fast=fast,
                write_back=write_back,
                mode=mode,
                report=report,
                loop=loop,
                executor=executor,
            )
        )
    finally:
        shutdown(loop)
        executor.shutdown()
598
599
async def schedule_formatting(
    sources: Set[Path],
    fast: bool,
    write_back: WriteBack,
    mode: Mode,
    report: "Report",
    loop: asyncio.AbstractEventLoop,
    executor: Executor,
) -> None:
    """Run formatting of `sources` in parallel using the provided `executor`.

    (Use ProcessPoolExecutors for actual parallelism.)

    Files found unchanged in the cache are reported as cached and skipped
    (the cache is not consulted when a diff is requested).  Results are
    reported to `report` as the individual tasks finish.

    `write_back`, `fast`, and `mode` options are passed to
    :func:`format_file_in_place`.
    """
    cache: Cache = {}
    if write_back not in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
        cache = read_cache(mode)
        sources, cached = filter_cached(cache, sources)
        for src in sorted(cached):
            report.done(src, Changed.CACHED)
    if not sources:
        return

    cancelled = []
    sources_to_cache = []
    lock = None
    if write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
        # For diff output, we need locks to ensure we don't interleave output
        # from different processes.
        manager = Manager()
        lock = manager.Lock()
    tasks = {
        asyncio.ensure_future(
            loop.run_in_executor(
                executor, format_file_in_place, src, fast, mode, write_back, lock
            )
        ): src
        for src in sorted(sources)
    }
    # `pending` is a live view of the dict keys: popping finished tasks from
    # `tasks` below shrinks it, which is what eventually ends the loop.
    pending = tasks.keys()
    try:
        loop.add_signal_handler(signal.SIGINT, cancel, pending)
        loop.add_signal_handler(signal.SIGTERM, cancel, pending)
    except NotImplementedError:
        # There are no good alternatives for these on Windows.
        pass
    while pending:
        done, _ = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
        for task in done:
            src = tasks.pop(task)
            if task.cancelled():
                cancelled.append(task)
            elif task.exception():
                report.failed(src, str(task.exception()))
            else:
                changed = Changed.YES if task.result() else Changed.NO
                # If the file was written back or was successfully checked as
                # well-formatted, store this information in the cache.
                if write_back is WriteBack.YES or (
                    write_back is WriteBack.CHECK and changed is Changed.NO
                ):
                    sources_to_cache.append(src)
                report.done(src, changed)
    if cancelled:
        # Do not pass `loop=`: the parameter was deprecated in Python 3.8 and
        # removed in 3.10, where it raises TypeError.  `gather` picks up the
        # loop from the futures it is given.
        await asyncio.gather(*cancelled, return_exceptions=True)
    if sources_to_cache:
        write_cache(cache, sources_to_cache, mode)
669
670
def format_file_in_place(
    src: Path,
    fast: bool,
    mode: Mode,
    write_back: WriteBack = WriteBack.NO,
    lock: Any = None,  # multiprocessing.Manager().Lock() is some crazy proxy
) -> bool:
    """Format file under `src` path. Return True if changed.

    If `write_back` is DIFF or COLOR_DIFF, write a diff to stdout (while
    holding `lock`, if one is given). If it is YES, write reformatted code to
    the file, keeping its original encoding and newline style.  A ".pyi"
    suffix switches `mode` to stub formatting.
    `mode` and `fast` options are passed to :func:`format_file_contents`.
    """
    if src.suffix == ".pyi":
        mode = replace(mode, is_pyi=True)

    # Read the modification time before touching the file; it labels the
    # "before" side of a diff.
    modified_at = datetime.utcfromtimestamp(src.stat().st_mtime)
    with open(src, "rb") as raw:
        original, encoding, newline = decode_bytes(raw.read())
    try:
        formatted = format_file_contents(original, fast=fast, mode=mode)
    except NothingChanged:
        return False

    if write_back == WriteBack.YES:
        with open(src, "w", encoding=encoding, newline=newline) as sink:
            sink.write(formatted)
        return True

    if write_back not in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
        return True

    generated_at = datetime.utcnow()
    patch = diff(
        original,
        formatted,
        f"{src}\t{modified_at} +0000",
        f"{src}\t{generated_at} +0000",
    )
    if write_back == WriteBack.COLOR_DIFF:
        patch = color_diff(patch)

    with lock or nullcontext():
        stream = io.TextIOWrapper(
            sys.stdout.buffer,
            encoding=encoding,
            newline=newline,
            write_through=True,
        )
        stream = wrap_stream_for_windows(stream)
        stream.write(patch)
        # Detach so that discarding the wrapper does not close stdout.
        stream.detach()

    return True
719
720
def format_stdin_to_stdout(
    fast: bool, *, write_back: WriteBack = WriteBack.NO, mode: Mode
) -> bool:
    """Format file on stdin. Return True if changed.

    If `write_back` is YES, write reformatted code back to stdout. If it is DIFF,
    write a diff to stdout. The `mode` argument is passed to
    :func:`format_file_contents`.

    Output is produced in a `finally` clause, so with `write_back` YES the
    unmodified source is echoed back when nothing changed, and also when
    formatting raised an error (which then still propagates to the caller).
    """
    then = datetime.utcnow()
    src, encoding, newline = decode_bytes(sys.stdin.buffer.read())
    # Start from the unmodified source; only replaced if formatting succeeds.
    dst = src
    try:
        dst = format_file_contents(src, fast=fast, mode=mode)
        return True

    except NothingChanged:
        return False

    finally:
        # Runs after either `return` above and before any other exception
        # propagates.
        f = io.TextIOWrapper(
            sys.stdout.buffer, encoding=encoding, newline=newline, write_through=True
        )
        if write_back == WriteBack.YES:
            f.write(dst)
        elif write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
            now = datetime.utcnow()
            src_name = f"STDIN\t{then} +0000"
            dst_name = f"STDOUT\t{now} +0000"
            d = diff(src, dst, src_name, dst_name)
            if write_back == WriteBack.COLOR_DIFF:
                d = color_diff(d)
                # The stream is only wrapped for Windows when colors are emitted.
                f = wrap_stream_for_windows(f)
            f.write(d)
        # Detach so that discarding the wrapper does not close stdout.
        f.detach()
756
757
def format_file_contents(src_contents: str, *, fast: bool, mode: Mode) -> FileContent:
    """Reformat contents of a file and return new contents.

    Raises :exc:`NothingChanged` when `src_contents` is blank or already
    formatted.

    If `fast` is False, additionally confirm that the reformatted code is
    valid by calling :func:`assert_equivalent` and :func:`assert_stable` on it.
    `mode` is passed to :func:`format_str`.
    """
    if not src_contents.strip():
        raise NothingChanged

    first_pass = format_str(src_contents, mode=mode)
    if first_pass == src_contents:
        raise NothingChanged

    if fast:
        return first_pass

    assert_equivalent(src_contents, first_pass)

    # Forced second pass to work around optional trailing commas (becoming
    # forced trailing commas on pass 2) interacting differently with optional
    # parentheses.  Admittedly ugly.
    second_pass = format_str(first_pass, mode=mode)
    if second_pass == first_pass:
        # Note: no need to explicitly call `assert_stable` if the second pass
        # reproduced the first one.
        return first_pass

    assert_equivalent(src_contents, second_pass, pass_num=2)
    assert_stable(src_contents, second_pass, mode=mode)
    return second_pass
786
787
def format_str(src_contents: str, *, mode: Mode) -> FileContent:
    """Reformat a string and return new contents.

    `mode` determines formatting options, such as how many characters per line are
    allowed.  Example:

    >>> import black
    >>> print(black.format_str("def f(arg:str='')->None:...", mode=black.Mode()))
    def f(arg: str = "") -> None:
        ...

    A more complex example:

    >>> print(
    ...   black.format_str(
    ...     "def f(arg:str='')->None: hey",
    ...     mode=black.Mode(
    ...       target_versions={black.TargetVersion.PY36},
    ...       line_length=10,
    ...       string_normalization=False,
    ...       is_pyi=False,
    ...     ),
    ...   ),
    ... )
    def f(
        arg: str = '',
    ) -> None:
        hey

    """
    # Leading whitespace is dropped before parsing.
    src_node = lib2to3_parse(src_contents.lstrip(), mode.target_versions)
    dst_contents = []
    future_imports = get_future_imports(src_node)
    # Explicitly configured target versions win; otherwise they are inferred
    # from the syntax used in the source.
    if mode.target_versions:
        versions = mode.target_versions
    else:
        versions = detect_target_versions(src_node)
    normalize_fmt_off(src_node)
    lines = LineGenerator(
        mode=mode,
        remove_u_prefix="unicode_literals" in future_imports
        or supports_feature(versions, Feature.UNICODE_LITERALS),
    )
    elt = EmptyLineTracker(is_pyi=mode.is_pyi)
    empty_line = Line(mode=mode)
    # Number of empty lines still owed after the previously emitted line.
    after = 0
    # Trailing-comma features are only handed to `transform_line` when every
    # target version supports them.
    split_line_features = {
        feature
        for feature in {Feature.TRAILING_COMMA_IN_CALL, Feature.TRAILING_COMMA_IN_DEF}
        if supports_feature(versions, feature)
    }
    for current_line in lines.visit(src_node):
        # Emit the empty lines owed after the previous line, then the ones
        # wanted before the current one.
        dst_contents.append(str(empty_line) * after)
        before, after = elt.maybe_empty_lines(current_line)
        dst_contents.append(str(empty_line) * before)
        for line in transform_line(
            current_line, mode=mode, features=split_line_features
        ):
            dst_contents.append(str(line))
    return "".join(dst_contents)
848
849
def decode_bytes(src: bytes) -> Tuple[FileContent, Encoding, NewLine]:
    """Return a tuple of (decoded_contents, encoding, newline).

    The encoding is detected the way the Python tokenizer does it.
    `newline` is either CRLF or LF, judged from the first line only, but
    `decoded_contents` is decoded with universal newlines (i.e. only
    contains LF).
    """
    stream = io.BytesIO(src)
    encoding, first_lines = tokenize.detect_encoding(stream.readline)
    if not first_lines:
        return "", encoding, "\n"

    newline = "\r\n" if first_lines[0].endswith(b"\r\n") else "\n"
    # Encoding detection consumed the first line(s); rewind before decoding.
    stream.seek(0)
    with io.TextIOWrapper(stream, encoding) as reader:
        return reader.read(), encoding, newline
865
866
def get_features_used(node: Node) -> Set[Feature]:
    """Return a set of (relatively) new Python features used in this file.

    Walks every node below `node` in pre-order.

    Currently looking for:
    - f-strings;
    - underscores in numeric literals;
    - trailing commas after * or ** in function signatures and calls;
    - positional only arguments in function signatures and lambdas;
    - assignment expression;
    - relaxed decorator syntax;
    """
    features: Set[Feature] = set()
    for n in node.pre_order():
        if n.type == token.STRING:
            # String prefixes are case-insensitive, so compare the lowered
            # head: this also catches the mixed-case raw f-string prefixes
            # (`Rf`, `rF`, `fR`, `Fr`), not just the all-lower/all-upper ones.
            value_head = n.value[:2].lower()  # type: ignore
            if value_head in {'f"', "f'", "rf", "fr"}:
                features.add(Feature.F_STRINGS)

        elif n.type == token.NUMBER:
            if "_" in n.value:  # type: ignore
                features.add(Feature.NUMERIC_UNDERSCORES)

        elif n.type == token.SLASH:
            if n.parent and n.parent.type in {syms.typedargslist, syms.arglist}:
                features.add(Feature.POS_ONLY_ARGUMENTS)

        elif n.type == token.COLONEQUAL:
            features.add(Feature.ASSIGNMENT_EXPRESSIONS)

        elif n.type == syms.decorator:
            if len(n.children) > 1 and not is_simple_decorator_expression(
                n.children[1]
            ):
                features.add(Feature.RELAXED_DECORATORS)

        elif (
            n.type in {syms.typedargslist, syms.arglist}
            and n.children
            and n.children[-1].type == token.COMMA
        ):
            # An argument list ending in a comma only counts as a new feature
            # when it also contains a * or ** unpacking.
            if n.type == syms.typedargslist:
                feature = Feature.TRAILING_COMMA_IN_DEF
            else:
                feature = Feature.TRAILING_COMMA_IN_CALL

            for ch in n.children:
                if ch.type in STARS:
                    features.add(feature)

                if ch.type == syms.argument:
                    for argch in ch.children:
                        if argch.type in STARS:
                            features.add(feature)

    return features
922
923
def detect_target_versions(node: Node) -> Set[TargetVersion]:
    """Return every target version whose feature set covers what `node` uses.

    A version qualifies when the features found by `get_features_used` are a
    subset of `VERSION_TO_FEATURES[version]`.
    """
    used = get_features_used(node)
    compatible: Set[TargetVersion] = set()
    for candidate in TargetVersion:
        if used <= VERSION_TO_FEATURES[candidate]:
            compatible.add(candidate)
    return compatible
930
931
def get_future_imports(node: Node) -> Set[str]:
    """Collect the names imported from `__future__` at the top of the file.

    Only the leading run of simple statements is inspected: docstring-like
    string statements are skipped, `from __future__ import ...` statements are
    harvested, and the scan stops at the first statement of any other kind.
    """
    found: Set[str] = set()

    def names_in(nodes: List[LN]) -> Generator[str, None, None]:
        # Yield the imported (original, not aliased) names from import children.
        for item in nodes:
            if isinstance(item, Leaf):
                # Non-NAME leaves are silently ignored.
                if item.type == token.NAME:
                    yield item.value
                continue

            if item.type == syms.import_as_name:
                # `name as alias`: report the name on the left-hand side.
                original = item.children[0]
                assert isinstance(original, Leaf), "Invalid syntax parsing imports"
                assert original.type == token.NAME, "Invalid syntax parsing imports"
                yield original.value
                continue

            if item.type == syms.import_as_names:
                yield from names_in(item.children)
                continue

            raise AssertionError("Invalid syntax parsing imports")

    for stmt in node.children:
        if stmt.type != syms.simple_stmt:
            break

        head = stmt.children[0]
        if isinstance(head, Leaf):
            # A lone string followed by a newline is treated as a docstring
            # and skipped; any other leaf-led statement ends the scan.
            looks_like_docstring = (
                len(stmt.children) == 2
                and head.type == token.STRING
                and stmt.children[1].type == token.NEWLINE
            )
            if not looks_like_docstring:
                break
            continue

        if head.type != syms.import_from:
            break

        module_name = head.children[1]
        if not isinstance(module_name, Leaf) or module_name.value != "__future__":
            break

        # Everything from the fourth child onwards is the imported names part.
        found.update(names_in(head.children[3:]))

    return found
980
981
def assert_equivalent(src: str, dst: str, *, pass_num: int = 1) -> None:
    """Raise AssertionError if `src` and `dst` aren't equivalent.

    Both strings are parsed with `parse_ast` and their `stringify_ast`
    renderings are compared line by line.

    Raises AssertionError when:
    - `src` cannot be parsed (the parse error is chained as the cause);
    - `dst` cannot be parsed (the traceback and `dst` are dumped to a file);
    - the two stringified ASTs differ (a diff is dumped to a file).

    `pass_num` only appears in the error messages.
    """
    try:
        src_ast = parse_ast(src)
    except Exception as exc:
        # Chain explicitly so the parse failure is reported as the direct
        # cause instead of an incidental "during handling" context.
        raise AssertionError(
            "cannot use --safe with this file; failed to parse source file.  AST"
            f" error message: {exc}"
        ) from exc

    try:
        dst_ast = parse_ast(dst)
    except Exception as exc:
        log = dump_to_file("".join(traceback.format_tb(exc.__traceback__)), dst)
        # The traceback is already saved in the log file, so suppress chaining.
        raise AssertionError(
            f"INTERNAL ERROR: Black produced invalid code on pass {pass_num}: {exc}. "
            "Please report a bug on https://github.com/psf/black/issues.  "
            f"This invalid output might be helpful: {log}"
        ) from None

    src_ast_str = "\n".join(stringify_ast(src_ast))
    dst_ast_str = "\n".join(stringify_ast(dst_ast))
    if src_ast_str != dst_ast_str:
        log = dump_to_file(diff(src_ast_str, dst_ast_str, "src", "dst"))
        raise AssertionError(
            "INTERNAL ERROR: Black produced code that is not equivalent to the"
            f" source on pass {pass_num}.  Please report a bug on "
            f"https://github.com/psf/black/issues.  This diff might be helpful: {log}"
        ) from None
1011
1012
def assert_stable(src: str, dst: str, mode: Mode) -> None:
    """Raise AssertionError if formatting `dst` again changes it.

    `dst` is run through `format_str` with the same `mode`.  When the result
    differs, the mode plus the source/first-pass and first-pass/second-pass
    diffs are dumped to a file whose location is included in the error.
    """
    second_pass = format_str(dst, mode=mode)
    if dst == second_pass:
        return

    first_diff = diff(src, dst, "source", "first pass")
    second_diff = diff(dst, second_pass, "first pass", "second pass")
    log = dump_to_file(str(mode), first_diff, second_diff)
    raise AssertionError(
        "INTERNAL ERROR: Black produced different code on the second pass of the"
        " formatter.  Please report a bug on https://github.com/psf/black/issues."
        f"  This diff might be helpful: {log}"
    ) from None
1027
1028
@contextmanager
def nullcontext() -> Iterator[None]:
    """Context manager that does nothing on entry or exit.

    Binds `None` as the `with ... as` target.  Meant as a stand-in for the
    `nullcontext` that ships with Python 3.7.
    """
    yield None
1036
1037
def patch_click() -> None:
    """Make Click not crash on Python 3.6 with LANG=C.

    On certain misconfigured environments, Python 3 selects the ASCII encoding as the
    default which restricts paths that it can access during the lifetime of the
    application.  Click refuses to work in this scenario by raising a RuntimeError.

    In case of Black the likelihood that non-ASCII characters are going to be used in
    file paths is minimal since it's Python source code.  Moreover, this crash was
    spurious on Python 3.7 thanks to PEP 538 and PEP 540.

    The two Click modules are imported independently and any ImportError is
    tolerated: `from click import <submodule>` raises plain ImportError (not
    ModuleNotFoundError) when the package exists but the submodule does not,
    and a missing private module must not prevent patching the other one.
    """
    modules: List[Any] = []
    try:
        from click import core
    except ImportError:
        pass
    else:
        modules.append(core)
    try:
        from click import _unicodefun  # type: ignore
    except ImportError:
        pass
    else:
        modules.append(_unicodefun)

    for module in modules:
        # Replace whichever environment check the module exposes with a no-op.
        if hasattr(module, "_verify_python3_env"):
            module._verify_python3_env = lambda: None  # type: ignore
        if hasattr(module, "_verify_python_env"):
            module._verify_python_env = lambda: None  # type: ignore
1060
1061
def patched_main() -> None:
    """Run `main` after the process-level setup it relies on.

    Calls `multiprocessing.freeze_support()` first, then `patch_click()`, and
    finally hands control to `main()`.
    """
    freeze_support()
    patch_click()
    main()
1066
1067
# Only run the patched entry point when this module is executed as a script.
if __name__ == "__main__":
    patched_main()