X-Git-Url: https://git.madduck.net/etc/vim.git/blobdiff_plain/31f4105731c9e5a6930905d358b62b2b834b940b..a33823e85956b30b3785120478775f8a9fe64b62:/black.py?ds=sidebyside diff --git a/black.py b/black.py index ddeaa88..31859d1 100644 --- a/black.py +++ b/black.py @@ -37,6 +37,8 @@ from typing import ( Union, cast, ) +from typing_extensions import Final +from mypy_extensions import mypyc_attr from appdirs import user_cache_dir from dataclasses import dataclass, field, replace @@ -184,7 +186,7 @@ VERSION_TO_FEATURES: Dict[TargetVersion, Set[Feature]] = { @dataclass -class FileMode: +class Mode: target_versions: Set[TargetVersion] = field(default_factory=set) line_length: int = DEFAULT_LINE_LENGTH string_normalization: bool = True @@ -207,10 +209,31 @@ class FileMode: return ".".join(parts) +# Legacy name, left for integrations. +FileMode = Mode + + def supports_feature(target_versions: Set[TargetVersion], feature: Feature) -> bool: return all(feature in VERSION_TO_FEATURES[version] for version in target_versions) +def find_pyproject_toml(path_search_start: str) -> Optional[str]: + """Find the absolute filepath to a pyproject.toml if it exists""" + path_project_root = find_project_root(path_search_start) + path_pyproject_toml = path_project_root / "pyproject.toml" + return str(path_pyproject_toml) if path_pyproject_toml.is_file() else None + + +def parse_pyproject_toml(path_config: str) -> Dict[str, Any]: + """Parse a pyproject toml file, pulling out relevant parts for Black + + If parsing fails, will raise a toml.TomlDecodeError + """ + pyproject_toml = toml.load(path_config) + config = pyproject_toml.get("tool", {}).get("black", {}) + return {k.replace("--", "").replace("-", "_"): v for k, v in config.items()} + + def read_pyproject_toml( ctx: click.Context, param: click.Parameter, value: Union[str, int, bool, None] ) -> Optional[str]: @@ -221,16 +244,12 @@ def read_pyproject_toml( """ assert not isinstance(value, (int, bool)), "Invalid parameter type passed" if not value: - root = find_project_root(ctx.params.get("src", ())) - path = root / "pyproject.toml" - if path.is_file(): - value = str(path) - else: + value = find_pyproject_toml(ctx.params.get("src", ())) + if value is None: return None try: - pyproject_toml = toml.load(value) - config = pyproject_toml.get("tool", {}).get("black", {}) + config = parse_pyproject_toml(value) except (toml.TomlDecodeError, OSError) as e: raise click.FileError( filename=value, hint=f"Error reading configuration file: {e}" @@ -241,12 +260,21 @@ def read_pyproject_toml( if ctx.default_map is None: ctx.default_map = {} - ctx.default_map.update( # type: ignore # bad types in .pyi - {k.replace("--", "").replace("-", "_"): v for k, v in config.items()} - ) + ctx.default_map.update(config) # type: ignore # bad types in .pyi return value +def target_version_option_callback( + c: click.Context, p: Union[click.Option, click.Parameter], v: Tuple[str, ...] +) -> List[TargetVersion]: + """Compute the target versions from a --target-version flag. + + This is its own function because mypy couldn't infer the type correctly + when it was a lambda, causing mypyc trouble. + """ + return [TargetVersion[val.upper()] for val in v] + + @click.command(context_settings=dict(help_option_names=["-h", "--help"])) @click.option("-c", "--code", type=str, help="Format the code passed in as a string.") @click.option( @@ -261,7 +289,7 @@ def read_pyproject_toml( "-t", "--target-version", type=click.Choice([v.name.lower() for v in TargetVersion]), - callback=lambda c, p, v: [TargetVersion[val.upper()] for val in v], + callback=target_version_option_callback, multiple=True, help=( "Python versions that should be supported by Black's output. [default: " @@ -388,7 +416,7 @@ def main( verbose: bool, include: str, exclude: str, - src: Tuple[str], + src: Tuple[str, ...], config: Optional[str], ) -> None: """The uncompromising code formatter.""" @@ -408,7 +436,7 @@ def main( else: # We'll autodetect later. versions = set() - mode = FileMode( + mode = Mode( target_versions=versions, line_length=line_length, is_pyi=pyi, @@ -429,7 +457,7 @@ def main( except re.error: err(f"Invalid regular expression for exclude given: {exclude!r}") ctx.exit(2) - report = Report(check=check, quiet=quiet, verbose=verbose) + report = Report(check=check, diff=diff, quiet=quiet, verbose=verbose) root = find_project_root(src) sources: Set[Path] = set() path_empty(src, quiet, verbose, ctx) @@ -470,7 +498,9 @@ def main( ctx.exit(report.return_code) -def path_empty(src: Tuple[str], quiet: bool, verbose: bool, ctx: click.Context) -> None: +def path_empty( + src: Tuple[str, ...], quiet: bool, verbose: bool, ctx: click.Context +) -> None: """ Exit if there is no `src` provided for formatting """ @@ -481,7 +511,7 @@ def path_empty(src: Tuple[str], quiet: bool, verbose: bool, ctx: click.Context) def reformat_one( - src: Path, fast: bool, write_back: WriteBack, mode: FileMode, report: "Report" + src: Path, fast: bool, write_back: WriteBack, mode: Mode, report: "Report" ) -> None: """Reformat a single file under `src` without spawning child processes. @@ -514,11 +544,7 @@ def reformat_one( def reformat_many( - sources: Set[Path], - fast: bool, - write_back: WriteBack, - mode: FileMode, - report: "Report", + sources: Set[Path], fast: bool, write_back: WriteBack, mode: Mode, report: "Report" ) -> None: """Reformat multiple files using a ProcessPoolExecutor.""" loop = asyncio.get_event_loop() @@ -548,7 +574,7 @@ async def schedule_formatting( sources: Set[Path], fast: bool, write_back: WriteBack, - mode: FileMode, + mode: Mode, report: "Report", loop: asyncio.AbstractEventLoop, executor: Executor, @@ -585,7 +611,7 @@ async def schedule_formatting( ): src for src in sorted(sources) } - pending: Iterable[asyncio.Future] = tasks.keys() + pending: Iterable["asyncio.Future[bool]"] = tasks.keys() try: loop.add_signal_handler(signal.SIGINT, cancel, pending) loop.add_signal_handler(signal.SIGTERM, cancel, pending) @@ -618,7 +644,7 @@ async def schedule_formatting( def format_file_in_place( src: Path, fast: bool, - mode: FileMode, + mode: Mode, write_back: WriteBack = WriteBack.NO, lock: Any = None, # multiprocessing.Manager().Lock() is some crazy proxy ) -> bool: @@ -639,10 +665,10 @@ def format_file_in_place( except NothingChanged: return False - if write_back == write_back.YES: + if write_back == WriteBack.YES: with open(src, "w", encoding=encoding, newline=newline) as f: f.write(dst_contents) - elif write_back == write_back.DIFF: + elif write_back == WriteBack.DIFF: now = datetime.utcnow() src_name = f"{src}\t{then} +0000" dst_name = f"{src}\t{now} +0000" @@ -662,7 +688,7 @@ def format_file_in_place( def format_stdin_to_stdout( - fast: bool, *, write_back: WriteBack = WriteBack.NO, mode: FileMode + fast: bool, *, write_back: WriteBack = WriteBack.NO, mode: Mode ) -> bool: """Format file on stdin. Return True if changed. @@ -694,9 +720,7 @@ def format_stdin_to_stdout( f.detach() -def format_file_contents( - src_contents: str, *, fast: bool, mode: FileMode -) -> FileContent: +def format_file_contents(src_contents: str, *, fast: bool, mode: Mode) -> FileContent: """Reformat contents a file and return new contents. If `fast` is False, additionally confirm that the reformatted code is @@ -716,11 +740,34 @@ def format_file_contents( return dst_contents -def format_str(src_contents: str, *, mode: FileMode) -> FileContent: +def format_str(src_contents: str, *, mode: Mode) -> FileContent: """Reformat a string and return new contents. `mode` determines formatting options, such as how many characters per line are - allowed. + allowed. Example: + + >>> import black + >>> print(black.format_str("def f(arg:str='')->None:...", mode=Mode())) + def f(arg: str = "") -> None: + ... + + A more complex example: + >>> print( + ... black.format_str( + ... "def f(arg:str='')->None: hey", + ... mode=black.Mode( + ... target_versions={black.TargetVersion.PY36}, + ... line_length=10, + ... string_normalization=False, + ... is_pyi=False, + ... ), + ... ), + ... ) + def f( + arg: str = '', + ) -> None: + hey + """ src_node = lib2to3_parse(src_contents.lstrip(), mode.target_versions) dst_contents = [] @@ -745,11 +792,9 @@ def format_str(src_contents: str, *, mode: FileMode) -> FileContent: if supports_feature(versions, feature) } for current_line in lines.visit(src_node): - for _ in range(after): - dst_contents.append(str(empty_line)) + dst_contents.append(str(empty_line) * after) before, after = elt.maybe_empty_lines(current_line) - for _ in range(before): - dst_contents.append(str(empty_line)) + dst_contents.append(str(empty_line) * before) for line in split_line( current_line, line_length=mode.line_length, features=split_line_features ): @@ -865,8 +910,16 @@ class Visitor(Generic[T]): if node.type < 256: name = token.tok_name[node.type] else: - name = type_repr(node.type) - yield from getattr(self, f"visit_{name}", self.visit_default)(node) + name = str(type_repr(node.type)) + # We explicitly branch on whether a visitor exists (instead of + # using self.visit_default as the default arg to getattr) in order + # to save needing to create a bound method object and so mypyc can + # generate a native call to visit_default. + visitf = getattr(self, f"visit_{name}", None) + if visitf: + yield from visitf(node) + else: + yield from self.visit_default(node) def visit_default(self, node: LN) -> Iterator[T]: """Default `visit_*()` implementation. Recurses to children of `node`.""" @@ -911,8 +964,8 @@ class DebugVisitor(Visitor[T]): list(v.visit(code)) -WHITESPACE = {token.DEDENT, token.INDENT, token.NEWLINE} -STATEMENT = { +WHITESPACE: Final = {token.DEDENT, token.INDENT, token.NEWLINE} +STATEMENT: Final = { syms.if_stmt, syms.while_stmt, syms.for_stmt, @@ -922,10 +975,10 @@ STATEMENT = { syms.funcdef, syms.classdef, } -STANDALONE_COMMENT = 153 +STANDALONE_COMMENT: Final = 153 token.tok_name[STANDALONE_COMMENT] = "STANDALONE_COMMENT" -LOGIC_OPERATORS = {"and", "or"} -COMPARATORS = { +LOGIC_OPERATORS: Final = {"and", "or"} +COMPARATORS: Final = { token.LESS, token.GREATER, token.EQEQUAL, @@ -933,7 +986,7 @@ COMPARATORS = { token.LESSEQUAL, token.GREATEREQUAL, } -MATH_OPERATORS = { +MATH_OPERATORS: Final = { token.VBAR, token.CIRCUMFLEX, token.AMPER, @@ -949,23 +1002,23 @@ MATH_OPERATORS = { token.TILDE, token.DOUBLESTAR, } -STARS = {token.STAR, token.DOUBLESTAR} -VARARGS_SPECIALS = STARS | {token.SLASH} -VARARGS_PARENTS = { +STARS: Final = {token.STAR, token.DOUBLESTAR} +VARARGS_SPECIALS: Final = STARS | {token.SLASH} +VARARGS_PARENTS: Final = { syms.arglist, syms.argument, # double star in arglist syms.trailer, # single argument to call syms.typedargslist, syms.varargslist, # lambdas } -UNPACKING_PARENTS = { +UNPACKING_PARENTS: Final = { syms.atom, # single element of a list or set literal syms.dictsetmaker, syms.listmaker, syms.testlist_gexp, syms.testlist_star_expr, } -TEST_DESCENDANTS = { +TEST_DESCENDANTS: Final = { syms.test, syms.lambdef, syms.or_test, @@ -982,7 +1035,7 @@ TEST_DESCENDANTS = { syms.term, syms.power, } -ASSIGNMENTS = { +ASSIGNMENTS: Final = { "=", "+=", "-=", @@ -998,13 +1051,13 @@ ASSIGNMENTS = { "**=", "//=", } -COMPREHENSION_PRIORITY = 20 -COMMA_PRIORITY = 18 -TERNARY_PRIORITY = 16 -LOGIC_PRIORITY = 14 -STRING_PRIORITY = 12 -COMPARATOR_PRIORITY = 10 -MATH_PRIORITIES = { +COMPREHENSION_PRIORITY: Final = 20 +COMMA_PRIORITY: Final = 18 +TERNARY_PRIORITY: Final = 16 +LOGIC_PRIORITY: Final = 14 +STRING_PRIORITY: Final = 12 +COMPARATOR_PRIORITY: Final = 10 +MATH_PRIORITIES: Final = { token.VBAR: 9, token.CIRCUMFLEX: 8, token.AMPER: 7, @@ -1020,7 +1073,7 @@ MATH_PRIORITIES = { token.TILDE: 3, token.DOUBLESTAR: 2, } -DOT_PRIORITY = 1 +DOT_PRIORITY: Final = 1 @dataclass @@ -1384,7 +1437,10 @@ class Line: for leaf_id, comments in self.comments.items(): for comment in comments: if is_type_comment(comment): - if leaf_id not in ignored_ids or comment_seen: + if comment_seen or ( + not is_type_comment(comment, " ignore") + and leaf_id not in ignored_ids + ): return True comment_seen = True @@ -1423,11 +1479,7 @@ class Line: return False def contains_multiline_strings(self) -> bool: - for leaf in self.leaves: - if is_multiline_string(leaf): - return True - - return False + return any(is_multiline_string(leaf) for leaf in self.leaves) def maybe_remove_trailing_comma(self, closing: Leaf) -> bool: """Remove trailing comma if there is one and it's safe.""" @@ -1729,13 +1781,13 @@ class LineGenerator(Visitor[Line]): self.current_line.append(node) yield from super().visit_default(node) - def visit_INDENT(self, node: Node) -> Iterator[Line]: + def visit_INDENT(self, node: Leaf) -> Iterator[Line]: """Increase indentation level, maybe yield a line.""" # In blib2to3 INDENT never holds comments. yield from self.line(+1) yield from self.visit_default(node) - def visit_DEDENT(self, node: Node) -> Iterator[Line]: + def visit_DEDENT(self, node: Leaf) -> Iterator[Line]: """Decrease indentation level, maybe yield a line.""" # The current line might still wait for trailing comments. At DEDENT time # there won't be any (they would be prefixes on the preceding NEWLINE). @@ -2463,7 +2515,7 @@ def left_hand_split(line: Line, features: Collection[Feature] = ()) -> Iterator[ body_leaves: List[Leaf] = [] head_leaves: List[Leaf] = [] current_leaves = head_leaves - matching_bracket = None + matching_bracket: Optional[Leaf] = None for leaf in line.leaves: if ( current_leaves is body_leaves @@ -2506,8 +2558,8 @@ def right_hand_split( body_leaves: List[Leaf] = [] head_leaves: List[Leaf] = [] current_leaves = tail_leaves - opening_bracket = None - closing_bracket = None + opening_bracket: Optional[Leaf] = None + closing_bracket: Optional[Leaf] = None for leaf in reversed(line.leaves): if current_leaves is body_leaves: if leaf is opening_bracket: @@ -2811,7 +2863,7 @@ def normalize_string_prefix(leaf: Leaf, remove_u_prefix: bool = False) -> None: match = re.match(r"^([furbFURB]*)(.*)$", leaf.value, re.DOTALL) assert match is not None, f"failed to match string {leaf.value!r}" orig_prefix = match.group(1) - new_prefix = orig_prefix.lower() + new_prefix = orig_prefix.replace("F", "f").replace("B", "b").replace("U", "u") if remove_u_prefix: new_prefix = new_prefix.replace("u", "") leaf.value = f"{new_prefix}{match.group(2)}" @@ -3028,7 +3080,7 @@ def convert_one_fmt_off_pair(node: Node) -> bool: # That happens when one of the `ignored_nodes` ended with a NEWLINE # leaf (possibly followed by a DEDENT). hidden_value = hidden_value[:-1] - first_idx = None + first_idx: Optional[int] = None for ignored in ignored_nodes: index = ignored.remove() if first_idx is None: @@ -3399,8 +3451,8 @@ def generate_trailers_to_omit(line: Line, line_length: int) -> Iterator[Set[Leaf yield omit length = 4 * line.depth - opening_bracket = None - closing_bracket = None + opening_bracket: Optional[Leaf] = None + closing_bracket: Optional[Leaf] = None inner_brackets: Set[LeafID] = set() for index, leaf, leaf_length in enumerate_with_length(line, reversed=True): length += leaf_length @@ -3571,7 +3623,7 @@ def find_project_root(srcs: Iterable[str]) -> Path: # Append a fake file so `parents` below returns `common_base_dir`, too. common_base /= "fake-file" for directory in common_base.parents: - if (directory / ".git").is_dir(): + if (directory / ".git").exists(): return directory if (directory / ".hg").is_dir(): @@ -3588,6 +3640,7 @@ class Report: """Provides a reformatting counter. Can be rendered with `str(report)`.""" check: bool = False + diff: bool = False quiet: bool = False verbose: bool = False change_count: int = 0 @@ -3597,7 +3650,7 @@ class Report: def done(self, src: Path, changed: Changed) -> None: """Increment the counter for successful reformatting. Write out a message.""" if changed is Changed.YES: - reformatted = "would reformat" if self.check else "reformatted" + reformatted = "would reformat" if self.check or self.diff else "reformatted" if self.verbose or not self.quiet: out(f"{reformatted} {src}") self.change_count += 1 @@ -3643,7 +3696,7 @@ class Report: Use `click.unstyle` to remove colors. """ - if self.check: + if self.check or self.diff: reformatted = "would be reformatted" unchanged = "would be left unchanged" failed = "would fail to reformat" @@ -3781,7 +3834,7 @@ def assert_equivalent(src: str, dst: str) -> None: ) from None -def assert_stable(src: str, dst: str, mode: FileMode) -> None: +def assert_stable(src: str, dst: str, mode: Mode) -> None: """Raise AssertionError if `dst` reformats differently the second time.""" newdst = format_str(dst, mode=mode) if dst != newdst: @@ -3797,6 +3850,7 @@ def assert_stable(src: str, dst: str, mode: FileMode) -> None: ) from None +@mypyc_attr(patchable=True) def dump_to_file(*output: str) -> str: """Dump `output` to a temporary file. Return path to the file.""" with tempfile.NamedTemporaryFile( @@ -3829,7 +3883,7 @@ def diff(a: str, b: str, a_name: str, b_name: str) -> str: ) -def cancel(tasks: Iterable[asyncio.Task]) -> None: +def cancel(tasks: Iterable["asyncio.Task[Any]"]) -> None: """asyncio signal handler that cancels all `tasks` and reports to stderr.""" err("Aborted!") for task in tasks: @@ -4051,11 +4105,11 @@ def can_omit_invisible_parens(line: Line, line_length: int) -> bool: return False -def get_cache_file(mode: FileMode) -> Path: +def get_cache_file(mode: Mode) -> Path: return CACHE_DIR / f"cache.{mode.get_cache_key()}.pickle" -def read_cache(mode: FileMode) -> Cache: +def read_cache(mode: Mode) -> Cache: """Read the cache if it exists and is well formed. If it is not well formed, the call to write_cache later should resolve the issue. @@ -4095,7 +4149,7 @@ def filter_cached(cache: Cache, sources: Iterable[Path]) -> Tuple[Set[Path], Set return todo, done -def write_cache(cache: Cache, sources: Iterable[Path], mode: FileMode) -> None: +def write_cache(cache: Cache, sources: Iterable[Path], mode: Mode) -> None: """Update the cache file.""" cache_file = get_cache_file(mode) try: