git.madduck.net Git - etc/vim.git/blob - src/black/__init__.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Update pre-commit config (#2331)
[etc/vim.git] / src / black / __init__.py
1 import asyncio
2 from concurrent.futures import Executor, ThreadPoolExecutor, ProcessPoolExecutor
3 from contextlib import contextmanager
4 from datetime import datetime
5 from enum import Enum
6 import io
7 from multiprocessing import Manager, freeze_support
8 import os
9 from pathlib import Path
10 import regex as re
11 import signal
12 import sys
13 import tokenize
14 import traceback
15 from typing import (
16     Any,
17     Dict,
18     Generator,
19     Iterator,
20     List,
21     Optional,
22     Pattern,
23     Set,
24     Sized,
25     Tuple,
26     Union,
27 )
28
29 from dataclasses import replace
30 import click
31
32 from black.const import DEFAULT_LINE_LENGTH, DEFAULT_INCLUDES, DEFAULT_EXCLUDES
33 from black.const import STDIN_PLACEHOLDER
34 from black.nodes import STARS, syms, is_simple_decorator_expression
35 from black.lines import Line, EmptyLineTracker
36 from black.linegen import transform_line, LineGenerator, LN
37 from black.comments import normalize_fmt_off
38 from black.mode import Mode, TargetVersion
39 from black.mode import Feature, supports_feature, VERSION_TO_FEATURES
40 from black.cache import read_cache, write_cache, get_cache_info, filter_cached, Cache
41 from black.concurrency import cancel, shutdown, maybe_install_uvloop
42 from black.output import dump_to_file, diff, color_diff, out, err
43 from black.report import Report, Changed
44 from black.files import find_project_root, find_pyproject_toml, parse_pyproject_toml
45 from black.files import gen_python_files, get_gitignore, normalize_path_maybe_ignore
46 from black.files import wrap_stream_for_windows
47 from black.parsing import InvalidInput  # noqa F401
48 from black.parsing import lib2to3_parse, parse_ast, stringify_ast
49
50
51 # lib2to3 fork
52 from blib2to3.pytree import Node, Leaf
53 from blib2to3.pgen2 import token
54
55 from _black_version import version as __version__
56
# types
# Readability aliases for plain `str`; `decode_bytes` returns them as the
# tuple (decoded_contents, encoding, newline).
FileContent = str
Encoding = str
NewLine = str
61
62
class NothingChanged(UserWarning):
    """Raised when reformatted code is the same as source.

    `format_file_contents` raises it; `format_file_in_place` and
    `format_stdin_to_stdout` catch it and report the input as unchanged.
    """
65
66
class WriteBack(Enum):
    """What to do with the formatted result of a source."""

    NO = 0  # default of `format_file_in_place` / `format_stdin_to_stdout`
    YES = 1  # write the reformatted code back
    DIFF = 2  # emit a diff on stdout
    CHECK = 3  # only report whether something would change
    COLOR_DIFF = 4  # emit a colored diff on stdout

    @classmethod
    def from_configuration(
        cls, *, check: bool, diff: bool, color: bool = False
    ) -> "WriteBack":
        """Map the `--check`, `--diff` and `--color` flags to a member.

        `diff` takes precedence over `check`; `color` only matters together
        with `diff`.  With neither `check` nor `diff` the result is `YES`.
        """
        if diff:
            return cls.COLOR_DIFF if color else cls.DIFF

        return cls.CHECK if check else cls.YES
85
86
# Legacy name, left for integrations: `FileMode` is an alias of `Mode`, so
# code that still refers to the old name keeps working.
FileMode = Mode
89
90
def read_pyproject_toml(
    ctx: click.Context, param: click.Parameter, value: Optional[str]
) -> Optional[str]:
    """Click callback that seeds `ctx.default_map` from a "pyproject.toml".

    When `value` is empty, the file is looked up via `find_pyproject_toml`
    using the `src` parameter already stored on `ctx`.  The parsed settings
    are layered on top of any existing `ctx.default_map`.

    Returns the path of the configuration file that was read, or None when no
    file was found or it held no Black configuration.

    Raises `click.FileError` if the file cannot be read or parsed, and
    `click.BadOptionUsage` if `target_version` is present but not a list.
    """
    if not value:
        value = find_pyproject_toml(ctx.params.get("src", ()))
    if value is None:
        return None

    try:
        config = parse_pyproject_toml(value)
    except (OSError, ValueError) as e:
        raise click.FileError(
            filename=value, hint=f"Error reading configuration file: {e}"
        )

    if not config:
        return None

    # Sanitize the values to be Click friendly. For more information please see:
    # https://github.com/psf/black/issues/1458
    # https://github.com/pallets/click/issues/1567
    config = {
        key: (setting if isinstance(setting, (list, dict)) else str(setting))
        for key, setting in config.items()
    }

    target_version = config.get("target_version")
    if not (target_version is None or isinstance(target_version, list)):
        raise click.BadOptionUsage(
            "target-version", "Config key target-version must be a list"
        )

    # Values from the file override defaults that were already registered.
    merged: Dict[str, Any] = {**(ctx.default_map or {}), **config}
    ctx.default_map = merged
    return value
135
136
def target_version_option_callback(
    c: click.Context, p: Union[click.Option, click.Parameter], v: Tuple[str, ...]
) -> List[TargetVersion]:
    """Turn the strings given to --target-version into `TargetVersion` members.

    Each value is upper-cased and looked up by name in `TargetVersion`.

    Kept as a named function because mypy couldn't infer the type correctly
    when it was a lambda, causing mypyc trouble.
    """
    versions: List[TargetVersion] = []
    for name in v:
        versions.append(TargetVersion[name.upper()])
    return versions
146
147
def re_compile_maybe_verbose(regex: str) -> Pattern[str]:
    """Compile `regex`, switching to verbose mode for multi-line patterns.

    A pattern containing a newline is prefixed with the inline `(?x)` flag
    before compilation; any other pattern is compiled as given.
    """
    pattern_source = f"(?x){regex}" if "\n" in regex else regex
    compiled: Pattern[str] = re.compile(pattern_source)
    return compiled
157
158
def validate_regex(
    ctx: click.Context,
    param: click.Parameter,
    value: Optional[str],
) -> Optional[Pattern]:
    """Click callback that compiles a regular-expression option.

    Returns None when the option was not given, otherwise the pattern
    compiled by `re_compile_maybe_verbose`.  Raises `click.BadParameter` if
    the value is not a valid regular expression.
    """
    if value is None:
        return None

    try:
        return re_compile_maybe_verbose(value)
    except re.error:
        raise click.BadParameter("Not a valid regular expression")
168
169
# Command-line entry point.  NOTE: the docstring of `main` is left untouched
# on purpose; click commands normally show it as the `--help` description.
@click.command(context_settings=dict(help_option_names=["-h", "--help"]))
@click.option("-c", "--code", type=str, help="Format the code passed in as a string.")
@click.option(
    "-l",
    "--line-length",
    type=int,
    default=DEFAULT_LINE_LENGTH,
    help="How many characters per line to allow.",
    show_default=True,
)
@click.option(
    "-t",
    "--target-version",
    type=click.Choice([v.name.lower() for v in TargetVersion]),
    callback=target_version_option_callback,
    multiple=True,
    help=(
        "Python versions that should be supported by Black's output. [default: per-file"
        " auto-detection]"
    ),
)
@click.option(
    "--pyi",
    is_flag=True,
    help=(
        "Format all input files like typing stubs regardless of file extension (useful"
        " when piping source on standard input)."
    ),
)
@click.option(
    "-S",
    "--skip-string-normalization",
    is_flag=True,
    help="Don't normalize string quotes or prefixes.",
)
@click.option(
    "-C",
    "--skip-magic-trailing-comma",
    is_flag=True,
    help="Don't use trailing commas as a reason to split lines.",
)
@click.option(
    "--experimental-string-processing",
    is_flag=True,
    hidden=True,
    help=(
        "Experimental option that performs more normalization on string literals."
        " Currently disabled because it leads to some crashes."
    ),
)
@click.option(
    "--check",
    is_flag=True,
    help=(
        "Don't write the files back, just return the status. Return code 0 means"
        " nothing would change. Return code 1 means some files would be reformatted."
        " Return code 123 means there was an internal error."
    ),
)
@click.option(
    "--diff",
    is_flag=True,
    help="Don't write the files back, just output a diff for each file on stdout.",
)
@click.option(
    "--color/--no-color",
    is_flag=True,
    help="Show colored diff. Only applies when `--diff` is given.",
)
@click.option(
    "--fast/--safe",
    is_flag=True,
    help="If --fast given, skip temporary sanity checks. [default: --safe]",
)
@click.option(
    "--required-version",
    type=str,
    help=(
        "Require a specific version of Black to be running (useful for unifying results"
        " across many environments e.g. with a pyproject.toml file)."
    ),
)
@click.option(
    "--include",
    type=str,
    default=DEFAULT_INCLUDES,
    callback=validate_regex,
    help=(
        "A regular expression that matches files and directories that should be"
        " included on recursive searches. An empty value means all files are included"
        " regardless of the name. Use forward slashes for directories on all platforms"
        " (Windows, too). Exclusions are calculated first, inclusions later."
    ),
    show_default=True,
)
# --exclude deliberately has no click default: `get_sources` treats None as
# "use DEFAULT_EXCLUDES and also honour .gitignore".
@click.option(
    "--exclude",
    type=str,
    callback=validate_regex,
    help=(
        "A regular expression that matches files and directories that should be"
        " excluded on recursive searches. An empty value means no paths are excluded."
        " Use forward slashes for directories on all platforms (Windows, too)."
        " Exclusions are calculated first, inclusions later. [default:"
        f" {DEFAULT_EXCLUDES}]"
    ),
    show_default=False,
)
@click.option(
    "--extend-exclude",
    type=str,
    callback=validate_regex,
    help=(
        "Like --exclude, but adds additional files and directories on top of the"
        " excluded ones. (Useful if you simply want to add to the default)"
    ),
)
@click.option(
    "--force-exclude",
    type=str,
    callback=validate_regex,
    help=(
        "Like --exclude, but files and directories matching this regex will be "
        "excluded even when they are passed explicitly as arguments."
    ),
)
@click.option(
    "--stdin-filename",
    type=str,
    help=(
        "The name of the file when passing it through stdin. Useful to make "
        "sure Black will respect --force-exclude option on some "
        "editors that rely on using stdin."
    ),
)
@click.option(
    "-q",
    "--quiet",
    is_flag=True,
    help=(
        "Don't emit non-error messages to stderr. Errors are still emitted; silence"
        " those with 2>/dev/null."
    ),
)
@click.option(
    "-v",
    "--verbose",
    is_flag=True,
    help=(
        "Also emit messages to stderr about files that were not changed or were ignored"
        " due to exclusion patterns."
    ),
)
@click.version_option(version=__version__)
# NOTE(review): `src` and `--config` are both eager.  The `--config` callback
# (`read_pyproject_toml`) reads `src` from `ctx.params` and rewrites
# `ctx.default_map`, so it presumably has to run after `src` and before the
# remaining options are resolved -- confirm against click's processing order.
@click.argument(
    "src",
    nargs=-1,
    type=click.Path(
        exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
    ),
    is_eager=True,
)
@click.option(
    "--config",
    type=click.Path(
        exists=True,
        file_okay=True,
        dir_okay=False,
        readable=True,
        allow_dash=False,
        path_type=str,
    ),
    is_eager=True,
    callback=read_pyproject_toml,
    help="Read configuration from FILE path.",
)
@click.pass_context
def main(
    ctx: click.Context,
    code: Optional[str],
    line_length: int,
    target_version: List[TargetVersion],
    check: bool,
    diff: bool,
    color: bool,
    fast: bool,
    pyi: bool,
    skip_string_normalization: bool,
    skip_magic_trailing_comma: bool,
    experimental_string_processing: bool,
    quiet: bool,
    verbose: bool,
    required_version: str,
    include: Pattern,
    exclude: Optional[Pattern],
    extend_exclude: Optional[Pattern],
    force_exclude: Optional[Pattern],
    stdin_filename: Optional[str],
    src: Tuple[str, ...],
    config: Optional[str],
) -> None:
    """The uncompromising code formatter."""
    # `config` is the value returned by the `read_pyproject_toml` callback:
    # the path of the configuration file that was read, or None.
    if config and verbose:
        out(f"Using configuration from {config}.", bold=False, fg="blue")

    error_msg = "Oh no! 💥 💔 💥"
    # Refuse to run (exit status 1) when a pinned version does not match the
    # running one.
    if required_version and required_version != __version__:
        err(
            f"{error_msg} The required version `{required_version}` does not match"
            f" the running version `{__version__}`!"
        )
        ctx.exit(1)

    # Translate the CLI flags into the WriteBack action and formatting Mode
    # shared by every source processed in this run.
    write_back = WriteBack.from_configuration(check=check, diff=diff, color=color)
    if target_version:
        versions = set(target_version)
    else:
        # We'll autodetect later.
        versions = set()
    mode = Mode(
        target_versions=versions,
        line_length=line_length,
        is_pyi=pyi,
        string_normalization=not skip_string_normalization,
        magic_trailing_comma=not skip_magic_trailing_comma,
        experimental_string_processing=experimental_string_processing,
    )

    if code is not None:
        # Run in quiet mode by default with -c; the extra output isn't useful.
        # You can still pass -v to get verbose output.
        quiet = True

    report = Report(check=check, diff=diff, quiet=quiet, verbose=verbose)

    if code is not None:
        # -c/--code: format the given string; `src` is not consulted.
        reformat_code(
            content=code, fast=fast, write_back=write_back, mode=mode, report=report
        )
    else:
        sources = get_sources(
            ctx=ctx,
            src=src,
            quiet=quiet,
            verbose=verbose,
            include=include,
            exclude=exclude,
            extend_exclude=extend_exclude,
            force_exclude=force_exclude,
            report=report,
            stdin_filename=stdin_filename,
        )

        # Exits with status 0 when nothing is left to format.
        path_empty(
            sources,
            "No Python files are present to be formatted. Nothing to do 😴",
            quiet,
            verbose,
            ctx,
        )

        # A single source is formatted in this process; several sources go
        # through the executor-based `reformat_many`.
        if len(sources) == 1:
            reformat_one(
                src=sources.pop(),
                fast=fast,
                write_back=write_back,
                mode=mode,
                report=report,
            )
        else:
            reformat_many(
                sources=sources,
                fast=fast,
                write_back=write_back,
                mode=mode,
                report=report,
            )

    if verbose or not quiet:
        out(error_msg if report.return_code else "All done! ✨ 🍰 ✨")
        # The summary is skipped for -c/--code runs.
        if code is None:
            click.echo(str(report), err=True)
    # The process exit status is whatever the report accumulated.
    ctx.exit(report.return_code)
453
454
def get_sources(
    *,
    ctx: click.Context,
    src: Tuple[str, ...],
    quiet: bool,
    verbose: bool,
    include: Pattern[str],
    exclude: Optional[Pattern[str]],
    extend_exclude: Optional[Pattern[str]],
    force_exclude: Optional[Pattern[str]],
    report: "Report",
    stdin_filename: Optional[str],
) -> Set[Path]:
    """Collect every path that should be formatted for the given `src` items.

    Files (and stdin given through `stdin_filename`) are added directly unless
    they are ignored by `normalize_path_maybe_ignore` or match
    `force_exclude`.  Directories are expanded with `gen_python_files`.  A bare
    "-" stands for stdin.  Anything else is reported as an invalid path.

    Exits through `path_empty` when `src` is empty.
    """
    root = find_project_root(src)
    sources: Set[Path] = set()
    path_empty(src, "No Path provided. Nothing to do 😴", quiet, verbose, ctx)

    # .gitignore is only consulted when the default excludes are in effect.
    gitignore = None
    if exclude is None:
        exclude = re_compile_maybe_verbose(DEFAULT_EXCLUDES)
        gitignore = get_gitignore(root)

    for s in src:
        if s == "-" and stdin_filename:
            is_stdin, p = True, Path(stdin_filename)
        else:
            is_stdin, p = False, Path(s)

        if not is_stdin and not p.is_file():
            # Not a single file: a directory, plain stdin, or a bad path.
            if p.is_dir():
                discovered = gen_python_files(
                    p.iterdir(),
                    root,
                    include,
                    exclude,
                    extend_exclude,
                    force_exclude,
                    report,
                    gitignore,
                )
                sources.update(discovered)
            elif s == "-":
                sources.add(p)
            else:
                err(f"invalid path: {s}")
            continue

        normalized_path = normalize_path_maybe_ignore(p, root, report)
        if normalized_path is None:
            continue

        normalized_path = "/" + normalized_path
        # Hard-exclude any files that matches the `--force-exclude` regex.
        hit = force_exclude.search(normalized_path) if force_exclude else None
        if hit and hit.group(0):
            report.path_ignored(p, "matches the --force-exclude regular expression")
            continue

        # Stdin with a file name is tagged so `reformat_one` can recognise it.
        sources.add(Path(f"{STDIN_PLACEHOLDER}{str(p)}") if is_stdin else p)
    return sources
525
526
def path_empty(
    src: Sized, msg: str, quiet: bool, verbose: bool, ctx: click.Context
) -> None:
    """
    Exit with status 0 through `ctx` when `src` is empty.

    `msg` is printed first unless output is silenced (`quiet` without
    `verbose`).  Does nothing when `src` is non-empty.
    """
    if src:
        return

    silenced = quiet and not verbose
    if not silenced:
        out(msg)
    ctx.exit(0)
537
538
def reformat_code(
    content: str, fast: bool, write_back: WriteBack, mode: Mode, report: Report
) -> None:
    """
    Reformat the string `content` and print the result, in this process.

    The string counterpart of `reformat_one`: `fast`, `write_back`, and `mode`
    are forwarded to :func:`format_stdin_to_stdout`.  The outcome is recorded
    on `report` under the pseudo-path "<string>"; any exception is recorded as
    a failure instead of propagating.
    """
    path = Path("<string>")
    try:
        was_changed = format_stdin_to_stdout(
            content=content, fast=fast, write_back=write_back, mode=mode
        )
        report.done(path, Changed.YES if was_changed else Changed.NO)
    except Exception as exc:
        if report.verbose:
            traceback.print_exc()
        report.failed(path, str(exc))
561
562
def reformat_one(
    src: Path, fast: bool, write_back: WriteBack, mode: Mode, report: "Report"
) -> None:
    """Reformat the single source `src` in this process.

    `src` is either a real file, "-" for stdin, or a path prefixed with
    `STDIN_PLACEHOLDER` (stdin with a file name).  `fast`, `write_back`, and
    `mode` are forwarded to :func:`format_file_in_place` or
    :func:`format_stdin_to_stdout`.  The outcome is recorded on `report`; any
    exception is recorded as a failure instead of propagating.
    """
    try:
        changed = Changed.NO

        src_name = str(src)
        has_placeholder = src_name != "-" and src_name.startswith(STDIN_PLACEHOLDER)
        is_stdin = src_name == "-" or has_placeholder
        if has_placeholder:
            # Use the original name again in case we want to print something
            # to the user
            src = Path(src_name[len(STDIN_PLACEHOLDER) :])

        if is_stdin:
            if src.suffix == ".pyi":
                mode = replace(mode, is_pyi=True)
            if format_stdin_to_stdout(fast=fast, write_back=write_back, mode=mode):
                changed = Changed.YES
        else:
            producing_diff = write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF)
            cache: Cache = {}
            if not producing_diff:
                # The cache is skipped entirely when a diff was requested.
                cache = read_cache(mode)
                resolved = src.resolve()
                cache_key = str(resolved)
                if cache_key in cache and cache[cache_key] == get_cache_info(resolved):
                    changed = Changed.CACHED
            if changed is not Changed.CACHED and format_file_in_place(
                src, fast=fast, write_back=write_back, mode=mode
            ):
                changed = Changed.YES

            wrote_file = write_back is WriteBack.YES and changed is not Changed.CACHED
            checked_clean = write_back is WriteBack.CHECK and changed is Changed.NO
            if wrote_file or checked_clean:
                write_cache(cache, [src], mode)
        report.done(src, changed)
    except Exception as exc:
        if report.verbose:
            traceback.print_exc()
        report.failed(src, str(exc))
610
611
def reformat_many(
    sources: Set[Path], fast: bool, write_back: WriteBack, mode: Mode, report: "Report"
) -> None:
    """Reformat multiple files using a ProcessPoolExecutor.

    Runs :func:`schedule_formatting` to completion on the current event loop.
    Falls back to a single-worker ThreadPoolExecutor when a process pool
    cannot be created.  The loop and the executor are both shut down
    afterwards, even if formatting raised.
    """
    executor: Executor
    loop = asyncio.get_event_loop()
    # os.cpu_count() returns None when the count cannot be determined; use one
    # worker then, so the min() below never compares None with an int.
    worker_count = os.cpu_count() or 1
    if sys.platform == "win32":
        # Work around https://bugs.python.org/issue26903
        worker_count = min(worker_count, 60)
    try:
        executor = ProcessPoolExecutor(max_workers=worker_count)
    except (ImportError, OSError):
        # we arrive here if the underlying system does not support multi-processing
        # like in AWS Lambda or Termux, in which case we gracefully fallback to
        # a ThreadPoolExecutor with just a single worker (more workers would not do us
        # any good due to the Global Interpreter Lock)
        executor = ThreadPoolExecutor(max_workers=1)

    try:
        loop.run_until_complete(
            schedule_formatting(
                sources=sources,
                fast=fast,
                write_back=write_back,
                mode=mode,
                report=report,
                loop=loop,
                executor=executor,
            )
        )
    finally:
        try:
            shutdown(loop)
        finally:
            # Release the workers even if shutting down the loop failed.
            executor.shutdown()
647
648
async def schedule_formatting(
    sources: Set[Path],
    fast: bool,
    write_back: WriteBack,
    mode: Mode,
    report: "Report",
    loop: asyncio.AbstractEventLoop,
    executor: Executor,
) -> None:
    """Run formatting of `sources` in parallel using the provided `executor`.

    (Use ProcessPoolExecutors for actual parallelism.)

    `write_back`, `fast`, and `mode` options are passed to
    :func:`format_file_in_place`.

    Unless a diff was requested, sources already present in the cache are
    reported as cached and skipped, and sources that were written back or
    checked clean are added to the cache at the end.  Results and failures
    are recorded on `report`.
    """
    cache: Cache = {}
    if write_back not in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
        cache = read_cache(mode)
        sources, cached = filter_cached(cache, sources)
        for src in sorted(cached):
            report.done(src, Changed.CACHED)
    if not sources:
        return

    cancelled = []
    sources_to_cache = []
    lock = None
    if write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
        # For diff output, we need locks to ensure we don't interleave output
        # from different processes.
        manager = Manager()
        lock = manager.Lock()
    # Map each in-flight future back to the source it is formatting.
    tasks = {
        asyncio.ensure_future(
            loop.run_in_executor(
                executor, format_file_in_place, src, fast, mode, write_back, lock
            )
        ): src
        for src in sorted(sources)
    }
    # A live view: it shrinks as finished tasks are popped from `tasks` below.
    pending = tasks.keys()
    try:
        loop.add_signal_handler(signal.SIGINT, cancel, pending)
        loop.add_signal_handler(signal.SIGTERM, cancel, pending)
    except NotImplementedError:
        # There are no good alternatives for these on Windows.
        pass
    while pending:
        done, _ = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
        for task in done:
            src = tasks.pop(task)
            if task.cancelled():
                cancelled.append(task)
            elif task.exception():
                report.failed(src, str(task.exception()))
            else:
                changed = Changed.YES if task.result() else Changed.NO
                # If the file was written back or was successfully checked as
                # well-formatted, store this information in the cache.
                if write_back is WriteBack.YES or (
                    write_back is WriteBack.CHECK and changed is Changed.NO
                ):
                    sources_to_cache.append(src)
                report.done(src, changed)
    if cancelled:
        # `asyncio.gather` no longer accepts `loop=` on Python 3.10+ (it
        # raises TypeError there), so it is not passed.
        await asyncio.gather(*cancelled, return_exceptions=True)
    if sources_to_cache:
        write_cache(cache, sources_to_cache, mode)
718
719
def format_file_in_place(
    src: Path,
    fast: bool,
    mode: Mode,
    write_back: WriteBack = WriteBack.NO,
    lock: Any = None,  # multiprocessing.Manager().Lock() is some crazy proxy
) -> bool:
    """Format the file at `src`; return True if its formatting changed.

    With `write_back` YES the reformatted code replaces the file's contents,
    keeping its detected encoding and newline style.  With DIFF or COLOR_DIFF
    a diff is written to stdout, under `lock` when one is given.  Any other
    value only reports whether the file would change.
    A ".pyi" suffix switches `mode` to stub formatting.
    `mode` and `fast` options are passed to :func:`format_file_contents`.
    """
    if src.suffix == ".pyi":
        mode = replace(mode, is_pyi=True)

    # Modification time of the original, used as the "before" stamp in diffs.
    then = datetime.utcfromtimestamp(src.stat().st_mtime)
    with open(src, "rb") as buf:
        raw = buf.read()
    src_contents, encoding, newline = decode_bytes(raw)
    try:
        dst_contents = format_file_contents(src_contents, fast=fast, mode=mode)
    except NothingChanged:
        return False

    if write_back == WriteBack.YES:
        with open(src, "w", encoding=encoding, newline=newline) as f:
            f.write(dst_contents)
        return True

    if write_back not in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
        return True

    now = datetime.utcnow()
    diff_contents = diff(
        src_contents,
        dst_contents,
        f"{src}\t{then} +0000",
        f"{src}\t{now} +0000",
    )
    if write_back == WriteBack.COLOR_DIFF:
        diff_contents = color_diff(diff_contents)

    with lock or nullcontext():
        f = io.TextIOWrapper(
            sys.stdout.buffer,
            encoding=encoding,
            newline=newline,
            write_through=True,
        )
        f = wrap_stream_for_windows(f)
        f.write(diff_contents)
        # Detach rather than close, so sys.stdout's buffer stays usable.
        f.detach()

    return True
768
769
def format_stdin_to_stdout(
    fast: bool,
    *,
    content: Optional[str] = None,
    write_back: WriteBack = WriteBack.NO,
    mode: Mode,
) -> bool:
    """Format file on stdin. Return True if changed.

    If content is None, it's read from sys.stdin.

    If `write_back` is YES, write reformatted code back to stdout. If it is DIFF,
    write a diff to stdout. The `mode` argument is passed to
    :func:`format_file_contents`.

    Output is produced in the `finally` clause, so it is emitted whether the
    source changed, was already formatted, or formatting raised (in which
    case the unmodified source is written).
    """
    then = datetime.utcnow()

    if content is None:
        src, encoding, newline = decode_bytes(sys.stdin.buffer.read())
    else:
        src, encoding, newline = content, "utf-8", ""

    dst = src
    try:
        dst = format_file_contents(src, fast=fast, mode=mode)
        return True

    except NothingChanged:
        return False

    finally:
        f = io.TextIOWrapper(
            sys.stdout.buffer, encoding=encoding, newline=newline, write_through=True
        )
        if write_back == WriteBack.YES:
            # Make sure there's a newline after the content.  Empty input is
            # left empty: indexing dst[-1] on "" would raise IndexError.
            if dst and dst[-1] != "\n":
                dst += "\n"
            f.write(dst)
        elif write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
            now = datetime.utcnow()
            src_name = f"STDIN\t{then} +0000"
            dst_name = f"STDOUT\t{now} +0000"
            d = diff(src, dst, src_name, dst_name)
            if write_back == WriteBack.COLOR_DIFF:
                d = color_diff(d)
                f = wrap_stream_for_windows(f)
            f.write(d)
        # Detach rather than close, so sys.stdout's buffer stays usable.
        f.detach()
818
819
def format_file_contents(src_contents: str, *, fast: bool, mode: Mode) -> FileContent:
    """Return `src_contents` reformatted according to `mode`.

    Raises `NothingChanged` when the source is blank or already formatted.

    Unless `fast` is True, the result is verified with
    :func:`assert_equivalent`, and a second formatting pass is run; if that
    pass changes the output again, its result is returned after being checked
    with :func:`assert_equivalent` and :func:`assert_stable`.
    `mode` is passed to :func:`format_str`.
    """
    if not src_contents.strip():
        raise NothingChanged

    dst_contents = format_str(src_contents, mode=mode)
    if src_contents == dst_contents:
        raise NothingChanged

    if fast:
        return dst_contents

    assert_equivalent(src_contents, dst_contents)

    # Forced second pass to work around optional trailing commas (becoming
    # forced trailing commas on pass 2) interacting differently with optional
    # parentheses.  Admittedly ugly.
    second_pass = format_str(dst_contents, mode=mode)
    if second_pass == dst_contents:
        # Note: no need to explicitly call `assert_stable` if `dst_contents` was
        # the same as the second pass.
        return dst_contents

    assert_equivalent(src_contents, second_pass, pass_num=2)
    assert_stable(src_contents, second_pass, mode=mode)
    return second_pass
848
849
def format_str(src_contents: str, *, mode: Mode) -> FileContent:
    """Reformat the source string `src_contents` and return the new contents.

    `mode` carries the formatting options, e.g. the allowed line length.  When
    it names no target versions, they are detected from the parsed source.
    Example:

    >>> import black
    >>> print(black.format_str("def f(arg:str='')->None:...", mode=black.Mode()))
    def f(arg: str = "") -> None:
        ...

    A more complex example:

    >>> print(
    ...   black.format_str(
    ...     "def f(arg:str='')->None: hey",
    ...     mode=black.Mode(
    ...       target_versions={black.TargetVersion.PY36},
    ...       line_length=10,
    ...       string_normalization=False,
    ...       is_pyi=False,
    ...     ),
    ...   ),
    ... )
    def f(
        arg: str = '',
    ) -> None:
        hey

    """
    src_node = lib2to3_parse(src_contents.lstrip(), mode.target_versions)
    dst_contents = []
    future_imports = get_future_imports(src_node)
    versions = mode.target_versions or detect_target_versions(src_node)
    normalize_fmt_off(src_node)

    strip_u_prefix = "unicode_literals" in future_imports or supports_feature(
        versions, Feature.UNICODE_LITERALS
    )
    lines = LineGenerator(mode=mode, remove_u_prefix=strip_u_prefix)
    elt = EmptyLineTracker(is_pyi=mode.is_pyi)
    empty_line = Line(mode=mode)

    # Only offer the trailing-comma features that every target supports.
    candidate_features = (Feature.TRAILING_COMMA_IN_CALL, Feature.TRAILING_COMMA_IN_DEF)
    split_line_features = set(
        filter(lambda feature: supports_feature(versions, feature), candidate_features)
    )

    # `after` is carried over: blank lines requested after one line are
    # emitted just before the next one.
    after = 0
    for current_line in lines.visit(src_node):
        dst_contents.append(str(empty_line) * after)
        before, after = elt.maybe_empty_lines(current_line)
        dst_contents.append(str(empty_line) * before)
        dst_contents.extend(
            str(line)
            for line in transform_line(
                current_line, mode=mode, features=split_line_features
            )
        )
    return "".join(dst_contents)
910
911
def decode_bytes(src: bytes) -> Tuple[FileContent, Encoding, NewLine]:
    """Decode `src` and report the encoding and newline style it uses.

    The result is a `(decoded_contents, encoding, newline)` tuple.  `newline` is
    CRLF when the first line ends with CRLF and LF otherwise, while
    `decoded_contents` is read with universal newlines (i.e. only contains LF).
    """
    buffer = io.BytesIO(src)
    encoding, first_lines = tokenize.detect_encoding(buffer.readline)
    if not first_lines:
        # Nothing could be read from the buffer: empty input.
        return "", encoding, "\n"

    uses_crlf = first_lines[0].endswith(b"\r\n")
    newline = "\r\n" if uses_crlf else "\n"

    # Encoding detection consumed the first line(s); rewind before decoding.
    buffer.seek(0)
    with io.TextIOWrapper(buffer, encoding) as reader:
        return reader.read(), encoding, newline
927
928
def get_features_used(node: Node) -> Set[Feature]:
    """Return a set of (relatively) new Python features used in this file.

    Currently looking for:
    - f-strings (the string prefix is matched case-insensitively, so `Rf"..."`
      and `fR"..."` count too);
    - underscores in numeric literals;
    - trailing commas after * or ** in function signatures and calls;
    - positional only arguments in function signatures and lambdas;
    - assignment expression;
    - relaxed decorator syntax;
    """
    features: Set[Feature] = set()
    for n in node.pre_order():
        if n.type == token.STRING:
            # String prefix letters are case-insensitive and may come in either
            # order, so normalize the case before matching.  Only the first two
            # characters are inspected: that is either a one-letter prefix plus
            # the opening quote, or a two-letter prefix.
            value_head = n.value[:2].lower()  # type: ignore
            if value_head in {'f"', "f'", "rf", "fr"}:
                features.add(Feature.F_STRINGS)

        elif n.type == token.NUMBER:
            if "_" in n.value:  # type: ignore
                features.add(Feature.NUMERIC_UNDERSCORES)

        elif n.type == token.SLASH:
            # A `/` that sits directly in a parameter/argument list is the
            # positional-only marker.  Lambda parameters live in a
            # `varargslist`, so that node type has to be accepted as well.
            if n.parent and n.parent.type in {
                syms.typedargslist,
                syms.arglist,
                syms.varargslist,
            }:
                features.add(Feature.POS_ONLY_ARGUMENTS)

        elif n.type == token.COLONEQUAL:
            features.add(Feature.ASSIGNMENT_EXPRESSIONS)

        elif n.type == syms.decorator:
            if len(n.children) > 1 and not is_simple_decorator_expression(
                n.children[1]
            ):
                features.add(Feature.RELAXED_DECORATORS)

        elif (
            n.type in {syms.typedargslist, syms.arglist}
            and n.children
            and n.children[-1].type == token.COMMA
        ):
            # The list ends with a trailing comma; that only counts as a "new"
            # feature when the list also contains a * or ** unpacking.
            if n.type == syms.typedargslist:
                feature = Feature.TRAILING_COMMA_IN_DEF
            else:
                feature = Feature.TRAILING_COMMA_IN_CALL

            for ch in n.children:
                if ch.type in STARS:
                    features.add(feature)

                if ch.type == syms.argument:
                    for argch in ch.children:
                        if argch.type in STARS:
                            features.add(feature)

    return features
984
985
def detect_target_versions(node: Node) -> Set[TargetVersion]:
    """Return every target version whose feature set covers what `node` uses."""
    used_features = get_features_used(node)
    compatible: Set[TargetVersion] = set()
    for candidate in TargetVersion:
        # A version qualifies only if it supports all of the detected features.
        if used_features <= VERSION_TO_FEATURES[candidate]:
            compatible.add(candidate)
    return compatible
992
993
def get_future_imports(node: Node) -> Set[str]:
    """Collect the names imported via `from __future__ import ...` in the file.

    Only the leading run of top-level simple statements is inspected, and only
    the first child of each one.  Bare string statements (docstrings) are
    skipped, `__future__` imports are collected, and scanning stops at the first
    statement of any other kind.
    """

    def iter_imported_names(nodes: List[LN]) -> Generator[str, None, None]:
        for item in nodes:
            if isinstance(item, Leaf):
                # Leaves that are not names are ignored.
                if item.type == token.NAME:
                    yield item.value
                continue

            if item.type == syms.import_as_names:
                yield from iter_imported_names(item.children)
                continue

            if item.type != syms.import_as_name:
                raise AssertionError("Invalid syntax parsing imports")

            # `name as alias`: report the original name, not the alias.
            original = item.children[0]
            assert isinstance(original, Leaf), "Invalid syntax parsing imports"
            assert original.type == token.NAME, "Invalid syntax parsing imports"
            yield original.value

    found: Set[str] = set()
    for statement in node.children:
        if statement.type != syms.simple_stmt:
            break

        head = statement.children[0]
        if isinstance(head, Leaf):
            # Continue looking if we see a docstring; otherwise stop.
            is_docstring = (
                len(statement.children) == 2
                and head.type == token.STRING
                and statement.children[1].type == token.NEWLINE
            )
            if not is_docstring:
                break
            continue

        if head.type != syms.import_from:
            break

        module_name = head.children[1]
        if not isinstance(module_name, Leaf) or module_name.value != "__future__":
            break

        # Children are `from`, the module name and `import`; the rest are names.
        found.update(iter_imported_names(head.children[3:]))

    return found
1042
1043
def assert_equivalent(src: str, dst: str, *, pass_num: int = 1) -> None:
    """Raise AssertionError if `src` and `dst` aren't equivalent.

    Both strings are parsed with `parse_ast` and their stringified ASTs are
    compared.  `pass_num` is only used in the error messages.

    Raises AssertionError when `src` cannot be parsed, when `dst` cannot be
    parsed, or when the two ASTs differ.
    """
    try:
        src_ast = parse_ast(src)
    except Exception as exc:
        # Chain explicitly so the parse failure is reported as the direct cause
        # instead of as an unrelated error raised "during handling".
        raise AssertionError(
            "cannot use --safe with this file; failed to parse source file.  AST"
            f" error message: {exc}"
        ) from exc

    try:
        dst_ast = parse_ast(dst)
    except Exception as exc:
        # Keep the traceback and the invalid output in a dump for the bug report.
        log = dump_to_file("".join(traceback.format_tb(exc.__traceback__)), dst)
        raise AssertionError(
            f"INTERNAL ERROR: Black produced invalid code on pass {pass_num}: {exc}. "
            "Please report a bug on https://github.com/psf/black/issues.  "
            f"This invalid output might be helpful: {log}"
        ) from None

    src_ast_str = "\n".join(stringify_ast(src_ast))
    dst_ast_str = "\n".join(stringify_ast(dst_ast))
    if src_ast_str != dst_ast_str:
        log = dump_to_file(diff(src_ast_str, dst_ast_str, "src", "dst"))
        raise AssertionError(
            "INTERNAL ERROR: Black produced code that is not equivalent to the"
            f" source on pass {pass_num}.  Please report a bug on "
            f"https://github.com/psf/black/issues.  This diff might be helpful: {log}"
        ) from None
1073
1074
def assert_stable(src: str, dst: str, mode: Mode) -> None:
    """Raise AssertionError if formatting `dst` again changes it.

    `src` is only used to include a source-to-first-pass diff in the dump.
    """
    second_pass = format_str(dst, mode=mode)
    if second_pass == dst:
        return

    report_path = dump_to_file(
        str(mode),
        diff(src, dst, "source", "first pass"),
        diff(dst, second_pass, "first pass", "second pass"),
    )
    raise AssertionError(
        "INTERNAL ERROR: Black produced different code on the second pass of the"
        " formatter.  Please report a bug on https://github.com/psf/black/issues."
        f"  This diff might be helpful: {report_path}"
    ) from None
1089
1090
@contextmanager
def nullcontext() -> Iterator[None]:
    """Return a context manager that does nothing on entry or exit.

    Stand-in for `contextlib.nullcontext`, which only exists on Python 3.7+.
    """
    yield None
1098
1099
def patch_click() -> None:
    """Make Click not crash on Python 3.6 with LANG=C.

    On certain misconfigured environments, Python 3 selects the ASCII encoding as the
    default which restricts paths that it can access during the lifetime of the
    application.  Click refuses to work in this scenario by raising a RuntimeError.

    In case of Black the likelihood that non-ASCII characters are going to be used in
    file paths is minimal since it's Python source code.  Moreover, this crash was
    spurious on Python 3.7 thanks to PEP 538 and PEP 540.

    Each Click module is imported independently and skipped when unavailable, so
    a Click install that lacks one of them still gets the other one patched.
    """
    modules: List[Any] = []
    try:
        from click import core
    except ImportError:
        pass
    else:
        modules.append(core)
    try:
        # `from click import <name>` raises plain ImportError (not its
        # ModuleNotFoundError subclass) when Click is installed but no longer
        # provides this private module, so the broader class must be caught.
        from click import _unicodefun  # type: ignore
    except ImportError:
        pass
    else:
        modules.append(_unicodefun)

    for module in modules:
        # The check has existed under two names; neutralize whichever is present.
        if hasattr(module, "_verify_python3_env"):
            module._verify_python3_env = lambda: None  # type: ignore
        if hasattr(module, "_verify_python_env"):
            module._verify_python_env = lambda: None  # type: ignore
1122
1123
def patched_main() -> None:
    """Run `main()` after performing the process-level setup steps below."""
    # Defined elsewhere in this module; presumably switches asyncio to uvloop
    # when it is available -- confirm against its definition.
    maybe_install_uvloop()
    # multiprocessing hook needed when running as a frozen Windows executable.
    freeze_support()
    # Disable Click's environment check before Click runs (see `patch_click`).
    patch_click()
    main()
1129
1130
# Script entry point: runs only when this file is executed directly, not on import.
if __name__ == "__main__":
    patched_main()