X-Git-Url: https://git.madduck.net/etc/vim.git/blobdiff_plain/73cb6e7734370108742d992d4fe1fa2829f100fd..6c1bd08f16b636de38b92aeb2e0a1e8ebef0a0b1:/src/black/__init__.py?ds=sidebyside

diff --git a/src/black/__init__.py b/src/black/__init__.py
index 7024c9d..4200066 100644
--- a/src/black/__init__.py
+++ b/src/black/__init__.py
@@ -1,7 +1,6 @@
 import asyncio
 from json.decoder import JSONDecodeError
 import json
-from concurrent.futures import Executor, ThreadPoolExecutor, ProcessPoolExecutor
 from contextlib import contextmanager
 from datetime import datetime
 from enum import Enum
@@ -10,12 +9,14 @@ from multiprocessing import Manager, freeze_support
 import os
 from pathlib import Path
 from pathspec.patterns.gitwildmatch import GitWildMatchPatternError
+import platform
 import re
 import signal
 import sys
 import tokenize
 import traceback
 from typing import (
+    TYPE_CHECKING,
     Any,
     Dict,
     Generator,
@@ -39,7 +40,7 @@ from mypy_extensions import mypyc_attr
 from black.const import DEFAULT_LINE_LENGTH, DEFAULT_INCLUDES, DEFAULT_EXCLUDES
 from black.const import STDIN_PLACEHOLDER
 from black.nodes import STARS, syms, is_simple_decorator_expression
-from black.nodes import is_string_token
+from black.nodes import is_string_token, is_number_token
 from black.lines import Line, EmptyLineTracker
 from black.linegen import transform_line, LineGenerator, LN
 from black.comments import normalize_fmt_off
@@ -49,7 +50,12 @@ from black.cache import read_cache, write_cache, get_cache_info, filter_cached,
 from black.concurrency import cancel, shutdown, maybe_install_uvloop
 from black.output import dump_to_file, ipynb_diff, diff, color_diff, out, err
 from black.report import Report, Changed, NothingChanged
-from black.files import find_project_root, find_pyproject_toml, parse_pyproject_toml
+from black.files import (
+    find_project_root,
+    find_pyproject_toml,
+    parse_pyproject_toml,
+    find_user_pyproject_toml,
+)
 from black.files import gen_python_files, get_gitignore, normalize_path_maybe_ignore
 from black.files import wrap_stream_for_windows
 from black.parsing import InvalidInput  # noqa F401
@@ -71,6 +77,9 @@ from blib2to3.pgen2 import token
 
 from _black_version import version as __version__
 
+if TYPE_CHECKING:
+    from concurrent.futures import Executor
+
 COMPILED = Path(__file__).suffix in (".pyd", ".so")
 
 # types
@@ -258,7 +267,7 @@ def validate_regex(
     "--preview",
     is_flag=True,
     help=(
-        "Enable potentially disruptive style changes that will be added to Black's main"
+        "Enable potentially disruptive style changes that may be added to Black's main"
         " functionality in the next major release."
     ),
 )
@@ -291,7 +300,8 @@ def validate_regex(
     type=str,
     help=(
         "Require a specific version of Black to be running (useful for unifying results"
-        " across many environments e.g. with a pyproject.toml file)."
+        " across many environments e.g. with a pyproject.toml file). It can be"
+        " either a major version number or an exact version."
     ),
 )
 @click.option(
@@ -375,7 +385,10 @@ def validate_regex(
 )
 @click.version_option(
     version=__version__,
-    message=f"%(prog)s, %(version)s (compiled: {'yes' if COMPILED else 'no'})",
+    message=(
+        f"%(prog)s, %(version)s (compiled: {'yes' if COMPILED else 'no'})\n"
+        f"Python ({platform.python_implementation()}) {platform.python_version()}"
+    ),
 )
 @click.argument(
     "src",
@@ -401,7 +414,7 @@ def validate_regex(
     help="Read configuration from FILE path.",
 )
 @click.pass_context
-def main(
+def main(  # noqa: C901
     ctx: click.Context,
     code: Optional[str],
     line_length: int,
@@ -468,13 +481,27 @@ def main(
 
         if config:
             config_source = ctx.get_parameter_source("config")
-            if config_source in (ParameterSource.DEFAULT, ParameterSource.DEFAULT_MAP):
+            user_level_config = str(find_user_pyproject_toml())
+            if config == user_level_config:
+                out(
+                    "Using configuration from user-level config at "
+                    f"'{user_level_config}'.",
+                    fg="blue",
+                )
+            elif config_source in (
+                ParameterSource.DEFAULT,
+                ParameterSource.DEFAULT_MAP,
+            ):
                 out("Using configuration from project root.", fg="blue")
             else:
                 out(f"Using configuration in '{config}'.", fg="blue")
 
     error_msg = "Oh no! 💥 💔 💥"
-    if required_version and required_version != __version__:
+    if (
+        required_version
+        and required_version != __version__
+        and required_version != __version__.split(".")[0]
+    ):
         err(
             f"{error_msg} The required version `{required_version}` does not match"
             f" the running version `{__version__}`!"
@@ -677,6 +704,9 @@ def reformat_code(
         report.failed(path, str(exc))
 
 
+# diff-shades depends on being able to monkeypatch this function to operate. I
+# know it's not ideal, but this shouldn't cause any issues ... hopefully. ~ichard26
+@mypyc_attr(patchable=True)
 def reformat_one(
     src: Path, fast: bool, write_back: WriteBack, mode: Mode, report: "Report"
 ) -> None:
@@ -740,6 +770,8 @@ def reformat_many(
     workers: Optional[int],
 ) -> None:
     """Reformat multiple files using a ProcessPoolExecutor."""
+    from concurrent.futures import Executor, ThreadPoolExecutor, ProcessPoolExecutor
+
     executor: Executor
     loop = asyncio.get_event_loop()
     worker_count = workers if workers is not None else DEFAULT_WORKERS
@@ -781,7 +813,7 @@ async def schedule_formatting(
     mode: Mode,
     report: "Report",
     loop: asyncio.AbstractEventLoop,
-    executor: Executor,
+    executor: "Executor",
 ) -> None:
     """Run formatting of `sources` in parallel using the provided `executor`.
 
@@ -968,17 +1000,7 @@ def check_stability_and_equivalence(
     content differently.
     """
     assert_equivalent(src_contents, dst_contents)
-
-    # Forced second pass to work around optional trailing commas (becoming
-    # forced trailing commas on pass 2) interacting differently with optional
-    # parentheses.  Admittedly ugly.
-    dst_contents_pass2 = format_str(dst_contents, mode=mode)
-    if dst_contents != dst_contents_pass2:
-        dst_contents = dst_contents_pass2
-        assert_equivalent(src_contents, dst_contents, pass_num=2)
-        assert_stable(src_contents, dst_contents, mode=mode)
-    # Note: no need to explicitly call `assert_stable` if `dst_contents` was
-    # the same as `dst_contents_pass2`.
+    assert_stable(src_contents, dst_contents, mode=mode)
 
 
 def format_file_contents(src_contents: str, *, fast: bool, mode: Mode) -> FileContent:
@@ -1108,7 +1130,7 @@ def format_ipynb_string(src_contents: str, *, fast: bool, mode: Mode) -> FileCon
         raise NothingChanged
 
 
-def format_str(src_contents: str, *, mode: Mode) -> FileContent:
+def format_str(src_contents: str, *, mode: Mode) -> str:
     """Reformat a string and return new contents.
 
     `mode` determines formatting options, such as how many characters per line are
@@ -1138,6 +1160,16 @@ def format_str(src_contents: str, *, mode: Mode) -> FileContent:
         hey
 
     """
+    dst_contents = _format_str_once(src_contents, mode=mode)
+    # Forced second pass to work around optional trailing commas (becoming
+    # forced trailing commas on pass 2) interacting differently with optional
+    # parentheses.  Admittedly ugly.
+    if src_contents != dst_contents:
+        return _format_str_once(dst_contents, mode=mode)
+    return dst_contents
+
+
+def _format_str_once(src_contents: str, *, mode: Mode) -> str:
     src_node = lib2to3_parse(src_contents.lstrip(), mode.target_versions)
     dst_contents = []
     future_imports = get_future_imports(src_node)
@@ -1146,7 +1178,7 @@ def format_str(src_contents: str, *, mode: Mode) -> FileContent:
     else:
         versions = detect_target_versions(src_node, future_imports=future_imports)
 
-    normalize_fmt_off(src_node)
+    normalize_fmt_off(src_node, preview=mode.preview)
     lines = LineGenerator(mode=mode)
     elt = EmptyLineTracker(is_pyi=mode.is_pyi)
     empty_line = Line(mode=mode)
@@ -1213,8 +1245,7 @@ def get_features_used(  # noqa: C901
             if value_head in {'f"', 'F"', "f'", "F'", "rf", "fr", "RF", "FR"}:
                 features.add(Feature.F_STRINGS)
 
-        elif n.type == token.NUMBER:
-            assert isinstance(n, Leaf)
+        elif is_number_token(n):
             if "_" in n.value:
                 features.add(Feature.NUMERIC_UNDERSCORES)
 
@@ -1269,6 +1300,25 @@ def get_features_used(  # noqa: C901
         ):
             features.add(Feature.ANN_ASSIGN_EXTENDED_RHS)
 
+        elif (
+            n.type == syms.except_clause
+            and len(n.children) >= 2
+            and n.children[1].type == token.STAR
+        ):
+            features.add(Feature.EXCEPT_STAR)
+
+        elif n.type in {syms.subscriptlist, syms.trailer} and any(
+            child.type == syms.star_expr for child in n.children
+        ):
+            features.add(Feature.VARIADIC_GENERICS)
+
+        elif (
+            n.type == syms.tname_star
+            and len(n.children) == 3
+            and n.children[2].type == syms.star_expr
+        ):
+            features.add(Feature.VARIADIC_GENERICS)
+
     return features
 
 
@@ -1332,16 +1382,16 @@ def get_future_imports(node: Node) -> Set[str]:
     return imports
 
 
-def assert_equivalent(src: str, dst: str, *, pass_num: int = 1) -> None:
+def assert_equivalent(src: str, dst: str) -> None:
     """Raise AssertionError if `src` and `dst` aren't equivalent."""
     try:
         src_ast = parse_ast(src)
     except Exception as exc:
         raise AssertionError(
-            f"cannot use --safe with this file; failed to parse source file AST: "
+            "cannot use --safe with this file; failed to parse source file AST: "
             f"{exc}\n"
-            f"This could be caused by running Black with an older Python version "
-            f"that does not support new syntax used in your source file."
+            "This could be caused by running Black with an older Python version "
+            "that does not support new syntax used in your source file."
         ) from exc
 
     try:
@@ -1349,7 +1399,7 @@ def assert_equivalent(src: str, dst: str, *, pass_num: int = 1) -> None:
     except Exception as exc:
         log = dump_to_file("".join(traceback.format_tb(exc.__traceback__)), dst)
         raise AssertionError(
-            f"INTERNAL ERROR: Black produced invalid code on pass {pass_num}: {exc}. "
+            f"INTERNAL ERROR: Black produced invalid code: {exc}. "
             "Please report a bug on https://github.com/psf/black/issues.  "
             f"This invalid output might be helpful: {log}"
         ) from None
@@ -1360,14 +1410,17 @@ def assert_equivalent(src: str, dst: str, *, pass_num: int = 1) -> None:
         log = dump_to_file(diff(src_ast_str, dst_ast_str, "src", "dst"))
         raise AssertionError(
             "INTERNAL ERROR: Black produced code that is not equivalent to the"
-            f" source on pass {pass_num}.  Please report a bug on "
+            " source.  Please report a bug on "
             f"https://github.com/psf/black/issues.  This diff might be helpful: {log}"
         ) from None
 
 
 def assert_stable(src: str, dst: str, mode: Mode) -> None:
     """Raise AssertionError if `dst` reformats differently the second time."""
-    newdst = format_str(dst, mode=mode)
+    # We shouldn't call format_str() here, because that formats the string
+    # twice and may hide a bug where we bounce back and forth between two
+    # versions.
+    newdst = _format_str_once(dst, mode=mode)
     if dst != newdst:
         log = dump_to_file(
             str(mode),
@@ -1401,13 +1454,23 @@ def patch_click() -> None:
     file paths is minimal since it's Python source code.  Moreover, this crash was
     spurious on Python 3.7 thanks to PEP 538 and PEP 540.
     """
+    modules: List[Any] = []
     try:
         from click import core
-        from click import _unicodefun
-    except ModuleNotFoundError:
-        return
+    except ImportError:
+        pass
+    else:
+        modules.append(core)
+    try:
+        # Removed in Click 8.1.0 and newer; we keep this around for users who have
+        # older versions installed.
+        from click import _unicodefun  # type: ignore
+    except ImportError:
+        pass
+    else:
+        modules.append(_unicodefun)
 
-    for module in (core, _unicodefun):
+    for module in modules:
         if hasattr(module, "_verify_python3_env"):
             module._verify_python3_env = lambda: None  # type: ignore
         if hasattr(module, "_verify_python_env"):