X-Git-Url: https://git.madduck.net/etc/vim.git/blobdiff_plain/521d1b8129c2d83b4ab49270fe7473802259c2a2..7af77d1cf1fdeb54a45ddae422e1ebc3329129fa:/src/black/__init__.py

diff --git a/src/black/__init__.py b/src/black/__init__.py
index cfa2c76..2d04cf8 100644
--- a/src/black/__init__.py
+++ b/src/black/__init__.py
@@ -1,7 +1,6 @@
 import asyncio
 from json.decoder import JSONDecodeError
 import json
-from concurrent.futures import Executor, ThreadPoolExecutor, ProcessPoolExecutor
 from contextlib import contextmanager
 from datetime import datetime
 from enum import Enum
@@ -10,12 +9,14 @@ from multiprocessing import Manager, freeze_support
 import os
 from pathlib import Path
 from pathspec.patterns.gitwildmatch import GitWildMatchPatternError
+import platform
 import re
 import signal
 import sys
 import tokenize
 import traceback
 from typing import (
+    TYPE_CHECKING,
     Any,
     Dict,
     Generator,
@@ -24,6 +25,7 @@ from typing import (
     MutableMapping,
     Optional,
     Pattern,
+    Sequence,
     Set,
     Sized,
     Tuple,
@@ -38,7 +40,7 @@ from mypy_extensions import mypyc_attr
 from black.const import DEFAULT_LINE_LENGTH, DEFAULT_INCLUDES, DEFAULT_EXCLUDES
 from black.const import STDIN_PLACEHOLDER
 from black.nodes import STARS, syms, is_simple_decorator_expression
-from black.nodes import is_string_token
+from black.nodes import is_string_token, is_number_token
 from black.lines import Line, EmptyLineTracker
 from black.linegen import transform_line, LineGenerator, LN
 from black.comments import normalize_fmt_off
@@ -48,7 +50,12 @@ from black.cache import read_cache, write_cache, get_cache_info, filter_cached,
 from black.concurrency import cancel, shutdown, maybe_install_uvloop
 from black.output import dump_to_file, ipynb_diff, diff, color_diff, out, err
 from black.report import Report, Changed, NothingChanged
-from black.files import find_project_root, find_pyproject_toml, parse_pyproject_toml
+from black.files import (
+    find_project_root,
+    find_pyproject_toml,
+    parse_pyproject_toml,
+    find_user_pyproject_toml,
+)
 from black.files import gen_python_files, get_gitignore, normalize_path_maybe_ignore
 from black.files import wrap_stream_for_windows
 from black.parsing import InvalidInput  # noqa F401
@@ -70,6 +77,9 @@ from blib2to3.pgen2 import token
 
 from _black_version import version as __version__
 
+if TYPE_CHECKING:
+    from concurrent.futures import Executor
+
 COMPILED = Path(__file__).suffix in (".pyd", ".so")
 
 # types
@@ -225,6 +235,16 @@ def validate_regex(
         "(useful when piping source on standard input)."
     ),
 )
+@click.option(
+    "--python-cell-magics",
+    multiple=True,
+    help=(
+        "When processing Jupyter Notebooks, add the given magic to the list"
+        f" of known python-magics ({', '.join(PYTHON_CELL_MAGICS)})."
+        " Useful for formatting cells with custom python magics."
+    ),
+    default=[],
+)
 @click.option(
     "-S",
     "--skip-string-normalization",
@@ -241,9 +261,14 @@ def validate_regex(
     "--experimental-string-processing",
     is_flag=True,
     hidden=True,
+    help="(DEPRECATED and now included in --preview) Normalize string literals.",
+)
+@click.option(
+    "--preview",
+    is_flag=True,
     help=(
-        "Experimental option that performs more normalization on string literals."
-        " Currently disabled because it leads to some crashes."
+        "Enable potentially disruptive style changes that may be added to Black's main"
+        " functionality in the next major release."
     ),
 )
 @click.option(
@@ -275,7 +300,8 @@ def validate_regex(
     type=str,
     help=(
         "Require a specific version of Black to be running (useful for unifying results"
-        " across many environments e.g. with a pyproject.toml file)."
+        " across many environments e.g. with a pyproject.toml file). It can be"
+        " either a major version number or an exact version."
     ),
 )
 @click.option(
@@ -359,7 +385,10 @@ def validate_regex(
 )
 @click.version_option(
     version=__version__,
-    message=f"%(prog)s, %(version)s (compiled: {'yes' if COMPILED else 'no'})",
+    message=(
+        f"%(prog)s, %(version)s (compiled: {'yes' if COMPILED else 'no'})\n"
+        f"Python ({platform.python_implementation()}) {platform.python_version()}"
+    ),
 )
 @click.argument(
     "src",
@@ -385,7 +414,7 @@ def validate_regex(
     help="Read configuration from FILE path.",
 )
 @click.pass_context
-def main(
+def main(  # noqa: C901
     ctx: click.Context,
     code: Optional[str],
     line_length: int,
@@ -396,9 +425,11 @@ def main(
     fast: bool,
     pyi: bool,
     ipynb: bool,
+    python_cell_magics: Sequence[str],
     skip_string_normalization: bool,
     skip_magic_trailing_comma: bool,
     experimental_string_processing: bool,
+    preview: bool,
     quiet: bool,
     verbose: bool,
     required_version: Optional[str],
@@ -413,6 +444,17 @@ def main(
 ) -> None:
     """The uncompromising code formatter."""
     ctx.ensure_object(dict)
+
+    if src and code is not None:
+        out(
+            main.get_usage(ctx)
+            + "\n\n'SRC' and 'code' cannot be passed simultaneously."
+        )
+        ctx.exit(1)
+    if not src and code is None:
+        out(main.get_usage(ctx) + "\n\nOne of 'SRC' or 'code' is required.")
+        ctx.exit(1)
+
     root, method = find_project_root(src) if code is None else (None, None)
     ctx.obj["root"] = root
 
@@ -439,13 +481,27 @@ def main(
 
         if config:
             config_source = ctx.get_parameter_source("config")
-            if config_source in (ParameterSource.DEFAULT, ParameterSource.DEFAULT_MAP):
+            user_level_config = str(find_user_pyproject_toml())
+            if config == user_level_config:
+                out(
+                    "Using configuration from user-level config at "
+                    f"'{user_level_config}'.",
+                    fg="blue",
+                )
+            elif config_source in (
+                ParameterSource.DEFAULT,
+                ParameterSource.DEFAULT_MAP,
+            ):
                 out("Using configuration from project root.", fg="blue")
             else:
                 out(f"Using configuration in '{config}'.", fg="blue")
 
     error_msg = "Oh no! 💥 💔 💥"
-    if required_version and required_version != __version__:
+    if (
+        required_version
+        and required_version != __version__
+        and required_version != __version__.split(".")[0]
+    ):
         err(
             f"{error_msg} The required version `{required_version}` does not match"
             f" the running version `{__version__}`!"
@@ -469,6 +525,8 @@ def main(
         string_normalization=not skip_string_normalization,
         magic_trailing_comma=not skip_magic_trailing_comma,
         experimental_string_processing=experimental_string_processing,
+        preview=preview,
+        python_cell_magics=set(python_cell_magics),
     )
 
     if code is not None:
@@ -526,6 +584,8 @@ def main(
             )
 
     if verbose or not quiet:
+        if code is None and (verbose or report.change_count or report.failure_count):
+            out()
         out(error_msg if report.return_code else "All done! ✨ 🍰 ✨")
         if code is None:
             click.echo(str(report), err=True)
@@ -547,7 +607,6 @@ def get_sources(
 ) -> Set[Path]:
     """Compute the set of files to be formatted."""
     sources: Set[Path] = set()
-    path_empty(src, "No Path provided. Nothing to do 😴", quiet, verbose, ctx)
 
     if exclude is None:
         exclude = re_compile_maybe_verbose(DEFAULT_EXCLUDES)
@@ -645,6 +704,9 @@ def reformat_code(
         report.failed(path, str(exc))
 
 
+# diff-shades depends on being able to monkeypatch this function to operate. I know it's
+# not ideal, but this shouldn't cause any issues ... hopefully. ~ichard26
+@mypyc_attr(patchable=True)
 def reformat_one(
     src: Path, fast: bool, write_back: WriteBack, mode: Mode, report: "Report"
 ) -> None:
@@ -708,6 +770,8 @@ def reformat_many(
     workers: Optional[int],
 ) -> None:
     """Reformat multiple files using a ProcessPoolExecutor."""
+    from concurrent.futures import Executor, ThreadPoolExecutor, ProcessPoolExecutor
+
     executor: Executor
     loop = asyncio.get_event_loop()
     worker_count = workers if workers is not None else DEFAULT_WORKERS
@@ -749,7 +813,7 @@ async def schedule_formatting(
     mode: Mode,
     report: "Report",
     loop: asyncio.AbstractEventLoop,
-    executor: Executor,
+    executor: "Executor",
 ) -> None:
     """Run formatting of `sources` in parallel using the provided `executor`.
 
@@ -936,17 +1000,7 @@ def check_stability_and_equivalence(
     content differently.
     """
     assert_equivalent(src_contents, dst_contents)
-
-    # Forced second pass to work around optional trailing commas (becoming
-    # forced trailing commas on pass 2) interacting differently with optional
-    # parentheses.  Admittedly ugly.
-    dst_contents_pass2 = format_str(dst_contents, mode=mode)
-    if dst_contents != dst_contents_pass2:
-        dst_contents = dst_contents_pass2
-        assert_equivalent(src_contents, dst_contents, pass_num=2)
-        assert_stable(src_contents, dst_contents, mode=mode)
-    # Note: no need to explicitly call `assert_stable` if `dst_contents` was
-    # the same as `dst_contents_pass2`.
+    assert_stable(src_contents, dst_contents, mode=mode)
 
 
 def format_file_contents(src_contents: str, *, fast: bool, mode: Mode) -> FileContent:
@@ -972,7 +1026,7 @@ def format_file_contents(src_contents: str, *, fast: bool, mode: Mode) -> FileCo
     return dst_contents
 
 
-def validate_cell(src: str) -> None:
+def validate_cell(src: str, mode: Mode) -> None:
     """Check that cell does not already contain TransformerManager transformations,
     or non-Python cell magics, which might cause tokenizer_rt to break because of
     indentations.
@@ -991,7 +1045,10 @@ def validate_cell(src: str) -> None:
     """
     if any(transformed_magic in src for transformed_magic in TRANSFORMED_MAGICS):
         raise NothingChanged
-    if src[:2] == "%%" and src.split()[0][2:] not in PYTHON_CELL_MAGICS:
+    if (
+        src[:2] == "%%"
+        and src.split()[0][2:] not in PYTHON_CELL_MAGICS | mode.python_cell_magics
+    ):
         raise NothingChanged
 
 
@@ -1011,7 +1068,7 @@ def format_cell(src: str, *, fast: bool, mode: Mode) -> str:
     could potentially be automagics or multi-line magics, which
     are currently not supported.
     """
-    validate_cell(src)
+    validate_cell(src, mode)
     src_without_trailing_semicolon, has_trailing_semicolon = remove_trailing_semicolon(
         src
     )
@@ -1073,7 +1130,7 @@ def format_ipynb_string(src_contents: str, *, fast: bool, mode: Mode) -> FileCon
         raise NothingChanged
 
 
-def format_str(src_contents: str, *, mode: Mode) -> FileContent:
+def format_str(src_contents: str, *, mode: Mode) -> str:
     """Reformat a string and return new contents.
 
     `mode` determines formatting options, such as how many characters per line are
@@ -1103,15 +1160,25 @@ def format_str(src_contents: str, *, mode: Mode) -> FileContent:
         hey
 
     """
+    dst_contents = _format_str_once(src_contents, mode=mode)
+    # Forced second pass to work around optional trailing commas (becoming
+    # forced trailing commas on pass 2) interacting differently with optional
+    # parentheses.  Admittedly ugly.
+    if src_contents != dst_contents:
+        return _format_str_once(dst_contents, mode=mode)
+    return dst_contents
+
+
+def _format_str_once(src_contents: str, *, mode: Mode) -> str:
     src_node = lib2to3_parse(src_contents.lstrip(), mode.target_versions)
     dst_contents = []
-    future_imports = get_future_imports(src_node)
     if mode.target_versions:
         versions = mode.target_versions
     else:
+        future_imports = get_future_imports(src_node)
         versions = detect_target_versions(src_node, future_imports=future_imports)
 
-    normalize_fmt_off(src_node)
+    normalize_fmt_off(src_node, preview=mode.preview)
     lines = LineGenerator(mode=mode)
     elt = EmptyLineTracker(is_pyi=mode.is_pyi)
     empty_line = Line(mode=mode)
@@ -1178,8 +1245,7 @@ def get_features_used(  # noqa: C901
             if value_head in {'f"', 'F"', "f'", "F'", "rf", "fr", "RF", "FR"}:
                 features.add(Feature.F_STRINGS)
 
-        elif n.type == token.NUMBER:
-            assert isinstance(n, Leaf)
+        elif is_number_token(n):
             if "_" in n.value:
                 features.add(Feature.NUMERIC_UNDERSCORES)
 
@@ -1234,6 +1300,25 @@ def get_features_used(  # noqa: C901
         ):
             features.add(Feature.ANN_ASSIGN_EXTENDED_RHS)
 
+        elif (
+            n.type == syms.except_clause
+            and len(n.children) >= 2
+            and n.children[1].type == token.STAR
+        ):
+            features.add(Feature.EXCEPT_STAR)
+
+        elif n.type in {syms.subscriptlist, syms.trailer} and any(
+            child.type == syms.star_expr for child in n.children
+        ):
+            features.add(Feature.VARIADIC_GENERICS)
+
+        elif (
+            n.type == syms.tname_star
+            and len(n.children) == 3
+            and n.children[2].type == syms.star_expr
+        ):
+            features.add(Feature.VARIADIC_GENERICS)
+
     return features
 
 
@@ -1297,13 +1382,16 @@ def get_future_imports(node: Node) -> Set[str]:
     return imports
 
 
-def assert_equivalent(src: str, dst: str, *, pass_num: int = 1) -> None:
+def assert_equivalent(src: str, dst: str) -> None:
     """Raise AssertionError if `src` and `dst` aren't equivalent."""
     try:
         src_ast = parse_ast(src)
     except Exception as exc:
         raise AssertionError(
-            f"cannot use --safe with this file; failed to parse source file: {exc}"
+            "cannot use --safe with this file; failed to parse source file AST: "
+            f"{exc}\n"
+            "This could be caused by running Black with an older Python version "
+            "that does not support new syntax used in your source file."
         ) from exc
 
     try:
@@ -1311,7 +1399,7 @@ def assert_equivalent(src: str, dst: str, *, pass_num: int = 1) -> None:
     except Exception as exc:
         log = dump_to_file("".join(traceback.format_tb(exc.__traceback__)), dst)
         raise AssertionError(
-            f"INTERNAL ERROR: Black produced invalid code on pass {pass_num}: {exc}. "
+            f"INTERNAL ERROR: Black produced invalid code: {exc}. "
             "Please report a bug on https://github.com/psf/black/issues.  "
             f"This invalid output might be helpful: {log}"
         ) from None
@@ -1322,14 +1410,17 @@ def assert_equivalent(src: str, dst: str, *, pass_num: int = 1) -> None:
         log = dump_to_file(diff(src_ast_str, dst_ast_str, "src", "dst"))
         raise AssertionError(
             "INTERNAL ERROR: Black produced code that is not equivalent to the"
-            f" source on pass {pass_num}.  Please report a bug on "
+            " source.  Please report a bug on "
             f"https://github.com/psf/black/issues.  This diff might be helpful: {log}"
         ) from None
 
 
 def assert_stable(src: str, dst: str, mode: Mode) -> None:
     """Raise AssertionError if `dst` reformats differently the second time."""
-    newdst = format_str(dst, mode=mode)
+    # We shouldn't call format_str() here, because that formats the string
+    # twice and may hide a bug where we bounce back and forth between two
+    # versions.
+    newdst = _format_str_once(dst, mode=mode)
     if dst != newdst:
         log = dump_to_file(
             str(mode),
@@ -1363,13 +1454,23 @@ def patch_click() -> None:
     file paths is minimal since it's Python source code.  Moreover, this crash was
     spurious on Python 3.7 thanks to PEP 538 and PEP 540.
     """
+    modules: List[Any] = []
     try:
         from click import core
-        from click import _unicodefun
-    except ModuleNotFoundError:
-        return
+    except ImportError:
+        pass
+    else:
+        modules.append(core)
+    try:
+        # Removed in Click 8.1.0 and newer; we keep this around for users who have
+        # older versions installed.
+        from click import _unicodefun  # type: ignore
+    except ImportError:
+        pass
+    else:
+        modules.append(_unicodefun)
 
-    for module in (core, _unicodefun):
+    for module in modules:
         if hasattr(module, "_verify_python3_env"):
             module._verify_python3_env = lambda: None  # type: ignore
         if hasattr(module, "_verify_python_env"):