git.madduck.net Git - etc/vim.git/blob - src/black/__init__.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access, can be arranged individually.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Make sure to split lines that start with a string operator (#2286)
[etc/vim.git] / src / black / __init__.py
1 import asyncio
2 from concurrent.futures import Executor, ThreadPoolExecutor, ProcessPoolExecutor
3 from contextlib import contextmanager
4 from datetime import datetime
5 from enum import Enum
6 import io
7 from multiprocessing import Manager, freeze_support
8 import os
9 from pathlib import Path
10 import regex as re
11 import signal
12 import sys
13 import tokenize
14 import traceback
15 from typing import (
16     Any,
17     Dict,
18     Generator,
19     Iterator,
20     List,
21     Optional,
22     Pattern,
23     Set,
24     Sized,
25     Tuple,
26     Union,
27 )
28
29 from dataclasses import replace
30 import click
31
32 from black.const import DEFAULT_LINE_LENGTH, DEFAULT_INCLUDES, DEFAULT_EXCLUDES
33 from black.const import STDIN_PLACEHOLDER
34 from black.nodes import STARS, syms, is_simple_decorator_expression
35 from black.lines import Line, EmptyLineTracker
36 from black.linegen import transform_line, LineGenerator, LN
37 from black.comments import normalize_fmt_off
38 from black.mode import Mode, TargetVersion
39 from black.mode import Feature, supports_feature, VERSION_TO_FEATURES
40 from black.cache import read_cache, write_cache, get_cache_info, filter_cached, Cache
41 from black.concurrency import cancel, shutdown
42 from black.output import dump_to_file, diff, color_diff, out, err
43 from black.report import Report, Changed
44 from black.files import find_project_root, find_pyproject_toml, parse_pyproject_toml
45 from black.files import gen_python_files, get_gitignore, normalize_path_maybe_ignore
46 from black.files import wrap_stream_for_windows
47 from black.parsing import InvalidInput  # noqa F401
48 from black.parsing import lib2to3_parse, parse_ast, stringify_ast
49
50
51 # lib2to3 fork
52 from blib2to3.pytree import Node, Leaf
53 from blib2to3.pgen2 import token
54
55 from _black_version import version as __version__
56
# If our environment has uvloop installed, install it for asyncio to use;
# when it is not importable, the default asyncio event loop is kept.
try:
    import uvloop

    uvloop.install()
except ImportError:
    pass

# types
# Aliases that document the role of plain strings in signatures such as
# `decode_bytes` below; each one is simply `str`.
FileContent = str
Encoding = str
NewLine = str
69
70
class NothingChanged(UserWarning):
    """Signal that formatting produced output identical to its input."""
73
74
class WriteBack(Enum):
    """What to do with reformatted output: write it, diff it, or only check."""

    NO = 0
    YES = 1
    DIFF = 2
    CHECK = 3
    COLOR_DIFF = 4

    @classmethod
    def from_configuration(
        cls, *, check: bool, diff: bool, color: bool = False
    ) -> "WriteBack":
        """Map the ``--check`` / ``--diff`` / ``--color`` flags to a member.

        A requested diff takes precedence over ``--check`` and is colored when
        `color` is set; otherwise ``--check`` alone yields CHECK, and the
        default is to write files back (YES).
        """
        if diff:
            return cls.COLOR_DIFF if color else cls.DIFF
        return cls.CHECK if check else cls.YES
93
94
# Legacy name for `Mode`, left so that integrations referring to
# `FileMode` keep working.
FileMode = Mode
97
98
def read_pyproject_toml(
    ctx: click.Context, param: click.Parameter, value: Optional[str]
) -> Optional[str]:
    """Inject Black configuration from "pyproject.toml" into defaults in `ctx`.

    Used as the Click callback of the ``--config`` option. When `value` is
    empty, the file is located with `find_pyproject_toml`, starting from the
    ``src`` value already present in ``ctx.params`` (if any).

    Returns the path to a successfully found and read configuration file, None
    otherwise.

    Raises:
        click.FileError: the configuration file could not be read or parsed.
        click.BadOptionUsage: the ``target_version`` key is not a list.
    """
    if not value:
        value = find_pyproject_toml(ctx.params.get("src", ()))
        if value is None:
            return None

    try:
        config = parse_pyproject_toml(value)
    except (OSError, ValueError) as e:
        # Chain the underlying error explicitly so its traceback is kept as
        # the direct cause of the Click error.
        raise click.FileError(
            filename=value, hint=f"Error reading configuration file: {e}"
        ) from e

    if not config:
        return None

    # Sanitize the values to be Click friendly. For more information please see:
    # https://github.com/psf/black/issues/1458
    # https://github.com/pallets/click/issues/1567
    config = {
        k: str(v) if not isinstance(v, (list, dict)) else v
        for k, v in config.items()
    }

    target_version = config.get("target_version")
    if target_version is not None and not isinstance(target_version, list):
        raise click.BadOptionUsage(
            "target-version", "Config key target-version must be a list"
        )

    # Layer the file's settings on top of any defaults already present, so
    # values from the file win over earlier defaults.
    default_map: Dict[str, Any] = {}
    if ctx.default_map:
        default_map.update(ctx.default_map)
    default_map.update(config)

    ctx.default_map = default_map
    return value
143
144
def target_version_option_callback(
    c: click.Context, p: Union[click.Option, click.Parameter], v: Tuple[str, ...]
) -> List[TargetVersion]:
    """Translate the ``--target-version`` choices into `TargetVersion` members.

    Written as a named function rather than a lambda: mypy could not infer the
    lambda's type correctly, which in turn caused problems for mypyc.
    """
    return [TargetVersion[name] for name in map(str.upper, v)]
154
155
def re_compile_maybe_verbose(regex: str) -> Pattern[str]:
    """Compile `regex`, enabling verbose mode when it spans several lines.

    A pattern that contains a newline gets the inline ``(?x)`` flag prepended,
    so whitespace and ``#`` comments inside it are ignored by the engine.
    """
    source = f"(?x){regex}" if "\n" in regex else regex
    compiled: Pattern[str] = re.compile(source)
    return compiled
165
166
def validate_regex(
    ctx: click.Context,
    param: click.Parameter,
    value: Optional[str],
) -> Optional[Pattern]:
    """Click callback: compile `value` to a pattern, passing None through.

    Raises:
        click.BadParameter: `value` is not a valid regular expression.
    """
    if value is None:
        return None
    try:
        return re_compile_maybe_verbose(value)
    except re.error:
        raise click.BadParameter("Not a valid regular expression")
176
177
# Command-line entry point.
# NOTE(review): Click builds the parameter list (and --help) from this decorator
# stack, and `src`/`--config` are eager — keep edits to the stack deliberate.
@click.command(context_settings=dict(help_option_names=["-h", "--help"]))
@click.option("-c", "--code", type=str, help="Format the code passed in as a string.")
@click.option(
    "-l",
    "--line-length",
    type=int,
    default=DEFAULT_LINE_LENGTH,
    help="How many characters per line to allow.",
    show_default=True,
)
@click.option(
    "-t",
    "--target-version",
    type=click.Choice([v.name.lower() for v in TargetVersion]),
    callback=target_version_option_callback,
    multiple=True,
    help=(
        "Python versions that should be supported by Black's output. [default: per-file"
        " auto-detection]"
    ),
)
@click.option(
    "--pyi",
    is_flag=True,
    help=(
        "Format all input files like typing stubs regardless of file extension (useful"
        " when piping source on standard input)."
    ),
)
@click.option(
    "-S",
    "--skip-string-normalization",
    is_flag=True,
    help="Don't normalize string quotes or prefixes.",
)
@click.option(
    "-C",
    "--skip-magic-trailing-comma",
    is_flag=True,
    help="Don't use trailing commas as a reason to split lines.",
)
@click.option(
    "--experimental-string-processing",
    is_flag=True,
    hidden=True,
    help=(
        "Experimental option that performs more normalization on string literals."
        " Currently disabled because it leads to some crashes."
    ),
)
@click.option(
    "--check",
    is_flag=True,
    help=(
        "Don't write the files back, just return the status. Return code 0 means"
        " nothing would change. Return code 1 means some files would be reformatted."
        " Return code 123 means there was an internal error."
    ),
)
@click.option(
    "--diff",
    is_flag=True,
    help="Don't write the files back, just output a diff for each file on stdout.",
)
@click.option(
    "--color/--no-color",
    is_flag=True,
    help="Show colored diff. Only applies when `--diff` is given.",
)
@click.option(
    "--fast/--safe",
    is_flag=True,
    help="If --fast given, skip temporary sanity checks. [default: --safe]",
)
@click.option(
    "--include",
    type=str,
    default=DEFAULT_INCLUDES,
    callback=validate_regex,
    help=(
        "A regular expression that matches files and directories that should be"
        " included on recursive searches. An empty value means all files are included"
        " regardless of the name. Use forward slashes for directories on all platforms"
        " (Windows, too). Exclusions are calculated first, inclusions later."
    ),
    show_default=True,
)
@click.option(
    "--exclude",
    type=str,
    callback=validate_regex,
    help=(
        "A regular expression that matches files and directories that should be"
        " excluded on recursive searches. An empty value means no paths are excluded."
        " Use forward slashes for directories on all platforms (Windows, too)."
        " Exclusions are calculated first, inclusions later. [default:"
        f" {DEFAULT_EXCLUDES}]"
    ),
    show_default=False,
)
@click.option(
    "--extend-exclude",
    type=str,
    callback=validate_regex,
    help=(
        "Like --exclude, but adds additional files and directories on top of the"
        " excluded ones. (Useful if you simply want to add to the default)"
    ),
)
@click.option(
    "--force-exclude",
    type=str,
    callback=validate_regex,
    help=(
        "Like --exclude, but files and directories matching this regex will be "
        "excluded even when they are passed explicitly as arguments."
    ),
)
@click.option(
    "--stdin-filename",
    type=str,
    help=(
        "The name of the file when passing it through stdin. Useful to make "
        "sure Black will respect --force-exclude option on some "
        "editors that rely on using stdin."
    ),
)
@click.option(
    "-q",
    "--quiet",
    is_flag=True,
    help=(
        "Don't emit non-error messages to stderr. Errors are still emitted; silence"
        " those with 2>/dev/null."
    ),
)
@click.option(
    "-v",
    "--verbose",
    is_flag=True,
    help=(
        "Also emit messages to stderr about files that were not changed or were ignored"
        " due to exclusion patterns."
    ),
)
@click.version_option(version=__version__)
@click.argument(
    "src",
    nargs=-1,
    type=click.Path(
        exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
    ),
    is_eager=True,
)
@click.option(
    "--config",
    type=click.Path(
        exists=True,
        file_okay=True,
        dir_okay=False,
        readable=True,
        allow_dash=False,
        path_type=str,
    ),
    is_eager=True,
    callback=read_pyproject_toml,
    help="Read configuration from FILE path.",
)
@click.pass_context
def main(
    ctx: click.Context,
    code: Optional[str],
    line_length: int,
    target_version: List[TargetVersion],
    check: bool,
    diff: bool,
    color: bool,
    fast: bool,
    pyi: bool,
    skip_string_normalization: bool,
    skip_magic_trailing_comma: bool,
    experimental_string_processing: bool,
    quiet: bool,
    verbose: bool,
    include: Pattern,
    exclude: Optional[Pattern],
    extend_exclude: Optional[Pattern],
    force_exclude: Optional[Pattern],
    stdin_filename: Optional[str],
    src: Tuple[str, ...],
    config: Optional[str],
) -> None:
    """The uncompromising code formatter."""
    # NOTE: the docstring above doubles as Click's --help description.
    write_back = WriteBack.from_configuration(check=check, diff=diff, color=color)
    if target_version:
        versions = set(target_version)
    else:
        # We'll autodetect later.
        versions = set()
    mode = Mode(
        target_versions=versions,
        line_length=line_length,
        is_pyi=pyi,
        string_normalization=not skip_string_normalization,
        magic_trailing_comma=not skip_magic_trailing_comma,
        experimental_string_processing=experimental_string_processing,
    )
    if config and verbose:
        out(f"Using configuration from {config}.", bold=False, fg="blue")
    # `--code`: format the given string, print the result and exit with
    # status 0 without looking at any source paths.
    if code is not None:
        print(format_str(code, mode=mode))
        ctx.exit(0)
    report = Report(check=check, diff=diff, quiet=quiet, verbose=verbose)
    sources = get_sources(
        ctx=ctx,
        src=src,
        quiet=quiet,
        verbose=verbose,
        include=include,
        exclude=exclude,
        extend_exclude=extend_exclude,
        force_exclude=force_exclude,
        report=report,
        stdin_filename=stdin_filename,
    )

    # Exit with status 0 if path discovery left nothing to format.
    path_empty(
        sources,
        "No Python files are present to be formatted. Nothing to do 😴",
        quiet,
        verbose,
        ctx,
    )

    # A single source is handled in-process by `reformat_one`; several sources
    # go through `reformat_many`, which runs them on an executor.
    if len(sources) == 1:
        reformat_one(
            src=sources.pop(),
            fast=fast,
            write_back=write_back,
            mode=mode,
            report=report,
        )
    else:
        reformat_many(
            sources=sources, fast=fast, write_back=write_back, mode=mode, report=report
        )

    if verbose or not quiet:
        out("Oh no! 💥 💔 💥" if report.return_code else "All done! ✨ 🍰 ✨")
        click.secho(str(report), err=True)
    # The process exit status comes from the report (see the --check help text).
    ctx.exit(report.return_code)
429
430
def get_sources(
    *,
    ctx: click.Context,
    src: Tuple[str, ...],
    quiet: bool,
    verbose: bool,
    include: Pattern[str],
    exclude: Optional[Pattern[str]],
    extend_exclude: Optional[Pattern[str]],
    force_exclude: Optional[Pattern[str]],
    report: "Report",
    stdin_filename: Optional[str],
) -> Set[Path]:
    """Collect the set of paths that should be formatted.

    Explicit files, and stdin when ``-`` is combined with `stdin_filename`, are
    only filtered by `force_exclude`. Directories are expanded with
    `gen_python_files` using all the include/exclude patterns. When no
    `exclude` pattern is given, the default excludes plus the project root's
    gitignore rules apply. Stdin entries are stored as ``STDIN_PLACEHOLDER``
    followed by the user-given filename. Unknown paths are reported via `err`.
    """
    root = find_project_root(src)
    found: Set[Path] = set()
    path_empty(src, "No Path provided. Nothing to do 😴", quiet, verbose, ctx)

    gitignore = None
    if exclude is None:
        exclude = re_compile_maybe_verbose(DEFAULT_EXCLUDES)
        gitignore = get_gitignore(root)

    for s in src:
        if s == "-" and stdin_filename:
            from_stdin = True
            path = Path(stdin_filename)
        else:
            from_stdin = False
            path = Path(s)

        if from_stdin or path.is_file():
            relative = normalize_path_maybe_ignore(path, root, report)
            if relative is None:
                continue

            # Hard-exclude any file matching `--force-exclude`, even though it
            # was named explicitly.
            hit = force_exclude.search("/" + relative) if force_exclude else None
            if hit and hit.group(0):
                report.path_ignored(
                    path, "matches the --force-exclude regular expression"
                )
                continue

            found.add(Path(f"{STDIN_PLACEHOLDER}{str(path)}") if from_stdin else path)
        elif path.is_dir():
            found.update(
                gen_python_files(
                    path.iterdir(),
                    root,
                    include,
                    exclude,
                    extend_exclude,
                    force_exclude,
                    report,
                    gitignore,
                )
            )
        elif s == "-":
            found.add(path)
        else:
            err(f"invalid path: {s}")
    return found
501
502
def path_empty(
    src: Sized, msg: str, quiet: bool, verbose: bool, ctx: click.Context
) -> None:
    """Exit with status 0 when `src` is empty.

    Before exiting, `msg` is printed unless output is quiet; `verbose` wins
    over `quiet`.
    """
    if src:
        return
    if verbose or not quiet:
        out(msg)
    ctx.exit(0)
513
514
def reformat_one(
    src: Path, fast: bool, write_back: WriteBack, mode: Mode, report: "Report"
) -> None:
    """Reformat a single file under `src` without spawning child processes.

    `src` is a real path, ``-`` for stdin, or ``STDIN_PLACEHOLDER`` followed by
    the filename given for stdin. `fast`, `write_back` and `mode` are forwarded
    to :func:`format_file_in_place` or :func:`format_stdin_to_stdout`. Every
    outcome, including an exception, is recorded on `report`.
    """
    try:
        changed = Changed.NO
        src_text = str(src)
        is_stdin = src_text == "-"
        if not is_stdin and src_text.startswith(STDIN_PLACEHOLDER):
            is_stdin = True
            # Switch back to the user-given name so messages mention it.
            src = Path(src_text[len(STDIN_PLACEHOLDER) :])

        if is_stdin:
            if src.suffix == ".pyi":
                mode = replace(mode, is_pyi=True)
            if format_stdin_to_stdout(fast=fast, write_back=write_back, mode=mode):
                changed = Changed.YES
        else:
            cache: Cache = {}
            # Diff modes never consult or update the cache.
            if write_back not in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
                cache = read_cache(mode)
                resolved = src.resolve()
                resolved_key = str(resolved)
                if resolved_key in cache and cache[resolved_key] == get_cache_info(
                    resolved
                ):
                    changed = Changed.CACHED
            is_cached = changed is Changed.CACHED
            if not is_cached and format_file_in_place(
                src, fast=fast, write_back=write_back, mode=mode
            ):
                changed = Changed.YES
            written_back = write_back is WriteBack.YES and not is_cached
            checked_clean = write_back is WriteBack.CHECK and changed is Changed.NO
            if written_back or checked_clean:
                write_cache(cache, [src], mode)
        report.done(src, changed)
    except Exception as exc:
        if report.verbose:
            traceback.print_exc()
        report.failed(src, str(exc))
562
563
def reformat_many(
    sources: Set[Path], fast: bool, write_back: WriteBack, mode: Mode, report: "Report"
) -> None:
    """Reformat multiple files using a ProcessPoolExecutor.

    Falls back to a single-worker ThreadPoolExecutor when a process pool cannot
    be created. The work is driven by :func:`schedule_formatting` on the event
    loop returned by ``asyncio.get_event_loop()``. Afterwards that loop is
    passed to ``shutdown`` and the executor is shut down, even if the loop
    shutdown fails.
    """
    executor: Executor
    loop = asyncio.get_event_loop()
    # `os.cpu_count()` returns None when the CPU count is undetermined; fall
    # back to one worker so the Windows cap below never calls min(None, 60).
    worker_count = os.cpu_count() or 1
    if sys.platform == "win32":
        # Work around https://bugs.python.org/issue26903
        worker_count = min(worker_count, 60)
    try:
        executor = ProcessPoolExecutor(max_workers=worker_count)
    except (ImportError, OSError):
        # we arrive here if the underlying system does not support multi-processing
        # like in AWS Lambda or Termux, in which case we gracefully fallback to
        # a ThreadPoolExecutor with just a single worker (more workers would not do us
        # any good due to the Global Interpreter Lock)
        executor = ThreadPoolExecutor(max_workers=1)

    try:
        loop.run_until_complete(
            schedule_formatting(
                sources=sources,
                fast=fast,
                write_back=write_back,
                mode=mode,
                report=report,
                loop=loop,
                executor=executor,
            )
        )
    finally:
        # `executor` is always bound here (both branches above assign it); make
        # sure it is released even if shutting down the loop raises.
        try:
            shutdown(loop)
        finally:
            executor.shutdown()
599
600
async def schedule_formatting(
    sources: Set[Path],
    fast: bool,
    write_back: WriteBack,
    mode: Mode,
    report: "Report",
    loop: asyncio.AbstractEventLoop,
    executor: Executor,
) -> None:
    """Run formatting of `sources` in parallel using the provided `executor`.

    (Use ProcessPoolExecutors for actual parallelism.)

    `write_back`, `fast`, and `mode` options are passed to
    :func:`format_file_in_place`. Outside the diff modes, sources with a
    matching cache entry are reported as cached and skipped, and successfully
    written or cleanly checked files are added to the cache at the end.
    SIGINT/SIGTERM handlers (where the loop supports them) cancel pending work.
    """
    cache: Cache = {}
    if write_back not in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
        cache = read_cache(mode)
        sources, cached = filter_cached(cache, sources)
        for src in sorted(cached):
            report.done(src, Changed.CACHED)
    if not sources:
        return

    cancelled = []
    sources_to_cache = []
    lock = None
    if write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
        # For diff output, we need locks to ensure we don't interleave output
        # from different processes.
        manager = Manager()
        lock = manager.Lock()
    tasks = {
        asyncio.ensure_future(
            loop.run_in_executor(
                executor, format_file_in_place, src, fast, mode, write_back, lock
            )
        ): src
        for src in sorted(sources)
    }
    # A live view of `tasks`: it shrinks as finished tasks are popped below,
    # and the signal handlers receive this same view.
    pending = tasks.keys()
    try:
        loop.add_signal_handler(signal.SIGINT, cancel, pending)
        loop.add_signal_handler(signal.SIGTERM, cancel, pending)
    except NotImplementedError:
        # There are no good alternatives for these on Windows.
        pass
    while pending:
        done, _ = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
        for task in done:
            src = tasks.pop(task)
            if task.cancelled():
                cancelled.append(task)
            elif task.exception():
                report.failed(src, str(task.exception()))
            else:
                changed = Changed.YES if task.result() else Changed.NO
                # If the file was written back or was successfully checked as
                # well-formatted, store this information in the cache.
                if write_back is WriteBack.YES or (
                    write_back is WriteBack.CHECK and changed is Changed.NO
                ):
                    sources_to_cache.append(src)
                report.done(src, changed)
    if cancelled:
        # `asyncio.gather` lost its `loop` parameter in Python 3.10 (deprecated
        # since 3.8); passing it raises TypeError. The futures already belong to
        # the running loop, so it is not needed.
        await asyncio.gather(*cancelled, return_exceptions=True)
    if sources_to_cache:
        write_cache(cache, sources_to_cache, mode)
670
671
def format_file_in_place(
    src: Path,
    fast: bool,
    mode: Mode,
    write_back: WriteBack = WriteBack.NO,
    lock: Any = None,  # multiprocessing.Manager().Lock() is some crazy proxy
) -> bool:
    """Format the file at `src`; return True if its formatting would change.

    ``.pyi`` files are always formatted in stub mode. With `write_back` YES the
    file is rewritten using its original encoding and newline style. With DIFF
    or COLOR_DIFF a diff (colored for COLOR_DIFF) is written to stdout, while
    holding `lock` when one is given. `mode` and `fast` are passed on to
    :func:`format_file_contents`.
    """
    if src.suffix == ".pyi":
        mode = replace(mode, is_pyi=True)

    src_stamp = datetime.utcfromtimestamp(src.stat().st_mtime)
    with open(src, "rb") as raw:
        source_text, encoding, newline = decode_bytes(raw.read())
    try:
        formatted = format_file_contents(source_text, fast=fast, mode=mode)
    except NothingChanged:
        return False

    if write_back == WriteBack.YES:
        with open(src, "w", encoding=encoding, newline=newline) as out_file:
            out_file.write(formatted)
        return True

    if write_back not in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
        # NO / CHECK: only report that a change would happen.
        return True

    dst_stamp = datetime.utcnow()
    rendered = diff(
        source_text,
        formatted,
        f"{src}\t{src_stamp} +0000",
        f"{src}\t{dst_stamp} +0000",
    )
    if write_back == WriteBack.COLOR_DIFF:
        rendered = color_diff(rendered)

    # Serialize stdout access so diffs from parallel workers don't interleave.
    with lock or nullcontext():
        stream = io.TextIOWrapper(
            sys.stdout.buffer,
            encoding=encoding,
            newline=newline,
            write_through=True,
        )
        stream = wrap_stream_for_windows(stream)
        stream.write(rendered)
        stream.detach()

    return True
720
721
def format_stdin_to_stdout(
    fast: bool, *, write_back: WriteBack = WriteBack.NO, mode: Mode
) -> bool:
    """Format file on stdin. Return True if changed.

    If `write_back` is YES, write reformatted code back to stdout. If it is DIFF,
    write a diff to stdout. The `mode` argument is passed to
    :func:`format_file_contents`.

    Output is written in the ``finally`` clause, so it is produced even when
    formatting raises; in that case (and when nothing changed) `dst` still
    equals the input, and any exception other than `NothingChanged` propagates
    after the output has been written.
    """
    # Timestamp for the diff's "STDIN" header, taken before stdin is read.
    then = datetime.utcnow()
    src, encoding, newline = decode_bytes(sys.stdin.buffer.read())
    # Fallback: if formatting fails or changes nothing, emit the input as-is.
    dst = src
    try:
        dst = format_file_contents(src, fast=fast, mode=mode)
        return True

    except NothingChanged:
        return False

    finally:
        # Write through a wrapper over the raw stdout buffer so the output reuses
        # the input's encoding and newline style.
        f = io.TextIOWrapper(
            sys.stdout.buffer, encoding=encoding, newline=newline, write_through=True
        )
        if write_back == WriteBack.YES:
            f.write(dst)
        elif write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
            now = datetime.utcnow()
            src_name = f"STDIN\t{then} +0000"
            dst_name = f"STDOUT\t{now} +0000"
            d = diff(src, dst, src_name, dst_name)
            if write_back == WriteBack.COLOR_DIFF:
                d = color_diff(d)
                f = wrap_stream_for_windows(f)
            f.write(d)
        # Detach so discarding the wrapper does not close sys.stdout.buffer.
        f.detach()
757
758
def format_file_contents(src_contents: str, *, fast: bool, mode: Mode) -> FileContent:
    """Return the reformatted version of `src_contents`.

    Raises `NothingChanged` when the input is blank or already formatted.
    Unless `fast` is set, the result is checked with :func:`assert_equivalent`,
    and a second formatting pass is run; if that pass changes the output again,
    the new output is checked with :func:`assert_equivalent` and
    :func:`assert_stable` and returned instead. `mode` goes to :func:`format_str`.
    """
    if not src_contents.strip():
        raise NothingChanged

    first_pass = format_str(src_contents, mode=mode)
    if first_pass == src_contents:
        raise NothingChanged
    if fast:
        return first_pass

    assert_equivalent(src_contents, first_pass)

    # Forced second pass to work around optional trailing commas (becoming
    # forced trailing commas on pass 2) interacting differently with optional
    # parentheses.  Admittedly ugly.
    second_pass = format_str(first_pass, mode=mode)
    if second_pass == first_pass:
        # Already a fixed point, so an explicit `assert_stable` is unnecessary.
        return first_pass

    assert_equivalent(src_contents, second_pass, pass_num=2)
    assert_stable(src_contents, second_pass, mode=mode)
    return second_pass
787
788
def format_str(src_contents: str, *, mode: Mode) -> FileContent:
    """Return `src_contents` reformatted according to `mode`.

    `mode` determines options such as how many characters per line are allowed
    and which Python versions to target; with no explicit target versions they
    are detected from the parsed source. Example:

    >>> import black
    >>> print(black.format_str("def f(arg:str='')->None:...", mode=black.Mode()))
    def f(arg: str = "") -> None:
        ...

    A more complex example:

    >>> print(
    ...   black.format_str(
    ...     "def f(arg:str='')->None: hey",
    ...     mode=black.Mode(
    ...       target_versions={black.TargetVersion.PY36},
    ...       line_length=10,
    ...       string_normalization=False,
    ...       is_pyi=False,
    ...     ),
    ...   ),
    ... )
    def f(
        arg: str = '',
    ) -> None:
        hey

    """
    src_node = lib2to3_parse(src_contents.lstrip(), mode.target_versions)
    future_imports = get_future_imports(src_node)
    # Version detection inspects the tree before `normalize_fmt_off` runs.
    versions = mode.target_versions or detect_target_versions(src_node)
    normalize_fmt_off(src_node)

    drop_u_prefix = "unicode_literals" in future_imports or supports_feature(
        versions, Feature.UNICODE_LITERALS
    )
    line_gen = LineGenerator(mode=mode, remove_u_prefix=drop_u_prefix)
    tracker = EmptyLineTracker(is_pyi=mode.is_pyi)
    blank_line = Line(mode=mode)
    split_features = {
        feature
        for feature in (Feature.TRAILING_COMMA_IN_CALL, Feature.TRAILING_COMMA_IN_DEF)
        if supports_feature(versions, feature)
    }

    pieces: List[str] = []
    pending_after = 0
    for current_line in line_gen.visit(src_node):
        # Blank lines owed after the previous line, then before this one.
        pieces.append(str(blank_line) * pending_after)
        before, pending_after = tracker.maybe_empty_lines(current_line)
        pieces.append(str(blank_line) * before)
        pieces.extend(
            str(line)
            for line in transform_line(current_line, mode=mode, features=split_features)
        )
    return "".join(pieces)
849
850
def decode_bytes(src: bytes) -> Tuple[FileContent, Encoding, NewLine]:
    """Decode `src` and return ``(decoded_contents, encoding, newline)``.

    The encoding is determined by :func:`tokenize.detect_encoding`. `newline`
    is CRLF if the first line ends with CRLF and LF otherwise, while the
    decoded text itself is read with universal newlines (it only contains LF).
    Empty input yields ``("", encoding, "\\n")``.
    """
    byte_stream = io.BytesIO(src)
    encoding, first_lines = tokenize.detect_encoding(byte_stream.readline)
    if not first_lines:
        return "", encoding, "\n"

    newline = "\r\n" if first_lines[0].endswith(b"\r\n") else "\n"
    byte_stream.seek(0)
    with io.TextIOWrapper(byte_stream, encoding) as text_stream:
        return text_stream.read(), encoding, newline
866
867
def get_features_used(node: Node) -> Set[Feature]:
    """Return a set of (relatively) new Python features used in this file.

    Currently looking for:
    - f-strings;
    - underscores in numeric literals;
    - trailing commas after * or ** in function signatures and calls;
    - positional only arguments in function signatures and lambdas;
    - assignment expression;
    - relaxed decorator syntax;
    """
    features: Set[Feature] = set()
    for n in node.pre_order():
        if n.type == token.STRING:
            # String prefixes are case-insensitive in any combination, so
            # lowercase before matching; this also catches `Rf"..."`, `fR'...'`
            # etc., which a case-sensitive lookup missed.
            value_head = n.value[:2].lower()  # type: ignore
            if value_head in {'f"', "f'", "rf", "fr"}:
                features.add(Feature.F_STRINGS)

        elif n.type == token.NUMBER:
            if "_" in n.value:  # type: ignore
                features.add(Feature.NUMERIC_UNDERSCORES)

        elif n.type == token.SLASH:
            # A `/` directly inside a parameter list marks positional-only
            # parameters; `varargslist` is the parameter list of a `lambda`.
            # A `/` used for division sits inside an expression node instead.
            if n.parent and n.parent.type in {
                syms.typedargslist,
                syms.arglist,
                syms.varargslist,
            }:
                features.add(Feature.POS_ONLY_ARGUMENTS)

        elif n.type == token.COLONEQUAL:
            features.add(Feature.ASSIGNMENT_EXPRESSIONS)

        elif n.type == syms.decorator:
            if len(n.children) > 1 and not is_simple_decorator_expression(
                n.children[1]
            ):
                features.add(Feature.RELAXED_DECORATORS)

        elif (
            n.type in {syms.typedargslist, syms.arglist}
            and n.children
            and n.children[-1].type == token.COMMA
        ):
            if n.type == syms.typedargslist:
                feature = Feature.TRAILING_COMMA_IN_DEF
            else:
                feature = Feature.TRAILING_COMMA_IN_CALL

            for ch in n.children:
                if ch.type in STARS:
                    features.add(feature)

                if ch.type == syms.argument:
                    for argch in ch.children:
                        if argch.type in STARS:
                            features.add(feature)

    return features
923
924
def detect_target_versions(node: Node) -> Set[TargetVersion]:
    """Return every target version that supports all features found in `node`."""
    used = get_features_used(node)
    compatible: Set[TargetVersion] = set()
    for version in TargetVersion:
        # A version qualifies only if it supports every feature that was used.
        if used.issubset(VERSION_TO_FEATURES[version]):
            compatible.add(version)
    return compatible
931
932
def get_future_imports(node: Node) -> Set[str]:
    """Collect the names imported with ``from __future__ import ...`` in `node`.

    Only the leading run of simple statements is examined. A docstring-only
    statement is skipped over. Scanning stops at the first statement that is
    neither such a docstring nor a ``from __future__`` import.
    """

    def names_in(children: List[LN]) -> Generator[str, None, None]:
        for child in children:
            if isinstance(child, Leaf):
                # NAME leaves are imported names; other leaves (e.g. commas)
                # are ignored.
                if child.type == token.NAME:
                    yield child.value
            elif child.type == syms.import_as_name:
                # For `name as alias`, report the original name.
                original = child.children[0]
                assert isinstance(original, Leaf), "Invalid syntax parsing imports"
                assert original.type == token.NAME, "Invalid syntax parsing imports"
                yield original.value
            elif child.type == syms.import_as_names:
                yield from names_in(child.children)
            else:
                raise AssertionError("Invalid syntax parsing imports")

    found: Set[str] = set()
    for stmt in node.children:
        if stmt.type != syms.simple_stmt:
            break

        head = stmt.children[0]
        if isinstance(head, Leaf):
            is_docstring = (
                len(stmt.children) == 2
                and head.type == token.STRING
                and stmt.children[1].type == token.NEWLINE
            )
            if not is_docstring:
                break
            # A docstring may precede the __future__ imports; keep scanning.
            continue

        if head.type != syms.import_from:
            break

        module_name = head.children[1]
        if not isinstance(module_name, Leaf) or module_name.value != "__future__":
            break

        # children[3:] is everything after `from __future__ import`.
        found.update(names_in(head.children[3:]))

    return found
981
982
def assert_equivalent(src: str, dst: str, *, pass_num: int = 1) -> None:
    """Raise AssertionError unless `src` and `dst` parse to equivalent ASTs.

    Failing to parse `src` means the --safe check cannot be performed. Failing
    to parse `dst`, or any difference between the stringified ASTs, is reported
    as an internal error. Diagnostics are handed to `dump_to_file`, and its
    result is included in the message.
    """
    try:
        src_ast = parse_ast(src)
    except Exception as exc:
        raise AssertionError(
            "cannot use --safe with this file; failed to parse source file.  AST"
            f" error message: {exc}"
        )

    try:
        dst_ast = parse_ast(dst)
    except Exception as exc:
        tb_text = "".join(traceback.format_tb(exc.__traceback__))
        log = dump_to_file(tb_text, dst)
        raise AssertionError(
            f"INTERNAL ERROR: Black produced invalid code on pass {pass_num}: {exc}. "
            "Please report a bug on https://github.com/psf/black/issues.  "
            f"This invalid output might be helpful: {log}"
        ) from None

    # Compare the textual dumps of both trees (source first, then output).
    src_ast_str, dst_ast_str = (
        "\n".join(stringify_ast(tree)) for tree in (src_ast, dst_ast)
    )
    if src_ast_str == dst_ast_str:
        return

    log = dump_to_file(diff(src_ast_str, dst_ast_str, "src", "dst"))
    raise AssertionError(
        "INTERNAL ERROR: Black produced code that is not equivalent to the"
        f" source on pass {pass_num}.  Please report a bug on "
        f"https://github.com/psf/black/issues.  This diff might be helpful: {log}"
    ) from None
1012
1013
def assert_stable(src: str, dst: str, mode: Mode) -> None:
    """Raise AssertionError if formatting `dst` again with `mode` changes it.

    On failure, the mode and two diffs (source vs. first pass, first pass vs.
    second pass) are passed to `dump_to_file`, and its result is included in
    the error message.
    """
    second_pass = format_str(dst, mode=mode)
    if second_pass == dst:
        return

    log = dump_to_file(
        str(mode),
        diff(src, dst, "source", "first pass"),
        diff(dst, second_pass, "first pass", "second pass"),
    )
    raise AssertionError(
        "INTERNAL ERROR: Black produced different code on the second pass of the"
        " formatter.  Please report a bug on https://github.com/psf/black/issues."
        f"  This diff might be helpful: {log}"
    ) from None
1028
1029
@contextmanager
def nullcontext() -> Iterator[None]:
    """Provide a context manager that does nothing.

    A stand-in for ``contextlib.nullcontext`` from Python 3.7. Entering it
    produces ``None`` and exiting it has no effect.
    """
    yield None
1037
1038
def patch_click() -> None:
    """Make Click not crash on Python 3.6 with LANG=C.

    On certain misconfigured environments, Python 3 selects the ASCII encoding as the
    default which restricts paths that it can access during the lifetime of the
    application.  Click refuses to work in this scenario by raising a RuntimeError.

    In case of Black the likelihood that non-ASCII characters are going to be used in
    file paths is minimal since it's Python source code.  Moreover, this crash was
    spurious on Python 3.7 thanks to PEP 538 and PEP 540.

    Each Click module is imported independently, and a missing one is simply
    skipped. ``click._unicodefun`` no longer exists in newer Click releases.
    There, ``from click import _unicodefun`` raises a plain ``ImportError``
    (not ``ModuleNotFoundError``), which previously escaped this function.
    """
    modules: List[Any] = []
    try:
        from click import core
    except ImportError:
        pass
    else:
        modules.append(core)

    try:
        # Removed in newer Click releases; kept for users with older Click.
        from click import _unicodefun  # type: ignore
    except ImportError:
        pass
    else:
        modules.append(_unicodefun)

    for module in modules:
        # The locale check has been named differently across Click versions;
        # neutralize whichever one is present.
        if hasattr(module, "_verify_python3_env"):
            module._verify_python3_env = lambda: None  # type: ignore
        if hasattr(module, "_verify_python_env"):
            module._verify_python_env = lambda: None  # type: ignore
1061
1062
def patched_main() -> None:
    """Run Black's command-line entry point after preparing the process.

    Calls ``freeze_support()`` and then ``patch_click()`` before handing
    control to ``main()``.
    """
    for prepare in (freeze_support, patch_click):
        prepare()
    main()


if __name__ == "__main__":
    patched_main()