X-Git-Url: https://git.madduck.net/etc/vim.git/blobdiff_plain/c47b91f513052cd39b818ea7c19716423c85c04e..8e618f386995fa89434834e6a793a1057e58112a:/src/black/__init__.py diff --git a/src/black/__init__.py b/src/black/__init__.py index a0c1ad4..222cb3c 100644 --- a/src/black/__init__.py +++ b/src/black/__init__.py @@ -1,22 +1,17 @@ -import asyncio import io import json -import os import platform import re -import signal import sys import tokenize import traceback from contextlib import contextmanager from dataclasses import replace -from datetime import datetime +from datetime import datetime, timezone from enum import Enum from json.decoder import JSONDecodeError -from multiprocessing import Manager, freeze_support from pathlib import Path from typing import ( - TYPE_CHECKING, Any, Dict, Generator, @@ -35,12 +30,12 @@ from typing import ( import click from click.core import ParameterSource from mypy_extensions import mypyc_attr +from pathspec import PathSpec from pathspec.patterns.gitwildmatch import GitWildMatchPatternError from _black_version import version as __version__ -from black.cache import Cache, filter_cached, get_cache_info, read_cache, write_cache +from black.cache import Cache, get_cache_info, read_cache, write_cache from black.comments import normalize_fmt_off -from black.concurrency import cancel, maybe_install_uvloop, shutdown from black.const import ( DEFAULT_EXCLUDES, DEFAULT_INCLUDES, @@ -67,7 +62,7 @@ from black.handle_ipynb_magics import ( unmask_cell, ) from black.linegen import LN, LineGenerator, transform_line -from black.lines import EmptyLineTracker, Line +from black.lines import EmptyLineTracker, LinesBlock from black.mode import ( FUTURE_FLAG_TO_FEATURE, VERSION_TO_FEATURES, @@ -91,9 +86,6 @@ from black.trans import iter_fexpr_spans from blib2to3.pgen2 import token from blib2to3.pytree import Leaf, Node -if TYPE_CHECKING: - from concurrent.futures import Executor - COMPILED = Path(__file__).suffix in (".pyd", ".so") # types @@ -125,8 +117,6 @@ class WriteBack(Enum): # Legacy name, left for integrations. FileMode = Mode -DEFAULT_WORKERS = os.cpu_count() - def read_pyproject_toml( ctx: click.Context, param: click.Parameter, value: Optional[str] @@ -137,7 +127,9 @@ def read_pyproject_toml( otherwise. """ if not value: - value = find_pyproject_toml(ctx.params.get("src", ())) + value = find_pyproject_toml( + ctx.params.get("src", ()), ctx.params.get("stdin_filename", None) + ) if value is None: return None @@ -229,8 +221,9 @@ def validate_regex( callback=target_version_option_callback, multiple=True, help=( - "Python versions that should be supported by Black's output. [default: per-file" - " auto-detection]" + "Python versions that should be supported by Black's output. By default, Black" + " will try to infer this from the project metadata in pyproject.toml. If this" + " does not yield conclusive results, Black will use per-file auto-detection." ), ) @click.option( @@ -254,11 +247,17 @@ def validate_regex( multiple=True, help=( "When processing Jupyter Notebooks, add the given magic to the list" - f" of known python-magics ({', '.join(PYTHON_CELL_MAGICS)})." + f" of known python-magics ({', '.join(sorted(PYTHON_CELL_MAGICS))})." " Useful for formatting cells with custom python magics." ), default=[], ) +@click.option( + "-x", + "--skip-source-first-line", + is_flag=True, + help="Skip the first line of the source code.", +) @click.option( "-S", "--skip-string-normalization", @@ -365,6 +364,7 @@ def validate_regex( @click.option( "--stdin-filename", type=str, + is_eager=True, help=( "The name of the file when passing it through stdin. Useful to make " "sure Black will respect --force-exclude option on some " @@ -375,9 +375,11 @@ def validate_regex( "-W", "--workers", type=click.IntRange(min=1), - default=DEFAULT_WORKERS, - show_default=True, - help="Number of parallel workers", + default=None, + help=( + "Number of parallel workers [default: BLACK_NUM_WORKERS environment variable " + "or number of CPUs in the system]" + ), ) @click.option( "-q", @@ -440,6 +442,7 @@ def main( # noqa: C901 pyi: bool, ipynb: bool, python_cell_magics: Sequence[str], + skip_source_first_line: bool, skip_string_normalization: bool, skip_magic_trailing_comma: bool, experimental_string_processing: bool, @@ -452,7 +455,7 @@ def main( # noqa: C901 extend_exclude: Optional[Pattern[str]], force_exclude: Optional[Pattern[str]], stdin_filename: Optional[str], - workers: int, + workers: Optional[int], src: Tuple[str, ...], config: Optional[str], ) -> None: @@ -481,22 +484,6 @@ def main( # noqa: C901 fg="blue", ) - normalized = [ - (source, source) - if source == "-" - else (normalize_path_maybe_ignore(Path(source), root), source) - for source in src - ] - srcs_string = ", ".join( - [ - f'"{_norm}"' - if _norm - else f'\033[31m"{source} (skipping - invalid)"\033[34m' - for _norm, source in normalized - ] - ) - out(f"Sources to be formatted: {srcs_string}", fg="blue") - if config: config_source = ctx.get_parameter_source("config") user_level_config = str(find_user_pyproject_toml()) @@ -513,6 +500,9 @@ def main( # noqa: C901 out("Using configuration from project root.", fg="blue") else: out(f"Using configuration in '{config}'.", fg="blue") + if ctx.default_map: + for param, value in ctx.default_map.items(): + out(f"{param}: {value}") error_msg = "Oh no! 💥 💔 💥" if ( @@ -540,6 +530,7 @@ def main( # noqa: C901 line_length=line_length, is_pyi=pyi, is_ipynb=ipynb, + skip_source_first_line=skip_source_first_line, string_normalization=not skip_string_normalization, magic_trailing_comma=not skip_magic_trailing_comma, experimental_string_processing=experimental_string_processing, @@ -592,6 +583,8 @@ def main( # noqa: C901 report=report, ) else: + from black.concurrency import reformat_many + reformat_many( sources=sources, fast=fast, @@ -625,12 +618,12 @@ def get_sources( ) -> Set[Path]: """Compute the set of files to be formatted.""" sources: Set[Path] = set() + root = ctx.obj["root"] - if exclude is None: - exclude = re_compile_maybe_verbose(DEFAULT_EXCLUDES) - gitignore = get_gitignore(ctx.obj["root"]) - else: - gitignore = None + using_default_exclude = exclude is None + exclude = re_compile_maybe_verbose(DEFAULT_EXCLUDES) if exclude is None else exclude + gitignore: Optional[Dict[Path, PathSpec]] = None + root_gitignore = get_gitignore(root) for s in src: if s == "-" and stdin_filename: @@ -641,9 +634,15 @@ def get_sources( is_stdin = False if is_stdin or p.is_file(): - normalized_path = normalize_path_maybe_ignore(p, ctx.obj["root"], report) + normalized_path: Optional[str] = normalize_path_maybe_ignore( + p, ctx.obj["root"], report + ) if normalized_path is None: + if verbose: + out(f'Skipping invalid source: "{normalized_path}"', fg="red") continue + if verbose: + out(f'Found input source: "{normalized_path}"', fg="blue") normalized_path = "/" + normalized_path # Hard-exclude any files that matches the `--force-exclude` regex. @@ -665,6 +664,15 @@ def get_sources( sources.add(p) elif p.is_dir(): + p = root / normalize_path_maybe_ignore(p, ctx.obj["root"], report) + if verbose: + out(f'Found input source directory: "{p}"', fg="blue") + + if using_default_exclude: + gitignore = { + root: root_gitignore, + p: get_gitignore(p), + } sources.update( gen_python_files( p.iterdir(), @@ -680,9 +688,12 @@ def get_sources( ) ) elif s == "-": + if verbose: + out("Found input source stdin", fg="blue") sources.add(p) else: err(f"invalid path: {s}") + return sources @@ -776,132 +787,6 @@ def reformat_one( report.failed(src, str(exc)) -# diff-shades depends on being to monkeypatch this function to operate. I know it's -# not ideal, but this shouldn't cause any issues ... hopefully. ~ichard26 -@mypyc_attr(patchable=True) -def reformat_many( - sources: Set[Path], - fast: bool, - write_back: WriteBack, - mode: Mode, - report: "Report", - workers: Optional[int], -) -> None: - """Reformat multiple files using a ProcessPoolExecutor.""" - from concurrent.futures import Executor, ProcessPoolExecutor, ThreadPoolExecutor - - executor: Executor - worker_count = workers if workers is not None else DEFAULT_WORKERS - if sys.platform == "win32": - # Work around https://bugs.python.org/issue26903 - assert worker_count is not None - worker_count = min(worker_count, 60) - try: - executor = ProcessPoolExecutor(max_workers=worker_count) - except (ImportError, NotImplementedError, OSError): - # we arrive here if the underlying system does not support multi-processing - # like in AWS Lambda or Termux, in which case we gracefully fallback to - # a ThreadPoolExecutor with just a single worker (more workers would not do us - # any good due to the Global Interpreter Lock) - executor = ThreadPoolExecutor(max_workers=1) - - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - try: - loop.run_until_complete( - schedule_formatting( - sources=sources, - fast=fast, - write_back=write_back, - mode=mode, - report=report, - loop=loop, - executor=executor, - ) - ) - finally: - try: - shutdown(loop) - finally: - asyncio.set_event_loop(None) - if executor is not None: - executor.shutdown() - - -async def schedule_formatting( - sources: Set[Path], - fast: bool, - write_back: WriteBack, - mode: Mode, - report: "Report", - loop: asyncio.AbstractEventLoop, - executor: "Executor", -) -> None: - """Run formatting of `sources` in parallel using the provided `executor`. - - (Use ProcessPoolExecutors for actual parallelism.) - - `write_back`, `fast`, and `mode` options are passed to - :func:`format_file_in_place`. - """ - cache: Cache = {} - if write_back not in (WriteBack.DIFF, WriteBack.COLOR_DIFF): - cache = read_cache(mode) - sources, cached = filter_cached(cache, sources) - for src in sorted(cached): - report.done(src, Changed.CACHED) - if not sources: - return - - cancelled = [] - sources_to_cache = [] - lock = None - if write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF): - # For diff output, we need locks to ensure we don't interleave output - # from different processes. - manager = Manager() - lock = manager.Lock() - tasks = { - asyncio.ensure_future( - loop.run_in_executor( - executor, format_file_in_place, src, fast, mode, write_back, lock - ) - ): src - for src in sorted(sources) - } - pending = tasks.keys() - try: - loop.add_signal_handler(signal.SIGINT, cancel, pending) - loop.add_signal_handler(signal.SIGTERM, cancel, pending) - except NotImplementedError: - # There are no good alternatives for these on Windows. - pass - while pending: - done, _ = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED) - for task in done: - src = tasks.pop(task) - if task.cancelled(): - cancelled.append(task) - elif task.exception(): - report.failed(src, str(task.exception())) - else: - changed = Changed.YES if task.result() else Changed.NO - # If the file was written back or was successfully checked as - # well-formatted, store this information in the cache. - if write_back is WriteBack.YES or ( - write_back is WriteBack.CHECK and changed is Changed.NO - ): - sources_to_cache.append(src) - report.done(src, changed) - if cancelled: - if sys.version_info >= (3, 7): - await asyncio.gather(*cancelled, return_exceptions=True) - else: - await asyncio.gather(*cancelled, loop=loop, return_exceptions=True) - if sources_to_cache: - write_cache(cache, sources_to_cache, mode) - - def format_file_in_place( src: Path, fast: bool, @@ -920,8 +805,11 @@ def format_file_in_place( elif src.suffix == ".ipynb": mode = replace(mode, is_ipynb=True) - then = datetime.utcfromtimestamp(src.stat().st_mtime) + then = datetime.fromtimestamp(src.stat().st_mtime, timezone.utc) + header = b"" with open(src, "rb") as buf: + if mode.skip_source_first_line: + header = buf.readline() src_contents, encoding, newline = decode_bytes(buf.read()) try: dst_contents = format_file_contents(src_contents, fast=fast, mode=mode) @@ -931,14 +819,16 @@ def format_file_in_place( raise ValueError( f"File '{src}' cannot be parsed as valid Jupyter notebook." ) from None + src_contents = header.decode(encoding) + src_contents + dst_contents = header.decode(encoding) + dst_contents if write_back == WriteBack.YES: with open(src, "w", encoding=encoding, newline=newline) as f: f.write(dst_contents) elif write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF): - now = datetime.utcnow() - src_name = f"{src}\t{then} +0000" - dst_name = f"{src}\t{now} +0000" + now = datetime.now(timezone.utc) + src_name = f"{src}\t{then}" + dst_name = f"{src}\t{now}" if mode.is_ipynb: diff_contents = ipynb_diff(src_contents, dst_contents, src_name, dst_name) else: @@ -976,7 +866,7 @@ def format_stdin_to_stdout( write a diff to stdout. The `mode` argument is passed to :func:`format_file_contents`. """ - then = datetime.utcnow() + then = datetime.now(timezone.utc) if content is None: src, encoding, newline = decode_bytes(sys.stdin.buffer.read()) @@ -1001,9 +891,9 @@ def format_stdin_to_stdout( dst += "\n" f.write(dst) elif write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF): - now = datetime.utcnow() - src_name = f"STDIN\t{then} +0000" - dst_name = f"STDOUT\t{now} +0000" + now = datetime.now(timezone.utc) + src_name = f"STDIN\t{then}" + dst_name = f"STDOUT\t{now}" d = diff(src, dst, src_name, dst_name) if write_back == WriteBack.COLOR_DIFF: d = color_diff(d) @@ -1032,9 +922,6 @@ def format_file_contents(src_contents: str, *, fast: bool, mode: Mode) -> FileCo valid by calling :func:`assert_equivalent` and :func:`assert_stable` on it. `mode` is passed to :func:`format_str`. """ - if not src_contents.strip(): - raise NothingChanged - if mode.is_ipynb: dst_contents = format_ipynb_string(src_contents, fast=fast, mode=mode) else: @@ -1129,6 +1016,9 @@ def format_ipynb_string(src_contents: str, *, fast: bool, mode: Mode) -> FileCon Operate cell-by-cell, only on code cells, only for Python notebooks. If the ``.ipynb`` originally had a trailing newline, it'll be preserved. """ + if not src_contents: + raise NothingChanged + trailing_newline = src_contents[-1] == "\n" modified = False nb = json.loads(src_contents) @@ -1193,31 +1083,46 @@ def format_str(src_contents: str, *, mode: Mode) -> str: def _format_str_once(src_contents: str, *, mode: Mode) -> str: src_node = lib2to3_parse(src_contents.lstrip(), mode.target_versions) - dst_contents = [] + dst_blocks: List[LinesBlock] = [] if mode.target_versions: versions = mode.target_versions else: future_imports = get_future_imports(src_node) versions = detect_target_versions(src_node, future_imports=future_imports) - normalize_fmt_off(src_node, preview=mode.preview) - lines = LineGenerator(mode=mode) - elt = EmptyLineTracker(is_pyi=mode.is_pyi) - empty_line = Line(mode=mode) - after = 0 + context_manager_features = { + feature + for feature in {Feature.PARENTHESIZED_CONTEXT_MANAGERS} + if supports_feature(versions, feature) + } + normalize_fmt_off(src_node) + lines = LineGenerator(mode=mode, features=context_manager_features) + elt = EmptyLineTracker(mode=mode) split_line_features = { feature for feature in {Feature.TRAILING_COMMA_IN_CALL, Feature.TRAILING_COMMA_IN_DEF} if supports_feature(versions, feature) } + block: Optional[LinesBlock] = None for current_line in lines.visit(src_node): - dst_contents.append(str(empty_line) * after) - before, after = elt.maybe_empty_lines(current_line) - dst_contents.append(str(empty_line) * before) + block = elt.maybe_empty_lines(current_line) + dst_blocks.append(block) for line in transform_line( current_line, mode=mode, features=split_line_features ): - dst_contents.append(str(line)) + block.content_lines.append(str(line)) + if dst_blocks: + dst_blocks[-1].after = 0 + dst_contents = [] + for block in dst_blocks: + dst_contents.extend(block.all_lines()) + if not dst_contents: + # Use decode_bytes to retrieve the correct source newline (CRLF or LF), + # and check if normalized_content has more than one line + normalized_content, _, newline = decode_bytes(src_contents.encode("utf-8")) + if "\n" in normalized_content: + return newline + return "" return "".join(dst_contents) @@ -1253,6 +1158,10 @@ def get_features_used( # noqa: C901 - relaxed decorator syntax; - usage of __future__ flags (annotations); - print / exec statements; + - parenthesized context managers; + - match statements; + - except* clause; + - variadic generics; """ features: Set[Feature] = set() if future_imports: @@ -1328,6 +1237,23 @@ def get_features_used( # noqa: C901 ): features.add(Feature.ANN_ASSIGN_EXTENDED_RHS) + elif ( + n.type == syms.with_stmt + and len(n.children) > 2 + and n.children[1].type == syms.atom + ): + atom_children = n.children[1].children + if ( + len(atom_children) == 3 + and atom_children[0].type == token.LPAR + and atom_children[1].type == syms.testlist_gexp + and atom_children[2].type == token.RPAR + ): + features.add(Feature.PARENTHESIZED_CONTEXT_MANAGERS) + + elif n.type == syms.match_stmt: + features.add(Feature.PATTERN_MATCHING) + elif ( n.type == syms.except_clause and len(n.children) >= 2 @@ -1347,6 +1273,9 @@ def get_features_used( # noqa: C901 ): features.add(Feature.VARIADIC_GENERICS) + elif n.type in (syms.type_stmt, syms.typeparams): + features.add(Feature.TYPE_PARAMS) + return features @@ -1500,14 +1429,19 @@ def patch_click() -> None: for module in modules: if hasattr(module, "_verify_python3_env"): - module._verify_python3_env = lambda: None # type: ignore + module._verify_python3_env = lambda: None if hasattr(module, "_verify_python_env"): - module._verify_python_env = lambda: None # type: ignore + module._verify_python_env = lambda: None def patched_main() -> None: - maybe_install_uvloop() - freeze_support() + # PyInstaller patches multiprocessing to need freeze_support() even in non-Windows + # environments so just assume we always need to call it if frozen. + if getattr(sys, "frozen", False): + from multiprocessing import freeze_support + + freeze_support() + patch_click() main()