X-Git-Url: https://git.madduck.net/etc/vim.git/blobdiff_plain/f80f49767cacafaeb20016d483b8315474504e6a..e0c572833a3e2b42cd45237c26a67c6f5be4b09d:/src/black/__init__.py

diff --git a/src/black/__init__.py b/src/black/__init__.py
index ba4d3de..eaf72f9 100644
--- a/src/black/__init__.py
+++ b/src/black/__init__.py
@@ -10,7 +10,7 @@ from multiprocessing import Manager, freeze_support
 import os
 from pathlib import Path
 from pathspec.patterns.gitwildmatch import GitWildMatchPatternError
-import regex as re
+import re
 import signal
 import sys
 import tokenize
@@ -24,22 +24,26 @@ from typing import (
+    Sequence,
-from dataclasses import replace
 import click
+from click.core import ParameterSource
+from dataclasses import replace
+from mypy_extensions import mypyc_attr
 from black.const import STDIN_PLACEHOLDER
 from black.nodes import STARS, syms, is_simple_decorator_expression
+from black.nodes import is_string_token
 from black.lines import Line, EmptyLineTracker
 from black.linegen import transform_line, LineGenerator, LN
 from black.comments import normalize_fmt_off
-from black.mode import Mode, TargetVersion
+from black.mode import FUTURE_FLAG_TO_FEATURE, Mode, TargetVersion
 from black.mode import Feature, supports_feature, VERSION_TO_FEATURES
 from black.cache import read_cache, write_cache, get_cache_info, filter_cached, Cache
 from black.concurrency import cancel, shutdown, maybe_install_uvloop
@@ -56,6 +60,7 @@ from black.handle_ipynb_magics import (
@@ -66,6 +71,8 @@ from blib2to3.pgen2 import token
 from _black_version import version as __version__
+COMPILED = Path(__file__).suffix in (".pyd", ".so")
 # types
 FileContent = str
 Encoding = str
@@ -173,11 +180,16 @@ def validate_regex(
 ) -> Optional[Pattern[str]]:
         return re_compile_maybe_verbose(value) if value is not None else None
-    except re.error:
-        raise click.BadParameter("Not a valid regular expression") from None
+    except re.error as e:
+        raise click.BadParameter(f"Not a valid regular expression: {e}") from None
-@click.command(context_settings=dict(help_option_names=["-h", "--help"]))
+    context_settings={"help_option_names": ["-h", "--help"]},
+    # While Click does set this field automatically using the docstring, mypyc
+    # (annoyingly) strips 'em so we need to set it here too.
+    help="The uncompromising code formatter.",
 @click.option("-c", "--code", type=str, help="Format the code passed in as a string.")
@@ -214,6 +226,16 @@ def validate_regex(
         "(useful when piping source on standard input)."
+    "--python-cell-magics",
+    multiple=True,
+    help=(
+        "When processing Jupyter Notebooks, add the given magic to the list"
+        f" of known python-magics ({', '.join(PYTHON_CELL_MAGICS)})."
+        " Useful for formatting cells with custom python magics."
+    ),
+    default=[],
@@ -230,9 +252,14 @@ def validate_regex(
+    help="(DEPRECATED and now included in --preview) Normalize string literals.",
+    "--preview",
+    is_flag=True,
-        "Experimental option that performs more normalization on string literals."
-        " Currently disabled because it leads to some crashes."
+        "Enable potentially disruptive style changes that will be added to Black's main"
+        " functionality in the next major release."
@@ -346,7 +373,10 @@ def validate_regex(
         " due to exclusion patterns."
+    version=__version__,
+    message=f"%(prog)s, %(version)s (compiled: {'yes' if COMPILED else 'no'})",
@@ -382,12 +412,14 @@ def main(
     fast: bool,
     pyi: bool,
     ipynb: bool,
+    python_cell_magics: Sequence[str],
     skip_string_normalization: bool,
     skip_magic_trailing_comma: bool,
     experimental_string_processing: bool,
+    preview: bool,
     quiet: bool,
     verbose: bool,
-    required_version: str,
+    required_version: Optional[str],
     include: Pattern[str],
     exclude: Optional[Pattern[str]],
     extend_exclude: Optional[Pattern[str]],
@@ -398,8 +430,37 @@ def main(
     config: Optional[str],
 ) -> None:
     """The uncompromising code formatter."""
-    if config and verbose:
-        out(f"Using configuration from {config}.", bold=False, fg="blue")
+    ctx.ensure_object(dict)
+    root, method = find_project_root(src) if code is None else (None, None)
+    ctx.obj["root"] = root
+    if verbose:
+        if root:
+            out(
+                f"Identified `{root}` as project root containing a {method}.",
+                fg="blue",
+            )
+            normalized = [
+                (normalize_path_maybe_ignore(Path(source), root), source)
+                for source in src
+            ]
+            srcs_string = ", ".join(
+                [
+                    f'"{_norm}"'
+                    if _norm
+                    else f'\033[31m"{source} (skipping - invalid)"\033[34m'
+                    for _norm, source in normalized
+                ]
+            )
+            out(f"Sources to be formatted: {srcs_string}", fg="blue")
+        if config:
+            config_source = ctx.get_parameter_source("config")
+            if config_source in (ParameterSource.DEFAULT, ParameterSource.DEFAULT_MAP):
+                out("Using configuration from project root.", fg="blue")
+            else:
+                out(f"Using configuration in '{config}'.", fg="blue")
     error_msg = "Oh no! 💥 💔 💥"
     if required_version and required_version != __version__:
@@ -426,6 +487,8 @@ def main(
         string_normalization=not skip_string_normalization,
         magic_trailing_comma=not skip_magic_trailing_comma,
+        preview=preview,
+        python_cell_magics=set(python_cell_magics),
     if code is not None:
@@ -483,6 +546,8 @@ def main(
     if verbose or not quiet:
+        if code is None and (verbose or report.change_count or report.failure_count):
+            out()
         out(error_msg if report.return_code else "All done! ✨ 🍰 ✨")
         if code is None:
             click.echo(str(report), err=True)
@@ -503,14 +568,12 @@ def get_sources(
     stdin_filename: Optional[str],
 ) -> Set[Path]:
     """Compute the set of files to be formatted."""
-    root = find_project_root(src)
     sources: Set[Path] = set()
     path_empty(src, "No Path provided. Nothing to do 😴", quiet, verbose, ctx)
     if exclude is None:
         exclude = re_compile_maybe_verbose(DEFAULT_EXCLUDES)
-        gitignore = get_gitignore(root)
+        gitignore = get_gitignore(ctx.obj["root"])
         gitignore = None
@@ -523,7 +586,7 @@ def get_sources(
             is_stdin = False
         if is_stdin or p.is_file():
-            normalized_path = normalize_path_maybe_ignore(p, root, report)
+            normalized_path = normalize_path_maybe_ignore(p, ctx.obj["root"], report)
             if normalized_path is None:
@@ -550,7 +613,7 @@ def get_sources(
-                    root,
+                    ctx.obj["root"],
@@ -655,6 +718,9 @@ def reformat_one(
         report.failed(src, str(exc))
+# diff-shades depends on being able to monkeypatch this function to operate. I know
+# it's not ideal, but this shouldn't cause any issues ... hopefully. ~ichard26
 def reformat_many(
     sources: Set[Path],
     fast: bool,
@@ -669,10 +735,11 @@ def reformat_many(
     worker_count = workers if workers is not None else DEFAULT_WORKERS
     if sys.platform == "win32":
         # Work around https://bugs.python.org/issue26903
+        assert worker_count is not None
         worker_count = min(worker_count, 60)
         executor = ProcessPoolExecutor(max_workers=worker_count)
-    except (ImportError, OSError):
+    except (ImportError, NotImplementedError, OSError):
         # we arrive here if the underlying system does not support multi-processing
         # like in AWS Lambda or Termux, in which case we gracefully fallback to
         # a ThreadPoolExecutor with just a single worker (more workers would not do us
@@ -927,8 +994,10 @@ def format_file_contents(src_contents: str, *, fast: bool, mode: Mode) -> FileCo
     return dst_contents
-def validate_cell(src: str) -> None:
-    """Check that cell does not already contain TransformerManager transformations.
+def validate_cell(src: str, mode: Mode) -> None:
+    """Check that cell does not already contain TransformerManager transformations,
+    or non-Python cell magics, which might cause tokenizer_rt to break because of
+    indentations.
     If a cell contains ``!ls``, then it'll be transformed to
     ``get_ipython().system('ls')``. However, if the cell originally contained
@@ -944,6 +1013,11 @@ def validate_cell(src: str) -> None:
     if any(transformed_magic in src for transformed_magic in TRANSFORMED_MAGICS):
         raise NothingChanged
+    if (
+        src[:2] == "%%"
+        and src.split()[0][2:] not in PYTHON_CELL_MAGICS | mode.python_cell_magics
+    ):
+        raise NothingChanged
 def format_cell(src: str, *, fast: bool, mode: Mode) -> str:
@@ -962,7 +1036,7 @@ def format_cell(src: str, *, fast: bool, mode: Mode) -> str:
     could potentially be automagics or multi-line magics, which
     are currently not supported.
-    validate_cell(src)
+    validate_cell(src, mode)
     src_without_trailing_semicolon, has_trailing_semicolon = remove_trailing_semicolon(
@@ -1060,22 +1134,10 @@ def format_str(src_contents: str, *, mode: Mode) -> FileContent:
     if mode.target_versions:
         versions = mode.target_versions
-        versions = detect_target_versions(src_node)
-    # TODO: fully drop support and this code hopefully in January 2022 :D
-    if TargetVersion.PY27 in mode.target_versions or versions == {TargetVersion.PY27}:
-        msg = (
-            "DEPRECATION: Python 2 support will be removed in the first stable release "
-            "expected in January 2022."
-        )
-        err(msg, fg="yellow", bold=True)
+        versions = detect_target_versions(src_node, future_imports=future_imports)
-    lines = LineGenerator(
-        mode=mode,
-        remove_u_prefix="unicode_literals" in future_imports
-        or supports_feature(versions, Feature.UNICODE_LITERALS),
-    )
+    lines = LineGenerator(mode=mode)
     elt = EmptyLineTracker(is_pyi=mode.is_pyi)
     empty_line = Line(mode=mode)
     after = 0
@@ -1112,7 +1174,9 @@ def decode_bytes(src: bytes) -> Tuple[FileContent, Encoding, NewLine]:
         return tiow.read(), encoding, newline
-def get_features_used(node: Node) -> Set[Feature]:  # noqa: C901
+def get_features_used(  # noqa: C901
+    node: Node, *, future_imports: Optional[Set[str]] = None
+) -> Set[Feature]:
     """Return a set of (relatively) new Python features used in this file.
     Currently looking for:
@@ -1122,17 +1186,26 @@ def get_features_used(node: Node) -> Set[Feature]:  # noqa: C901
     - positional only arguments in function signatures and lambdas;
     - assignment expression;
     - relaxed decorator syntax;
+    - usage of __future__ flags (annotations);
     - print / exec statements;
     features: Set[Feature] = set()
+    if future_imports:
+        features |= {
+            FUTURE_FLAG_TO_FEATURE[future_import]
+            for future_import in future_imports
+            if future_import in FUTURE_FLAG_TO_FEATURE
+        }
     for n in node.pre_order():
-        if n.type == token.STRING:
-            value_head = n.value[:2]  # type: ignore
+        if is_string_token(n):
+            value_head = n.value[:2]
             if value_head in {'f"', 'F"', "f'", "F'", "rf", "fr", "RF", "FR"}:
         elif n.type == token.NUMBER:
-            if "_" in n.value:  # type: ignore
+            assert isinstance(n, Leaf)
+            if "_" in n.value:
         elif n.type == token.SLASH:
@@ -1171,17 +1244,29 @@ def get_features_used(node: Node) -> Set[Feature]:  # noqa: C901
                         if argch.type in STARS:
-        elif n.type == token.PRINT_STMT:
-            features.add(Feature.PRINT_STMT)
-        elif n.type == token.EXEC_STMT:
-            features.add(Feature.EXEC_STMT)
+        elif (
+            n.type in {syms.return_stmt, syms.yield_expr}
+            and len(n.children) >= 2
+            and n.children[1].type == syms.testlist_star_expr
+            and any(child.type == syms.star_expr for child in n.children[1].children)
+        ):
+            features.add(Feature.UNPACKING_ON_FLOW)
+        elif (
+            n.type == syms.annassign
+            and len(n.children) >= 4
+            and n.children[3].type == syms.testlist_star_expr
+        ):
+            features.add(Feature.ANN_ASSIGN_EXTENDED_RHS)
     return features
-def detect_target_versions(node: Node) -> Set[TargetVersion]:
+def detect_target_versions(
+    node: Node, *, future_imports: Optional[Set[str]] = None
+) -> Set[TargetVersion]:
     """Detect the version to target based on the nodes used."""
-    features = get_features_used(node)
+    features = get_features_used(node, future_imports=future_imports)
     return {
         version for version in TargetVersion if features <= VERSION_TO_FEATURES[version]
@@ -1243,7 +1328,10 @@ def assert_equivalent(src: str, dst: str, *, pass_num: int = 1) -> None:
         src_ast = parse_ast(src)
     except Exception as exc:
         raise AssertionError(
-            "cannot use --safe with this file; failed to parse source file."
+            f"cannot use --safe with this file; failed to parse source file AST: "
+            f"{exc}\n"
+            f"This could be caused by running Black with an older Python version "
+            f"that does not support new syntax used in your source file."
         ) from exc