]> git.madduck.net Git - etc/vim.git/blob - src/black/__init__.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access, can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Move `--code` #2259 change log to correct unreleased section of CHANGES.md
[etc/vim.git] / src / black / __init__.py
1 import asyncio
2 from concurrent.futures import Executor, ThreadPoolExecutor, ProcessPoolExecutor
3 from contextlib import contextmanager
4 from datetime import datetime
5 from enum import Enum
6 import io
7 from multiprocessing import Manager, freeze_support
8 import os
9 from pathlib import Path
10 import regex as re
11 import signal
12 import sys
13 import tokenize
14 import traceback
15 from typing import (
16     Any,
17     Dict,
18     Generator,
19     Iterator,
20     List,
21     Optional,
22     Pattern,
23     Set,
24     Sized,
25     Tuple,
26     Union,
27 )
28
29 from dataclasses import replace
30 import click
31
32 from black.const import DEFAULT_LINE_LENGTH, DEFAULT_INCLUDES, DEFAULT_EXCLUDES
33 from black.const import STDIN_PLACEHOLDER
34 from black.nodes import STARS, syms, is_simple_decorator_expression
35 from black.lines import Line, EmptyLineTracker
36 from black.linegen import transform_line, LineGenerator, LN
37 from black.comments import normalize_fmt_off
38 from black.mode import Mode, TargetVersion
39 from black.mode import Feature, supports_feature, VERSION_TO_FEATURES
40 from black.cache import read_cache, write_cache, get_cache_info, filter_cached, Cache
41 from black.concurrency import cancel, shutdown
42 from black.output import dump_to_file, diff, color_diff, out, err
43 from black.report import Report, Changed
44 from black.files import find_project_root, find_pyproject_toml, parse_pyproject_toml
45 from black.files import gen_python_files, get_gitignore, normalize_path_maybe_ignore
46 from black.files import wrap_stream_for_windows
47 from black.parsing import InvalidInput  # noqa F401
48 from black.parsing import lib2to3_parse, parse_ast, stringify_ast
49
50
51 # lib2to3 fork
52 from blib2to3.pytree import Node, Leaf
53 from blib2to3.pgen2 import token
54
55 from _black_version import version as __version__
56
# Prefer uvloop's event loop implementation when that optional dependency is
# importable; otherwise asyncio's default loop is used.
try:
    import uvloop

    uvloop.install()
except ImportError:
    pass

# Type aliases: each is just `str`, named for readability in signatures such as
# the (contents, encoding, newline) triple returned by `decode_bytes`.
FileContent = Encoding = NewLine = str
69
70
class NothingChanged(UserWarning):
    """Signal that formatting would leave the source text exactly as it was."""
73
74
class WriteBack(Enum):
    """What to do with the formatted result of each source."""

    NO = 0
    YES = 1
    DIFF = 2
    CHECK = 3
    COLOR_DIFF = 4

    @classmethod
    def from_configuration(
        cls, *, check: bool, diff: bool, color: bool = False
    ) -> "WriteBack":
        """Pick the write-back behaviour from the CLI flags.

        `diff` takes precedence over `check`. `color` only matters when a diff
        is requested. With neither `diff` nor `check`, files are written back.
        """
        if diff:
            return cls.COLOR_DIFF if color else cls.DIFF
        return cls.CHECK if check else cls.YES
93
94
# Legacy name, left for integrations: `FileMode` is bound to the very same
# class object as `Mode`, so code that still refers to the old name keeps working.
FileMode = Mode
97
98
def read_pyproject_toml(
    ctx: click.Context, param: click.Parameter, value: Optional[str]
) -> Optional[str]:
    """Merge Black settings from a pyproject.toml file into `ctx.default_map`.

    This is the click callback behind ``--config``. When `value` is empty, the
    file is looked up with `find_pyproject_toml` based on the ``src`` arguments
    parsed so far.

    Returns the path of the configuration file that was read, or None when no
    file was found or it yielded no settings.

    Raises:
        click.FileError: the file could not be read or parsed.
        click.BadOptionUsage: ``target_version`` is present but not a list.
    """
    if not value:
        value = find_pyproject_toml(ctx.params.get("src", ()))
        if value is None:
            return None

    try:
        config = parse_pyproject_toml(value)
    except (OSError, ValueError) as exc:
        raise click.FileError(
            filename=value, hint=f"Error reading configuration file: {exc}"
        )

    if not config:
        return None

    # Sanitize the values to be Click friendly: scalars become strings, while
    # lists and dicts are passed through untouched. For more information see:
    # https://github.com/psf/black/issues/1458
    # https://github.com/pallets/click/issues/1567
    config = {
        key: setting if isinstance(setting, (list, dict)) else str(setting)
        for key, setting in config.items()
    }

    target_version = config.get("target_version")
    if target_version is not None and not isinstance(target_version, list):
        raise click.BadOptionUsage(
            "target-version", "Config key target-version must be a list"
        )

    # File settings override any defaults already present on the context.
    merged: Dict[str, Any] = {**(ctx.default_map or {}), **config}
    ctx.default_map = merged
    return value
143
144
def target_version_option_callback(
    c: click.Context, p: Union[click.Option, click.Parameter], v: Tuple[str, ...]
) -> List[TargetVersion]:
    """Turn the lower-case ``--target-version`` choices into `TargetVersion` members.

    Kept as a named function rather than a lambda: mypy couldn't infer the
    lambda's type correctly, which caused trouble for mypyc.
    """
    return [TargetVersion[choice.upper()] for choice in v]
154
155
def re_compile_maybe_verbose(regex: str) -> Pattern[str]:
    """Compile `regex`, turning on verbose mode if it spans multiple lines.

    Verbose mode is enabled by prefixing the inline ``(?x)`` flag, so whitespace
    and newlines in a multi-line pattern are ignored.
    """
    source = f"(?x){regex}" if "\n" in regex else regex
    compiled: Pattern[str] = re.compile(source)
    return compiled
165
166
def validate_regex(
    ctx: click.Context,
    param: click.Parameter,
    value: Optional[str],
) -> Optional[Pattern]:
    """Click callback: compile a regex option value, passing None through.

    Raises:
        click.BadParameter: `value` is not a valid regular expression.
    """
    if value is None:
        return None
    try:
        return re_compile_maybe_verbose(value)
    except re.error:
        raise click.BadParameter("Not a valid regular expression")
176
177
@click.command(context_settings=dict(help_option_names=["-h", "--help"]))
@click.option("-c", "--code", type=str, help="Format the code passed in as a string.")
@click.option(
    "-l",
    "--line-length",
    type=int,
    default=DEFAULT_LINE_LENGTH,
    help="How many characters per line to allow.",
    show_default=True,
)
@click.option(
    "-t",
    "--target-version",
    type=click.Choice([v.name.lower() for v in TargetVersion]),
    callback=target_version_option_callback,
    multiple=True,
    help=(
        "Python versions that should be supported by Black's output. [default: per-file"
        " auto-detection]"
    ),
)
@click.option(
    "--pyi",
    is_flag=True,
    help=(
        "Format all input files like typing stubs regardless of file extension (useful"
        " when piping source on standard input)."
    ),
)
@click.option(
    "-S",
    "--skip-string-normalization",
    is_flag=True,
    help="Don't normalize string quotes or prefixes.",
)
@click.option(
    "-C",
    "--skip-magic-trailing-comma",
    is_flag=True,
    help="Don't use trailing commas as a reason to split lines.",
)
@click.option(
    "--experimental-string-processing",
    is_flag=True,
    hidden=True,
    help=(
        "Experimental option that performs more normalization on string literals."
        " Currently disabled because it leads to some crashes."
    ),
)
@click.option(
    "--check",
    is_flag=True,
    help=(
        "Don't write the files back, just return the status. Return code 0 means"
        " nothing would change. Return code 1 means some files would be reformatted."
        " Return code 123 means there was an internal error."
    ),
)
@click.option(
    "--diff",
    is_flag=True,
    help="Don't write the files back, just output a diff for each file on stdout.",
)
@click.option(
    "--color/--no-color",
    is_flag=True,
    help="Show colored diff. Only applies when `--diff` is given.",
)
@click.option(
    "--fast/--safe",
    is_flag=True,
    help="If --fast given, skip temporary sanity checks. [default: --safe]",
)
@click.option(
    "--include",
    type=str,
    default=DEFAULT_INCLUDES,
    callback=validate_regex,
    help=(
        "A regular expression that matches files and directories that should be"
        " included on recursive searches. An empty value means all files are included"
        " regardless of the name. Use forward slashes for directories on all platforms"
        " (Windows, too). Exclusions are calculated first, inclusions later."
    ),
    show_default=True,
)
@click.option(
    "--exclude",
    type=str,
    callback=validate_regex,
    help=(
        "A regular expression that matches files and directories that should be"
        " excluded on recursive searches. An empty value means no paths are excluded."
        " Use forward slashes for directories on all platforms (Windows, too)."
        " Exclusions are calculated first, inclusions later. [default:"
        f" {DEFAULT_EXCLUDES}]"
    ),
    show_default=False,
)
@click.option(
    "--extend-exclude",
    type=str,
    callback=validate_regex,
    help=(
        "Like --exclude, but adds additional files and directories on top of the"
        " excluded ones. (Useful if you simply want to add to the default)"
    ),
)
@click.option(
    "--force-exclude",
    type=str,
    callback=validate_regex,
    help=(
        "Like --exclude, but files and directories matching this regex will be "
        "excluded even when they are passed explicitly as arguments."
    ),
)
@click.option(
    "--stdin-filename",
    type=str,
    help=(
        "The name of the file when passing it through stdin. Useful to make "
        "sure Black will respect --force-exclude option on some "
        "editors that rely on using stdin."
    ),
)
@click.option(
    "-q",
    "--quiet",
    is_flag=True,
    help=(
        "Don't emit non-error messages to stderr. Errors are still emitted; silence"
        " those with 2>/dev/null."
    ),
)
@click.option(
    "-v",
    "--verbose",
    is_flag=True,
    help=(
        "Also emit messages to stderr about files that were not changed or were ignored"
        " due to exclusion patterns."
    ),
)
@click.version_option(version=__version__)
@click.argument(
    "src",
    nargs=-1,
    type=click.Path(
        exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
    ),
    is_eager=True,
)
@click.option(
    "--config",
    type=click.Path(
        exists=True,
        file_okay=True,
        dir_okay=False,
        readable=True,
        allow_dash=False,
        path_type=str,
    ),
    is_eager=True,
    callback=read_pyproject_toml,
    help="Read configuration from FILE path.",
)
@click.pass_context
def main(
    ctx: click.Context,
    code: Optional[str],
    line_length: int,
    target_version: List[TargetVersion],
    check: bool,
    diff: bool,
    color: bool,
    fast: bool,
    pyi: bool,
    skip_string_normalization: bool,
    skip_magic_trailing_comma: bool,
    experimental_string_processing: bool,
    quiet: bool,
    verbose: bool,
    include: Pattern,
    exclude: Optional[Pattern],
    extend_exclude: Optional[Pattern],
    force_exclude: Optional[Pattern],
    stdin_filename: Optional[str],
    src: Tuple[str, ...],
    config: Optional[str],
) -> None:
    """The uncompromising code formatter."""
    # NOTE: the docstring above doubles as the CLI --help text; keep it as is.
    write_back = WriteBack.from_configuration(check=check, diff=diff, color=color)
    # No --target-version flags -> empty set, i.e. detect target versions later.
    versions = set(target_version or ())
    mode = Mode(
        target_versions=versions,
        line_length=line_length,
        is_pyi=pyi,
        string_normalization=not skip_string_normalization,
        magic_trailing_comma=not skip_magic_trailing_comma,
        experimental_string_processing=experimental_string_processing,
    )
    if config and verbose:
        out(f"Using configuration from {config}.", bold=False, fg="blue")

    # With -c the progress output isn't useful, so force quiet mode; -v still
    # re-enables the final summary below.
    quiet = quiet or code is not None

    report = Report(check=check, diff=diff, quiet=quiet, verbose=verbose)

    if code is None:
        sources = get_sources(
            ctx=ctx,
            src=src,
            quiet=quiet,
            verbose=verbose,
            include=include,
            exclude=exclude,
            extend_exclude=extend_exclude,
            force_exclude=force_exclude,
            report=report,
            stdin_filename=stdin_filename,
        )
        # Exits the process (status 0) when nothing is left to format.
        path_empty(
            sources,
            "No Python files are present to be formatted. Nothing to do 😴",
            quiet,
            verbose,
            ctx,
        )
        # A single source is handled in-process; several go through a pool.
        if len(sources) == 1:
            reformat_one(
                src=sources.pop(),
                fast=fast,
                write_back=write_back,
                mode=mode,
                report=report,
            )
        else:
            reformat_many(
                sources=sources,
                fast=fast,
                write_back=write_back,
                mode=mode,
                report=report,
            )
    else:
        reformat_code(
            content=code, fast=fast, write_back=write_back, mode=mode, report=report
        )

    if verbose or not quiet:
        summary = "Oh no! 💥 💔 💥" if report.return_code else "All done! ✨ 🍰 ✨"
        out(summary)
        if code is None:
            click.secho(str(report), err=True)
    ctx.exit(report.return_code)
443
444
def get_sources(
    *,
    ctx: click.Context,
    src: Tuple[str, ...],
    quiet: bool,
    verbose: bool,
    include: Pattern[str],
    exclude: Optional[Pattern[str]],
    extend_exclude: Optional[Pattern[str]],
    force_exclude: Optional[Pattern[str]],
    report: "Report",
    stdin_filename: Optional[str],
) -> Set[Path]:
    """Expand the command-line `src` entries into the set of paths to format.

    * Files, and stdin when `stdin_filename` is given, are kept unless
      `force_exclude` matches them. Kept stdin entries are prefixed with
      ``STDIN_PLACEHOLDER`` so later stages can tell them apart.
    * Directories are walked with `gen_python_files`.
    * A bare ``-`` without `stdin_filename` is kept as-is.
    * Anything else is reported as an invalid path.

    `.gitignore` rules are only applied when no explicit `exclude` was given.
    Exits via `path_empty` when `src` is empty.
    """
    root = find_project_root(src)
    sources: Set[Path] = set()
    path_empty(src, "No Path provided. Nothing to do 😴", quiet, verbose, ctx)

    gitignore = None
    if exclude is None:
        exclude = re_compile_maybe_verbose(DEFAULT_EXCLUDES)
        gitignore = get_gitignore(root)

    for entry in src:
        is_stdin = entry == "-" and bool(stdin_filename)
        path = Path(stdin_filename) if is_stdin else Path(entry)

        if is_stdin or path.is_file():
            normalized = normalize_path_maybe_ignore(path, root, report)
            if normalized is None:
                continue

            # Hard-exclude any file that matches the `--force-exclude` regex.
            hit = force_exclude.search("/" + normalized) if force_exclude else None
            if hit and hit.group(0):
                report.path_ignored(
                    path, "matches the --force-exclude regular expression"
                )
                continue

            if is_stdin:
                path = Path(f"{STDIN_PLACEHOLDER}{str(path)}")
            sources.add(path)
        elif path.is_dir():
            sources.update(
                gen_python_files(
                    path.iterdir(),
                    root,
                    include,
                    exclude,
                    extend_exclude,
                    force_exclude,
                    report,
                    gitignore,
                )
            )
        elif entry == "-":
            sources.add(path)
        else:
            err(f"invalid path: {entry}")
    return sources
515
516
def path_empty(
    src: Sized, msg: str, quiet: bool, verbose: bool, ctx: click.Context
) -> None:
    """Stop with exit status 0 when `src` is empty.

    `msg` is printed first unless output is suppressed (quiet and not verbose).
    Returns normally when `src` is non-empty.
    """
    if src:
        return
    if verbose or not quiet:
        out(msg)
    ctx.exit(0)
527
528
def reformat_code(
    content: str, fast: bool, write_back: WriteBack, mode: Mode, report: Report
) -> None:
    """Format the source text `content` in-process and record the outcome.

    This is the string counterpart of `reformat_one`. The text goes through
    :func:`format_stdin_to_stdout`, which emits any output on stdout, and the
    result is recorded in `report` under the pseudo-path ``<string>``.
    `fast`, `write_back`, and `mode` are forwarded unchanged. Any exception is
    recorded with `report.failed` instead of propagating.
    """
    path = Path("<string>")
    try:
        modified = format_stdin_to_stdout(
            content=content, fast=fast, write_back=write_back, mode=mode
        )
        report.done(path, Changed.YES if modified else Changed.NO)
    except Exception as exc:
        if report.verbose:
            traceback.print_exc()
        report.failed(path, str(exc))
551
552
def reformat_one(
    src: Path, fast: bool, write_back: WriteBack, mode: Mode, report: "Report"
) -> None:
    """Format a single source (file or stdin) in the current process.

    Stdin is recognised as a literal ``-`` or as a path carrying the
    ``STDIN_PLACEHOLDER`` prefix. The prefix is stripped so messages show the
    user's file name.

    For regular files, the cache is consulted first unless diff output was
    requested, and a matching entry yields `Changed.CACHED`. The cache is
    updated after a write-back, or after a check found the file unchanged.

    `fast`, `write_back`, and `mode` are passed to
    :func:`format_file_in_place` or :func:`format_stdin_to_stdout`. Any
    exception is recorded with `report.failed` instead of propagating.
    """
    try:
        changed = Changed.NO
        src_name = str(src)
        is_stdin = src_name == "-"
        if not is_stdin and src_name.startswith(STDIN_PLACEHOLDER):
            is_stdin = True
            # Use the original name again in case we print something to the user.
            src = Path(src_name[len(STDIN_PLACEHOLDER) :])

        if is_stdin:
            if src.suffix == ".pyi":
                mode = replace(mode, is_pyi=True)
            if format_stdin_to_stdout(fast=fast, write_back=write_back, mode=mode):
                changed = Changed.YES
        else:
            cache: Cache = {}
            if write_back not in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
                cache = read_cache(mode)
                resolved = src.resolve()
                key = str(resolved)
                if key in cache and cache[key] == get_cache_info(resolved):
                    changed = Changed.CACHED
            if changed is not Changed.CACHED and format_file_in_place(
                src, fast=fast, write_back=write_back, mode=mode
            ):
                changed = Changed.YES
            should_cache = (
                write_back is WriteBack.YES and changed is not Changed.CACHED
            ) or (write_back is WriteBack.CHECK and changed is Changed.NO)
            if should_cache:
                write_cache(cache, [src], mode)
        report.done(src, changed)
    except Exception as exc:
        if report.verbose:
            traceback.print_exc()
        report.failed(src, str(exc))
600
601
def reformat_many(
    sources: Set[Path], fast: bool, write_back: WriteBack, mode: Mode, report: "Report"
) -> None:
    """Reformat multiple files using a ProcessPoolExecutor.

    Falls back to a single-worker ThreadPoolExecutor where process pools are
    unavailable. A dedicated event loop drives :func:`schedule_formatting`.
    When done, it is handed to `shutdown` and unset as the current loop, and
    the executor is shut down.
    """
    executor: Executor
    # os.cpu_count() returns None when the count can't be determined; fall back
    # to 1, which is also what ProcessPoolExecutor itself uses in that case.
    # Without this, min(None, 60) below raises TypeError on Windows.
    worker_count = os.cpu_count() or 1
    if sys.platform == "win32":
        # Work around https://bugs.python.org/issue26903
        worker_count = min(worker_count, 60)
    try:
        executor = ProcessPoolExecutor(max_workers=worker_count)
    except (ImportError, NotImplementedError, OSError):
        # we arrive here if the underlying system does not support multi-processing
        # like in AWS Lambda or Termux, in which case we gracefully fallback to
        # a ThreadPoolExecutor with just a single worker (more workers would not do us
        # any good due to the Global Interpreter Lock). ProcessPoolExecutor raises
        # NotImplementedError when the platform's semaphore limits are too low.
        executor = ThreadPoolExecutor(max_workers=1)

    # asyncio.get_event_loop() is deprecated when no loop is running (3.10+),
    # and reusing the global default loop would hand it to `shutdown` below.
    # Use a private loop instead and unset it afterwards.
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        loop.run_until_complete(
            schedule_formatting(
                sources=sources,
                fast=fast,
                write_back=write_back,
                mode=mode,
                report=report,
                loop=loop,
                executor=executor,
            )
        )
    finally:
        try:
            shutdown(loop)
        finally:
            asyncio.set_event_loop(None)
            executor.shutdown()
