git.madduck.net Git - etc/vim.git/blob - src/black/__init__.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access, can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Improve Python 2 only syntax detection (GH-2592)
[etc/vim.git] / src / black / __init__.py
1 import asyncio
2 from json.decoder import JSONDecodeError
3 import json
4 from concurrent.futures import Executor, ThreadPoolExecutor, ProcessPoolExecutor
5 from contextlib import contextmanager
6 from datetime import datetime
7 from enum import Enum
8 import io
9 from multiprocessing import Manager, freeze_support
10 import os
11 from pathlib import Path
12 from pathspec.patterns.gitwildmatch import GitWildMatchPatternError
13 import regex as re
14 import signal
15 import sys
16 import tokenize
17 import traceback
18 from typing import (
19     Any,
20     Dict,
21     Generator,
22     Iterator,
23     List,
24     MutableMapping,
25     Optional,
26     Pattern,
27     Set,
28     Sized,
29     Tuple,
30     Union,
31 )
32
33 from dataclasses import replace
34 import click
35
36 from black.const import DEFAULT_LINE_LENGTH, DEFAULT_INCLUDES, DEFAULT_EXCLUDES
37 from black.const import STDIN_PLACEHOLDER
38 from black.nodes import STARS, syms, is_simple_decorator_expression
39 from black.lines import Line, EmptyLineTracker
40 from black.linegen import transform_line, LineGenerator, LN
41 from black.comments import normalize_fmt_off
42 from black.mode import Mode, TargetVersion
43 from black.mode import Feature, supports_feature, VERSION_TO_FEATURES
44 from black.cache import read_cache, write_cache, get_cache_info, filter_cached, Cache
45 from black.concurrency import cancel, shutdown, maybe_install_uvloop
46 from black.output import dump_to_file, ipynb_diff, diff, color_diff, out, err
47 from black.report import Report, Changed, NothingChanged
48 from black.files import find_project_root, find_pyproject_toml, parse_pyproject_toml
49 from black.files import gen_python_files, get_gitignore, normalize_path_maybe_ignore
50 from black.files import wrap_stream_for_windows
51 from black.parsing import InvalidInput  # noqa F401
52 from black.parsing import lib2to3_parse, parse_ast, stringify_ast
53 from black.handle_ipynb_magics import (
54     mask_cell,
55     unmask_cell,
56     remove_trailing_semicolon,
57     put_trailing_semicolon_back,
58     TRANSFORMED_MAGICS,
59     jupyter_dependencies_are_installed,
60 )
61
62
63 # lib2to3 fork
64 from blib2to3.pytree import Node, Leaf
65 from blib2to3.pgen2 import token
66
67 from _black_version import version as __version__
68
# types
# Plain-`str` aliases used in signatures to say what role a string plays.
FileContent = str  # text of a source file
Encoding = str  # name of a text encoding
NewLine = str  # newline sequence passed as `newline=` when writing output
73
74
class WriteBack(Enum):
    """What to do with the formatted result of a file.

    `YES` writes the reformatted code back, `DIFF` and `COLOR_DIFF` emit a
    (plain or colored) diff instead, and `CHECK` only reports status.  `NO`
    is the default used by :func:`format_file_in_place`.
    """

    NO = 0
    YES = 1
    DIFF = 2
    CHECK = 3
    COLOR_DIFF = 4

    @classmethod
    def from_configuration(
        cls, *, check: bool, diff: bool, color: bool = False
    ) -> "WriteBack":
        """Translate the `--check` / `--diff` / `--color` flags into a member.

        A requested diff always wins over `check`; `color` only matters when
        a diff was requested.
        """
        if not diff:
            # No diff requested: either a dry-run check or a real write.
            return cls.CHECK if check else cls.YES

        return cls.COLOR_DIFF if color else cls.DIFF
93
94
# Legacy name, left for integrations.
FileMode = Mode

# Default for the `--workers` option and the fallback used by `reformat_many`.
# NOTE: os.cpu_count() returns None when the CPU count cannot be determined,
# so this value is not guaranteed to be an int.
DEFAULT_WORKERS = os.cpu_count()
99
100
def read_pyproject_toml(
    ctx: click.Context, param: click.Parameter, value: Optional[str]
) -> Optional[str]:
    """Load Black settings from a "pyproject.toml" and merge them into `ctx`.

    Used as the callback of the `--config` option.  When `value` is empty the
    file is looked up via :func:`find_pyproject_toml` based on the `src`
    parameter.  The parsed settings are layered on top of any existing
    `ctx.default_map`.

    Returns the path of the configuration file that was read, or None when no
    file was found or it contained no Black settings.  Raises
    `click.FileError` if the file cannot be read or parsed and
    `click.BadOptionUsage` if `target_version` is present but not a list.
    """
    if not value:
        value = find_pyproject_toml(ctx.params.get("src", ()))
        if value is None:
            return None

    try:
        parsed = parse_pyproject_toml(value)
    except (OSError, ValueError) as e:
        raise click.FileError(
            filename=value, hint=f"Error reading configuration file: {e}"
        ) from None

    if not parsed:
        return None

    # Sanitize the values to be Click friendly. For more information please see:
    # https://github.com/psf/black/issues/1458
    # https://github.com/pallets/click/issues/1567
    config = {
        key: item if isinstance(item, (list, dict)) else str(item)
        for key, item in parsed.items()
    }

    target_version = config.get("target_version")
    if not (target_version is None or isinstance(target_version, list)):
        raise click.BadOptionUsage(
            "target-version", "Config key target-version must be a list"
        )

    # Existing defaults first, then the file's settings override them.
    merged: Dict[str, Any] = {}
    if ctx.default_map:
        merged.update(ctx.default_map)
    merged.update(config)
    ctx.default_map = merged

    return value
145
146
def target_version_option_callback(
    c: click.Context, p: Union[click.Option, click.Parameter], v: Tuple[str, ...]
) -> List[TargetVersion]:
    """Map the names given to --target-version onto `TargetVersion` members.

    Kept as a named function rather than a lambda because mypy could not
    infer the lambda's type, which caused trouble for mypyc.
    """
    versions: List[TargetVersion] = []
    for name in v:
        # Members are looked up by their upper-cased name.
        versions.append(TargetVersion[name.upper()])
    return versions
