X-Git-Url: https://git.madduck.net/etc/vim.git/blobdiff_plain/b2ee211b5ad84b62738ac0997b73bf6ee9a74d06..5434407af7ba262f74d272c738006cbf1d0ab11a:/src/black/__init__.py?ds=inline diff --git a/src/black/__init__.py b/src/black/__init__.py index f46b866..5c6cb67 100644 --- a/src/black/__init__.py +++ b/src/black/__init__.py @@ -1,4 +1,6 @@ import asyncio +from json.decoder import JSONDecodeError +import json from concurrent.futures import Executor, ThreadPoolExecutor, ProcessPoolExecutor from contextlib import contextmanager from datetime import datetime @@ -7,6 +9,7 @@ import io from multiprocessing import Manager, freeze_support import os from pathlib import Path +from pathspec.patterns.gitwildmatch import GitWildMatchPatternError import regex as re import signal import sys @@ -18,6 +21,7 @@ from typing import ( Generator, Iterator, List, + MutableMapping, Optional, Pattern, Set, @@ -38,14 +42,22 @@ from black.comments import normalize_fmt_off from black.mode import Mode, TargetVersion from black.mode import Feature, supports_feature, VERSION_TO_FEATURES from black.cache import read_cache, write_cache, get_cache_info, filter_cached, Cache -from black.concurrency import cancel, shutdown -from black.output import dump_to_file, diff, color_diff, out, err -from black.report import Report, Changed +from black.concurrency import cancel, shutdown, maybe_install_uvloop +from black.output import dump_to_file, ipynb_diff, diff, color_diff, out, err +from black.report import Report, Changed, NothingChanged from black.files import find_project_root, find_pyproject_toml, parse_pyproject_toml from black.files import gen_python_files, get_gitignore, normalize_path_maybe_ignore from black.files import wrap_stream_for_windows from black.parsing import InvalidInput # noqa F401 from black.parsing import lib2to3_parse, parse_ast, stringify_ast +from black.handle_ipynb_magics import ( + mask_cell, + unmask_cell, + remove_trailing_semicolon, + put_trailing_semicolon_back, + TRANSFORMED_MAGICS, + jupyter_dependencies_are_installed, +) # lib2to3 fork @@ -54,17 +66,12 @@ from blib2to3.pgen2 import token from _black_version import version as __version__ - # types FileContent = str Encoding = str NewLine = str -class NothingChanged(UserWarning): - """Raised when reformatted code is the same as source.""" - - class WriteBack(Enum): NO = 0 YES = 1 @@ -88,6 +95,8 @@ class WriteBack(Enum): # Legacy name, left for integrations. FileMode = Mode +DEFAULT_WORKERS = os.cpu_count() + def read_pyproject_toml( ctx: click.Context, param: click.Parameter, value: Optional[str] @@ -107,7 +116,7 @@ def read_pyproject_toml( except (OSError, ValueError) as e: raise click.FileError( filename=value, hint=f"Error reading configuration file: {e}" - ) + ) from None if not config: return None @@ -161,11 +170,11 @@ def validate_regex( ctx: click.Context, param: click.Parameter, value: Optional[str], -) -> Optional[Pattern]: +) -> Optional[Pattern[str]]: try: return re_compile_maybe_verbose(value) if value is not None else None except re.error: - raise click.BadParameter("Not a valid regular expression") + raise click.BadParameter("Not a valid regular expression") from None @click.command(context_settings=dict(help_option_names=["-h", "--help"])) @@ -197,6 +206,14 @@ def validate_regex( " when piping source on standard input)." ), ) +@click.option( + "--ipynb", + is_flag=True, + help=( + "Format all input files like Jupyter Notebooks regardless of file extension " + "(useful when piping source on standard input)." + ), +) @click.option( "-S", "--skip-string-normalization", @@ -242,6 +259,14 @@ def validate_regex( is_flag=True, help="If --fast given, skip temporary sanity checks. [default: --safe]", ) +@click.option( + "--required-version", + type=str, + help=( + "Require a specific version of Black to be running (useful for unifying results" + " across many environments e.g. with a pyproject.toml file)." + ), +) @click.option( "--include", type=str, @@ -295,6 +320,14 @@ def validate_regex( "editors that rely on using stdin." ), ) +@click.option( + "-W", + "--workers", + type=click.IntRange(min=1), + default=DEFAULT_WORKERS, + show_default=True, + help="Number of parallel workers", +) @click.option( "-q", "--quiet", @@ -321,6 +354,7 @@ def validate_regex( exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True ), is_eager=True, + metavar="SRC ...", ) @click.option( "--config", @@ -347,20 +381,37 @@ def main( color: bool, fast: bool, pyi: bool, + ipynb: bool, skip_string_normalization: bool, skip_magic_trailing_comma: bool, experimental_string_processing: bool, quiet: bool, verbose: bool, - include: Pattern, - exclude: Optional[Pattern], - extend_exclude: Optional[Pattern], - force_exclude: Optional[Pattern], + required_version: str, + include: Pattern[str], + exclude: Optional[Pattern[str]], + extend_exclude: Optional[Pattern[str]], + force_exclude: Optional[Pattern[str]], stdin_filename: Optional[str], + workers: int, src: Tuple[str, ...], config: Optional[str], ) -> None: """The uncompromising code formatter.""" + if config and verbose: + out(f"Using configuration from {config}.", bold=False, fg="blue") + + error_msg = "Oh no! 💥 💔 💥" + if required_version and required_version != __version__: + err( + f"{error_msg} The required version `{required_version}` does not match" + f" the running version `{__version__}`!" + ) + ctx.exit(1) + if ipynb and pyi: + err("Cannot pass both `pyi` and `ipynb` flags!") + ctx.exit(1) + write_back = WriteBack.from_configuration(check=check, diff=diff, color=color) if target_version: versions = set(target_version) @@ -371,53 +422,70 @@ def main( target_versions=versions, line_length=line_length, is_pyi=pyi, + is_ipynb=ipynb, string_normalization=not skip_string_normalization, magic_trailing_comma=not skip_magic_trailing_comma, experimental_string_processing=experimental_string_processing, ) - if config and verbose: - out(f"Using configuration from {config}.", bold=False, fg="blue") + if code is not None: - print(format_str(code, mode=mode)) - ctx.exit(0) - report = Report(check=check, diff=diff, quiet=quiet, verbose=verbose) - sources = get_sources( - ctx=ctx, - src=src, - quiet=quiet, - verbose=verbose, - include=include, - exclude=exclude, - extend_exclude=extend_exclude, - force_exclude=force_exclude, - report=report, - stdin_filename=stdin_filename, - ) + # Run in quiet mode by default with -c; the extra output isn't useful. + # You can still pass -v to get verbose output. + quiet = True - path_empty( - sources, - "No Python files are present to be formatted. Nothing to do 😴", - quiet, - verbose, - ctx, - ) + report = Report(check=check, diff=diff, quiet=quiet, verbose=verbose) - if len(sources) == 1: - reformat_one( - src=sources.pop(), - fast=fast, - write_back=write_back, - mode=mode, - report=report, + if code is not None: + reformat_code( + content=code, fast=fast, write_back=write_back, mode=mode, report=report ) else: - reformat_many( - sources=sources, fast=fast, write_back=write_back, mode=mode, report=report + try: + sources = get_sources( + ctx=ctx, + src=src, + quiet=quiet, + verbose=verbose, + include=include, + exclude=exclude, + extend_exclude=extend_exclude, + force_exclude=force_exclude, + report=report, + stdin_filename=stdin_filename, + ) + except GitWildMatchPatternError: + ctx.exit(1) + + path_empty( + sources, + "No Python files are present to be formatted. Nothing to do 😴", + quiet, + verbose, + ctx, ) + if len(sources) == 1: + reformat_one( + src=sources.pop(), + fast=fast, + write_back=write_back, + mode=mode, + report=report, + ) + else: + reformat_many( + sources=sources, + fast=fast, + write_back=write_back, + mode=mode, + report=report, + workers=workers, + ) + if verbose or not quiet: - out("Oh no! 💥 💔 💥" if report.return_code else "All done! ✨ 🍰 ✨") - click.secho(str(report), err=True) + out(error_msg if report.return_code else "All done! ✨ 🍰 ✨") + if code is None: + click.echo(str(report), err=True) ctx.exit(report.return_code) @@ -472,6 +540,11 @@ def get_sources( if is_stdin: p = Path(f"{STDIN_PLACEHOLDER}{str(p)}") + if p.suffix == ".ipynb" and not jupyter_dependencies_are_installed( + verbose=verbose, quiet=quiet + ): + continue + sources.add(p) elif p.is_dir(): sources.update( @@ -484,6 +557,8 @@ def get_sources( force_exclude, report, gitignore, + verbose=verbose, + quiet=quiet, ) ) elif s == "-": @@ -499,11 +574,36 @@ def path_empty( """ Exit if there is no `src` provided for formatting """ - if not src and (verbose or not quiet): - out(msg) + if not src: + if verbose or not quiet: + out(msg) ctx.exit(0) +def reformat_code( + content: str, fast: bool, write_back: WriteBack, mode: Mode, report: Report +) -> None: + """ + Reformat and print out `content` without spawning child processes. + Similar to `reformat_one`, but for string content. + + `fast`, `write_back`, and `mode` options are passed to + :func:`format_file_in_place` or :func:`format_stdin_to_stdout`. + """ + path = Path("") + try: + changed = Changed.NO + if format_stdin_to_stdout( + content=content, fast=fast, write_back=write_back, mode=mode + ): + changed = Changed.YES + report.done(path, changed) + except Exception as exc: + if report.verbose: + traceback.print_exc() + report.failed(path, str(exc)) + + def reformat_one( src: Path, fast: bool, write_back: WriteBack, mode: Mode, report: "Report" ) -> None: @@ -528,6 +628,8 @@ def reformat_one( if is_stdin: if src.suffix == ".pyi": mode = replace(mode, is_pyi=True) + elif src.suffix == ".ipynb": + mode = replace(mode, is_ipynb=True) if format_stdin_to_stdout(fast=fast, write_back=write_back, mode=mode): changed = Changed.YES else: @@ -554,12 +656,17 @@ def reformat_one( def reformat_many( - sources: Set[Path], fast: bool, write_back: WriteBack, mode: Mode, report: "Report" + sources: Set[Path], + fast: bool, + write_back: WriteBack, + mode: Mode, + report: "Report", + workers: Optional[int], ) -> None: """Reformat multiple files using a ProcessPoolExecutor.""" executor: Executor loop = asyncio.get_event_loop() - worker_count = os.cpu_count() + worker_count = workers if workers is not None else DEFAULT_WORKERS if sys.platform == "win32": # Work around https://bugs.python.org/issue26903 worker_count = min(worker_count, 60) @@ -676,6 +783,8 @@ def format_file_in_place( """ if src.suffix == ".pyi": mode = replace(mode, is_pyi=True) + elif src.suffix == ".ipynb": + mode = replace(mode, is_ipynb=True) then = datetime.utcfromtimestamp(src.stat().st_mtime) with open(src, "rb") as buf: @@ -684,6 +793,10 @@ def format_file_in_place( dst_contents = format_file_contents(src_contents, fast=fast, mode=mode) except NothingChanged: return False + except JSONDecodeError: + raise ValueError( + f"File '{src}' cannot be parsed as valid Jupyter notebook." + ) from None if write_back == WriteBack.YES: with open(src, "w", encoding=encoding, newline=newline) as f: @@ -692,7 +805,10 @@ def format_file_in_place( now = datetime.utcnow() src_name = f"{src}\t{then} +0000" dst_name = f"{src}\t{now} +0000" - diff_contents = diff(src_contents, dst_contents, src_name, dst_name) + if mode.is_ipynb: + diff_contents = ipynb_diff(src_contents, dst_contents, src_name, dst_name) + else: + diff_contents = diff(src_contents, dst_contents, src_name, dst_name) if write_back == WriteBack.COLOR_DIFF: diff_contents = color_diff(diff_contents) @@ -712,16 +828,27 @@ def format_file_in_place( def format_stdin_to_stdout( - fast: bool, *, write_back: WriteBack = WriteBack.NO, mode: Mode + fast: bool, + *, + content: Optional[str] = None, + write_back: WriteBack = WriteBack.NO, + mode: Mode, ) -> bool: """Format file on stdin. Return True if changed. + If content is None, it's read from sys.stdin. + If `write_back` is YES, write reformatted code back to stdout. If it is DIFF, write a diff to stdout. The `mode` argument is passed to :func:`format_file_contents`. """ then = datetime.utcnow() - src, encoding, newline = decode_bytes(sys.stdin.buffer.read()) + + if content is None: + src, encoding, newline = decode_bytes(sys.stdin.buffer.read()) + else: + src, encoding, newline = content, "utf-8", "" + dst = src try: dst = format_file_contents(src, fast=fast, mode=mode) @@ -735,6 +862,9 @@ def format_stdin_to_stdout( sys.stdout.buffer, encoding=encoding, newline=newline, write_through=True ) if write_back == WriteBack.YES: + # Make sure there's a newline after the content + if dst and dst[-1] != "\n": + dst += "\n" f.write(dst) elif write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF): now = datetime.utcnow() @@ -748,6 +878,29 @@ def format_stdin_to_stdout( f.detach() +def check_stability_and_equivalence( + src_contents: str, dst_contents: str, *, mode: Mode +) -> None: + """Perform stability and equivalence checks. + + Raise AssertionError if source and destination contents are not + equivalent, or if a second pass of the formatter would format the + content differently. + """ + assert_equivalent(src_contents, dst_contents) + + # Forced second pass to work around optional trailing commas (becoming + # forced trailing commas on pass 2) interacting differently with optional + # parentheses. Admittedly ugly. + dst_contents_pass2 = format_str(dst_contents, mode=mode) + if dst_contents != dst_contents_pass2: + dst_contents = dst_contents_pass2 + assert_equivalent(src_contents, dst_contents, pass_num=2) + assert_stable(src_contents, dst_contents, mode=mode) + # Note: no need to explicitly call `assert_stable` if `dst_contents` was + # the same as `dst_contents_pass2`. + + def format_file_contents(src_contents: str, *, fast: bool, mode: Mode) -> FileContent: """Reformat contents of a file and return new contents. @@ -758,26 +911,116 @@ def format_file_contents(src_contents: str, *, fast: bool, mode: Mode) -> FileCo if not src_contents.strip(): raise NothingChanged - dst_contents = format_str(src_contents, mode=mode) + if mode.is_ipynb: + dst_contents = format_ipynb_string(src_contents, fast=fast, mode=mode) + else: + dst_contents = format_str(src_contents, mode=mode) if src_contents == dst_contents: raise NothingChanged - if not fast: - assert_equivalent(src_contents, dst_contents) - - # Forced second pass to work around optional trailing commas (becoming - # forced trailing commas on pass 2) interacting differently with optional - # parentheses. Admittedly ugly. - dst_contents_pass2 = format_str(dst_contents, mode=mode) - if dst_contents != dst_contents_pass2: - dst_contents = dst_contents_pass2 - assert_equivalent(src_contents, dst_contents, pass_num=2) - assert_stable(src_contents, dst_contents, mode=mode) - # Note: no need to explicitly call `assert_stable` if `dst_contents` was - # the same as `dst_contents_pass2`. + if not fast and not mode.is_ipynb: + # Jupyter notebooks will already have been checked above. + check_stability_and_equivalence(src_contents, dst_contents, mode=mode) return dst_contents +def validate_cell(src: str) -> None: + """Check that cell does not already contain TransformerManager transformations. + + If a cell contains ``!ls``, then it'll be transformed to + ``get_ipython().system('ls')``. However, if the cell originally contained + ``get_ipython().system('ls')``, then it would get transformed in the same way: + + >>> TransformerManager().transform_cell("get_ipython().system('ls')") + "get_ipython().system('ls')\n" + >>> TransformerManager().transform_cell("!ls") + "get_ipython().system('ls')\n" + + Due to the impossibility of safely roundtripping in such situations, cells + containing transformed magics will be ignored. + """ + if any(transformed_magic in src for transformed_magic in TRANSFORMED_MAGICS): + raise NothingChanged + + +def format_cell(src: str, *, fast: bool, mode: Mode) -> str: + """Format code in given cell of Jupyter notebook. + + General idea is: + + - if cell has trailing semicolon, remove it; + - if cell has IPython magics, mask them; + - format cell; + - reinstate IPython magics; + - reinstate trailing semicolon (if originally present); + - strip trailing newlines. + + Cells with syntax errors will not be processed, as they + could potentially be automagics or multi-line magics, which + are currently not supported. + """ + validate_cell(src) + src_without_trailing_semicolon, has_trailing_semicolon = remove_trailing_semicolon( + src + ) + try: + masked_src, replacements = mask_cell(src_without_trailing_semicolon) + except SyntaxError: + raise NothingChanged from None + masked_dst = format_str(masked_src, mode=mode) + if not fast: + check_stability_and_equivalence(masked_src, masked_dst, mode=mode) + dst_without_trailing_semicolon = unmask_cell(masked_dst, replacements) + dst = put_trailing_semicolon_back( + dst_without_trailing_semicolon, has_trailing_semicolon + ) + dst = dst.rstrip("\n") + if dst == src: + raise NothingChanged from None + return dst + + +def validate_metadata(nb: MutableMapping[str, Any]) -> None: + """If notebook is marked as non-Python, don't format it. + + All notebook metadata fields are optional, see + https://nbformat.readthedocs.io/en/latest/format_description.html. So + if a notebook has empty metadata, we will try to parse it anyway. + """ + language = nb.get("metadata", {}).get("language_info", {}).get("name", None) + if language is not None and language != "python": + raise NothingChanged from None + + +def format_ipynb_string(src_contents: str, *, fast: bool, mode: Mode) -> FileContent: + """Format Jupyter notebook. + + Operate cell-by-cell, only on code cells, only for Python notebooks. + If the ``.ipynb`` originally had a trailing newline, it'll be preserved. + """ + trailing_newline = src_contents[-1] == "\n" + modified = False + nb = json.loads(src_contents) + validate_metadata(nb) + for cell in nb["cells"]: + if cell.get("cell_type", None) == "code": + try: + src = "".join(cell["source"]) + dst = format_cell(src, fast=fast, mode=mode) + except NothingChanged: + pass + else: + cell["source"] = dst.splitlines(keepends=True) + modified = True + if modified: + dst_contents = json.dumps(nb, indent=1, ensure_ascii=False) + if trailing_newline: + dst_contents = dst_contents + "\n" + return dst_contents + else: + raise NothingChanged + + def format_str(src_contents: str, *, mode: Mode) -> FileContent: """Reformat a string and return new contents. @@ -880,7 +1123,11 @@ def get_features_used(node: Node) -> Set[Feature]: features.add(Feature.NUMERIC_UNDERSCORES) elif n.type == token.SLASH: - if n.parent and n.parent.type in {syms.typedargslist, syms.arglist}: + if n.parent and n.parent.type in { + syms.typedargslist, + syms.arglist, + syms.varargslist, + }: features.add(Feature.POS_ONLY_ARGUMENTS) elif n.type == token.COLONEQUAL: @@ -978,9 +1225,8 @@ def assert_equivalent(src: str, dst: str, *, pass_num: int = 1) -> None: src_ast = parse_ast(src) except Exception as exc: raise AssertionError( - "cannot use --safe with this file; failed to parse source file. AST" - f" error message: {exc}" - ) + "cannot use --safe with this file; failed to parse source file." + ) from exc try: dst_ast = parse_ast(dst) @@ -1041,7 +1287,7 @@ def patch_click() -> None: """ try: from click import core - from click import _unicodefun # type: ignore + from click import _unicodefun except ModuleNotFoundError: return @@ -1053,6 +1299,7 @@ def patch_click() -> None: def patched_main() -> None: + maybe_install_uvloop() freeze_support() patch_click() main()