X-Git-Url: https://git.madduck.net/etc/vim.git/blobdiff_plain/104aec555fae0883ef5b53709569bd9c4d420bc5..93701d249e2cadf0ec096a752a5cbbe8da1a1130:/src/black/__init__.py?ds=inline diff --git a/src/black/__init__.py b/src/black/__init__.py index 60f4fa3..59018d0 100644 --- a/src/black/__init__.py +++ b/src/black/__init__.py @@ -10,7 +10,7 @@ from multiprocessing import Manager, freeze_support import os from pathlib import Path from pathspec.patterns.gitwildmatch import GitWildMatchPatternError -import regex as re +import re import signal import sys import tokenize @@ -30,8 +30,9 @@ from typing import ( Union, ) -from dataclasses import replace import click +from dataclasses import replace +from mypy_extensions import mypyc_attr from black.const import DEFAULT_LINE_LENGTH, DEFAULT_INCLUDES, DEFAULT_EXCLUDES from black.const import STDIN_PLACEHOLDER @@ -39,7 +40,7 @@ from black.nodes import STARS, syms, is_simple_decorator_expression from black.lines import Line, EmptyLineTracker from black.linegen import transform_line, LineGenerator, LN from black.comments import normalize_fmt_off -from black.mode import Mode, TargetVersion +from black.mode import FUTURE_FLAG_TO_FEATURE, Mode, TargetVersion from black.mode import Feature, supports_feature, VERSION_TO_FEATURES from black.cache import read_cache, write_cache, get_cache_info, filter_cached, Cache from black.concurrency import cancel, shutdown, maybe_install_uvloop @@ -56,6 +57,7 @@ from black.handle_ipynb_magics import ( remove_trailing_semicolon, put_trailing_semicolon_back, TRANSFORMED_MAGICS, + PYTHON_CELL_MAGICS, jupyter_dependencies_are_installed, ) @@ -66,6 +68,8 @@ from blib2to3.pgen2 import token from _black_version import version as __version__ +COMPILED = Path(__file__).suffix in (".pyd", ".so") + # types FileContent = str Encoding = str @@ -95,6 +99,8 @@ class WriteBack(Enum): # Legacy name, left for integrations. FileMode = Mode +DEFAULT_WORKERS = os.cpu_count() + def read_pyproject_toml( ctx: click.Context, param: click.Parameter, value: Optional[str] @@ -114,7 +120,7 @@ def read_pyproject_toml( except (OSError, ValueError) as e: raise click.FileError( filename=value, hint=f"Error reading configuration file: {e}" - ) + ) from None if not config: return None @@ -168,14 +174,19 @@ def validate_regex( ctx: click.Context, param: click.Parameter, value: Optional[str], -) -> Optional[Pattern]: +) -> Optional[Pattern[str]]: try: return re_compile_maybe_verbose(value) if value is not None else None - except re.error: - raise click.BadParameter("Not a valid regular expression") + except re.error as e: + raise click.BadParameter(f"Not a valid regular expression: {e}") from None -@click.command(context_settings=dict(help_option_names=["-h", "--help"])) +@click.command( + context_settings={"help_option_names": ["-h", "--help"]}, + # While Click does set this field automatically using the docstring, mypyc + # (annoyingly) strips 'em so we need to set it here too. + help="The uncompromising code formatter.", +) @click.option("-c", "--code", type=str, help="Format the code passed in as a string.") @click.option( "-l", @@ -318,6 +329,14 @@ def validate_regex( "editors that rely on using stdin." ), ) +@click.option( + "-W", + "--workers", + type=click.IntRange(min=1), + default=DEFAULT_WORKERS, + show_default=True, + help="Number of parallel workers", +) @click.option( "-q", "--quiet", @@ -336,7 +355,10 @@ def validate_regex( " due to exclusion patterns." ), ) -@click.version_option(version=__version__) +@click.version_option( + version=__version__, + message=f"%(prog)s, %(version)s (compiled: {'yes' if COMPILED else 'no'})", +) @click.argument( "src", nargs=-1, @@ -377,12 +399,13 @@ def main( experimental_string_processing: bool, quiet: bool, verbose: bool, - required_version: str, - include: Pattern, - exclude: Optional[Pattern], - extend_exclude: Optional[Pattern], - force_exclude: Optional[Pattern], + required_version: Optional[str], + include: Pattern[str], + exclude: Optional[Pattern[str]], + extend_exclude: Optional[Pattern[str]], + force_exclude: Optional[Pattern[str]], stdin_filename: Optional[str], + workers: int, src: Tuple[str, ...], config: Optional[str], ) -> None: @@ -468,6 +491,7 @@ def main( write_back=write_back, mode=mode, report=report, + workers=workers, ) if verbose or not quiet: @@ -643,19 +667,28 @@ def reformat_one( report.failed(src, str(exc)) +# diff-shades depends on being to monkeypatch this function to operate. I know it's +# not ideal, but this shouldn't cause any issues ... hopefully. ~ichard26 +@mypyc_attr(patchable=True) def reformat_many( - sources: Set[Path], fast: bool, write_back: WriteBack, mode: Mode, report: "Report" + sources: Set[Path], + fast: bool, + write_back: WriteBack, + mode: Mode, + report: "Report", + workers: Optional[int], ) -> None: """Reformat multiple files using a ProcessPoolExecutor.""" executor: Executor loop = asyncio.get_event_loop() - worker_count = os.cpu_count() + worker_count = workers if workers is not None else DEFAULT_WORKERS if sys.platform == "win32": # Work around https://bugs.python.org/issue26903 + assert worker_count is not None worker_count = min(worker_count, 60) try: executor = ProcessPoolExecutor(max_workers=worker_count) - except (ImportError, OSError): + except (ImportError, NotImplementedError, OSError): # we arrive here if the underlying system does not support multi-processing # like in AWS Lambda or Termux, in which case we gracefully fallback to # a ThreadPoolExecutor with just a single worker (more workers would not do us @@ -746,7 +779,10 @@ async def schedule_formatting( sources_to_cache.append(src) report.done(src, changed) if cancelled: - await asyncio.gather(*cancelled, loop=loop, return_exceptions=True) + if sys.version_info >= (3, 7): + await asyncio.gather(*cancelled, return_exceptions=True) + else: + await asyncio.gather(*cancelled, loop=loop, return_exceptions=True) if sources_to_cache: write_cache(cache, sources_to_cache, mode) @@ -777,7 +813,9 @@ def format_file_in_place( except NothingChanged: return False except JSONDecodeError: - raise ValueError(f"File '{src}' cannot be parsed as valid Jupyter notebook.") + raise ValueError( + f"File '{src}' cannot be parsed as valid Jupyter notebook." + ) from None if write_back == WriteBack.YES: with open(src, "w", encoding=encoding, newline=newline) as f: @@ -906,7 +944,9 @@ def format_file_contents(src_contents: str, *, fast: bool, mode: Mode) -> FileCo def validate_cell(src: str) -> None: - """Check that cell does not already contain TransformerManager transformations. + """Check that cell does not already contain TransformerManager transformations, + or non-Python cell magics, which might cause tokenizer_rt to break because of + indentations. If a cell contains ``!ls``, then it'll be transformed to ``get_ipython().system('ls')``. However, if the cell originally contained @@ -922,6 +962,8 @@ def validate_cell(src: str) -> None: """ if any(transformed_magic in src for transformed_magic in TRANSFORMED_MAGICS): raise NothingChanged + if src[:2] == "%%" and src.split()[0][2:] not in PYTHON_CELL_MAGICS: + raise NothingChanged def format_cell(src: str, *, fast: bool, mode: Mode) -> str: @@ -947,7 +989,7 @@ def format_cell(src: str, *, fast: bool, mode: Mode) -> str: try: masked_src, replacements = mask_cell(src_without_trailing_semicolon) except SyntaxError: - raise NothingChanged + raise NothingChanged from None masked_dst = format_str(masked_src, mode=mode) if not fast: check_stability_and_equivalence(masked_src, masked_dst, mode=mode) @@ -957,7 +999,7 @@ def format_cell(src: str, *, fast: bool, mode: Mode) -> str: ) dst = dst.rstrip("\n") if dst == src: - raise NothingChanged + raise NothingChanged from None return dst @@ -970,14 +1012,14 @@ def validate_metadata(nb: MutableMapping[str, Any]) -> None: """ language = nb.get("metadata", {}).get("language_info", {}).get("name", None) if language is not None and language != "python": - raise NothingChanged + raise NothingChanged from None def format_ipynb_string(src_contents: str, *, fast: bool, mode: Mode) -> FileContent: """Format Jupyter notebook. Operate cell-by-cell, only on code cells, only for Python notebooks. - If the ``.ipynb`` originally had a trailing newline, it'll be preseved. + If the ``.ipynb`` originally had a trailing newline, it'll be preserved. """ trailing_newline = src_contents[-1] == "\n" modified = False @@ -1038,7 +1080,16 @@ def format_str(src_contents: str, *, mode: Mode) -> FileContent: if mode.target_versions: versions = mode.target_versions else: - versions = detect_target_versions(src_node) + versions = detect_target_versions(src_node, future_imports=future_imports) + + # TODO: fully drop support and this code hopefully in January 2022 :D + if TargetVersion.PY27 in mode.target_versions or versions == {TargetVersion.PY27}: + msg = ( + "DEPRECATION: Python 2 support will be removed in the first stable release " + "expected in January 2022." + ) + err(msg, fg="yellow", bold=True) + normalize_fmt_off(src_node) lines = LineGenerator( mode=mode, @@ -1081,7 +1132,9 @@ def decode_bytes(src: bytes) -> Tuple[FileContent, Encoding, NewLine]: return tiow.read(), encoding, newline -def get_features_used(node: Node) -> Set[Feature]: +def get_features_used( # noqa: C901 + node: Node, *, future_imports: Optional[Set[str]] = None +) -> Set[Feature]: """Return a set of (relatively) new Python features used in this file. Currently looking for: @@ -1091,8 +1144,17 @@ def get_features_used(node: Node) -> Set[Feature]: - positional only arguments in function signatures and lambdas; - assignment expression; - relaxed decorator syntax; + - usage of __future__ flags (annotations); + - print / exec statements; """ features: Set[Feature] = set() + if future_imports: + features |= { + FUTURE_FLAG_TO_FEATURE[future_import] + for future_import in future_imports + if future_import in FUTURE_FLAG_TO_FEATURE + } + for n in node.pre_order(): if n.type == token.STRING: value_head = n.value[:2] # type: ignore @@ -1100,11 +1162,24 @@ def get_features_used(node: Node) -> Set[Feature]: features.add(Feature.F_STRINGS) elif n.type == token.NUMBER: - if "_" in n.value: # type: ignore + assert isinstance(n, Leaf) + if "_" in n.value: features.add(Feature.NUMERIC_UNDERSCORES) + elif n.value.endswith(("L", "l")): + # Python 2: 10L + features.add(Feature.LONG_INT_LITERAL) + elif len(n.value) >= 2 and n.value[0] == "0" and n.value[1].isdigit(): + # Python 2: 0123; 00123; ... + if not all(char == "0" for char in n.value): + # although we don't want to match 0000 or similar + features.add(Feature.OCTAL_INT_LITERAL) elif n.type == token.SLASH: - if n.parent and n.parent.type in {syms.typedargslist, syms.arglist}: + if n.parent and n.parent.type in { + syms.typedargslist, + syms.arglist, + syms.varargslist, + }: features.add(Feature.POS_ONLY_ARGUMENTS) elif n.type == token.COLONEQUAL: @@ -1135,12 +1210,40 @@ def get_features_used(node: Node) -> Set[Feature]: if argch.type in STARS: features.add(feature) + # Python 2 only features (for its deprecation) except for integers, see above + elif n.type == syms.print_stmt: + features.add(Feature.PRINT_STMT) + elif n.type == syms.exec_stmt: + features.add(Feature.EXEC_STMT) + elif n.type == syms.tfpdef: + # def set_position((x, y), value): + # ... + features.add(Feature.AUTOMATIC_PARAMETER_UNPACKING) + elif n.type == syms.except_clause: + # try: + # ... + # except Exception, err: + # ... + if len(n.children) >= 4: + if n.children[-2].type == token.COMMA: + features.add(Feature.COMMA_STYLE_EXCEPT) + elif n.type == syms.raise_stmt: + # raise Exception, "msg" + if len(n.children) >= 4: + if n.children[-2].type == token.COMMA: + features.add(Feature.COMMA_STYLE_RAISE) + elif n.type == token.BACKQUOTE: + # `i'm surprised this ever existed` + features.add(Feature.BACKQUOTE_REPR) + return features -def detect_target_versions(node: Node) -> Set[TargetVersion]: +def detect_target_versions( + node: Node, *, future_imports: Optional[Set[str]] = None +) -> Set[TargetVersion]: """Detect the version to target based on the nodes used.""" - features = get_features_used(node) + features = get_features_used(node, future_imports=future_imports) return { version for version in TargetVersion if features <= VERSION_TO_FEATURES[version] } @@ -1202,9 +1305,8 @@ def assert_equivalent(src: str, dst: str, *, pass_num: int = 1) -> None: src_ast = parse_ast(src) except Exception as exc: raise AssertionError( - "cannot use --safe with this file; failed to parse source file. AST" - f" error message: {exc}" - ) + "cannot use --safe with this file; failed to parse source file." + ) from exc try: dst_ast = parse_ast(dst) @@ -1265,7 +1367,7 @@ def patch_click() -> None: """ try: from click import core - from click import _unicodefun # type: ignore + from click import _unicodefun except ModuleNotFoundError: return