156
157
def re_compile_maybe_verbose(regex: str) -> Pattern[str]:
    """Return the compiled form of the pattern string `regex`.

    A pattern that spans several lines is compiled in verbose mode by
    prefixing it with the inline ``(?x)`` flag.
    """
    source = f"(?x){regex}" if "\n" in regex else regex
    compiled: Pattern[str] = re.compile(source)
    return compiled
167
168
def validate_regex(
    ctx: click.Context,
    param: click.Parameter,
    value: Optional[str],
) -> Optional[Pattern[str]]:
    """Click callback that turns a regex option value into a compiled pattern.

    None is passed through unchanged; a string that fails to compile is
    reported as `click.BadParameter`.
    """
    if value is None:
        return None

    try:
        return re_compile_maybe_verbose(value)
    except re.error:
        raise click.BadParameter("Not a valid regular expression") from None
178
179
# NOTE: the function's docstring doubles as the text click shows for `--help`,
# so it is kept short; implementation notes live in `#` comments instead.
@click.command(context_settings=dict(help_option_names=["-h", "--help"]))
@click.option("-c", "--code", type=str, help="Format the code passed in as a string.")
@click.option(
    "-l",
    "--line-length",
    type=int,
    default=DEFAULT_LINE_LENGTH,
    help="How many characters per line to allow.",
    show_default=True,
)
@click.option(
    "-t",
    "--target-version",
    type=click.Choice([v.name.lower() for v in TargetVersion]),
    callback=target_version_option_callback,
    multiple=True,
    help=(
        "Python versions that should be supported by Black's output. [default: per-file"
        " auto-detection]"
    ),
)
@click.option(
    "--pyi",
    is_flag=True,
    help=(
        "Format all input files like typing stubs regardless of file extension (useful"
        " when piping source on standard input)."
    ),
)
@click.option(
    "--ipynb",
    is_flag=True,
    help=(
        "Format all input files like Jupyter Notebooks regardless of file extension "
        "(useful when piping source on standard input)."
    ),
)
@click.option(
    "-S",
    "--skip-string-normalization",
    is_flag=True,
    help="Don't normalize string quotes or prefixes.",
)
@click.option(
    "-C",
    "--skip-magic-trailing-comma",
    is_flag=True,
    help="Don't use trailing commas as a reason to split lines.",
)
@click.option(
    "--experimental-string-processing",
    is_flag=True,
    hidden=True,
    help=(
        "Experimental option that performs more normalization on string literals."
        " Currently disabled because it leads to some crashes."
    ),
)
@click.option(
    "--check",
    is_flag=True,
    help=(
        "Don't write the files back, just return the status. Return code 0 means"
        " nothing would change. Return code 1 means some files would be reformatted."
        " Return code 123 means there was an internal error."
    ),
)
@click.option(
    "--diff",
    is_flag=True,
    help="Don't write the files back, just output a diff for each file on stdout.",
)
@click.option(
    "--color/--no-color",
    is_flag=True,
    help="Show colored diff. Only applies when `--diff` is given.",
)
@click.option(
    "--fast/--safe",
    is_flag=True,
    help="If --fast given, skip temporary sanity checks. [default: --safe]",
)
@click.option(
    "--required-version",
    type=str,
    help=(
        "Require a specific version of Black to be running (useful for unifying results"
        " across many environments e.g. with a pyproject.toml file)."
    ),
)
@click.option(
    "--include",
    type=str,
    default=DEFAULT_INCLUDES,
    callback=validate_regex,
    help=(
        "A regular expression that matches files and directories that should be"
        " included on recursive searches. An empty value means all files are included"
        " regardless of the name. Use forward slashes for directories on all platforms"
        " (Windows, too). Exclusions are calculated first, inclusions later."
    ),
    show_default=True,
)
# --exclude has no click-level default: `get_sources` substitutes
# DEFAULT_EXCLUDES (and consults .gitignore) only when the option is None.
@click.option(
    "--exclude",
    type=str,
    callback=validate_regex,
    help=(
        "A regular expression that matches files and directories that should be"
        " excluded on recursive searches. An empty value means no paths are excluded."
        " Use forward slashes for directories on all platforms (Windows, too)."
        " Exclusions are calculated first, inclusions later. [default:"
        f" {DEFAULT_EXCLUDES}]"
    ),
    show_default=False,
)
@click.option(
    "--extend-exclude",
    type=str,
    callback=validate_regex,
    help=(
        "Like --exclude, but adds additional files and directories on top of the"
        " excluded ones. (Useful if you simply want to add to the default)"
    ),
)
@click.option(
    "--force-exclude",
    type=str,
    callback=validate_regex,
    help=(
        "Like --exclude, but files and directories matching this regex will be "
        "excluded even when they are passed explicitly as arguments."
    ),
)
@click.option(
    "--stdin-filename",
    type=str,
    help=(
        "The name of the file when passing it through stdin. Useful to make "
        "sure Black will respect --force-exclude option on some "
        "editors that rely on using stdin."
    ),
)
@click.option(
    "-W",
    "--workers",
    type=click.IntRange(min=1),
    default=DEFAULT_WORKERS,
    show_default=True,
    help="Number of parallel workers",
)
@click.option(
    "-q",
    "--quiet",
    is_flag=True,
    help=(
        "Don't emit non-error messages to stderr. Errors are still emitted; silence"
        " those with 2>/dev/null."
    ),
)
@click.option(
    "-v",
    "--verbose",
    is_flag=True,
    help=(
        "Also emit messages to stderr about files that were not changed or were ignored"
        " due to exclusion patterns."
    ),
)
@click.version_option(version=__version__)
# Both `src` and `--config` are eager.  NOTE(review): presumably this is so
# the `--config` callback (`read_pyproject_toml`) can read `src` from
# `ctx.params` and install defaults before the remaining options are
# processed -- confirm against click's parameter-ordering rules.
@click.argument(
    "src",
    nargs=-1,
    type=click.Path(
        exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
    ),
    is_eager=True,
    metavar="SRC ...",
)
@click.option(
    "--config",
    type=click.Path(
        exists=True,
        file_okay=True,
        dir_okay=False,
        readable=True,
        allow_dash=False,
        path_type=str,
    ),
    is_eager=True,
    callback=read_pyproject_toml,
    help="Read configuration from FILE path.",
)
@click.pass_context
def main(
    ctx: click.Context,
    code: Optional[str],
    line_length: int,
    target_version: List[TargetVersion],
    check: bool,
    diff: bool,
    color: bool,
    fast: bool,
    pyi: bool,
    ipynb: bool,
    skip_string_normalization: bool,
    skip_magic_trailing_comma: bool,
    experimental_string_processing: bool,
    quiet: bool,
    verbose: bool,
    # The option declares no default, so this may be None when not given.
    required_version: Optional[str],
    include: Pattern[str],
    exclude: Optional[Pattern[str]],
    extend_exclude: Optional[Pattern[str]],
    force_exclude: Optional[Pattern[str]],
    stdin_filename: Optional[str],
    workers: int,
    src: Tuple[str, ...],
    config: Optional[str],
) -> None:
    """The uncompromising code formatter."""
    # `config` is whatever the --config callback returned: the path of the
    # configuration file that was read, or None.
    if config and verbose:
        out(f"Using configuration from {config}.", bold=False, fg="blue")

    error_msg = "Oh no! 💥 💔 💥"
    # Refuse to run when a specific Black version was requested and it is not
    # the one that is running.
    if required_version and required_version != __version__:
        err(
            f"{error_msg} The required version `{required_version}` does not match"
            f" the running version `{__version__}`!"
        )
        ctx.exit(1)
    # The two "treat every input as ..." flags are mutually exclusive.
    if ipynb and pyi:
        err("Cannot pass both `pyi` and `ipynb` flags!")
        ctx.exit(1)

    write_back = WriteBack.from_configuration(check=check, diff=diff, color=color)
    if target_version:
        versions = set(target_version)
    else:
        # We'll autodetect later.
        versions = set()
    mode = Mode(
        target_versions=versions,
        line_length=line_length,
        is_pyi=pyi,
        is_ipynb=ipynb,
        string_normalization=not skip_string_normalization,
        magic_trailing_comma=not skip_magic_trailing_comma,
        experimental_string_processing=experimental_string_processing,
    )

    if code is not None:
        # Run in quiet mode by default with -c; the extra output isn't useful.
        # You can still pass -v to get verbose output.
        quiet = True

    report = Report(check=check, diff=diff, quiet=quiet, verbose=verbose)

    if code is not None:
        # -c/--code: format the given string instead of any files.
        reformat_code(
            content=code, fast=fast, write_back=write_back, mode=mode, report=report
        )
    else:
        try:
            sources = get_sources(
                ctx=ctx,
                src=src,
                quiet=quiet,
                verbose=verbose,
                include=include,
                exclude=exclude,
                extend_exclude=extend_exclude,
                force_exclude=force_exclude,
                report=report,
                stdin_filename=stdin_filename,
            )
        except GitWildMatchPatternError:
            ctx.exit(1)

        # Exits with status 0 when nothing is left to format.
        path_empty(
            sources,
            "No Python files are present to be formatted. Nothing to do 😴",
            quiet,
            verbose,
            ctx,
        )

        # A single source is handled in-process; several go through the
        # executor-based `reformat_many`.
        if len(sources) == 1:
            reformat_one(
                src=sources.pop(),
                fast=fast,
                write_back=write_back,
                mode=mode,
                report=report,
            )
        else:
            reformat_many(
                sources=sources,
                fast=fast,
                write_back=write_back,
                mode=mode,
                report=report,
                workers=workers,
            )

    if verbose or not quiet:
        out(error_msg if report.return_code else "All done! ✨ 🍰 ✨")
        # The per-file summary is only printed when files (not -c) were given.
        if code is None:
            click.echo(str(report), err=True)
    # The process exit status is the report's return code.
    ctx.exit(report.return_code)
490
491
def get_sources(
    *,
    ctx: click.Context,
    src: Tuple[str, ...],
    quiet: bool,
    verbose: bool,
    include: Pattern[str],
    exclude: Optional[Pattern[str]],
    extend_exclude: Optional[Pattern[str]],
    force_exclude: Optional[Pattern[str]],
    report: "Report",
    stdin_filename: Optional[str],
) -> Set[Path]:
    """Work out which paths should be formatted.

    Each entry of `src` is handled according to what it is:

    - a file (or "-" combined with `stdin_filename`) is added unless it is
      ignored by path normalization or matches `force_exclude`;
    - a directory is expanded through :func:`gen_python_files`;
    - a bare "-" is added as is;
    - anything else is reported as an invalid path.

    Exits through `path_empty` when `src` itself is empty.
    """
    root = find_project_root(src)
    collected: Set[Path] = set()
    path_empty(src, "No Path provided. Nothing to do 😴", quiet, verbose, ctx)

    # .gitignore is only consulted when the default excludes are in effect.
    gitignore = None
    if exclude is None:
        exclude = re_compile_maybe_verbose(DEFAULT_EXCLUDES)
        gitignore = get_gitignore(root)

    for entry in src:
        if entry == "-" and stdin_filename:
            candidate, from_stdin = Path(stdin_filename), True
        else:
            candidate, from_stdin = Path(entry), False

        if from_stdin or candidate.is_file():
            relative = normalize_path_maybe_ignore(candidate, root, report)
            if relative is None:
                continue

            rooted = "/" + relative
            # Hard-exclude any files that matches the `--force-exclude` regex.
            hit = force_exclude.search(rooted) if force_exclude else None
            if hit and hit.group(0):
                report.path_ignored(
                    candidate, "matches the --force-exclude regular expression"
                )
                continue

            if from_stdin:
                # Tag the name so later stages know the content comes from stdin.
                candidate = Path(f"{STDIN_PLACEHOLDER}{str(candidate)}")

            if candidate.suffix == ".ipynb" and not jupyter_dependencies_are_installed(
                verbose=verbose, quiet=quiet
            ):
                continue

            collected.add(candidate)
        elif candidate.is_dir():
            found = gen_python_files(
                candidate.iterdir(),
                root,
                include,
                exclude,
                extend_exclude,
                force_exclude,
                report,
                gitignore,
                verbose=verbose,
                quiet=quiet,
            )
            collected.update(found)
        elif entry == "-":
            collected.add(candidate)
        else:
            err(f"invalid path: {entry}")

    return collected
569
570
def path_empty(
    src: Sized, msg: str, quiet: bool, verbose: bool, ctx: click.Context
) -> None:
    """Exit with status 0 when `src` is empty.

    `msg` is printed first unless output is silenced (`quiet` without
    `verbose`).  Does nothing when `src` has content.
    """
    if src:
        return

    if verbose or not quiet:
        out(msg)
    ctx.exit(0)
