X-Git-Url: https://git.madduck.net/etc/vim.git/blobdiff_plain/62ed5389fca51384721245ff1c2f1c62a13a04ff..1d2ed2bb421df94a8d86728a187663f1c3898322:/src/black/__init__.py?ds=sidebyside diff --git a/src/black/__init__.py b/src/black/__init__.py index 5c6cb67..fa918ce 100644 --- a/src/black/__init__.py +++ b/src/black/__init__.py @@ -10,7 +10,7 @@ from multiprocessing import Manager, freeze_support import os from pathlib import Path from pathspec.patterns.gitwildmatch import GitWildMatchPatternError -import regex as re +import re import signal import sys import tokenize @@ -30,16 +30,19 @@ from typing import ( Union, ) -from dataclasses import replace import click +from click.core import ParameterSource +from dataclasses import replace +from mypy_extensions import mypyc_attr from black.const import DEFAULT_LINE_LENGTH, DEFAULT_INCLUDES, DEFAULT_EXCLUDES from black.const import STDIN_PLACEHOLDER from black.nodes import STARS, syms, is_simple_decorator_expression +from black.nodes import is_string_token from black.lines import Line, EmptyLineTracker from black.linegen import transform_line, LineGenerator, LN from black.comments import normalize_fmt_off -from black.mode import Mode, TargetVersion +from black.mode import FUTURE_FLAG_TO_FEATURE, Mode, TargetVersion from black.mode import Feature, supports_feature, VERSION_TO_FEATURES from black.cache import read_cache, write_cache, get_cache_info, filter_cached, Cache from black.concurrency import cancel, shutdown, maybe_install_uvloop @@ -56,6 +59,7 @@ from black.handle_ipynb_magics import ( remove_trailing_semicolon, put_trailing_semicolon_back, TRANSFORMED_MAGICS, + PYTHON_CELL_MAGICS, jupyter_dependencies_are_installed, ) @@ -66,6 +70,8 @@ from blib2to3.pgen2 import token from _black_version import version as __version__ +COMPILED = Path(__file__).suffix in (".pyd", ".so") + # types FileContent = str Encoding = str @@ -173,11 +179,16 @@ def validate_regex( ) -> Optional[Pattern[str]]: try: return 
re_compile_maybe_verbose(value) if value is not None else None - except re.error: - raise click.BadParameter("Not a valid regular expression") from None + except re.error as e: + raise click.BadParameter(f"Not a valid regular expression: {e}") from None -@click.command(context_settings=dict(help_option_names=["-h", "--help"])) +@click.command( + context_settings={"help_option_names": ["-h", "--help"]}, + # While Click does set this field automatically using the docstring, mypyc + # (annoyingly) strips 'em so we need to set it here too. + help="The uncompromising code formatter.", +) @click.option("-c", "--code", type=str, help="Format the code passed in as a string.") @click.option( "-l", @@ -346,7 +357,10 @@ def validate_regex( " due to exclusion patterns." ), ) -@click.version_option(version=__version__) +@click.version_option( + version=__version__, + message=f"%(prog)s, %(version)s (compiled: {'yes' if COMPILED else 'no'})", +) @click.argument( "src", nargs=-1, @@ -387,7 +401,7 @@ def main( experimental_string_processing: bool, quiet: bool, verbose: bool, - required_version: str, + required_version: Optional[str], include: Pattern[str], exclude: Optional[Pattern[str]], extend_exclude: Optional[Pattern[str]], @@ -398,8 +412,37 @@ def main( config: Optional[str], ) -> None: """The uncompromising code formatter.""" - if config and verbose: - out(f"Using configuration from {config}.", bold=False, fg="blue") + ctx.ensure_object(dict) + root, method = find_project_root(src) if code is None else (None, None) + ctx.obj["root"] = root + + if verbose: + if root: + out( + f"Identified `{root}` as project root containing a {method}.", + fg="blue", + ) + + normalized = [ + (normalize_path_maybe_ignore(Path(source), root), source) + for source in src + ] + srcs_string = ", ".join( + [ + f'"{_norm}"' + if _norm + else f'\033[31m"{source} (skipping - invalid)"\033[34m' + for _norm, source in normalized + ] + ) + out(f"Sources to be formatted: {srcs_string}", fg="blue") + + if 
config: + config_source = ctx.get_parameter_source("config") + if config_source in (ParameterSource.DEFAULT, ParameterSource.DEFAULT_MAP): + out("Using configuration from project root.", fg="blue") + else: + out(f"Using configuration in '{config}'.", fg="blue") error_msg = "Oh no! 💥 💔 💥" if required_version and required_version != __version__: @@ -483,6 +526,8 @@ def main( ) if verbose or not quiet: + if code is None and (verbose or report.change_count or report.failure_count): + out() out(error_msg if report.return_code else "All done! ✨ 🍰 ✨") if code is None: click.echo(str(report), err=True) @@ -503,14 +548,12 @@ def get_sources( stdin_filename: Optional[str], ) -> Set[Path]: """Compute the set of files to be formatted.""" - - root = find_project_root(src) sources: Set[Path] = set() path_empty(src, "No Path provided. Nothing to do 😴", quiet, verbose, ctx) if exclude is None: exclude = re_compile_maybe_verbose(DEFAULT_EXCLUDES) - gitignore = get_gitignore(root) + gitignore = get_gitignore(ctx.obj["root"]) else: gitignore = None @@ -523,7 +566,7 @@ def get_sources( is_stdin = False if is_stdin or p.is_file(): - normalized_path = normalize_path_maybe_ignore(p, root, report) + normalized_path = normalize_path_maybe_ignore(p, ctx.obj["root"], report) if normalized_path is None: continue @@ -550,7 +593,7 @@ def get_sources( sources.update( gen_python_files( p.iterdir(), - root, + ctx.obj["root"], include, exclude, extend_exclude, @@ -655,6 +698,9 @@ def reformat_one( report.failed(src, str(exc)) +# diff-shades depends on being able to monkeypatch this function to operate. I know it's +# not ideal, but this shouldn't cause any issues ... hopefully. 
~ichard26 +@mypyc_attr(patchable=True) def reformat_many( sources: Set[Path], fast: bool, @@ -669,10 +715,11 @@ def reformat_many( worker_count = workers if workers is not None else DEFAULT_WORKERS if sys.platform == "win32": # Work around https://bugs.python.org/issue26903 + assert worker_count is not None worker_count = min(worker_count, 60) try: executor = ProcessPoolExecutor(max_workers=worker_count) - except (ImportError, OSError): + except (ImportError, NotImplementedError, OSError): # we arrive here if the underlying system does not support multi-processing # like in AWS Lambda or Termux, in which case we gracefully fallback to # a ThreadPoolExecutor with just a single worker (more workers would not do us @@ -763,7 +810,10 @@ async def schedule_formatting( sources_to_cache.append(src) report.done(src, changed) if cancelled: - await asyncio.gather(*cancelled, loop=loop, return_exceptions=True) + if sys.version_info >= (3, 7): + await asyncio.gather(*cancelled, return_exceptions=True) + else: + await asyncio.gather(*cancelled, loop=loop, return_exceptions=True) if sources_to_cache: write_cache(cache, sources_to_cache, mode) @@ -925,7 +975,9 @@ def format_file_contents(src_contents: str, *, fast: bool, mode: Mode) -> FileCo def validate_cell(src: str) -> None: - """Check that cell does not already contain TransformerManager transformations. + """Check that cell does not already contain TransformerManager transformations, + or non-Python cell magics, which might cause tokenizer_rt to break because of + indentations. If a cell contains ``!ls``, then it'll be transformed to ``get_ipython().system('ls')``. 
However, if the cell originally contained @@ -941,6 +993,8 @@ def validate_cell(src: str) -> None: """ if any(transformed_magic in src for transformed_magic in TRANSFORMED_MAGICS): raise NothingChanged + if src[:2] == "%%" and src.split()[0][2:] not in PYTHON_CELL_MAGICS: + raise NothingChanged def format_cell(src: str, *, fast: bool, mode: Mode) -> str: @@ -1057,13 +1111,10 @@ def format_str(src_contents: str, *, mode: Mode) -> FileContent: if mode.target_versions: versions = mode.target_versions else: - versions = detect_target_versions(src_node) + versions = detect_target_versions(src_node, future_imports=future_imports) + normalize_fmt_off(src_node) - lines = LineGenerator( - mode=mode, - remove_u_prefix="unicode_literals" in future_imports - or supports_feature(versions, Feature.UNICODE_LITERALS), - ) + lines = LineGenerator(mode=mode) elt = EmptyLineTracker(is_pyi=mode.is_pyi) empty_line = Line(mode=mode) after = 0 @@ -1100,7 +1151,9 @@ def decode_bytes(src: bytes) -> Tuple[FileContent, Encoding, NewLine]: return tiow.read(), encoding, newline -def get_features_used(node: Node) -> Set[Feature]: +def get_features_used( # noqa: C901 + node: Node, *, future_imports: Optional[Set[str]] = None +) -> Set[Feature]: """Return a set of (relatively) new Python features used in this file. 
Currently looking for: @@ -1110,16 +1163,26 @@ def get_features_used(node: Node) -> Set[Feature]: - positional only arguments in function signatures and lambdas; - assignment expression; - relaxed decorator syntax; + - usage of __future__ flags (annotations); + - print / exec statements; """ features: Set[Feature] = set() + if future_imports: + features |= { + FUTURE_FLAG_TO_FEATURE[future_import] + for future_import in future_imports + if future_import in FUTURE_FLAG_TO_FEATURE + } + for n in node.pre_order(): - if n.type == token.STRING: - value_head = n.value[:2] # type: ignore + if is_string_token(n): + value_head = n.value[:2] if value_head in {'f"', 'F"', "f'", "F'", "rf", "fr", "RF", "FR"}: features.add(Feature.F_STRINGS) elif n.type == token.NUMBER: - if "_" in n.value: # type: ignore + assert isinstance(n, Leaf) + if "_" in n.value: features.add(Feature.NUMERIC_UNDERSCORES) elif n.type == token.SLASH: @@ -1158,12 +1221,29 @@ def get_features_used(node: Node) -> Set[Feature]: if argch.type in STARS: features.add(feature) + elif ( + n.type in {syms.return_stmt, syms.yield_expr} + and len(n.children) >= 2 + and n.children[1].type == syms.testlist_star_expr + and any(child.type == syms.star_expr for child in n.children[1].children) + ): + features.add(Feature.UNPACKING_ON_FLOW) + + elif ( + n.type == syms.annassign + and len(n.children) >= 4 + and n.children[3].type == syms.testlist_star_expr + ): + features.add(Feature.ANN_ASSIGN_EXTENDED_RHS) + return features -def detect_target_versions(node: Node) -> Set[TargetVersion]: +def detect_target_versions( + node: Node, *, future_imports: Optional[Set[str]] = None +) -> Set[TargetVersion]: """Detect the version to target based on the nodes used.""" - features = get_features_used(node) + features = get_features_used(node, future_imports=future_imports) return { version for version in TargetVersion if features <= VERSION_TO_FEATURES[version] } @@ -1225,7 +1305,7 @@ def assert_equivalent(src: str, dst: str, *, 
pass_num: int = 1) -> None: src_ast = parse_ast(src) except Exception as exc: raise AssertionError( - "cannot use --safe with this file; failed to parse source file." + f"cannot use --safe with this file; failed to parse source file: {exc}" ) from exc try: