git.madduck.net Git - etc/vim.git/blobdiff - src/black/__init__.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access, can be arranged individually.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Bump docker/build-push-action from 3 to 4 (#3549)
[etc/vim.git] / src / black / __init__.py
index 2a5c750a5835562ddf5c743b01ab0cc58b4cb75c..4ebf28821c39ca339b05f372de2695246f6fcc99 100644 (file)
@@ -1,10 +1,7 @@
-import asyncio
 import io
 import json
-import os
 import platform
 import re
-import signal
 import sys
 import tokenize
 import traceback
@@ -13,10 +10,8 @@ from dataclasses import replace
 from datetime import datetime
 from enum import Enum
 from json.decoder import JSONDecodeError
-from multiprocessing import Manager, freeze_support
 from pathlib import Path
 from typing import (
-    TYPE_CHECKING,
     Any,
     Dict,
     Generator,
@@ -35,12 +30,12 @@ from typing import (
 import click
 from click.core import ParameterSource
 from mypy_extensions import mypyc_attr
+from pathspec import PathSpec
 from pathspec.patterns.gitwildmatch import GitWildMatchPatternError
 
 from _black_version import version as __version__
-from black.cache import Cache, filter_cached, get_cache_info, read_cache, write_cache
+from black.cache import Cache, get_cache_info, read_cache, write_cache
 from black.comments import normalize_fmt_off
-from black.concurrency import cancel, maybe_install_uvloop, shutdown
 from black.const import (
     DEFAULT_EXCLUDES,
     DEFAULT_INCLUDES,
@@ -67,7 +62,7 @@ from black.handle_ipynb_magics import (
     unmask_cell,
 )
 from black.linegen import LN, LineGenerator, transform_line
-from black.lines import EmptyLineTracker, Line
+from black.lines import EmptyLineTracker, LinesBlock
 from black.mode import (
     FUTURE_FLAG_TO_FEATURE,
     VERSION_TO_FEATURES,
@@ -87,12 +82,10 @@ from black.output import color_diff, diff, dump_to_file, err, ipynb_diff, out
 from black.parsing import InvalidInput  # noqa F401
 from black.parsing import lib2to3_parse, parse_ast, stringify_ast
 from black.report import Changed, NothingChanged, Report
+from black.trans import iter_fexpr_spans
 from blib2to3.pgen2 import token
 from blib2to3.pytree import Leaf, Node
 
-if TYPE_CHECKING:
-    from concurrent.futures import Executor
-
 COMPILED = Path(__file__).suffix in (".pyd", ".so")
 
 # types
@@ -124,8 +117,6 @@ class WriteBack(Enum):
 # Legacy name, left for integrations.
 FileMode = Mode
 
-DEFAULT_WORKERS = os.cpu_count()
-
 
 def read_pyproject_toml(
     ctx: click.Context, param: click.Parameter, value: Optional[str]
@@ -228,8 +219,9 @@ def validate_regex(
     callback=target_version_option_callback,
     multiple=True,
     help=(
-        "Python versions that should be supported by Black's output. [default: per-file"
-        " auto-detection]"
+        "Python versions that should be supported by Black's output. By default, Black"
+        " will try to infer this from the project metadata in pyproject.toml. If this"
+        " does not yield conclusive results, Black will use per-file auto-detection."
     ),
 )
 @click.option(
@@ -253,11 +245,17 @@ def validate_regex(
     multiple=True,
     help=(
         "When processing Jupyter Notebooks, add the given magic to the list"
-        f" of known python-magics ({', '.join(PYTHON_CELL_MAGICS)})."
+        f" of known python-magics ({', '.join(sorted(PYTHON_CELL_MAGICS))})."
         " Useful for formatting cells with custom python magics."
     ),
     default=[],
 )
+@click.option(
+    "-x",
+    "--skip-source-first-line",
+    is_flag=True,
+    help="Skip the first line of the source code.",
+)
 @click.option(
     "-S",
     "--skip-string-normalization",
@@ -374,9 +372,8 @@ def validate_regex(
     "-W",
     "--workers",
     type=click.IntRange(min=1),
-    default=DEFAULT_WORKERS,
-    show_default=True,
-    help="Number of parallel workers",
+    default=None,
+    help="Number of parallel workers [default: number of CPUs in the system]",
 )
 @click.option(
     "-q",
@@ -439,6 +436,7 @@ def main(  # noqa: C901
     pyi: bool,
     ipynb: bool,
     python_cell_magics: Sequence[str],
+    skip_source_first_line: bool,
     skip_string_normalization: bool,
     skip_magic_trailing_comma: bool,
     experimental_string_processing: bool,
@@ -451,7 +449,7 @@ def main(  # noqa: C901
     extend_exclude: Optional[Pattern[str]],
     force_exclude: Optional[Pattern[str]],
     stdin_filename: Optional[str],
-    workers: int,
+    workers: Optional[int],
     src: Tuple[str, ...],
     config: Optional[str],
 ) -> None:
@@ -468,7 +466,9 @@ def main(  # noqa: C901
         out(main.get_usage(ctx) + "\n\nOne of 'SRC' or 'code' is required.")
         ctx.exit(1)
 
-    root, method = find_project_root(src) if code is None else (None, None)
+    root, method = (
+        find_project_root(src, stdin_filename) if code is None else (None, None)
+    )
     ctx.obj["root"] = root
 
     if verbose:
@@ -479,14 +479,20 @@ def main(  # noqa: C901
             )
 
             normalized = [
-                (normalize_path_maybe_ignore(Path(source), root), source)
+                (
+                    (source, source)
+                    if source == "-"
+                    else (normalize_path_maybe_ignore(Path(source), root), source)
+                )
                 for source in src
             ]
             srcs_string = ", ".join(
                 [
-                    f'"{_norm}"'
-                    if _norm
-                    else f'\033[31m"{source} (skipping - invalid)"\033[34m'
+                    (
+                        f'"{_norm}"'
+                        if _norm
+                        else f'\033[31m"{source} (skipping - invalid)"\033[34m'
+                    )
                     for _norm, source in normalized
                 ]
             )
@@ -497,8 +503,10 @@ def main(  # noqa: C901
             user_level_config = str(find_user_pyproject_toml())
             if config == user_level_config:
                 out(
-                    "Using configuration from user-level config at "
-                    f"'{user_level_config}'.",
+                    (
+                        "Using configuration from user-level config at "
+                        f"'{user_level_config}'."
+                    ),
                     fg="blue",
                 )
             elif config_source in (
@@ -508,6 +516,9 @@ def main(  # noqa: C901
                 out("Using configuration from project root.", fg="blue")
             else:
                 out(f"Using configuration in '{config}'.", fg="blue")
+            if ctx.default_map:
+                for param, value in ctx.default_map.items():
+                    out(f"{param}: {value}")
 
     error_msg = "Oh no! 💥 💔 💥"
     if (
@@ -535,6 +546,7 @@ def main(  # noqa: C901
         line_length=line_length,
         is_pyi=pyi,
         is_ipynb=ipynb,
+        skip_source_first_line=skip_source_first_line,
         string_normalization=not skip_string_normalization,
         magic_trailing_comma=not skip_magic_trailing_comma,
         experimental_string_processing=experimental_string_processing,
@@ -587,6 +599,8 @@ def main(  # noqa: C901
                 report=report,
             )
         else:
+            from black.concurrency import reformat_many
+
             reformat_many(
                 sources=sources,
                 fast=fast,
@@ -620,12 +634,12 @@ def get_sources(
 ) -> Set[Path]:
     """Compute the set of files to be formatted."""
     sources: Set[Path] = set()
+    root = ctx.obj["root"]
 
-    if exclude is None:
-        exclude = re_compile_maybe_verbose(DEFAULT_EXCLUDES)
-        gitignore = get_gitignore(ctx.obj["root"])
-    else:
-        gitignore = None
+    using_default_exclude = exclude is None
+    exclude = re_compile_maybe_verbose(DEFAULT_EXCLUDES) if exclude is None else exclude
+    gitignore: Optional[Dict[Path, PathSpec]] = None
+    root_gitignore = get_gitignore(root)
 
     for s in src:
         if s == "-" and stdin_filename:
@@ -660,6 +674,12 @@ def get_sources(
 
             sources.add(p)
         elif p.is_dir():
+            p = root / normalize_path_maybe_ignore(p, ctx.obj["root"], report)
+            if using_default_exclude:
+                gitignore = {
+                    root: root_gitignore,
+                    p: get_gitignore(p),
+                }
             sources.update(
                 gen_python_files(
                     p.iterdir(),
@@ -771,132 +791,6 @@ def reformat_one(
         report.failed(src, str(exc))
 
 
-# diff-shades depends on being to monkeypatch this function to operate. I know it's
-# not ideal, but this shouldn't cause any issues ... hopefully. ~ichard26
-@mypyc_attr(patchable=True)
-def reformat_many(
-    sources: Set[Path],
-    fast: bool,
-    write_back: WriteBack,
-    mode: Mode,
-    report: "Report",
-    workers: Optional[int],
-) -> None:
-    """Reformat multiple files using a ProcessPoolExecutor."""
-    from concurrent.futures import Executor, ProcessPoolExecutor, ThreadPoolExecutor
-
-    executor: Executor
-    worker_count = workers if workers is not None else DEFAULT_WORKERS
-    if sys.platform == "win32":
-        # Work around https://bugs.python.org/issue26903
-        assert worker_count is not None
-        worker_count = min(worker_count, 60)
-    try:
-        executor = ProcessPoolExecutor(max_workers=worker_count)
-    except (ImportError, NotImplementedError, OSError):
-        # we arrive here if the underlying system does not support multi-processing
-        # like in AWS Lambda or Termux, in which case we gracefully fallback to
-        # a ThreadPoolExecutor with just a single worker (more workers would not do us
-        # any good due to the Global Interpreter Lock)
-        executor = ThreadPoolExecutor(max_workers=1)
-
-    loop = asyncio.new_event_loop()
-    asyncio.set_event_loop(loop)
-    try:
-        loop.run_until_complete(
-            schedule_formatting(
-                sources=sources,
-                fast=fast,
-                write_back=write_back,
-                mode=mode,
-                report=report,
-                loop=loop,
-                executor=executor,
-            )
-        )
-    finally:
-        try:
-            shutdown(loop)
-        finally:
-            asyncio.set_event_loop(None)
-        if executor is not None:
-            executor.shutdown()
-
-
-async def schedule_formatting(
-    sources: Set[Path],
-    fast: bool,
-    write_back: WriteBack,
-    mode: Mode,
-    report: "Report",
-    loop: asyncio.AbstractEventLoop,
-    executor: "Executor",
-) -> None:
-    """Run formatting of `sources` in parallel using the provided `executor`.
-
-    (Use ProcessPoolExecutors for actual parallelism.)
-
-    `write_back`, `fast`, and `mode` options are passed to
-    :func:`format_file_in_place`.
-    """
-    cache: Cache = {}
-    if write_back not in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
-        cache = read_cache(mode)
-        sources, cached = filter_cached(cache, sources)
-        for src in sorted(cached):
-            report.done(src, Changed.CACHED)
-    if not sources:
-        return
-
-    cancelled = []
-    sources_to_cache = []
-    lock = None
-    if write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
-        # For diff output, we need locks to ensure we don't interleave output
-        # from different processes.
-        manager = Manager()
-        lock = manager.Lock()
-    tasks = {
-        asyncio.ensure_future(
-            loop.run_in_executor(
-                executor, format_file_in_place, src, fast, mode, write_back, lock
-            )
-        ): src
-        for src in sorted(sources)
-    }
-    pending = tasks.keys()
-    try:
-        loop.add_signal_handler(signal.SIGINT, cancel, pending)
-        loop.add_signal_handler(signal.SIGTERM, cancel, pending)
-    except NotImplementedError:
-        # There are no good alternatives for these on Windows.
-        pass
-    while pending:
-        done, _ = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
-        for task in done:
-            src = tasks.pop(task)
-            if task.cancelled():
-                cancelled.append(task)
-            elif task.exception():
-                report.failed(src, str(task.exception()))
-            else:
-                changed = Changed.YES if task.result() else Changed.NO
-                # If the file was written back or was successfully checked as
-                # well-formatted, store this information in the cache.
-                if write_back is WriteBack.YES or (
-                    write_back is WriteBack.CHECK and changed is Changed.NO
-                ):
-                    sources_to_cache.append(src)
-                report.done(src, changed)
-    if cancelled:
-        if sys.version_info >= (3, 7):
-            await asyncio.gather(*cancelled, return_exceptions=True)
-        else:
-            await asyncio.gather(*cancelled, loop=loop, return_exceptions=True)
-    if sources_to_cache:
-        write_cache(cache, sources_to_cache, mode)
-
-
 def format_file_in_place(
     src: Path,
     fast: bool,
@@ -916,7 +810,10 @@ def format_file_in_place(
         mode = replace(mode, is_ipynb=True)
 
     then = datetime.utcfromtimestamp(src.stat().st_mtime)
+    header = b""
     with open(src, "rb") as buf:
+        if mode.skip_source_first_line:
+            header = buf.readline()
         src_contents, encoding, newline = decode_bytes(buf.read())
     try:
         dst_contents = format_file_contents(src_contents, fast=fast, mode=mode)
@@ -926,6 +823,8 @@ def format_file_in_place(
         raise ValueError(
             f"File '{src}' cannot be parsed as valid Jupyter notebook."
         ) from None
+    src_contents = header.decode(encoding) + src_contents
+    dst_contents = header.decode(encoding) + dst_contents
 
     if write_back == WriteBack.YES:
         with open(src, "w", encoding=encoding, newline=newline) as f:
@@ -1027,9 +926,6 @@ def format_file_contents(src_contents: str, *, fast: bool, mode: Mode) -> FileCo
     valid by calling :func:`assert_equivalent` and :func:`assert_stable` on it.
     `mode` is passed to :func:`format_str`.
     """
-    if not src_contents.strip():
-        raise NothingChanged
-
     if mode.is_ipynb:
         dst_contents = format_ipynb_string(src_contents, fast=fast, mode=mode)
     else:
@@ -1124,6 +1020,9 @@ def format_ipynb_string(src_contents: str, *, fast: bool, mode: Mode) -> FileCon
     Operate cell-by-cell, only on code cells, only for Python notebooks.
     If the ``.ipynb`` originally had a trailing newline, it'll be preserved.
     """
+    if not src_contents:
+        raise NothingChanged
+
     trailing_newline = src_contents[-1] == "\n"
     modified = False
     nb = json.loads(src_contents)
@@ -1188,31 +1087,46 @@ def format_str(src_contents: str, *, mode: Mode) -> str:
 
 def _format_str_once(src_contents: str, *, mode: Mode) -> str:
     src_node = lib2to3_parse(src_contents.lstrip(), mode.target_versions)
-    dst_contents = []
+    dst_blocks: List[LinesBlock] = []
     if mode.target_versions:
         versions = mode.target_versions
     else:
         future_imports = get_future_imports(src_node)
         versions = detect_target_versions(src_node, future_imports=future_imports)
 
-    normalize_fmt_off(src_node, preview=mode.preview)
-    lines = LineGenerator(mode=mode)
-    elt = EmptyLineTracker(is_pyi=mode.is_pyi)
-    empty_line = Line(mode=mode)
-    after = 0
+    context_manager_features = {
+        feature
+        for feature in {Feature.PARENTHESIZED_CONTEXT_MANAGERS}
+        if supports_feature(versions, feature)
+    }
+    normalize_fmt_off(src_node)
+    lines = LineGenerator(mode=mode, features=context_manager_features)
+    elt = EmptyLineTracker(mode=mode)
     split_line_features = {
         feature
         for feature in {Feature.TRAILING_COMMA_IN_CALL, Feature.TRAILING_COMMA_IN_DEF}
         if supports_feature(versions, feature)
     }
+    block: Optional[LinesBlock] = None
     for current_line in lines.visit(src_node):
-        dst_contents.append(str(empty_line) * after)
-        before, after = elt.maybe_empty_lines(current_line)
-        dst_contents.append(str(empty_line) * before)
+        block = elt.maybe_empty_lines(current_line)
+        dst_blocks.append(block)
         for line in transform_line(
             current_line, mode=mode, features=split_line_features
         ):
-            dst_contents.append(str(line))
+            block.content_lines.append(str(line))
+    if dst_blocks:
+        dst_blocks[-1].after = 0
+    dst_contents = []
+    for block in dst_blocks:
+        dst_contents.extend(block.all_lines())
+    if not dst_contents:
+        # Use decode_bytes to retrieve the correct source newline (CRLF or LF),
+        # and check if normalized_content has more than one line
+        normalized_content, _, newline = decode_bytes(src_contents.encode("utf-8"))
+        if "\n" in normalized_content:
+            return newline
+        return ""
     return "".join(dst_contents)
 
 
@@ -1240,6 +1154,7 @@ def get_features_used(  # noqa: C901
 
     Currently looking for:
     - f-strings;
+    - self-documenting expressions in f-strings (f"{x=}");
     - underscores in numeric literals;
     - trailing commas after * or ** in function signatures and calls;
     - positional only arguments in function signatures and lambdas;
@@ -1247,6 +1162,10 @@ def get_features_used(  # noqa: C901
     - relaxed decorator syntax;
     - usage of __future__ flags (annotations);
     - print / exec statements;
+    - parenthesized context managers;
+    - match statements;
+    - except* clause;
+    - variadic generics;
     """
     features: Set[Feature] = set()
     if future_imports:
@@ -1261,6 +1180,11 @@ def get_features_used(  # noqa: C901
             value_head = n.value[:2]
             if value_head in {'f"', 'F"', "f'", "F'", "rf", "fr", "RF", "FR"}:
                 features.add(Feature.F_STRINGS)
+                if Feature.DEBUG_F_STRINGS not in features:
+                    for span_beg, span_end in iter_fexpr_spans(n.value):
+                        if n.value[span_beg : span_end - 1].rstrip().endswith("="):
+                            features.add(Feature.DEBUG_F_STRINGS)
+                            break
 
         elif is_number_token(n):
             if "_" in n.value:
@@ -1317,6 +1241,23 @@ def get_features_used(  # noqa: C901
         ):
             features.add(Feature.ANN_ASSIGN_EXTENDED_RHS)
 
+        elif (
+            n.type == syms.with_stmt
+            and len(n.children) > 2
+            and n.children[1].type == syms.atom
+        ):
+            atom_children = n.children[1].children
+            if (
+                len(atom_children) == 3
+                and atom_children[0].type == token.LPAR
+                and atom_children[1].type == syms.testlist_gexp
+                and atom_children[2].type == token.RPAR
+            ):
+                features.add(Feature.PARENTHESIZED_CONTEXT_MANAGERS)
+
+        elif n.type == syms.match_stmt:
+            features.add(Feature.PATTERN_MATCHING)
+
         elif (
             n.type == syms.except_clause
             and len(n.children) >= 2
@@ -1489,14 +1430,19 @@ def patch_click() -> None:
 
     for module in modules:
         if hasattr(module, "_verify_python3_env"):
-            module._verify_python3_env = lambda: None  # type: ignore
+            module._verify_python3_env = lambda: None
         if hasattr(module, "_verify_python_env"):
-            module._verify_python_env = lambda: None  # type: ignore
+            module._verify_python_env = lambda: None
 
 
 def patched_main() -> None:
-    maybe_install_uvloop()
-    freeze_support()
+    # PyInstaller patches multiprocessing to need freeze_support() even in non-Windows
+    # environments so just assume we always need to call it if frozen.
+    if getattr(sys, "frozen", False):
+        from multiprocessing import freeze_support
+
+        freeze_support()
+
     patch_click()
     main()