581
582
def reformat_code(
    content: str, fast: bool, write_back: WriteBack, mode: Mode, report: Report
) -> None:
    """Reformat the string `content` and print it, without child processes.

    String-input counterpart of `reformat_one`: `fast`, `write_back`, and
    `mode` are forwarded to :func:`format_stdin_to_stdout`, and the outcome
    is recorded on `report` under the pseudo-path "<string>".
    """
    pseudo_path = Path("<string>")
    try:
        was_changed = format_stdin_to_stdout(
            content=content, fast=fast, write_back=write_back, mode=mode
        )
        report.done(pseudo_path, Changed.YES if was_changed else Changed.NO)
    except Exception as exc:
        # Any failure is recorded on the report rather than propagated.
        if report.verbose:
            traceback.print_exc()
        report.failed(pseudo_path, str(exc))
605
606
def reformat_one(
    src: Path, fast: bool, write_back: WriteBack, mode: Mode, report: "Report"
) -> None:
    """Reformat the single file `src` in this process.

    `src` may be "-" or a name carrying the `STDIN_PLACEHOLDER` prefix, in
    which case the input comes from stdin.  `fast`, `write_back`, and `mode`
    are forwarded to :func:`format_file_in_place` or
    :func:`format_stdin_to_stdout`.  Results and failures are recorded on
    `report`; exceptions are not propagated.
    """
    try:
        changed = Changed.NO

        src_text = str(src)
        is_stdin = src_text == "-" or src_text.startswith(STDIN_PLACEHOLDER)
        if is_stdin and src_text != "-":
            # Use the original name again in case we want to print something
            # to the user
            src = Path(src_text[len(STDIN_PLACEHOLDER) :])

        if is_stdin:
            suffix = src.suffix
            if suffix == ".pyi":
                mode = replace(mode, is_pyi=True)
            elif suffix == ".ipynb":
                mode = replace(mode, is_ipynb=True)
            if format_stdin_to_stdout(fast=fast, write_back=write_back, mode=mode):
                changed = Changed.YES
        else:
            cache: Cache = {}
            # The cache is bypassed entirely when a diff was requested.
            if write_back not in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
                cache = read_cache(mode)
                resolved = src.resolve()
                cache_key = str(resolved)
                if cache_key in cache and cache[cache_key] == get_cache_info(resolved):
                    changed = Changed.CACHED

            if changed is not Changed.CACHED:
                if format_file_in_place(
                    src, fast=fast, write_back=write_back, mode=mode
                ):
                    changed = Changed.YES

            # Record the file after a real write-back, or after a check that
            # found nothing to change.
            written = write_back is WriteBack.YES and changed is not Changed.CACHED
            verified = write_back is WriteBack.CHECK and changed is Changed.NO
            if written or verified:
                write_cache(cache, [src], mode)
        report.done(src, changed)
    except Exception as exc:
        if report.verbose:
            traceback.print_exc()
        report.failed(src, str(exc))
656
657
def reformat_many(
    sources: Set[Path],
    fast: bool,
    write_back: WriteBack,
    mode: Mode,
    report: "Report",
    workers: Optional[int],
) -> None:
    """Reformat multiple files using a ProcessPoolExecutor.

    `workers` is the number of worker processes; when it is None the
    module-level `DEFAULT_WORKERS` is used, falling back to a single worker
    if that is None as well.  If process pools are unavailable, a
    single-worker ThreadPoolExecutor is used instead.  The remaining
    arguments are forwarded to :func:`schedule_formatting`.
    """
    executor: Executor
    loop = asyncio.get_event_loop()
    # DEFAULT_WORKERS comes from os.cpu_count(), which returns None when the
    # CPU count cannot be determined.  Fall back to one worker so the
    # `min()` below never compares None with an int.
    worker_count = workers if workers is not None else (DEFAULT_WORKERS or 1)
    if sys.platform == "win32":
        # Work around https://bugs.python.org/issue26903
        worker_count = min(worker_count, 60)
    try:
        executor = ProcessPoolExecutor(max_workers=worker_count)
    except (ImportError, OSError):
        # we arrive here if the underlying system does not support multi-processing
        # like in AWS Lambda or Termux, in which case we gracefully fallback to
        # a ThreadPoolExecutor with just a single worker (more workers would not do us
        # any good due to the Global Interpreter Lock)
        executor = ThreadPoolExecutor(max_workers=1)

    try:
        loop.run_until_complete(
            schedule_formatting(
                sources=sources,
                fast=fast,
                write_back=write_back,
                mode=mode,
                report=report,
                loop=loop,
                executor=executor,
            )
        )
    finally:
        # Tear down the loop first, then the worker pool; `executor` is
        # always bound by the time this block is reached.
        shutdown(loop)
        executor.shutdown()
698
699
async def schedule_formatting(
    sources: Set[Path],
    fast: bool,
    write_back: WriteBack,
    mode: Mode,
    report: "Report",
    loop: asyncio.AbstractEventLoop,
    executor: Executor,
) -> None:
    """Run formatting of `sources` in parallel using the provided `executor`.

    (Use ProcessPoolExecutors for actual parallelism.)

    `write_back`, `fast`, and `mode` options are passed to
    :func:`format_file_in_place`.

    Every source ends up on `report` as done (changed, unchanged, or cached)
    or failed; cancelled tasks are awaited at the end instead of being
    reported.  Files that were written back, or checked and found unchanged,
    are added to the cache.
    """
    cache: Cache = {}
    # The cache is skipped when a diff was requested; otherwise sources that
    # are already cached are reported right away and dropped from the work set.
    if write_back not in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
        cache = read_cache(mode)
        sources, cached = filter_cached(cache, sources)
        for src in sorted(cached):
            report.done(src, Changed.CACHED)
    if not sources:
        return

    cancelled = []
    sources_to_cache = []
    lock = None
    if write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
        # For diff output, we need locks to ensure we don't interleave output
        # from different processes.
        manager = Manager()
        lock = manager.Lock()
    # One future per source, keyed by the future so a finished task can be
    # mapped back to its path.  Arguments are passed positionally in
    # `format_file_in_place`'s parameter order: (src, fast, mode, write_back, lock).
    tasks = {
        asyncio.ensure_future(
            loop.run_in_executor(
                executor, format_file_in_place, src, fast, mode, write_back, lock
            )
        ): src
        for src in sorted(sources)
    }
    # `pending` is a live view of the dict's keys: each `tasks.pop(task)`
    # below shrinks it, which is what eventually ends the `while` loop.
    pending = tasks.keys()
    try:
        loop.add_signal_handler(signal.SIGINT, cancel, pending)
        loop.add_signal_handler(signal.SIGTERM, cancel, pending)
    except NotImplementedError:
        # There are no good alternatives for these on Windows.
        pass
    while pending:
        done, _ = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
        for task in done:
            src = tasks.pop(task)
            if task.cancelled():
                cancelled.append(task)
            elif task.exception():
                report.failed(src, str(task.exception()))
            else:
                changed = Changed.YES if task.result() else Changed.NO
                # If the file was written back or was successfully checked as
                # well-formatted, store this information in the cache.
                if write_back is WriteBack.YES or (
                    write_back is WriteBack.CHECK and changed is Changed.NO
                ):
                    sources_to_cache.append(src)
                report.done(src, changed)
    if cancelled:
        # Wait for the cancelled tasks to finish; `return_exceptions=True`
        # keeps their CancelledError from propagating.  The `loop=` argument
        # is only passed on Python versions older than 3.7.
        if sys.version_info >= (3, 7):
            await asyncio.gather(*cancelled, return_exceptions=True)
        else:
            await asyncio.gather(*cancelled, loop=loop, return_exceptions=True)
    if sources_to_cache:
        write_cache(cache, sources_to_cache, mode)
772
773
def format_file_in_place(
    src: Path,
    fast: bool,
    mode: Mode,
    write_back: WriteBack = WriteBack.NO,
    lock: Any = None,  # multiprocessing.Manager().Lock() is some crazy proxy
) -> bool:
    """Format the file at `src`; return True if its formatting changed.

    With `write_back` set to YES the reformatted code is written to the file;
    with DIFF or COLOR_DIFF a diff is written to stdout (while holding `lock`
    if one is given).  A ".pyi" or ".ipynb" suffix switches `mode` to stub or
    notebook formatting.  `mode` and `fast` are forwarded to
    :func:`format_file_contents`.

    Raises ValueError when the contents cannot be decoded as JSON while
    formatting (reported as an invalid Jupyter notebook).
    """
    suffix = src.suffix
    if suffix == ".pyi":
        mode = replace(mode, is_pyi=True)
    elif suffix == ".ipynb":
        mode = replace(mode, is_ipynb=True)

    # Taken before formatting so the diff header shows the file's own mtime.
    modified_at = datetime.utcfromtimestamp(src.stat().st_mtime)
    with open(src, "rb") as raw:
        src_contents, encoding, newline = decode_bytes(raw.read())

    try:
        dst_contents = format_file_contents(src_contents, fast=fast, mode=mode)
    except NothingChanged:
        return False
    except JSONDecodeError:
        raise ValueError(
            f"File '{src}' cannot be parsed as valid Jupyter notebook."
        ) from None

    if write_back == WriteBack.YES:
        with open(src, "w", encoding=encoding, newline=newline) as target:
            target.write(dst_contents)
        return True

    if write_back not in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
        return True

    generated_at = datetime.utcnow()
    src_name = f"{src}\t{modified_at} +0000"
    dst_name = f"{src}\t{generated_at} +0000"
    if mode.is_ipynb:
        diff_contents = ipynb_diff(src_contents, dst_contents, src_name, dst_name)
    else:
        diff_contents = diff(src_contents, dst_contents, src_name, dst_name)
    if write_back == WriteBack.COLOR_DIFF:
        diff_contents = color_diff(diff_contents)

    with lock or nullcontext():
        stream = io.TextIOWrapper(
            sys.stdout.buffer,
            encoding=encoding,
            newline=newline,
            write_through=True,
        )
        stream = wrap_stream_for_windows(stream)
        stream.write(diff_contents)
        # Detach so that dropping the wrapper does not close stdout's buffer.
        stream.detach()

    return True
831
832
def format_stdin_to_stdout(
    fast: bool,
    *,
    content: Optional[str] = None,
    write_back: WriteBack = WriteBack.NO,
    mode: Mode,
) -> bool:
    """Format file on stdin. Return True if changed.

    If content is None, it's read from sys.stdin.

    If `write_back` is YES, write reformatted code back to stdout. If it is DIFF,
    write a diff to stdout. The `mode` argument is passed to
    :func:`format_file_contents`.

    Output is produced in a `finally` block, so it is emitted whether the
    content changed, stayed the same, or formatting raised; in the latter two
    cases the text written (or diffed against) is the unmodified input.
    """
    then = datetime.utcnow()

    if content is None:
        src, encoding, newline = decode_bytes(sys.stdin.buffer.read())
    else:
        # Content passed in as a string: assume UTF-8 and no newline translation.
        src, encoding, newline = content, "utf-8", ""

    # Start from the input so the `finally` block below has something to emit
    # even when formatting does not produce a new result.
    dst = src
    try:
        dst = format_file_contents(src, fast=fast, mode=mode)
        return True

    except NothingChanged:
        return False

    finally:
        # Runs before either `return` above takes effect, and also when any
        # other exception is propagating.
        f = io.TextIOWrapper(
            sys.stdout.buffer, encoding=encoding, newline=newline, write_through=True
        )
        if write_back == WriteBack.YES:
            # Make sure there's a newline after the content
            if dst and dst[-1] != "\n":
                dst += "\n"
            f.write(dst)
        elif write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
            now = datetime.utcnow()
            src_name = f"STDIN\t{then} +0000"
            dst_name = f"STDOUT\t{now} +0000"
            d = diff(src, dst, src_name, dst_name)
            if write_back == WriteBack.COLOR_DIFF:
                d = color_diff(d)
                # Only the colored diff goes through the Windows stream wrapper.
                f = wrap_stream_for_windows(f)
            f.write(d)
        # Detach so that dropping the wrapper does not close stdout's buffer.
        f.detach()
882
883
def check_stability_and_equivalence(
    src_contents: str, dst_contents: str, *, mode: Mode
) -> None:
    """Verify that `dst_contents` is a safe reformatting of `src_contents`.

    Raises AssertionError if the two are not equivalent, or if formatting the
    result a second time would change it in a way that is not itself
    equivalent and stable.
    """
    assert_equivalent(src_contents, dst_contents)

    # Forced second pass to work around optional trailing commas (becoming
    # forced trailing commas on pass 2) interacting differently with optional
    # parentheses.  Admittedly ugly.
    second_pass = format_str(dst_contents, mode=mode)
    if second_pass == dst_contents:
        # Already a fixed point, so an explicit `assert_stable` is unnecessary.
        return

    assert_equivalent(src_contents, second_pass, pass_num=2)
    assert_stable(src_contents, second_pass, mode=mode)
905
906
def format_file_contents(src_contents: str, *, fast: bool, mode: Mode) -> FileContent:
    """Return the reformatted version of `src_contents`.

    Raise NothingChanged when the input is blank or when formatting leaves it
    untouched.  Unless `fast` is True, the result is additionally validated
    with :func:`check_stability_and_equivalence` (notebooks run their checks
    inside :func:`format_ipynb_string`).  `mode` is forwarded to the formatter.
    """
    if src_contents.strip() == "":
        raise NothingChanged

    is_notebook = mode.is_ipynb
    if is_notebook:
        dst_contents = format_ipynb_string(src_contents, fast=fast, mode=mode)
    else:
        dst_contents = format_str(src_contents, mode=mode)

    if dst_contents == src_contents:
        raise NothingChanged

    # Jupyter notebooks will already have been checked cell by cell above.
    if not (fast or is_notebook):
        check_stability_and_equivalence(src_contents, dst_contents, mode=mode)
    return dst_contents
928
929
def validate_cell(src: str) -> None:
    """Reject cells that already contain TransformerManager output.

    A cell containing ``!ls`` is transformed into
    ``get_ipython().system('ls')``, but a cell that was literally written as
    ``get_ipython().system('ls')`` is transformed into the very same text:

        >>> TransformerManager().transform_cell("get_ipython().system('ls')")
        "get_ipython().system('ls')\n"
        >>> TransformerManager().transform_cell("!ls")
        "get_ipython().system('ls')\n"

    Such cells cannot be round-tripped safely, so NothingChanged is raised
    and they are left alone.
    """
    for transformed_magic in TRANSFORMED_MAGICS:
        if transformed_magic in src:
            raise NothingChanged
947
948
def format_cell(src: str, *, fast: bool, mode: Mode) -> str:
    """Format the source of a single Jupyter notebook code cell.

    The steps are:

      - drop a trailing semicolon, if there is one;
      - mask IPython magics;
      - format the masked code;
      - put the IPython magics back;
      - put the trailing semicolon back (if it was there originally);
      - strip trailing newlines.

    Cells with syntax errors are not processed, as they could be automagics
    or multi-line magics, which are currently not supported.  NothingChanged
    is raised for those cells and for cells whose formatted text equals `src`.
    """
    validate_cell(src)
    bare_src, had_semicolon = remove_trailing_semicolon(src)
    try:
        masked_src, replacements = mask_cell(bare_src)
    except SyntaxError:
        raise NothingChanged from None

    masked_dst = format_str(masked_src, mode=mode)
    if not fast:
        check_stability_and_equivalence(masked_src, masked_dst, mode=mode)

    unmasked_dst = unmask_cell(masked_dst, replacements)
    dst = put_trailing_semicolon_back(unmasked_dst, had_semicolon).rstrip("\n")
    if dst == src:
        raise NothingChanged from None
    return dst
984
985
def validate_metadata(nb: MutableMapping[str, Any]) -> None:
    """Refuse to format notebooks that declare a non-Python language.

    Every notebook metadata field is optional, see
    https://nbformat.readthedocs.io/en/latest/format_description.html, so a
    notebook without language information is still attempted.  Raise
    NothingChanged when a language other than "python" is declared.
    """
    metadata = nb.get("metadata", {})
    language_info = metadata.get("language_info", {})
    language = language_info.get("name", None)
    if language is None or language == "python":
        return
    raise NothingChanged from None
996
997
def format_ipynb_string(src_contents: str, *, fast: bool, mode: Mode) -> FileContent:
    """Format Jupyter notebook.

    Operate cell-by-cell, only on code cells, only for Python notebooks.
    If the ``.ipynb`` originally had a trailing newline, it'll be preserved.

    `fast` and `mode` are forwarded to :func:`format_cell` for every code cell.
    Raise NothingChanged when the notebook is marked as non-Python (see
    :func:`validate_metadata`) or when no cell was modified.  Invalid JSON
    (including an empty string) is reported by `json.loads`.
    """
    # `endswith` instead of indexing `[-1]`: an empty string must not blow up
    # with IndexError here; it is rejected by `json.loads` below like any
    # other invalid notebook.
    trailing_newline = src_contents.endswith("\n")
    modified = False
    nb = json.loads(src_contents)
    validate_metadata(nb)
    for cell in nb["cells"]:
        if cell.get("cell_type", None) == "code":
            try:
                # Cell source is stored as a list of lines; join before formatting.
                src = "".join(cell["source"])
                dst = format_cell(src, fast=fast, mode=mode)
            except NothingChanged:
                # Leave unchanged (or unformattable) cells exactly as they were.
                pass
            else:
                cell["source"] = dst.splitlines(keepends=True)
                modified = True
    if modified:
        dst_contents = json.dumps(nb, indent=1, ensure_ascii=False)
        if trailing_newline:
            dst_contents = dst_contents + "\n"
        return dst_contents
    else:
        raise NothingChanged
1025
1026
def format_str(src_contents: str, *, mode: Mode) -> FileContent:
    """Reformat a string of source code and return the new contents.

    `mode` carries the formatting options, such as the allowed line length.
    Example:

    >>> import black
    >>> print(black.format_str("def f(arg:str='')->None:...", mode=black.Mode()))
    def f(arg: str = "") -> None:
        ...

    A more complex example:

    >>> print(
    ...   black.format_str(
    ...     "def f(arg:str='')->None: hey",
    ...     mode=black.Mode(
    ...       target_versions={black.TargetVersion.PY36},
    ...       line_length=10,
    ...       string_normalization=False,
    ...       is_pyi=False,
    ...     ),
    ...   ),
    ... )
    def f(
        arg: str = '',
    ) -> None:
        hey

    """
    src_node = lib2to3_parse(src_contents.lstrip(), mode.target_versions)
    future_imports = get_future_imports(src_node)
    # Explicitly configured targets win; otherwise infer them from the code.
    versions = mode.target_versions or detect_target_versions(src_node)

    # TODO: fully drop support and this code hopefully in January 2022 :D
    if TargetVersion.PY27 in mode.target_versions or versions == {TargetVersion.PY27}:
        msg = (
            "DEPRECATION: Python 2 support will be removed in the first stable release "
            "expected in January 2022."
        )
        err(msg, fg="yellow", bold=True)

    normalize_fmt_off(src_node)
    strip_u_prefix = "unicode_literals" in future_imports or supports_feature(
        versions, Feature.UNICODE_LITERALS
    )
    line_generator = LineGenerator(mode=mode, remove_u_prefix=strip_u_prefix)
    tracker = EmptyLineTracker(is_pyi=mode.is_pyi)
    # Text of one empty line, repeated to emit the requested blank lines.
    blank = str(Line(mode=mode))

    split_line_features = set()
    for feature in (Feature.TRAILING_COMMA_IN_CALL, Feature.TRAILING_COMMA_IN_DEF):
        if supports_feature(versions, feature):
            split_line_features.add(feature)

    pieces = []
    blanks_after = 0
    for current_line in line_generator.visit(src_node):
        # Blank lines owed after the previous line come first, then the ones
        # requested before the current line.
        pieces.append(blank * blanks_after)
        blanks_before, blanks_after = tracker.maybe_empty_lines(current_line)
        pieces.append(blank * blanks_before)
        pieces.extend(
            str(line)
            for line in transform_line(
                current_line, mode=mode, features=split_line_features
            )
        )
    return "".join(pieces)
1096
1097
def decode_bytes(src: bytes) -> Tuple[FileContent, Encoding, NewLine]:
    """Decode `src` and return (decoded_contents, encoding, newline).

    `newline` is CRLF or LF, judged from the first line seen while detecting
    the encoding.  `decoded_contents` is read with universal newlines, so it
    only contains LF.
    """
    buffer = io.BytesIO(src)
    encoding, first_lines = tokenize.detect_encoding(buffer.readline)
    if len(first_lines) == 0:
        return "", encoding, "\n"

    if first_lines[0].endswith(b"\r\n"):
        newline = "\r\n"
    else:
        newline = "\n"

    # Encoding detection consumed part of the buffer; rewind before decoding.
    buffer.seek(0)
    with io.TextIOWrapper(buffer, encoding) as reader:
        decoded = reader.read()
    return decoded, encoding, newline
1113
1114
def get_features_used(node: Node) -> Set[Feature]:  # noqa: C901
    """Return a set of (relatively) new Python features used in this file.

    Walks `node` in pre-order and currently looks for:
    - f-strings (any casing of the ``f`` / ``rf`` / ``fr`` prefixes);
    - underscores in numeric literals;
    - trailing commas after * or ** in function signatures and calls;
    - positional only arguments in function signatures and lambdas;
    - assignment expression;
    - relaxed decorator syntax;
    - Python 2 only syntax: long and octal integer literals, print / exec
      statements, automatic parameter unpacking, comma-style ``except`` and
      ``raise``, and backquote repr.
    """
    features: Set[Feature] = set()
    for n in node.pre_order():
        if n.type == token.STRING:
            # String prefixes are case-insensitive (``Rf``, ``fR``, ... are all
            # valid), so compare the lowercased head.
            value_head = n.value[:2].lower()  # type: ignore
            if value_head in {'f"', "f'", "rf", "fr"}:
                features.add(Feature.F_STRINGS)

        elif n.type == token.NUMBER:
            assert isinstance(n, Leaf)
            if "_" in n.value:
                features.add(Feature.NUMERIC_UNDERSCORES)
            elif n.value.endswith(("L", "l")):
                # Python 2: 10L
                features.add(Feature.LONG_INT_LITERAL)
            elif len(n.value) >= 2 and n.value[0] == "0" and n.value.isdigit():
                # Python 2: 0123; 00123; ...
                # Only all-digit literals qualify: floats and imaginary numbers
                # such as 01.5, 01e3 or 01j are valid Python 3.
                if not all(char == "0" for char in n.value):
                    # although we don't want to match 0000 or similar
                    features.add(Feature.OCTAL_INT_LITERAL)

        elif n.type == token.SLASH:
            if n.parent and n.parent.type in {
                syms.typedargslist,
                syms.arglist,
                syms.varargslist,
            }:
                features.add(Feature.POS_ONLY_ARGUMENTS)

        elif n.type == token.COLONEQUAL:
            features.add(Feature.ASSIGNMENT_EXPRESSIONS)

        elif n.type == syms.decorator:
            if len(n.children) > 1 and not is_simple_decorator_expression(
                n.children[1]
            ):
                features.add(Feature.RELAXED_DECORATORS)

        elif (
            n.type in {syms.typedargslist, syms.arglist}
            and n.children
            and n.children[-1].type == token.COMMA
        ):
            if n.type == syms.typedargslist:
                feature = Feature.TRAILING_COMMA_IN_DEF
            else:
                feature = Feature.TRAILING_COMMA_IN_CALL

            # The trailing comma only counts as a feature when the list also
            # unpacks with * or **, either directly or inside an argument.
            for ch in n.children:
                if ch.type in STARS:
                    features.add(feature)

                if ch.type == syms.argument:
                    for argch in ch.children:
                        if argch.type in STARS:
                            features.add(feature)

        # Python 2 only features (for its deprecation) except for integers, see above
        elif n.type == syms.print_stmt:
            features.add(Feature.PRINT_STMT)
        elif n.type == syms.exec_stmt:
            features.add(Feature.EXEC_STMT)
        elif n.type == syms.tfpdef:
            # def set_position((x, y), value):
            #     ...
            features.add(Feature.AUTOMATIC_PARAMETER_UNPACKING)
        elif n.type == syms.except_clause:
            # try:
            #     ...
            # except Exception, err:
            #     ...
            if len(n.children) >= 4:
                if n.children[-2].type == token.COMMA:
                    features.add(Feature.COMMA_STYLE_EXCEPT)
        elif n.type == syms.raise_stmt:
            # raise Exception, "msg"
            if len(n.children) >= 4:
                if n.children[-2].type == token.COMMA:
                    features.add(Feature.COMMA_STYLE_RAISE)
        elif n.type == token.BACKQUOTE:
            # `i'm surprised this ever existed`
            features.add(Feature.BACKQUOTE_REPR)

    return features
