git.madduck.net Git - etc/vim.git/blob - src/black/__init__.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access, can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

add context manager to temporarily change the cwd (#2377)
[etc/vim.git] / src / black / __init__.py
1 import asyncio
2 from concurrent.futures import Executor, ThreadPoolExecutor, ProcessPoolExecutor
3 from contextlib import contextmanager
4 from datetime import datetime
5 from enum import Enum
6 import io
7 from multiprocessing import Manager, freeze_support
8 import os
9 from pathlib import Path
10 import regex as re
11 import signal
12 import sys
13 import tokenize
14 import traceback
15 from typing import (
16     Any,
17     Dict,
18     Generator,
19     Iterator,
20     List,
21     Optional,
22     Pattern,
23     Set,
24     Sized,
25     Tuple,
26     Union,
27 )
28
29 from dataclasses import replace
30 import click
31
32 from black.const import DEFAULT_LINE_LENGTH, DEFAULT_INCLUDES, DEFAULT_EXCLUDES
33 from black.const import STDIN_PLACEHOLDER
34 from black.nodes import STARS, syms, is_simple_decorator_expression
35 from black.lines import Line, EmptyLineTracker
36 from black.linegen import transform_line, LineGenerator, LN
37 from black.comments import normalize_fmt_off
38 from black.mode import Mode, TargetVersion
39 from black.mode import Feature, supports_feature, VERSION_TO_FEATURES
40 from black.cache import read_cache, write_cache, get_cache_info, filter_cached, Cache
41 from black.concurrency import cancel, shutdown, maybe_install_uvloop
42 from black.output import dump_to_file, diff, color_diff, out, err
43 from black.report import Report, Changed
44 from black.files import find_project_root, find_pyproject_toml, parse_pyproject_toml
45 from black.files import gen_python_files, get_gitignore, normalize_path_maybe_ignore
46 from black.files import wrap_stream_for_windows
47 from black.parsing import InvalidInput  # noqa F401
48 from black.parsing import lib2to3_parse, parse_ast, stringify_ast
49
50
51 # lib2to3 fork
52 from blib2to3.pytree import Node, Leaf
53 from blib2to3.pgen2 import token
54
55 from _black_version import version as __version__
56
# Type aliases: plain `str` under names that say what the string holds.
# Source code as text, e.g. what `format_str` returns.
FileContent = str
# Name of a text encoding, as reported by `tokenize.detect_encoding`.
Encoding = str
# Line-ending sequence; `decode_bytes` reports either "\r\n" or "\n".
NewLine = str
61
62
class NothingChanged(UserWarning):
    """Signals that formatting would leave the source exactly as it was.

    `format_file_contents` raises it for blank input and for input that is
    already formatted; callers catch it and report the source as unchanged.
    """
65
66
class WriteBack(Enum):
    """What to do with formatting results.

    YES rewrites files, DIFF / COLOR_DIFF print a (colored) diff, and NO /
    CHECK leave files untouched (only the status is reported).
    """

    NO = 0
    YES = 1
    DIFF = 2
    CHECK = 3
    COLOR_DIFF = 4

    @classmethod
    def from_configuration(
        cls, *, check: bool, diff: bool, color: bool = False
    ) -> "WriteBack":
        """Pick the mode implied by the --check / --diff / --color flags.

        --diff wins over --check and becomes COLOR_DIFF when --color is set;
        --check alone gives CHECK; with neither flag files are rewritten (YES).
        """
        if diff:
            return cls.COLOR_DIFF if color else cls.DIFF
        return cls.CHECK if check else cls.YES
85
86
# Legacy name, left for integrations: `FileMode` is the very same object as
# `Mode`, kept so code still using the old name continues to work.
FileMode = Mode
89
90
def read_pyproject_toml(
    ctx: click.Context, param: click.Parameter, value: Optional[str]
) -> Optional[str]:
    """Inject Black configuration from "pyproject.toml" into defaults in `ctx`.

    Click callback for ``--config``. Without an explicit path, the file is
    located with `find_pyproject_toml` based on the ``src`` values already in
    ``ctx.params``. Its settings are layered over any existing
    ``ctx.default_map`` entries.

    Returns the path of the configuration file that was read, or None when no
    file was found or it yielded no settings.

    Raises click.FileError when the file cannot be read or parsed, and
    click.BadOptionUsage when ``target_version`` is not a list.
    """
    if not value:
        value = find_pyproject_toml(ctx.params.get("src", ()))
        if value is None:
            return None

    try:
        raw_config = parse_pyproject_toml(value)
    except (OSError, ValueError) as e:
        raise click.FileError(
            filename=value, hint=f"Error reading configuration file: {e}"
        )

    if not raw_config:
        return None

    # Click wants scalar defaults as strings; lists and tables pass through.
    # For more information please see:
    # https://github.com/psf/black/issues/1458
    # https://github.com/pallets/click/issues/1567
    settings = {
        key: item if isinstance(item, (list, dict)) else str(item)
        for key, item in raw_config.items()
    }

    target_version = settings.get("target_version")
    if target_version is not None and not isinstance(target_version, list):
        raise click.BadOptionUsage(
            "target-version", "Config key target-version must be a list"
        )

    # Existing defaults first, then the file's settings on top.
    merged: Dict[str, Any] = dict(ctx.default_map or {})
    merged.update(settings)
    ctx.default_map = merged
    return value
135
136
def target_version_option_callback(
    c: click.Context, p: Union[click.Option, click.Parameter], v: Tuple[str, ...]
) -> List[TargetVersion]:
    """Map each ``--target-version`` choice to its `TargetVersion` member.

    Choices are offered in lower case and looked up by their upper-case enum
    name. Kept as a named function rather than a lambda because mypy could not
    infer the lambda's type correctly, which caused trouble for mypyc.
    """
    versions: List[TargetVersion] = []
    for name in v:
        versions.append(TargetVersion[name.upper()])
    return versions
146
147
def re_compile_maybe_verbose(regex: str) -> Pattern[str]:
    """Compile `regex`, enabling verbose mode for multi-line patterns.

    A pattern containing a newline gets the inline ``(?x)`` flag prepended, so
    whitespace and ``#`` comments inside it are ignored when matching.
    """
    if "\n" not in regex:
        source = regex
    else:
        source = f"(?x){regex}"
    pattern: Pattern[str] = re.compile(source)
    return pattern
157
158
def validate_regex(
    ctx: click.Context,
    param: click.Parameter,
    value: Optional[str],
) -> Optional[Pattern]:
    """Click callback that compiles a regular-expression option value.

    Returns None when the option was not given, otherwise the compiled pattern
    (see `re_compile_maybe_verbose`). A pattern that fails to compile is
    reported to the user as a `click.BadParameter` error.
    """
    if value is None:
        return None
    try:
        return re_compile_maybe_verbose(value)
    except re.error:
        raise click.BadParameter("Not a valid regular expression")
168
169
# Command-line interface. Each click decorator below declares one option or
# argument of `main`; the help strings are user-facing (shown by --help).
@click.command(context_settings=dict(help_option_names=["-h", "--help"]))
@click.option("-c", "--code", type=str, help="Format the code passed in as a string.")
@click.option(
    "-l",
    "--line-length",
    type=int,
    default=DEFAULT_LINE_LENGTH,
    help="How many characters per line to allow.",
    show_default=True,
)
@click.option(
    "-t",
    "--target-version",
    type=click.Choice([v.name.lower() for v in TargetVersion]),
    callback=target_version_option_callback,
    multiple=True,
    help=(
        "Python versions that should be supported by Black's output. [default: per-file"
        " auto-detection]"
    ),
)
@click.option(
    "--pyi",
    is_flag=True,
    help=(
        "Format all input files like typing stubs regardless of file extension (useful"
        " when piping source on standard input)."
    ),
)
@click.option(
    "-S",
    "--skip-string-normalization",
    is_flag=True,
    help="Don't normalize string quotes or prefixes.",
)
@click.option(
    "-C",
    "--skip-magic-trailing-comma",
    is_flag=True,
    help="Don't use trailing commas as a reason to split lines.",
)
@click.option(
    "--experimental-string-processing",
    is_flag=True,
    hidden=True,
    help=(
        "Experimental option that performs more normalization on string literals."
        " Currently disabled because it leads to some crashes."
    ),
)
@click.option(
    "--check",
    is_flag=True,
    help=(
        "Don't write the files back, just return the status. Return code 0 means"
        " nothing would change. Return code 1 means some files would be reformatted."
        " Return code 123 means there was an internal error."
    ),
)
@click.option(
    "--diff",
    is_flag=True,
    help="Don't write the files back, just output a diff for each file on stdout.",
)
@click.option(
    "--color/--no-color",
    is_flag=True,
    help="Show colored diff. Only applies when `--diff` is given.",
)
@click.option(
    "--fast/--safe",
    is_flag=True,
    help="If --fast given, skip temporary sanity checks. [default: --safe]",
)
@click.option(
    "--required-version",
    type=str,
    help=(
        "Require a specific version of Black to be running (useful for unifying results"
        " across many environments e.g. with a pyproject.toml file)."
    ),
)
@click.option(
    "--include",
    type=str,
    default=DEFAULT_INCLUDES,
    callback=validate_regex,
    help=(
        "A regular expression that matches files and directories that should be"
        " included on recursive searches. An empty value means all files are included"
        " regardless of the name. Use forward slashes for directories on all platforms"
        " (Windows, too). Exclusions are calculated first, inclusions later."
    ),
    show_default=True,
)
@click.option(
    "--exclude",
    type=str,
    callback=validate_regex,
    help=(
        "A regular expression that matches files and directories that should be"
        " excluded on recursive searches. An empty value means no paths are excluded."
        " Use forward slashes for directories on all platforms (Windows, too)."
        " Exclusions are calculated first, inclusions later. [default:"
        f" {DEFAULT_EXCLUDES}]"
    ),
    show_default=False,
)
@click.option(
    "--extend-exclude",
    type=str,
    callback=validate_regex,
    help=(
        "Like --exclude, but adds additional files and directories on top of the"
        " excluded ones. (Useful if you simply want to add to the default)"
    ),
)
@click.option(
    "--force-exclude",
    type=str,
    callback=validate_regex,
    help=(
        "Like --exclude, but files and directories matching this regex will be "
        "excluded even when they are passed explicitly as arguments."
    ),
)
@click.option(
    "--stdin-filename",
    type=str,
    help=(
        "The name of the file when passing it through stdin. Useful to make "
        "sure Black will respect --force-exclude option on some "
        "editors that rely on using stdin."
    ),
)
@click.option(
    "-q",
    "--quiet",
    is_flag=True,
    help=(
        "Don't emit non-error messages to stderr. Errors are still emitted; silence"
        " those with 2>/dev/null."
    ),
)
@click.option(
    "-v",
    "--verbose",
    is_flag=True,
    help=(
        "Also emit messages to stderr about files that were not changed or were ignored"
        " due to exclusion patterns."
    ),
)
@click.version_option(version=__version__)
# `src` is eager, like --config below: read_pyproject_toml looks at
# ctx.params["src"] to locate pyproject.toml. NOTE(review): this relies on
# click having processed `src` before the --config callback runs - confirm.
@click.argument(
    "src",
    nargs=-1,
    type=click.Path(
        exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
    ),
    is_eager=True,
    metavar="SRC ...",
)
# The --config callback (read_pyproject_toml) merges the file's settings into
# ctx.default_map, so they serve as defaults for the other options.
@click.option(
    "--config",
    type=click.Path(
        exists=True,
        file_okay=True,
        dir_okay=False,
        readable=True,
        allow_dash=False,
        path_type=str,
    ),
    is_eager=True,
    callback=read_pyproject_toml,
    help="Read configuration from FILE path.",
)
@click.pass_context
def main(
    ctx: click.Context,
    code: Optional[str],
    line_length: int,
    target_version: List[TargetVersion],
    check: bool,
    diff: bool,
    color: bool,
    fast: bool,
    pyi: bool,
    skip_string_normalization: bool,
    skip_magic_trailing_comma: bool,
    experimental_string_processing: bool,
    quiet: bool,
    verbose: bool,
    required_version: str,
    include: Pattern,
    exclude: Optional[Pattern],
    extend_exclude: Optional[Pattern],
    force_exclude: Optional[Pattern],
    stdin_filename: Optional[str],
    src: Tuple[str, ...],
    config: Optional[str],
) -> None:
    """The uncompromising code formatter."""
    # NOTE: click uses the docstring above as the command's --help text.
    if config and verbose:
        out(f"Using configuration from {config}.", bold=False, fg="blue")

    error_msg = "Oh no! 💥 💔 💥"
    # Abort with exit code 1 when --required-version pins a different version.
    if required_version and required_version != __version__:
        err(
            f"{error_msg} The required version `{required_version}` does not match"
            f" the running version `{__version__}`!"
        )
        ctx.exit(1)

    # Translate the CLI flags into a write-back mode and a formatting Mode.
    write_back = WriteBack.from_configuration(check=check, diff=diff, color=color)
    if target_version:
        versions = set(target_version)
    else:
        # We'll autodetect later.
        # (An empty set makes format_str detect target versions per file.)
        versions = set()
    mode = Mode(
        target_versions=versions,
        line_length=line_length,
        is_pyi=pyi,
        string_normalization=not skip_string_normalization,
        magic_trailing_comma=not skip_magic_trailing_comma,
        experimental_string_processing=experimental_string_processing,
    )

    if code is not None:
        # Run in quiet mode by default with -c; the extra output isn't useful.
        # You can still pass -v to get verbose output.
        quiet = True

    report = Report(check=check, diff=diff, quiet=quiet, verbose=verbose)

    if code is not None:
        reformat_code(
            content=code, fast=fast, write_back=write_back, mode=mode, report=report
        )
    else:
        sources = get_sources(
            ctx=ctx,
            src=src,
            quiet=quiet,
            verbose=verbose,
            include=include,
            exclude=exclude,
            extend_exclude=extend_exclude,
            force_exclude=force_exclude,
            report=report,
            stdin_filename=stdin_filename,
        )

        # Exits with status 0 when no files were collected.
        path_empty(
            sources,
            "No Python files are present to be formatted. Nothing to do 😴",
            quiet,
            verbose,
            ctx,
        )

        # A single source is formatted in this process; several are handed
        # to reformat_many and its executor pool.
        if len(sources) == 1:
            reformat_one(
                src=sources.pop(),
                fast=fast,
                write_back=write_back,
                mode=mode,
                report=report,
            )
        else:
            reformat_many(
                sources=sources,
                fast=fast,
                write_back=write_back,
                mode=mode,
                report=report,
            )

    # Print a summary unless quiet (verbose overrides quiet), then exit with
    # the report's return code.
    if verbose or not quiet:
        out(error_msg if report.return_code else "All done! ✨ 🍰 ✨")
        if code is None:
            click.echo(str(report), err=True)
    ctx.exit(report.return_code)
454
455
def get_sources(
    *,
    ctx: click.Context,
    src: Tuple[str, ...],
    quiet: bool,
    verbose: bool,
    include: Pattern[str],
    exclude: Optional[Pattern[str]],
    extend_exclude: Optional[Pattern[str]],
    force_exclude: Optional[Pattern[str]],
    report: "Report",
    stdin_filename: Optional[str],
) -> Set[Path]:
    """Compute the set of files to be formatted.

    Each entry of `src` is handled as follows:

    - an existing file, or "-" when `stdin_filename` is given, is kept unless
      `normalize_path_maybe_ignore` rejects it or it matches `force_exclude`;
      the stdin entry is stored as STDIN_PLACEHOLDER + `stdin_filename`;
    - a directory is expanded with `gen_python_files`;
    - a bare "-" (no `stdin_filename`) is kept as is;
    - anything else is reported as an invalid path.

    When `exclude` is None, DEFAULT_EXCLUDES is used together with the
    patterns from `get_gitignore(root)`. Exits early via `path_empty` when
    `src` is empty.
    """
    root = find_project_root(src)
    found: Set[Path] = set()
    path_empty(src, "No Path provided. Nothing to do 😴", quiet, verbose, ctx)

    if exclude is None:
        # No user-supplied exclusions: use the defaults plus gitignore rules.
        exclude = re_compile_maybe_verbose(DEFAULT_EXCLUDES)
        gitignore = get_gitignore(root)
    else:
        gitignore = None

    for entry in src:
        if entry == "-" and stdin_filename:
            path, is_stdin = Path(stdin_filename), True
        else:
            path, is_stdin = Path(entry), False

        if is_stdin or path.is_file():
            normalized = normalize_path_maybe_ignore(path, root, report)
            if normalized is None:
                continue

            # Hard-exclude anything matching `--force-exclude`, even when it was
            # passed explicitly; the regex sees the path with a leading "/".
            hit = force_exclude.search("/" + normalized) if force_exclude else None
            if hit and hit.group(0):
                report.path_ignored(
                    path, "matches the --force-exclude regular expression"
                )
                continue

            if is_stdin:
                # Tag it so reformat_one can recognise stdin and recover the name.
                path = Path(f"{STDIN_PLACEHOLDER}{str(path)}")
            found.add(path)
        elif path.is_dir():
            found.update(
                gen_python_files(
                    path.iterdir(),
                    root,
                    include,
                    exclude,
                    extend_exclude,
                    force_exclude,
                    report,
                    gitignore,
                )
            )
        elif entry == "-":
            found.add(path)
        else:
            err(f"invalid path: {entry}")
    return found
526
527
def path_empty(
    src: Sized, msg: str, quiet: bool, verbose: bool, ctx: click.Context
) -> None:
    """Exit the program with status 0 when `src` is empty.

    `msg` is printed first unless output is silenced (`quiet` without
    `verbose`). Returns normally when `src` has at least one element.
    """
    if src:
        return
    if not quiet or verbose:
        out(msg)
    ctx.exit(0)
538
539
def reformat_code(
    content: str, fast: bool, write_back: WriteBack, mode: Mode, report: Report
) -> None:
    """Format the string `content` in-process and record the outcome.

    The string counterpart of `reformat_one`: `content`, `fast`, `write_back`
    and `mode` are handed to :func:`format_stdin_to_stdout`, and the result is
    recorded in `report` under the pseudo-path ``<string>``. Any exception is
    recorded as a failure instead of propagating; with a verbose report the
    traceback is printed as well.
    """
    pseudo_path = Path("<string>")
    try:
        reformatted = format_stdin_to_stdout(
            content=content, fast=fast, write_back=write_back, mode=mode
        )
        report.done(pseudo_path, Changed.YES if reformatted else Changed.NO)
    except Exception as exc:
        if report.verbose:
            traceback.print_exc()
        report.failed(pseudo_path, str(exc))
562
563
def reformat_one(
    src: Path, fast: bool, write_back: WriteBack, mode: Mode, report: "Report"
) -> None:
    """Reformat a single file under `src` without spawning child processes.

    `src` may also stand for standard input: either the literal "-" or a path
    tagged with STDIN_PLACEHOLDER, whose original name is restored first.
    Stdin goes through :func:`format_stdin_to_stdout`. Regular files go
    through :func:`format_file_in_place`, skipped when their cache entry
    matches (the cache is not consulted in diff modes). The outcome is
    recorded in `report`; any exception is recorded as a failure rather than
    raised.
    """
    try:
        src_str = str(src)
        is_stdin = src_str == "-"
        if not is_stdin and src_str.startswith(STDIN_PLACEHOLDER):
            is_stdin = True
            # Use the original name again in case we want to print something
            # to the user
            src = Path(src_str[len(STDIN_PLACEHOLDER) :])

        if is_stdin:
            if src.suffix == ".pyi":
                mode = replace(mode, is_pyi=True)
            reformatted = format_stdin_to_stdout(
                fast=fast, write_back=write_back, mode=mode
            )
            changed = Changed.YES if reformatted else Changed.NO
        else:
            changed = Changed.NO
            cache: Cache = {}
            if write_back not in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
                cache = read_cache(mode)
                resolved = src.resolve()
                key = str(resolved)
                if key in cache and cache[key] == get_cache_info(resolved):
                    changed = Changed.CACHED
            if changed is not Changed.CACHED and format_file_in_place(
                src, fast=fast, write_back=write_back, mode=mode
            ):
                changed = Changed.YES
            # Remember files that were written back, or checked and found clean.
            after_write = write_back is WriteBack.YES and changed is not Changed.CACHED
            after_check = write_back is WriteBack.CHECK and changed is Changed.NO
            if after_write or after_check:
                write_cache(cache, [src], mode)
        report.done(src, changed)
    except Exception as exc:
        if report.verbose:
            traceback.print_exc()
        report.failed(src, str(exc))
611
612
def reformat_many(
    sources: Set[Path], fast: bool, write_back: WriteBack, mode: Mode, report: "Report"
) -> None:
    """Reformat multiple files using a ProcessPoolExecutor.

    One worker per CPU is requested (capped at 60 on Windows). When process
    pools are unavailable, a single-worker ThreadPoolExecutor is used instead.
    The work is driven by :func:`schedule_formatting` on a fresh event loop
    that is the thread's current loop only while this call runs.
    """
    executor: Executor
    # os.cpu_count() returns None when the CPU count cannot be determined;
    # min(None, 60) below would then raise TypeError.
    worker_count = os.cpu_count() or 1
    if sys.platform == "win32":
        # Work around https://bugs.python.org/issue26903
        worker_count = min(worker_count, 60)
    try:
        executor = ProcessPoolExecutor(max_workers=worker_count)
    except (ImportError, NotImplementedError, OSError):
        # we arrive here if the underlying system does not support multi-processing
        # like in AWS Lambda or Termux, in which case we gracefully fallback to
        # a ThreadPoolExecutor with just a single worker (more workers would not do us
        # any good due to the Global Interpreter Lock)
        # (ProcessPoolExecutor raises NotImplementedError when the platform
        # lacks usable named semaphores.)
        executor = ThreadPoolExecutor(max_workers=1)

    # Use a loop owned by this call: asyncio.get_event_loop() is deprecated
    # when no loop is running, and could hand back a loop left over from an
    # earlier call (NOTE(review): presumably already closed by `shutdown`).
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        loop.run_until_complete(
            schedule_formatting(
                sources=sources,
                fast=fast,
                write_back=write_back,
                mode=mode,
                report=report,
                loop=loop,
                executor=executor,
            )
        )
    finally:
        try:
            shutdown(loop)
        finally:
            # Don't leave this call's loop installed for the thread.
            asyncio.set_event_loop(None)
        executor.shutdown()
648
649
async def schedule_formatting(
    sources: Set[Path],
    fast: bool,
    write_back: WriteBack,
    mode: Mode,
    report: "Report",
    loop: asyncio.AbstractEventLoop,
    executor: Executor,
) -> None:
    """Run formatting of `sources` in parallel using the provided `executor`.

    (Use ProcessPoolExecutors for actual parallelism.)

    `write_back`, `fast`, and `mode` options are passed to
    :func:`format_file_in_place`.

    Outside the diff modes, files that `filter_cached` reports as cached are
    marked CACHED in `report` and not reformatted. Every other result is
    recorded in `report`. Files that were written back, or checked and found
    already formatted, are added to the cache at the end.
    """
    cache: Cache = {}
    if write_back not in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
        cache = read_cache(mode)
        sources, cached = filter_cached(cache, sources)
        for src in sorted(cached):
            report.done(src, Changed.CACHED)
    if not sources:
        return

    cancelled = []
    sources_to_cache = []
    lock = None
    if write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
        # For diff output, we need locks to ensure we don't interleave output
        # from different processes.
        manager = Manager()
        lock = manager.Lock()
    tasks = {
        asyncio.ensure_future(
            loop.run_in_executor(
                executor, format_file_in_place, src, fast, mode, write_back, lock
            )
        ): src
        for src in sorted(sources)
    }
    # A live view of `tasks`: it shrinks as finished tasks are popped below,
    # so the signal handlers only ever cancel work that is still pending.
    pending = tasks.keys()
    try:
        loop.add_signal_handler(signal.SIGINT, cancel, pending)
        loop.add_signal_handler(signal.SIGTERM, cancel, pending)
    except NotImplementedError:
        # There are no good alternatives for these on Windows.
        pass
    while pending:
        done, _ = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
        for task in done:
            src = tasks.pop(task)
            if task.cancelled():
                cancelled.append(task)
            elif task.exception():
                report.failed(src, str(task.exception()))
            else:
                changed = Changed.YES if task.result() else Changed.NO
                # If the file was written back or was successfully checked as
                # well-formatted, store this information in the cache.
                if write_back is WriteBack.YES or (
                    write_back is WriteBack.CHECK and changed is Changed.NO
                ):
                    sources_to_cache.append(src)
                report.done(src, changed)
    if cancelled:
        # `gather` lost its `loop` parameter in Python 3.10 (passing it raises
        # TypeError); it uses the running loop, presumably `loop` since callers
        # run this coroutine on it.
        await asyncio.gather(*cancelled, return_exceptions=True)
    if sources_to_cache:
        write_cache(cache, sources_to_cache, mode)
719
720
def format_file_in_place(
    src: Path,
    fast: bool,
    mode: Mode,
    write_back: WriteBack = WriteBack.NO,
    lock: Any = None,  # multiprocessing.Manager().Lock() is some crazy proxy
) -> bool:
    """Format file under `src` path. Return True if changed.

    If `write_back` is YES, write reformatted code back to the file, using the
    encoding and newline style detected when reading it. If it is DIFF or
    COLOR_DIFF, write a (colored) diff to stdout instead, holding `lock` (when
    given) while writing so output from concurrent workers does not
    interleave. Any other value writes nothing.
    `mode` and `fast` options are passed to :func:`format_file_contents`;
    files with a ``.pyi`` suffix are always formatted as stubs.
    """
    # Local import: the module only imports `datetime` at the top level.
    from datetime import timezone

    if src.suffix == ".pyi":
        mode = replace(mode, is_pyi=True)

    # Naive UTC timestamps, labelled "+0000" in the diff header.
    # datetime.utcfromtimestamp()/utcnow() are deprecated since Python 3.12;
    # dropping tzinfo from an aware UTC datetime yields the identical value.
    then = datetime.fromtimestamp(src.stat().st_mtime, timezone.utc).replace(
        tzinfo=None
    )
    with open(src, "rb") as buf:
        src_contents, encoding, newline = decode_bytes(buf.read())
    try:
        dst_contents = format_file_contents(src_contents, fast=fast, mode=mode)
    except NothingChanged:
        return False

    if write_back == WriteBack.YES:
        with open(src, "w", encoding=encoding, newline=newline) as f:
            f.write(dst_contents)
    elif write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
        now = datetime.now(timezone.utc).replace(tzinfo=None)
        src_name = f"{src}\t{then} +0000"
        dst_name = f"{src}\t{now} +0000"
        diff_contents = diff(src_contents, dst_contents, src_name, dst_name)

        if write_back == WriteBack.COLOR_DIFF:
            diff_contents = color_diff(diff_contents)

        with lock or nullcontext():
            f = io.TextIOWrapper(
                sys.stdout.buffer,
                encoding=encoding,
                newline=newline,
                write_through=True,
            )
            f = wrap_stream_for_windows(f)
            f.write(diff_contents)
            # Detach so sys.stdout.buffer is not closed along with the wrapper.
            f.detach()

    return True
769
770
def format_stdin_to_stdout(
    fast: bool,
    *,
    content: Optional[str] = None,
    write_back: WriteBack = WriteBack.NO,
    mode: Mode,
) -> bool:
    """Format file on stdin. Return True if changed.

    If content is None, it's read from sys.stdin.

    If `write_back` is YES, write reformatted code back to stdout. If it is DIFF
    or COLOR_DIFF, write a (colored) diff to stdout. The `mode` argument is
    passed to :func:`format_file_contents`.

    Output is produced in a `finally` clause. When formatting changes nothing
    or raises, YES therefore echoes the unmodified input, so stdout always
    receives the source. Non-empty output always ends with a newline.
    """
    # Local import: the module only imports `datetime` at the top level.
    from datetime import timezone

    # Naive UTC timestamp (labelled "+0000" below) without the deprecated
    # datetime.utcnow(); the resulting value is identical.
    then = datetime.now(timezone.utc).replace(tzinfo=None)

    if content is None:
        src, encoding, newline = decode_bytes(sys.stdin.buffer.read())
    else:
        src, encoding, newline = content, "utf-8", ""

    # Falls back to the input if formatting raises or changes nothing.
    dst = src
    try:
        dst = format_file_contents(src, fast=fast, mode=mode)
        return True

    except NothingChanged:
        return False

    finally:
        f = io.TextIOWrapper(
            sys.stdout.buffer, encoding=encoding, newline=newline, write_through=True
        )
        if write_back == WriteBack.YES:
            # Make sure there's a newline after the content
            if dst and dst[-1] != "\n":
                dst += "\n"
            f.write(dst)
        elif write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
            now = datetime.now(timezone.utc).replace(tzinfo=None)
            src_name = f"STDIN\t{then} +0000"
            dst_name = f"STDOUT\t{now} +0000"
            d = diff(src, dst, src_name, dst_name)
            if write_back == WriteBack.COLOR_DIFF:
                d = color_diff(d)
                f = wrap_stream_for_windows(f)
            f.write(d)
        # Detach so sys.stdout.buffer is not closed along with the wrapper.
        f.detach()
820
821
def format_file_contents(src_contents: str, *, fast: bool, mode: Mode) -> FileContent:
    """Return `src_contents` reformatted according to `mode`.

    Raises `NothingChanged` when the source is blank (whitespace only) or when
    formatting leaves it unchanged.

    Unless `fast` is set, the result is verified with :func:`assert_equivalent`.
    It is then formatted a second time, because optional trailing commas added
    on the first pass can make the next pass split lines differently. If that
    second pass changes the output, the new result is checked with
    :func:`assert_equivalent` and :func:`assert_stable` and returned instead.
    """
    if not src_contents.strip():
        raise NothingChanged

    dst_contents = format_str(src_contents, mode=mode)
    if dst_contents == src_contents:
        raise NothingChanged

    if fast:
        return dst_contents

    assert_equivalent(src_contents, dst_contents)

    # Forced second pass to work around optional trailing commas (becoming
    # forced trailing commas on pass 2) interacting differently with optional
    # parentheses.  Admittedly ugly.
    second_pass = format_str(dst_contents, mode=mode)
    if second_pass == dst_contents:
        # Already stable: no explicit `assert_stable` call needed.
        return dst_contents

    assert_equivalent(src_contents, second_pass, pass_num=2)
    assert_stable(src_contents, second_pass, mode=mode)
    return second_pass
850
851
def format_str(src_contents: str, *, mode: Mode) -> FileContent:
    """Reformat a string and return new contents.

    Leading whitespace is stripped before parsing. When `mode.target_versions`
    is empty, the target versions are detected from the source itself.

    `mode` determines formatting options, such as how many characters per line are
    allowed.  Example:

    >>> import black
    >>> print(black.format_str("def f(arg:str='')->None:...", mode=black.Mode()))
    def f(arg: str = "") -> None:
        ...

    A more complex example:

    >>> print(
    ...   black.format_str(
    ...     "def f(arg:str='')->None: hey",
    ...     mode=black.Mode(
    ...       target_versions={black.TargetVersion.PY36},
    ...       line_length=10,
    ...       string_normalization=False,
    ...       is_pyi=False,
    ...     ),
    ...   ),
    ... )
    def f(
        arg: str = '',
    ) -> None:
        hey

    """
    src_node = lib2to3_parse(src_contents.lstrip(), mode.target_versions)
    dst_contents = []
    future_imports = get_future_imports(src_node)
    # Explicit target versions win; otherwise infer them from the parsed tree.
    # (Done before normalize_fmt_off mutates the tree.)
    if mode.target_versions:
        versions = mode.target_versions
    else:
        versions = detect_target_versions(src_node)
    normalize_fmt_off(src_node)
    lines = LineGenerator(
        mode=mode,
        # Enabled when the file imports unicode_literals from __future__ or
        # every target version supports Feature.UNICODE_LITERALS.
        remove_u_prefix="unicode_literals" in future_imports
        or supports_feature(versions, Feature.UNICODE_LITERALS),
    )
    elt = EmptyLineTracker(is_pyi=mode.is_pyi)
    empty_line = Line(mode=mode)
    # Blank lines owed after the previous line; they are emitted at the start
    # of the next iteration, so the value left after the last line is dropped.
    after = 0
    # Trailing-comma features supported by all target versions; passed to
    # transform_line.
    split_line_features = {
        feature
        for feature in {Feature.TRAILING_COMMA_IN_CALL, Feature.TRAILING_COMMA_IN_DEF}
        if supports_feature(versions, feature)
    }
    for current_line in lines.visit(src_node):
        dst_contents.append(str(empty_line) * after)
        before, after = elt.maybe_empty_lines(current_line)
        dst_contents.append(str(empty_line) * before)
        # transform_line may yield several output lines for one logical line.
        for line in transform_line(
            current_line, mode=mode, features=split_line_features
        ):
            dst_contents.append(str(line))
    return "".join(dst_contents)
912
913
def decode_bytes(src: bytes) -> Tuple[FileContent, Encoding, NewLine]:
    """Decode raw source bytes and report how they were encoded.

    Returns a ``(decoded_contents, encoding, newline)`` triple.  ``encoding`` is
    whatever ``tokenize.detect_encoding`` reports for the bytes.  ``newline`` is
    CRLF when the first line ends in CRLF and LF otherwise.  The decoded text is
    read with universal newlines, so it only ever contains LF.  Empty input
    yields empty contents and an LF newline.
    """
    buffer = io.BytesIO(src)
    encoding, first_lines = tokenize.detect_encoding(buffer.readline)
    if not first_lines:
        return "", encoding, "\n"

    # Only the first physical line decides which line ending to report.
    newline = "\r\n" if first_lines[0].endswith(b"\r\n") else "\n"
    buffer.seek(0)
    with io.TextIOWrapper(buffer, encoding) as text:
        contents = text.read()
    return contents, encoding, newline
929
930
def get_features_used(node: Node) -> Set[Feature]:
    """Return a set of (relatively) new Python features used in this file.

    Currently looking for:
    - f-strings (any letter case of the `f`, `rf` and `fr` prefixes);
    - underscores in numeric literals;
    - trailing commas after * or ** in function signatures and calls;
    - positional only arguments in function signatures and lambdas;
    - assignment expressions;
    - relaxed decorator syntax;
    """
    features: Set[Feature] = set()
    for n in node.pre_order():
        if n.type == token.STRING:
            # String prefixes are case-insensitive (`Rf"..."` and `fR"..."` are
            # f-strings too), so normalise the case before comparing.  The first
            # two characters distinguish `f"`/`f'`/`rf`/`fr` from other prefixes.
            value_head = n.value[:2].lower()  # type: ignore
            if value_head in {'f"', "f'", "rf", "fr"}:
                features.add(Feature.F_STRINGS)

        elif n.type == token.NUMBER:
            if "_" in n.value:  # type: ignore
                features.add(Feature.NUMERIC_UNDERSCORES)

        elif n.type == token.SLASH:
            # `/` as a parameter marker: `def` signatures use `typedargslist`,
            # lambda parameters use `varargslist`.
            if n.parent and n.parent.type in {
                syms.typedargslist,
                syms.arglist,
                syms.varargslist,
            }:
                features.add(Feature.POS_ONLY_ARGUMENTS)

        elif n.type == token.COLONEQUAL:
            features.add(Feature.ASSIGNMENT_EXPRESSIONS)

        elif n.type == syms.decorator:
            if len(n.children) > 1 and not is_simple_decorator_expression(
                n.children[1]
            ):
                features.add(Feature.RELAXED_DECORATORS)

        elif (
            n.type in {syms.typedargslist, syms.arglist}
            and n.children
            and n.children[-1].type == token.COMMA
        ):
            if n.type == syms.typedargslist:
                feature = Feature.TRAILING_COMMA_IN_DEF
            else:
                feature = Feature.TRAILING_COMMA_IN_CALL

            # The trailing comma only counts as the newer feature when the list
            # also contains a `*`/`**` entry, either directly or inside an
            # `argument` node.
            for ch in n.children:
                if ch.type in STARS:
                    features.add(feature)

                if ch.type == syms.argument:
                    for argch in ch.children:
                        if argch.type in STARS:
                            features.add(feature)

    return features
986
987
def detect_target_versions(node: Node) -> Set[TargetVersion]:
    """Return every target version able to handle all features used in `node`.

    A version qualifies when the features found by `get_features_used` are a
    subset of that version's entry in `VERSION_TO_FEATURES`.
    """
    used = get_features_used(node)
    compatible: Set[TargetVersion] = set()
    for target in TargetVersion:
        if used <= VERSION_TO_FEATURES[target]:
            compatible.add(target)
    return compatible
994
995
def get_future_imports(node: Node) -> Set[str]:
    """Collect the names imported with `from __future__ import ...` in a module.

    Top-level statements are scanned in order.  A leading docstring statement is
    skipped, and scanning stops at the first statement that is not a
    `__future__` import.  For aliased imports the original name is reported,
    not the alias.

    Raises:
        AssertionError: if an import list contains an unexpected node shape.
    """

    def names_from(children: List[LN]) -> Generator[str, None, None]:
        # Yield imported names from the children of an import list.  Non-name
        # leaves such as commas and parentheses are ignored.
        for kid in children:
            if isinstance(kid, Leaf):
                if kid.type == token.NAME:
                    yield kid.value
                continue

            if kid.type == syms.import_as_name:
                original = kid.children[0]
                assert isinstance(original, Leaf), "Invalid syntax parsing imports"
                assert original.type == token.NAME, "Invalid syntax parsing imports"
                yield original.value
            elif kid.type == syms.import_as_names:
                yield from names_from(kid.children)
            else:
                raise AssertionError("Invalid syntax parsing imports")

    found: Set[str] = set()
    for stmt in node.children:
        if stmt.type != syms.simple_stmt:
            break

        head = stmt.children[0]
        if isinstance(head, Leaf):
            # A bare string followed by NEWLINE is a docstring: keep scanning.
            is_docstring = (
                len(stmt.children) == 2
                and head.type == token.STRING
                and stmt.children[1].type == token.NEWLINE
            )
            if is_docstring:
                continue
            break

        if head.type != syms.import_from:
            break

        module = head.children[1]
        if not isinstance(module, Leaf) or module.value != "__future__":
            break

        # children[3:] skips `from`, the module name and `import`.
        found.update(names_from(head.children[3:]))

    return found
1044
1045
def assert_equivalent(src: str, dst: str, *, pass_num: int = 1) -> None:
    """Raise AssertionError if `src` and `dst` aren't equivalent.

    Both inputs are parsed with `parse_ast`.  Their `stringify_ast` renderings
    are joined into text and compared.  `pass_num` only appears in error
    messages.

    Raises:
        AssertionError: if `src` fails to parse (chained to the parse error), if
            `dst` fails to parse, or if the two renderings differ.  In the last
            two cases the traceback or diff is passed to `dump_to_file` and its
            result is quoted in the message.
    """
    try:
        src_ast = parse_ast(src)
    except Exception as exc:
        # The input itself is unparseable.  Attach the parse error as the
        # explicit cause instead of leaving it as implicit "during handling"
        # context.
        raise AssertionError(
            "cannot use --safe with this file; failed to parse source file.  AST"
            f" error message: {exc}"
        ) from exc

    try:
        dst_ast = parse_ast(dst)
    except Exception as exc:
        # Black's own output failed to parse.  Save the traceback together with
        # the output, and keep the report free of the chained context.
        log = dump_to_file("".join(traceback.format_tb(exc.__traceback__)), dst)
        raise AssertionError(
            f"INTERNAL ERROR: Black produced invalid code on pass {pass_num}: {exc}. "
            "Please report a bug on https://github.com/psf/black/issues.  "
            f"This invalid output might be helpful: {log}"
        ) from None

    src_ast_str = "\n".join(stringify_ast(src_ast))
    dst_ast_str = "\n".join(stringify_ast(dst_ast))
    if src_ast_str != dst_ast_str:
        log = dump_to_file(diff(src_ast_str, dst_ast_str, "src", "dst"))
        raise AssertionError(
            "INTERNAL ERROR: Black produced code that is not equivalent to the"
            f" source on pass {pass_num}.  Please report a bug on "
            f"https://github.com/psf/black/issues.  This diff might be helpful: {log}"
        ) from None
1075
1076
def assert_stable(src: str, dst: str, mode: Mode) -> None:
    """Check that formatting `dst` again with `mode` leaves it unchanged.

    Raises:
        AssertionError: when a second `format_str` pass over `dst` produces
            different text.  The mode, the source-to-first-pass diff and the
            first-to-second-pass diff are handed to `dump_to_file`, and its
            result is quoted in the message.
    """
    second_pass = format_str(dst, mode=mode)
    if second_pass == dst:
        return

    log = dump_to_file(
        str(mode),
        diff(src, dst, "source", "first pass"),
        diff(dst, second_pass, "first pass", "second pass"),
    )
    raise AssertionError(
        "INTERNAL ERROR: Black produced different code on the second pass of the"
        " formatter.  Please report a bug on https://github.com/psf/black/issues."
        f"  This diff might be helpful: {log}"
    ) from None
1091
1092
@contextmanager
def nullcontext() -> Iterator[None]:
    """Context manager that does nothing.

    Entering it gives `None`, and exceptions raised inside the `with` body
    propagate unchanged.  It stands in for `contextlib.nullcontext`, which only
    exists from Python 3.7 onwards.
    """
    yield None
1100
1101
def patch_click() -> None:
    """Make Click not crash on Python 3.6 with LANG=C.

    On certain misconfigured environments, Python 3 selects the ASCII encoding as the
    default which restricts paths that it can access during the lifetime of the
    application.  Click refuses to work in this scenario by raising a RuntimeError.

    In case of Black the likelihood that non-ASCII characters are going to be used in
    file paths is minimal since it's Python source code.  Moreover, this crash was
    spurious on Python 3.7 thanks to PEP 538 and PEP 540.

    Each Click module is imported and patched independently.  A module missing
    from the installed Click version, or Click being absent altogether, is
    skipped without affecting the others.
    """
    modules: List[Any] = []
    try:
        from click import core
    except ImportError:
        pass
    else:
        modules.append(core)

    try:
        # `_unicodefun` may be absent from some Click releases.  When a submodule
        # is missing, `from pkg import name` raises a plain ImportError ("cannot
        # import name"), not ModuleNotFoundError, so the broader class is needed.
        from click import _unicodefun  # type: ignore
    except ImportError:
        pass
    else:
        modules.append(_unicodefun)

    for module in modules:
        # Replace whichever environment-check hook this Click version defines.
        if hasattr(module, "_verify_python3_env"):
            module._verify_python3_env = lambda: None  # type: ignore
        if hasattr(module, "_verify_python_env"):
            module._verify_python_env = lambda: None  # type: ignore
1124
1125
def patched_main() -> None:
    """Run the `main` CLI after the runtime setup steps.

    The steps are called one after another in exactly the order listed, and
    `main` runs last.
    """
    steps: Tuple[Any, ...] = (maybe_install_uvloop, freeze_support, patch_click, main)
    for step in steps:
        step()
1131
1132
# Run the patched CLI entry point when this file is executed directly.
if __name__ == "__main__":
    patched_main()