X-Git-Url: https://git.madduck.net/etc/vim.git/blobdiff_plain/afed2c01903465f9a486ac481a66aa3413cc1b01..5d0d5936db2ed7a01c50a374e32753e1afe9cc71:/src/black/__init__.py diff --git a/src/black/__init__.py b/src/black/__init__.py index 117dc83..9f44722 100644 --- a/src/black/__init__.py +++ b/src/black/__init__.py @@ -1,6 +1,5 @@ import io import json -import os import platform import re import sys @@ -28,14 +27,10 @@ from typing import ( Union, ) -if sys.version_info >= (3, 8): - from typing import Final -else: - from typing_extensions import Final - import click from click.core import ParameterSource from mypy_extensions import mypyc_attr +from pathspec import PathSpec from pathspec.patterns.gitwildmatch import GitWildMatchPatternError from _black_version import version as __version__ @@ -67,7 +62,7 @@ from black.handle_ipynb_magics import ( unmask_cell, ) from black.linegen import LN, LineGenerator, transform_line -from black.lines import EmptyLineTracker, Line +from black.lines import EmptyLineTracker, LinesBlock from black.mode import ( FUTURE_FLAG_TO_FEATURE, VERSION_TO_FEATURES, @@ -92,7 +87,6 @@ from blib2to3.pgen2 import token from blib2to3.pytree import Leaf, Node COMPILED = Path(__file__).suffix in (".pyd", ".so") -DEFAULT_WORKERS: Final = os.cpu_count() # types FileContent = str @@ -255,6 +249,12 @@ def validate_regex( ), default=[], ) +@click.option( + "-x", + "--skip-source-first-line", + is_flag=True, + help="Skip the first line of the source code.", +) @click.option( "-S", "--skip-string-normalization", @@ -371,9 +371,8 @@ def validate_regex( "-W", "--workers", type=click.IntRange(min=1), - default=DEFAULT_WORKERS, - show_default=True, - help="Number of parallel workers", + default=None, + help="Number of parallel workers [default: number of CPUs in the system]", ) @click.option( "-q", @@ -436,6 +435,7 @@ def main( # noqa: C901 pyi: bool, ipynb: bool, python_cell_magics: Sequence[str], + skip_source_first_line: bool, skip_string_normalization: bool, skip_magic_trailing_comma: bool, experimental_string_processing: bool, @@ -448,7 +448,7 @@ def main( # noqa: C901 extend_exclude: Optional[Pattern[str]], force_exclude: Optional[Pattern[str]], stdin_filename: Optional[str], - workers: int, + workers: Optional[int], src: Tuple[str, ...], config: Optional[str], ) -> None: @@ -478,16 +478,20 @@ def main( # noqa: C901 ) normalized = [ - (source, source) - if source == "-" - else (normalize_path_maybe_ignore(Path(source), root), source) + ( + (source, source) + if source == "-" + else (normalize_path_maybe_ignore(Path(source), root), source) + ) for source in src ] srcs_string = ", ".join( [ - f'"{_norm}"' - if _norm - else f'\033[31m"{source} (skipping - invalid)"\033[34m' + ( + f'"{_norm}"' + if _norm + else f'\033[31m"{source} (skipping - invalid)"\033[34m' + ) for _norm, source in normalized ] ) @@ -498,8 +502,10 @@ def main( # noqa: C901 user_level_config = str(find_user_pyproject_toml()) if config == user_level_config: out( - "Using configuration from user-level config at " - f"'{user_level_config}'.", + ( + "Using configuration from user-level config at " + f"'{user_level_config}'." + ), fg="blue", ) elif config_source in ( @@ -509,6 +515,9 @@ def main( # noqa: C901 out("Using configuration from project root.", fg="blue") else: out(f"Using configuration in '{config}'.", fg="blue") + if ctx.default_map: + for param, value in ctx.default_map.items(): + out(f"{param}: {value}") error_msg = "Oh no! 💥 💔 💥" if ( @@ -536,6 +545,7 @@ def main( # noqa: C901 line_length=line_length, is_pyi=pyi, is_ipynb=ipynb, + skip_source_first_line=skip_source_first_line, string_normalization=not skip_string_normalization, magic_trailing_comma=not skip_magic_trailing_comma, experimental_string_processing=experimental_string_processing, @@ -625,6 +635,11 @@ def get_sources( sources: Set[Path] = set() root = ctx.obj["root"] + using_default_exclude = exclude is None + exclude = re_compile_maybe_verbose(DEFAULT_EXCLUDES) if exclude is None else exclude + gitignore: Optional[Dict[Path, PathSpec]] = None + root_gitignore = get_gitignore(root) + for s in src: if s == "-" and stdin_filename: p = Path(stdin_filename) @@ -658,11 +673,11 @@ def get_sources( sources.add(p) elif p.is_dir(): - if exclude is None: - exclude = re_compile_maybe_verbose(DEFAULT_EXCLUDES) - gitignore = get_gitignore(root) - else: - gitignore = None + if using_default_exclude: + gitignore = { + root: root_gitignore, + root / p: get_gitignore(p), + } sources.update( gen_python_files( p.iterdir(), @@ -793,7 +808,10 @@ def format_file_in_place( mode = replace(mode, is_ipynb=True) then = datetime.utcfromtimestamp(src.stat().st_mtime) + header = b"" with open(src, "rb") as buf: + if mode.skip_source_first_line: + header = buf.readline() src_contents, encoding, newline = decode_bytes(buf.read()) try: dst_contents = format_file_contents(src_contents, fast=fast, mode=mode) @@ -803,6 +821,8 @@ def format_file_in_place( raise ValueError( f"File '{src}' cannot be parsed as valid Jupyter notebook." ) from None + src_contents = header.decode(encoding) + src_contents + dst_contents = header.decode(encoding) + dst_contents if write_back == WriteBack.YES: with open(src, "w", encoding=encoding, newline=newline) as f: @@ -904,7 +924,7 @@ def format_file_contents(src_contents: str, *, fast: bool, mode: Mode) -> FileCo valid by calling :func:`assert_equivalent` and :func:`assert_stable` on it. `mode` is passed to :func:`format_str`. """ - if not src_contents.strip(): + if not mode.preview and not src_contents.strip(): raise NothingChanged if mode.is_ipynb: @@ -1001,6 +1021,9 @@ def format_ipynb_string(src_contents: str, *, fast: bool, mode: Mode) -> FileCon Operate cell-by-cell, only on code cells, only for Python notebooks. If the ``.ipynb`` originally had a trailing newline, it'll be preserved. """ + if mode.preview and not src_contents: + raise NothingChanged + trailing_newline = src_contents[-1] == "\n" modified = False nb = json.loads(src_contents) @@ -1065,7 +1088,7 @@ def format_str(src_contents: str, *, mode: Mode) -> str: def _format_str_once(src_contents: str, *, mode: Mode) -> str: src_node = lib2to3_parse(src_contents.lstrip(), mode.target_versions) - dst_contents = [] + dst_blocks: List[LinesBlock] = [] if mode.target_versions: versions = mode.target_versions else: @@ -1074,22 +1097,32 @@ def _format_str_once(src_contents: str, *, mode: Mode) -> str: normalize_fmt_off(src_node, preview=mode.preview) lines = LineGenerator(mode=mode) - elt = EmptyLineTracker(is_pyi=mode.is_pyi) - empty_line = Line(mode=mode) - after = 0 + elt = EmptyLineTracker(mode=mode) split_line_features = { feature for feature in {Feature.TRAILING_COMMA_IN_CALL, Feature.TRAILING_COMMA_IN_DEF} if supports_feature(versions, feature) } + block: Optional[LinesBlock] = None for current_line in lines.visit(src_node): - dst_contents.append(str(empty_line) * after) - before, after = elt.maybe_empty_lines(current_line) - dst_contents.append(str(empty_line) * before) + block = elt.maybe_empty_lines(current_line) + dst_blocks.append(block) for line in transform_line( current_line, mode=mode, features=split_line_features ): - dst_contents.append(str(line)) + block.content_lines.append(str(line)) + if dst_blocks: + dst_blocks[-1].after = 0 + dst_contents = [] + for block in dst_blocks: + dst_contents.extend(block.all_lines()) + if mode.preview and not dst_contents: + # Use decode_bytes to retrieve the correct source newline (CRLF or LF), + # and check if normalized_content has more than one line + normalized_content, _, newline = decode_bytes(src_contents.encode("utf-8")) + if "\n" in normalized_content: + return newline + return "" return "".join(dst_contents) @@ -1372,13 +1405,15 @@ def patch_click() -> None: for module in modules: if hasattr(module, "_verify_python3_env"): - module._verify_python3_env = lambda: None # type: ignore + module._verify_python3_env = lambda: None if hasattr(module, "_verify_python_env"): - module._verify_python_env = lambda: None # type: ignore + module._verify_python_env = lambda: None def patched_main() -> None: - if sys.platform == "win32" and getattr(sys, "frozen", False): + # PyInstaller patches multiprocessing to need freeze_support() even in non-Windows + # environments so just assume we always need to call it if frozen. + if getattr(sys, "frozen", False): from multiprocessing import freeze_support freeze_support()