1210
1211
def detect_target_versions(node: Node) -> Set[TargetVersion]:
    """Return every target version whose feature set covers the features used in `node`."""
    used_features = get_features_used(node)
    compatible: Set[TargetVersion] = set()
    for version in TargetVersion:
        if used_features <= VERSION_TO_FEATURES[version]:
            compatible.add(version)
    return compatible
1218
1219
def get_future_imports(node: Node) -> Set[str]:
    """Collect the names imported from `__future__` at the top of the file."""
    imports: Set[str] = set()

    def collect_names(children: List[LN]) -> Generator[str, None, None]:
        # Yield the original (pre-``as``) name of every imported item.
        for item in children:
            if isinstance(item, Leaf):
                if item.type == token.NAME:
                    yield item.value
                continue

            if item.type == syms.import_as_name:
                orig_name = item.children[0]
                assert isinstance(orig_name, Leaf), "Invalid syntax parsing imports"
                assert orig_name.type == token.NAME, "Invalid syntax parsing imports"
                yield orig_name.value
            elif item.type == syms.import_as_names:
                yield from collect_names(item.children)
            else:
                raise AssertionError("Invalid syntax parsing imports")

    for stmt in node.children:
        if stmt.type != syms.simple_stmt:
            break

        head = stmt.children[0]
        if isinstance(head, Leaf):
            # Keep scanning past a docstring; any other leaf ends the search.
            is_docstring = (
                len(stmt.children) == 2
                and head.type == token.STRING
                and stmt.children[1].type == token.NEWLINE
            )
            if is_docstring:
                continue
            break

        if head.type != syms.import_from:
            break

        module_name = head.children[1]
        if not (isinstance(module_name, Leaf) and module_name.value == "__future__"):
            break

        imports.update(collect_names(head.children[3:]))

    return imports
1268
1269
def assert_equivalent(src: str, dst: str, *, pass_num: int = 1) -> None:
    """Raise AssertionError unless `src` and `dst` parse to the same AST.

    `pass_num` only appears in the error messages, to tell which formatting
    pass produced `dst`.
    """
    try:
        src_ast = parse_ast(src)
    except Exception as exc:
        raise AssertionError(
            "cannot use --safe with this file; failed to parse source file."
        ) from exc

    try:
        dst_ast = parse_ast(dst)
    except Exception as exc:
        # Save the traceback together with the invalid output for bug reports.
        tb_text = "".join(traceback.format_tb(exc.__traceback__))
        log = dump_to_file(tb_text, dst)
        raise AssertionError(
            f"INTERNAL ERROR: Black produced invalid code on pass {pass_num}: {exc}. "
            "Please report a bug on https://github.com/psf/black/issues.  "
            f"This invalid output might be helpful: {log}"
        ) from None

    src_ast_str = "\n".join(stringify_ast(src_ast))
    dst_ast_str = "\n".join(stringify_ast(dst_ast))
    if src_ast_str == dst_ast_str:
        return

    log = dump_to_file(diff(src_ast_str, dst_ast_str, "src", "dst"))
    raise AssertionError(
        "INTERNAL ERROR: Black produced code that is not equivalent to the"
        f" source on pass {pass_num}.  Please report a bug on "
        f"https://github.com/psf/black/issues.  This diff might be helpful: {log}"
    ) from None
1298
1299
def assert_stable(src: str, dst: str, mode: Mode) -> None:
    """Raise AssertionError if formatting `dst` again changes it.

    `src` is only used to build the diagnostic diff written to the log file.
    """
    second_pass = format_str(dst, mode=mode)
    if second_pass == dst:
        return

    mode_text = str(mode)
    first_pass_diff = diff(src, dst, "source", "first pass")
    second_pass_diff = diff(dst, second_pass, "first pass", "second pass")
    log = dump_to_file(mode_text, first_pass_diff, second_pass_diff)
    raise AssertionError(
        "INTERNAL ERROR: Black produced different code on the second pass of the"
        " formatter.  Please report a bug on https://github.com/psf/black/issues."
        f"  This diff might be helpful: {log}"
    ) from None
1314
1315
@contextmanager
def nullcontext() -> Iterator[None]:
    """Provide a context manager that does nothing on entry or exit.

    Meant to be used like `nullcontext` from Python 3.7.
    """
    yield None
1323
1324
def patch_click() -> None:
    """Make Click not crash on Python 3.6 with LANG=C.

    On certain misconfigured environments, Python 3 selects the ASCII encoding as the
    default which restricts paths that it can access during the lifetime of the
    application.  Click refuses to work in this scenario by raising a RuntimeError.

    In case of Black the likelihood that non-ASCII characters are going to be used in
    file paths is minimal since it's Python source code.  Moreover, this crash was
    spurious on Python 3.7 thanks to PEP 538 and PEP 540.

    Every Click module that can be imported has its environment verification hook
    replaced by a no-op; modules that are missing are skipped silently.
    """
    modules: List[Any] = []
    # Import each module independently.  A missing submodule of an installed
    # package raises plain ImportError ("cannot import name ..."), not
    # ModuleNotFoundError, so catch the base class.
    try:
        from click import core
    except ImportError:
        pass
    else:
        modules.append(core)
    try:
        # Newer Click releases (8.1+) no longer ship `_unicodefun`; it is still
        # patched for users with older versions installed.
        from click import _unicodefun
    except ImportError:
        pass
    else:
        modules.append(_unicodefun)

    for module in modules:
        if hasattr(module, "_verify_python3_env"):
            module._verify_python3_env = lambda: None  # type: ignore
        if hasattr(module, "_verify_python_env"):
            module._verify_python_env = lambda: None  # type: ignore
1347
1348
def patched_main() -> None:
    """Run the environment set-up steps, then hand control to `main`."""
    # NOTE(review): presumably switches to the uvloop event loop when it is
    # available -- the helper is defined outside this section; confirm there.
    maybe_install_uvloop()
    # `freeze_support` is imported from `multiprocessing` at the top of the file.
    freeze_support()
    # Disable Click's Python environment verification before Click is used.
    patch_click()
    main()
1354
1355
# Run the patched entry point only when this module is executed as a script.
if __name__ == "__main__":
    patched_main()