637
638
async def schedule_formatting(
    sources: Set[Path],
    fast: bool,
    write_back: WriteBack,
    mode: Mode,
    report: "Report",
    loop: asyncio.AbstractEventLoop,
    executor: Executor,
) -> None:
    """Run formatting of `sources` in parallel using the provided `executor`.

    (Use ProcessPoolExecutors for actual parallelism.)

    `write_back`, `fast`, and `mode` options are passed to
    :func:`format_file_in_place`.

    Unless diff output is requested, sources that `filter_cached` reports as
    cached are marked `Changed.CACHED` and skipped. Files that were written
    back, or checked and found unchanged, are added to the cache at the end.
    SIGINT/SIGTERM are routed to `cancel` with the still-pending tasks where
    the event loop supports signal handlers.
    """
    cache: Cache = {}
    if write_back not in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
        cache = read_cache(mode)
        sources, cached = filter_cached(cache, sources)
        for src in sorted(cached):
            report.done(src, Changed.CACHED)
    if not sources:
        return

    cancelled = []
    sources_to_cache = []
    lock = None
    if write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
        # For diff output, we need locks to ensure we don't interleave output
        # from different processes.
        manager = Manager()
        lock = manager.Lock()
    tasks = {
        asyncio.ensure_future(
            loop.run_in_executor(
                executor, format_file_in_place, src, fast, mode, write_back, lock
            )
        ): src
        for src in sorted(sources)
    }
    # A live dict view: it shrinks as finished tasks are popped from `tasks`,
    # so both the loop below and the signal handlers only see pending tasks.
    pending = tasks.keys()
    try:
        loop.add_signal_handler(signal.SIGINT, cancel, pending)
        loop.add_signal_handler(signal.SIGTERM, cancel, pending)
    except NotImplementedError:
        # There are no good alternatives for these on Windows.
        pass
    while pending:
        done, _ = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
        for task in done:
            src = tasks.pop(task)
            if task.cancelled():
                cancelled.append(task)
            elif task.exception():
                report.failed(src, str(task.exception()))
            else:
                changed = Changed.YES if task.result() else Changed.NO
                # If the file was written back or was successfully checked as
                # well-formatted, store this information in the cache.
                if write_back is WriteBack.YES or (
                    write_back is WriteBack.CHECK and changed is Changed.NO
                ):
                    sources_to_cache.append(src)
                report.done(src, changed)
    if cancelled:
        # `asyncio.gather()` no longer accepts a `loop` argument (deprecated in
        # 3.8, removed in 3.10, where passing it raises TypeError). The tasks
        # are already bound to `loop`, so it is not needed.
        await asyncio.gather(*cancelled, return_exceptions=True)
    if sources_to_cache:
        write_cache(cache, sources_to_cache, mode)
708
709
def format_file_in_place(
    src: Path,
    fast: bool,
    mode: Mode,
    write_back: WriteBack = WriteBack.NO,
    lock: Any = None,  # multiprocessing.Manager().Lock() is some crazy proxy
) -> bool:
    """Reformat the file at `src`; return True when the formatted text differs.

    The result is handled according to `write_back`:

    * YES: the new code overwrites the file, keeping the encoding and newline
      style detected by `decode_bytes`.
    * DIFF / COLOR_DIFF: a diff (colorized for COLOR_DIFF) is written to stdout
      while holding `lock`, if one was given.
    * any other value: nothing is written; only the return value matters.

    ``.pyi`` files are always formatted in stub mode. `mode` and `fast` are
    passed to :func:`format_file_contents`.
    """
    if src.suffix == ".pyi":
        mode = replace(mode, is_pyi=True)

    # Capture the file's modification time for the "before" diff header.
    then = datetime.utcfromtimestamp(src.stat().st_mtime)
    with open(src, "rb") as handle:
        src_contents, encoding, newline = decode_bytes(handle.read())
    try:
        dst_contents = format_file_contents(src_contents, fast=fast, mode=mode)
    except NothingChanged:
        return False

    if write_back == WriteBack.YES:
        with open(src, "w", encoding=encoding, newline=newline) as target:
            target.write(dst_contents)
        return True

    if write_back not in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
        return True

    now = datetime.utcnow()
    diff_contents = diff(
        src_contents, dst_contents, f"{src}\t{then} +0000", f"{src}\t{now} +0000"
    )
    if write_back == WriteBack.COLOR_DIFF:
        diff_contents = color_diff(diff_contents)

    # `lock` keeps diffs from concurrent workers from interleaving on stdout;
    # `nullcontext()` stands in when no lock was passed.
    with lock or nullcontext():
        stream = io.TextIOWrapper(
            sys.stdout.buffer,
            encoding=encoding,
            newline=newline,
            write_through=True,
        )
        stream = wrap_stream_for_windows(stream)
        stream.write(diff_contents)
        # Detach so the wrapper does not close the shared sys.stdout.buffer.
        stream.detach()
    return True
758
759
def format_stdin_to_stdout(
    fast: bool,
    *,
    content: Optional[str] = None,
    write_back: WriteBack = WriteBack.NO,
    mode: Mode,
) -> bool:
    """Format file on stdin. Return True if changed.

    If content is None, it's read from sys.stdin.

    If `write_back` is YES, write reformatted code back to stdout. If it is DIFF,
    write a diff to stdout. The `mode` argument is passed to
    :func:`format_file_contents`.

    Output is produced in a `finally` clause, so it is emitted even if
    formatting raises. In that case `dst` still holds the unmodified source,
    and the exception propagates after the output has been written.
    """
    then = datetime.utcnow()

    if content is None:
        src, encoding, newline = decode_bytes(sys.stdin.buffer.read())
    else:
        # In-memory content: newline="" disables newline translation on output.
        src, encoding, newline = content, "utf-8", ""

    dst = src
    try:
        dst = format_file_contents(src, fast=fast, mode=mode)
        return True

    except NothingChanged:
        return False

    finally:
        f = io.TextIOWrapper(
            sys.stdout.buffer, encoding=encoding, newline=newline, write_through=True
        )
        if write_back == WriteBack.YES:
            # Make sure non-empty output ends with a newline. Empty input stays
            # empty: indexing `dst[-1]` on "" raised IndexError here before,
            # replacing the function's real result with a crash.
            if dst and dst[-1] != "\n":
                dst += "\n"
            f.write(dst)
        elif write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
            now = datetime.utcnow()
            src_name = f"STDIN\t{then} +0000"
            dst_name = f"STDOUT\t{now} +0000"
            d = diff(src, dst, src_name, dst_name)
            if write_back == WriteBack.COLOR_DIFF:
                d = color_diff(d)
                f = wrap_stream_for_windows(f)
            f.write(d)
        # Detach so the wrapper does not close sys.stdout.buffer.
        f.detach()
808
809
def format_file_contents(src_contents: str, *, fast: bool, mode: Mode) -> FileContent:
    """Return the reformatted version of `src_contents`.

    Raises `NothingChanged` for blank input or when formatting is a no-op.
    Unless `fast` is set, the result is also checked with
    :func:`assert_equivalent`. A forced second formatting pass runs as well,
    and if it still changes the code, that output is re-checked with
    :func:`assert_equivalent` and :func:`assert_stable` and returned instead.
    `mode` is passed to :func:`format_str`.
    """
    if not src_contents.strip():
        raise NothingChanged

    dst_contents = format_str(src_contents, mode=mode)
    if src_contents == dst_contents:
        raise NothingChanged

    if fast:
        return dst_contents

    assert_equivalent(src_contents, dst_contents)

    # Forced second pass to work around optional trailing commas (becoming
    # forced trailing commas on pass 2) interacting differently with optional
    # parentheses.  Admittedly ugly.
    second_pass = format_str(dst_contents, mode=mode)
    if second_pass == dst_contents:
        # Already a fixed point, so an explicit `assert_stable` is unnecessary.
        return dst_contents

    assert_equivalent(src_contents, second_pass, pass_num=2)
    assert_stable(src_contents, second_pass, mode=mode)
    return second_pass
838
839
def format_str(src_contents: str, *, mode: Mode) -> FileContent:
    """Reformat the source string `src_contents` and return the new text.

    Nothing is written anywhere. `mode` controls formatting options, such as
    how many characters per line are allowed.  Example:

    >>> import black
    >>> print(black.format_str("def f(arg:str='')->None:...", mode=black.Mode()))
    def f(arg: str = "") -> None:
        ...

    A more complex example:

    >>> print(
    ...   black.format_str(
    ...     "def f(arg:str='')->None: hey",
    ...     mode=black.Mode(
    ...       target_versions={black.TargetVersion.PY36},
    ...       line_length=10,
    ...       string_normalization=False,
    ...       is_pyi=False,
    ...     ),
    ...   ),
    ... )
    def f(
        arg: str = '',
    ) -> None:
        hey

    """
    src_node = lib2to3_parse(src_contents.lstrip(), mode.target_versions)
    # Both inspections below read the tree before normalize_fmt_off mutates it.
    future_imports = get_future_imports(src_node)
    versions = mode.target_versions or detect_target_versions(src_node)
    normalize_fmt_off(src_node)

    line_generator = LineGenerator(
        mode=mode,
        remove_u_prefix="unicode_literals" in future_imports
        or supports_feature(versions, Feature.UNICODE_LITERALS),
    )
    empty_line_tracker = EmptyLineTracker(is_pyi=mode.is_pyi)
    blank = str(Line(mode=mode))
    split_line_features = {
        feature
        for feature in (Feature.TRAILING_COMMA_IN_CALL, Feature.TRAILING_COMMA_IN_DEF)
        if supports_feature(versions, feature)
    }

    chunks: List[str] = []
    after = 0
    for current_line in line_generator.visit(src_node):
        # Blank lines owed after the previous line, then those before this one.
        chunks.append(blank * after)
        before, after = empty_line_tracker.maybe_empty_lines(current_line)
        chunks.append(blank * before)
        chunks.extend(
            str(line)
            for line in transform_line(
                current_line, mode=mode, features=split_line_features
            )
        )
    return "".join(chunks)
900
901
def decode_bytes(src: bytes) -> Tuple[FileContent, Encoding, NewLine]:
    """Decode raw source bytes, reporting the encoding and line ending used.

    The encoding is whatever `tokenize.detect_encoding` determines for `src`.
    The newline style is taken from the first line only: CRLF if that line ends
    with CRLF, LF otherwise.  The contents themselves are read with universal
    newlines, so the returned text only ever contains LF line breaks.

    Empty input yields an empty string, the detected encoding and LF.
    """
    buffer = io.BytesIO(src)
    encoding, first_lines = tokenize.detect_encoding(buffer.readline)
    if not first_lines:
        return "", encoding, "\n"

    newline = "\r\n" if first_lines[0].endswith(b"\r\n") else "\n"

    # `detect_encoding` consumed up to two lines; rewind before decoding it all.
    buffer.seek(0)
    with io.TextIOWrapper(buffer, encoding) as text_stream:
        contents = text_stream.read()
    return contents, encoding, newline
917
918
def get_features_used(node: Node) -> Set[Feature]:
    """Return a set of (relatively) new Python features used in this file.

    Currently looking for:
    - f-strings (with any upper/lower-case spelling of the `f`/`r` prefix);
    - underscores in numeric literals;
    - trailing commas after * or ** in function signatures and calls;
    - positional only arguments in function signatures and lambdas;
    - assignment expression;
    - relaxed decorator syntax;

    The tree is walked once in pre-order.  `detect_target_versions` uses the
    result to keep only target versions that support every returned feature.
    """
    features: Set[Feature] = set()
    for n in node.pre_order():
        if n.type == token.STRING:
            # String prefixes are case-insensitive, so `Rf"..."` and `fR'...'`
            # are f-strings too.  Lower-casing the first two characters lets a
            # small set cover every spelling.
            value_head = n.value[:2].lower()  # type: ignore
            if value_head in {'f"', "f'", "rf", "fr"}:
                features.add(Feature.F_STRINGS)

        elif n.type == token.NUMBER:
            if "_" in n.value:  # type: ignore
                features.add(Feature.NUMERIC_UNDERSCORES)

        elif n.type == token.SLASH:
            # `def` signatures put their parameters in `typedargslist`, while
            # lambdas use `varargslist`; both can hold the `/` marker.
            if n.parent and n.parent.type in {
                syms.typedargslist,
                syms.arglist,
                syms.varargslist,
            }:
                features.add(Feature.POS_ONLY_ARGUMENTS)

        elif n.type == token.COLONEQUAL:
            features.add(Feature.ASSIGNMENT_EXPRESSIONS)

        elif n.type == syms.decorator:
            if len(n.children) > 1 and not is_simple_decorator_expression(
                n.children[1]
            ):
                features.add(Feature.RELAXED_DECORATORS)

        elif (
            n.type in {syms.typedargslist, syms.arglist}
            and n.children
            and n.children[-1].type == token.COMMA
        ):
            # A trailing comma is only a "new" feature when it follows a
            # `*`/`**` argument, either directly or inside an `argument` node.
            if n.type == syms.typedargslist:
                feature = Feature.TRAILING_COMMA_IN_DEF
            else:
                feature = Feature.TRAILING_COMMA_IN_CALL

            for ch in n.children:
                if ch.type in STARS:
                    features.add(feature)

                if ch.type == syms.argument:
                    for argch in ch.children:
                        if argch.type in STARS:
                            features.add(feature)

    return features
974
975
def detect_target_versions(node: Node) -> Set[TargetVersion]:
    """Infer which target versions can handle the code in `node`.

    A version qualifies when every feature found by `get_features_used` is
    in that version's entry of `VERSION_TO_FEATURES`.
    """
    used = get_features_used(node)
    compatible: Set[TargetVersion] = set()
    for candidate in TargetVersion:
        if used <= VERSION_TO_FEATURES[candidate]:
            compatible.add(candidate)
    return compatible
982
983
def get_future_imports(node: Node) -> Set[str]:
    """Return the names imported with ``from __future__ import ...`` in the file.

    Only the leading run of module-level simple statements is examined.
    String-only statements (docstrings) are skipped, and the scan continues
    across ``from __future__ import`` statements.  It stops at the first
    statement of any other kind.  For ``name as alias`` the original ``name``
    is collected.

    Raises:
        AssertionError: if an import target has a node type that is not
            recognised.
    """

    def names_from(children: List[LN]) -> Generator[str, None, None]:
        # Yield each imported name among the nodes that follow
        # ``from __future__ import``.  Leaves that are not NAME tokens
        # (e.g. parentheses) are skipped.
        for kid in children:
            if isinstance(kid, Leaf):
                if kid.type == token.NAME:
                    yield kid.value
                continue

            if kid.type == syms.import_as_names:
                # A comma-separated group of targets: recurse into it.
                yield from names_from(kid.children)
                continue

            if kid.type != syms.import_as_name:
                raise AssertionError("Invalid syntax parsing imports")

            # ``name as alias``: the first child is the original name.
            original = kid.children[0]
            assert isinstance(original, Leaf), "Invalid syntax parsing imports"
            assert original.type == token.NAME, "Invalid syntax parsing imports"
            yield original.value

    found: Set[str] = set()
    for stmt in node.children:
        if stmt.type != syms.simple_stmt:
            break

        head = stmt.children[0]
        if isinstance(head, Leaf):
            # A lone STRING followed by NEWLINE is a docstring; keep looking
            # past it.  Any other leaf-led statement ends the search.
            is_docstring = (
                len(stmt.children) == 2
                and head.type == token.STRING
                and stmt.children[1].type == token.NEWLINE
            )
            if not is_docstring:
                break
            continue

        if head.type != syms.import_from:
            break

        module_leaf = head.children[1]
        if not isinstance(module_leaf, Leaf) or module_leaf.value != "__future__":
            break

        # Children are: ``from``, the module, ``import``, then the targets.
        found.update(names_from(head.children[3:]))

    return found
1032
1033
def assert_equivalent(src: str, dst: str, *, pass_num: int = 1) -> None:
    """Raise AssertionError if `src` and `dst` aren't equivalent.

    Both strings are parsed with `parse_ast` and their `stringify_ast` dumps
    are compared.  There are three failure modes:

    - `src` does not parse: the file cannot be checked with --safe.
    - `dst` does not parse: the traceback and `dst` are written with
      `dump_to_file`, and an internal-error message naming that log is raised.
    - the dumps differ: a diff of the two dumps is written with `dump_to_file`,
      and an internal-error message naming that log is raised.

    `pass_num` is only used in the error messages.
    """
    try:
        before = parse_ast(src)
    except Exception as err:
        raise AssertionError(
            "cannot use --safe with this file; failed to parse source file.  AST"
            f" error message: {err}"
        )

    try:
        after = parse_ast(dst)
    except Exception as err:
        trace_text = "".join(traceback.format_tb(err.__traceback__))
        log_path = dump_to_file(trace_text, dst)
        raise AssertionError(
            f"INTERNAL ERROR: Black produced invalid code on pass {pass_num}: {err}. "
            "Please report a bug on https://github.com/psf/black/issues.  "
            f"This invalid output might be helpful: {log_path}"
        ) from None

    before_dump = "\n".join(stringify_ast(before))
    after_dump = "\n".join(stringify_ast(after))
    if before_dump == after_dump:
        return

    log_path = dump_to_file(diff(before_dump, after_dump, "src", "dst"))
    raise AssertionError(
        "INTERNAL ERROR: Black produced code that is not equivalent to the"
        f" source on pass {pass_num}.  Please report a bug on "
        f"https://github.com/psf/black/issues.  This diff might be helpful: {log_path}"
    ) from None
1063
1064
def assert_stable(src: str, dst: str, mode: Mode) -> None:
    """Raise AssertionError if `dst` reformats differently the second time.

    `dst` is run through `format_str` again with the same `mode`.  The check
    passes only if that second pass returns `dst` unchanged.  On failure, the
    mode and two diffs are written with `dump_to_file`: source -> first pass,
    and first pass -> second pass.  The log location goes into the message.
    `src` is used only for the first of those diffs.
    """
    second_pass = format_str(dst, mode=mode)
    if second_pass == dst:
        return

    log_path = dump_to_file(
        str(mode),
        diff(src, dst, "source", "first pass"),
        diff(dst, second_pass, "first pass", "second pass"),
    )
    raise AssertionError(
        "INTERNAL ERROR: Black produced different code on the second pass of the"
        " formatter.  Please report a bug on https://github.com/psf/black/issues."
        f"  This diff might be helpful: {log_path}"
    ) from None
1079
1080
@contextmanager
def nullcontext() -> Iterator[None]:
    """Provide a context manager that does nothing.

    A stand-in for `contextlib.nullcontext` (added in Python 3.7), so that
    ``with nullcontext(): ...`` works on older interpreters as well.  The
    ``as`` target, if any, is bound to ``None``, and exceptions raised in the
    body propagate unchanged.
    """
    yield None
1088
1089
def patch_click() -> None:
    """Make Click not crash on Python 3.6 with LANG=C.

    On certain misconfigured environments, Python 3 selects the ASCII encoding as the
    default which restricts paths that it can access during the lifetime of the
    application.  Click refuses to work in this scenario by raising a RuntimeError.

    In case of Black the likelihood that non-ASCII characters are going to be used in
    file paths is minimal since it's Python source code.  Moreover, this crash was
    spurious on Python 3.7 thanks to PEP 538 and PEP 540.

    Each Click module that may carry the environment check is imported on its
    own.  Click 8.1 removed ``click._unicodefun``, and ``from click import
    _unicodefun`` then raises a plain ``ImportError`` ("cannot import name"),
    not ``ModuleNotFoundError``.  A module that cannot be imported is skipped,
    and whichever modules did import are still patched.
    """
    modules: List[Any] = []
    try:
        from click import core
    except ImportError:
        pass
    else:
        modules.append(core)

    try:
        from click import _unicodefun  # type: ignore
    except ImportError:
        # Expected on Click >= 8.1, where this module no longer exists.
        pass
    else:
        modules.append(_unicodefun)

    for module in modules:
        # The check's name differs between Click versions; neutralise both.
        if hasattr(module, "_verify_python3_env"):
            module._verify_python3_env = lambda: None  # type: ignore
        if hasattr(module, "_verify_python_env"):
            module._verify_python_env = lambda: None  # type: ignore
1112
1113
def patched_main() -> None:
    """Run the Black command-line entry point after applying startup workarounds.

    The calls are order-dependent: multiprocessing setup first, then the Click
    patch, and only then the CLI itself.
    """
    # multiprocessing's hook for frozen executables; called before any worker
    # processes could be started.
    freeze_support()
    # Disable Click's ASCII-locale check (see `patch_click`) before the CLI runs.
    patch_click()
    main()
1118
1119
# Entry point when this file is executed directly as a script rather than imported.
if __name__ == "__main__":
    patched_main()