X-Git-Url: https://git.madduck.net/etc/vim.git/blobdiff_plain/3e60f6d454616a795acb1e3e2b43efa979de4f46..07c046b7cd0e4ced8f1713c8b8fa9fb0190cf472:/black.py diff --git a/black.py b/black.py index 68c0052..e55e4fe 100644 --- a/black.py +++ b/black.py @@ -186,7 +186,7 @@ VERSION_TO_FEATURES: Dict[TargetVersion, Set[Feature]] = { @dataclass -class FileMode: +class Mode: target_versions: Set[TargetVersion] = field(default_factory=set) line_length: int = DEFAULT_LINE_LENGTH string_normalization: bool = True @@ -209,30 +209,46 @@ class FileMode: return ".".join(parts) +# Legacy name, left for integrations. +FileMode = Mode + + def supports_feature(target_versions: Set[TargetVersion], feature: Feature) -> bool: return all(feature in VERSION_TO_FEATURES[version] for version in target_versions) +def find_pyproject_toml(path_search_start: str) -> Optional[str]: + """Find the absolute filepath to a pyproject.toml if it exists""" + path_project_root = find_project_root(path_search_start) + path_pyproject_toml = path_project_root / "pyproject.toml" + return str(path_pyproject_toml) if path_pyproject_toml.is_file() else None + + +def parse_pyproject_toml(path_config: str) -> Dict[str, Any]: + """Parse a pyproject toml file, pulling out relevant parts for Black + + If parsing fails, will raise a toml.TomlDecodeError + """ + pyproject_toml = toml.load(path_config) + config = pyproject_toml.get("tool", {}).get("black", {}) + return {k.replace("--", "").replace("-", "_"): v for k, v in config.items()} + + def read_pyproject_toml( - ctx: click.Context, param: click.Parameter, value: Union[str, int, bool, None] + ctx: click.Context, param: click.Parameter, value: Optional[str] ) -> Optional[str]: """Inject Black configuration from "pyproject.toml" into defaults in `ctx`. Returns the path to a successfully found and read configuration file, None otherwise. """ - assert not isinstance(value, (int, bool)), "Invalid parameter type passed" if not value: - root = find_project_root(ctx.params.get("src", ())) - path = root / "pyproject.toml" - if path.is_file(): - value = str(path) - else: + value = find_pyproject_toml(ctx.params.get("src", ())) + if value is None: return None try: - pyproject_toml = toml.load(value) - config = pyproject_toml.get("tool", {}).get("black", {}) + config = parse_pyproject_toml(value) except (toml.TomlDecodeError, OSError) as e: raise click.FileError( filename=value, hint=f"Error reading configuration file: {e}" @@ -241,11 +257,12 @@ def read_pyproject_toml( if not config: return None - if ctx.default_map is None: - ctx.default_map = {} - ctx.default_map.update( # type: ignore # bad types in .pyi - {k.replace("--", "").replace("-", "_"): v for k, v in config.items()} - ) + default_map: Dict[str, Any] = {} + if ctx.default_map: + default_map.update(ctx.default_map) + default_map.update(config) + + ctx.default_map = default_map return value @@ -379,7 +396,12 @@ def target_version_option_callback( @click.option( "--config", type=click.Path( - exists=False, file_okay=True, dir_okay=False, readable=True, allow_dash=False + exists=True, + file_okay=True, + dir_okay=False, + readable=True, + allow_dash=False, + path_type=str, ), is_eager=True, callback=read_pyproject_toml, @@ -408,7 +430,7 @@ def main( write_back = WriteBack.from_configuration(check=check, diff=diff) if target_version: if py36: - err(f"Cannot use both --target-version and --py36") + err("Cannot use both --target-version and --py36") ctx.exit(2) else: versions = set(target_version) @@ -421,7 +443,7 @@ def main( else: # We'll autodetect later. versions = set() - mode = FileMode( + mode = Mode( target_versions=versions, line_length=line_length, is_pyi=pyi, @@ -442,7 +464,7 @@ def main( except re.error: err(f"Invalid regular expression for exclude given: {exclude!r}") ctx.exit(2) - report = Report(check=check, quiet=quiet, verbose=verbose) + report = Report(check=check, diff=diff, quiet=quiet, verbose=verbose) root = find_project_root(src) sources: Set[Path] = set() path_empty(src, quiet, verbose, ctx) @@ -496,7 +518,7 @@ def path_empty( def reformat_one( - src: Path, fast: bool, write_back: WriteBack, mode: FileMode, report: "Report" + src: Path, fast: bool, write_back: WriteBack, mode: Mode, report: "Report" ) -> None: """Reformat a single file under `src` without spawning child processes. @@ -529,11 +551,7 @@ def reformat_one( def reformat_many( - sources: Set[Path], - fast: bool, - write_back: WriteBack, - mode: FileMode, - report: "Report", + sources: Set[Path], fast: bool, write_back: WriteBack, mode: Mode, report: "Report" ) -> None: """Reformat multiple files using a ProcessPoolExecutor.""" loop = asyncio.get_event_loop() @@ -563,7 +581,7 @@ async def schedule_formatting( sources: Set[Path], fast: bool, write_back: WriteBack, - mode: FileMode, + mode: Mode, report: "Report", loop: asyncio.AbstractEventLoop, executor: Executor, @@ -633,7 +651,7 @@ async def schedule_formatting( def format_file_in_place( src: Path, fast: bool, - mode: FileMode, + mode: Mode, write_back: WriteBack = WriteBack.NO, lock: Any = None, # multiprocessing.Manager().Lock() is some crazy proxy ) -> bool: @@ -677,7 +695,7 @@ def format_file_in_place( def format_stdin_to_stdout( - fast: bool, *, write_back: WriteBack = WriteBack.NO, mode: FileMode + fast: bool, *, write_back: WriteBack = WriteBack.NO, mode: Mode ) -> bool: """Format file on stdin. Return True if changed. @@ -709,9 +727,7 @@ def format_stdin_to_stdout( f.detach() -def format_file_contents( - src_contents: str, *, fast: bool, mode: FileMode -) -> FileContent: +def format_file_contents(src_contents: str, *, fast: bool, mode: Mode) -> FileContent: """Reformat contents a file and return new contents. If `fast` is False, additionally confirm that the reformatted code is @@ -731,11 +747,34 @@ def format_file_contents( return dst_contents -def format_str(src_contents: str, *, mode: FileMode) -> FileContent: +def format_str(src_contents: str, *, mode: Mode) -> FileContent: """Reformat a string and return new contents. `mode` determines formatting options, such as how many characters per line are - allowed. + allowed. Example: + + >>> import black + >>> print(black.format_str("def f(arg:str='')->None:...", mode=Mode())) + def f(arg: str = "") -> None: + ... + + A more complex example: + >>> print( + ... black.format_str( + ... "def f(arg:str='')->None: hey", + ... mode=black.Mode( + ... target_versions={black.TargetVersion.PY36}, + ... line_length=10, + ... string_normalization=False, + ... is_pyi=False, + ... ), + ... ), + ... ) + def f( + arg: str = '', + ) -> None: + hey + """ src_node = lib2to3_parse(src_contents.lstrip(), mode.target_versions) dst_contents = [] @@ -760,11 +799,9 @@ def format_str(src_contents: str, *, mode: FileMode) -> FileContent: if supports_feature(versions, feature) } for current_line in lines.visit(src_node): - for _ in range(after): - dst_contents.append(str(empty_line)) + dst_contents.append(str(empty_line) * after) before, after = elt.maybe_empty_lines(current_line) - for _ in range(before): - dst_contents.append(str(empty_line)) + dst_contents.append(str(empty_line) * before) for line in split_line( current_line, line_length=mode.line_length, features=split_line_features ): @@ -1407,7 +1444,10 @@ class Line: for leaf_id, comments in self.comments.items(): for comment in comments: if is_type_comment(comment): - if leaf_id not in ignored_ids or comment_seen: + if comment_seen or ( + not is_type_comment(comment, " ignore") + and leaf_id not in ignored_ids + ): return True comment_seen = True @@ -1446,11 +1486,7 @@ class Line: return False def contains_multiline_strings(self) -> bool: - for leaf in self.leaves: - if is_multiline_string(leaf): - return True - - return False + return any(is_multiline_string(leaf) for leaf in self.leaves) def maybe_remove_trailing_comma(self, closing: Leaf) -> bool: """Remove trailing comma if there is one and it's safe.""" @@ -2834,7 +2870,7 @@ def normalize_string_prefix(leaf: Leaf, remove_u_prefix: bool = False) -> None: match = re.match(r"^([furbFURB]*)(.*)$", leaf.value, re.DOTALL) assert match is not None, f"failed to match string {leaf.value!r}" orig_prefix = match.group(1) - new_prefix = orig_prefix.lower() + new_prefix = orig_prefix.replace("F", "f").replace("B", "b").replace("U", "u") if remove_u_prefix: new_prefix = new_prefix.replace("u", "") leaf.value = f"{new_prefix}{match.group(2)}" @@ -3080,18 +3116,49 @@ def generate_ignored_nodes(leaf: Leaf) -> Iterator[LN]: """ container: Optional[LN] = container_of(leaf) while container is not None and container.type != token.ENDMARKER: - is_fmt_on = False - for comment in list_comments(container.prefix, is_endmarker=False): - if comment.value in FMT_ON: - is_fmt_on = True - elif comment.value in FMT_OFF: - is_fmt_on = False - if is_fmt_on: + if fmt_on(container): return - yield container + # fix for fmt: on in children + if contains_fmt_on_at_column(container, leaf.column): + for child in container.children: + if contains_fmt_on_at_column(child, leaf.column): + return + yield child + else: + yield container + container = container.next_sibling + + +def fmt_on(container: LN) -> bool: + is_fmt_on = False + for comment in list_comments(container.prefix, is_endmarker=False): + if comment.value in FMT_ON: + is_fmt_on = True + elif comment.value in FMT_OFF: + is_fmt_on = False + return is_fmt_on + - container = container.next_sibling +def contains_fmt_on_at_column(container: LN, column: int) -> bool: + for child in container.children: + if ( + isinstance(child, Node) + and first_leaf_column(child) == column + or isinstance(child, Leaf) + and child.column == column + ): + if fmt_on(child): + return True + + return False + + +def first_leaf_column(node: Node) -> Optional[int]: + for child in node.children: + if isinstance(child, Leaf): + return child.column + return None def maybe_make_parens_invisible_in_atom(node: LN, parent: LN) -> bool: @@ -3538,7 +3605,7 @@ def gen_python_files_in_dir( for child in path.iterdir(): # First ignore files matching .gitignore if gitignore.match_file(child.as_posix()): - report.path_ignored(child, f"matches the .gitignore file content") + report.path_ignored(child, "matches the .gitignore file content") continue # Then ignore with `exclude` option. @@ -3562,7 +3629,7 @@ def gen_python_files_in_dir( exclude_match = exclude.search(normalized_path) if exclude_match and exclude_match.group(0): - report.path_ignored(child, f"matches the --exclude regular expression") + report.path_ignored(child, "matches the --exclude regular expression") continue if child.is_dir(): @@ -3594,7 +3661,7 @@ def find_project_root(srcs: Iterable[str]) -> Path: # Append a fake file so `parents` below returns `common_base_dir`, too. common_base /= "fake-file" for directory in common_base.parents: - if (directory / ".git").is_dir(): + if (directory / ".git").exists(): return directory if (directory / ".hg").is_dir(): @@ -3611,6 +3678,7 @@ class Report: """Provides a reformatting counter. Can be rendered with `str(report)`.""" check: bool = False + diff: bool = False quiet: bool = False verbose: bool = False change_count: int = 0 @@ -3620,7 +3688,7 @@ class Report: def done(self, src: Path, changed: Changed) -> None: """Increment the counter for successful reformatting. Write out a message.""" if changed is Changed.YES: - reformatted = "would reformat" if self.check else "reformatted" + reformatted = "would reformat" if self.check or self.diff else "reformatted" if self.verbose or not self.quiet: out(f"{reformatted} {src}") self.change_count += 1 @@ -3666,7 +3734,7 @@ class Report: Use `click.unstyle` to remove colors. """ - if self.check: + if self.check or self.diff: reformatted = "would be reformatted" unchanged = "would be left unchanged" failed = "would fail to reformat" @@ -3804,7 +3872,7 @@ def assert_equivalent(src: str, dst: str) -> None: ) from None -def assert_stable(src: str, dst: str, mode: FileMode) -> None: +def assert_stable(src: str, dst: str, mode: Mode) -> None: """Raise AssertionError if `dst` reformats differently the second time.""" newdst = format_str(dst, mode=mode) if dst != newdst: @@ -3846,8 +3914,8 @@ def diff(a: str, b: str, a_name: str, b_name: str) -> str: """Return a unified diff string between strings `a` and `b`.""" import difflib - a_lines = [line + "\n" for line in a.split("\n")] - b_lines = [line + "\n" for line in b.split("\n")] + a_lines = [line + "\n" for line in a.splitlines()] + b_lines = [line + "\n" for line in b.splitlines()] return "".join( difflib.unified_diff(a_lines, b_lines, fromfile=a_name, tofile=b_name, n=5) ) @@ -4075,11 +4143,11 @@ def can_omit_invisible_parens(line: Line, line_length: int) -> bool: return False -def get_cache_file(mode: FileMode) -> Path: +def get_cache_file(mode: Mode) -> Path: return CACHE_DIR / f"cache.{mode.get_cache_key()}.pickle" -def read_cache(mode: FileMode) -> Cache: +def read_cache(mode: Mode) -> Cache: """Read the cache if it exists and is well formed. If it is not well formed, the call to write_cache later should resolve the issue. @@ -4119,7 +4187,7 @@ def filter_cached(cache: Cache, sources: Iterable[Path]) -> Tuple[Set[Path], Set return todo, done -def write_cache(cache: Cache, sources: Iterable[Path], mode: FileMode) -> None: +def write_cache(cache: Cache, sources: Iterable[Path], mode: Mode) -> None: """Update the cache file.""" cache_file = get_cache_file(mode) try: