X-Git-Url: https://git.madduck.net/etc/vim.git/blobdiff_plain/95fd5cb6485c21cf0294899a4b2cb99c453fa01d..7a14a37981862ef418f3cdb4a7e2375856f97529:/black.py diff --git a/black.py b/black.py index 953d532..26a2915 100644 --- a/black.py +++ b/black.py @@ -37,9 +37,11 @@ from typing import ( Union, cast, ) +from typing_extensions import Final +from mypy_extensions import mypyc_attr from appdirs import user_cache_dir -from attr import dataclass, evolve, Factory +from dataclasses import dataclass, field, replace import click import toml from typed_ast import ast3, ast27 @@ -184,8 +186,8 @@ VERSION_TO_FEATURES: Dict[TargetVersion, Set[Feature]] = { @dataclass -class FileMode: - target_versions: Set[TargetVersion] = Factory(set) +class Mode: + target_versions: Set[TargetVersion] = field(default_factory=set) line_length: int = DEFAULT_LINE_LENGTH string_normalization: bool = True is_pyi: bool = False @@ -207,10 +209,31 @@ class FileMode: return ".".join(parts) +# Legacy name, left for integrations. +FileMode = Mode + + def supports_feature(target_versions: Set[TargetVersion], feature: Feature) -> bool: return all(feature in VERSION_TO_FEATURES[version] for version in target_versions) +def find_pyproject_toml(path_search_start: str) -> Optional[str]: + """Find the absolute filepath to a pyproject.toml if it exists""" + path_project_root = find_project_root(path_search_start) + path_pyproject_toml = path_project_root / "pyproject.toml" + return str(path_pyproject_toml) if path_pyproject_toml.is_file() else None + + +def parse_pyproject_toml(path_config: str) -> Dict[str, Any]: + """Parse a pyproject toml file, pulling out relevant parts for Black + + If parsing fails, will raise a toml.TomlDecodeError + """ + pyproject_toml = toml.load(path_config) + config = pyproject_toml.get("tool", {}).get("black", {}) + return {k.replace("--", "").replace("-", "_"): v for k, v in config.items()} + + def read_pyproject_toml( ctx: click.Context, param: click.Parameter, value: Union[str, int, bool, None] ) -> Optional[str]: @@ -221,16 +244,12 @@ def read_pyproject_toml( """ assert not isinstance(value, (int, bool)), "Invalid parameter type passed" if not value: - root = find_project_root(ctx.params.get("src", ())) - path = root / "pyproject.toml" - if path.is_file(): - value = str(path) - else: + value = find_pyproject_toml(ctx.params.get("src", ())) + if value is None: return None try: - pyproject_toml = toml.load(value) - config = pyproject_toml.get("tool", {}).get("black", {}) + config = parse_pyproject_toml(value) except (toml.TomlDecodeError, OSError) as e: raise click.FileError( filename=value, hint=f"Error reading configuration file: {e}" @@ -241,12 +260,21 @@ def read_pyproject_toml( if ctx.default_map is None: ctx.default_map = {} - ctx.default_map.update( # type: ignore # bad types in .pyi - {k.replace("--", "").replace("-", "_"): v for k, v in config.items()} - ) + ctx.default_map.update(config) # type: ignore # bad types in .pyi return value +def target_version_option_callback( + c: click.Context, p: Union[click.Option, click.Parameter], v: Tuple[str, ...] +) -> List[TargetVersion]: + """Compute the target versions from a --target-version flag. + + This is its own function because mypy couldn't infer the type correctly + when it was a lambda, causing mypyc trouble. + """ + return [TargetVersion[val.upper()] for val in v] + + @click.command(context_settings=dict(help_option_names=["-h", "--help"])) @click.option("-c", "--code", type=str, help="Format the code passed in as a string.") @click.option( @@ -261,7 +289,7 @@ def read_pyproject_toml( "-t", "--target-version", type=click.Choice([v.name.lower() for v in TargetVersion]), - callback=lambda c, p, v: [TargetVersion[val.upper()] for val in v], + callback=target_version_option_callback, multiple=True, help=( "Python versions that should be supported by Black's output. [default: " @@ -366,7 +394,7 @@ def read_pyproject_toml( @click.option( "--config", type=click.Path( - exists=False, file_okay=True, dir_okay=False, readable=True, allow_dash=False + exists=True, file_okay=True, dir_okay=False, readable=True, allow_dash=False ), is_eager=True, callback=read_pyproject_toml, @@ -388,14 +416,14 @@ def main( verbose: bool, include: str, exclude: str, - src: Tuple[str], + src: Tuple[str, ...], config: Optional[str], ) -> None: """The uncompromising code formatter.""" write_back = WriteBack.from_configuration(check=check, diff=diff) if target_version: if py36: - err(f"Cannot use both --target-version and --py36") + err("Cannot use both --target-version and --py36") ctx.exit(2) else: versions = set(target_version) @@ -408,7 +436,7 @@ def main( else: # We'll autodetect later. versions = set() - mode = FileMode( + mode = Mode( target_versions=versions, line_length=line_length, is_pyi=pyi, @@ -429,7 +457,7 @@ def main( except re.error: err(f"Invalid regular expression for exclude given: {exclude!r}") ctx.exit(2) - report = Report(check=check, quiet=quiet, verbose=verbose) + report = Report(check=check, diff=diff, quiet=quiet, verbose=verbose) root = find_project_root(src) sources: Set[Path] = set() path_empty(src, quiet, verbose, ctx) @@ -470,7 +498,9 @@ def main( ctx.exit(report.return_code) -def path_empty(src: Tuple[str], quiet: bool, verbose: bool, ctx: click.Context) -> None: +def path_empty( + src: Tuple[str, ...], quiet: bool, verbose: bool, ctx: click.Context +) -> None: """ Exit if there is no `src` provided for formatting """ @@ -481,7 +511,7 @@ def path_empty(src: Tuple[str], quiet: bool, verbose: bool, ctx: click.Context) def reformat_one( - src: Path, fast: bool, write_back: WriteBack, mode: FileMode, report: "Report" + src: Path, fast: bool, write_back: WriteBack, mode: Mode, report: "Report" ) -> None: """Reformat a single file under `src` without spawning child processes. @@ -514,11 +544,7 @@ def reformat_one( def reformat_many( - sources: Set[Path], - fast: bool, - write_back: WriteBack, - mode: FileMode, - report: "Report", + sources: Set[Path], fast: bool, write_back: WriteBack, mode: Mode, report: "Report" ) -> None: """Reformat multiple files using a ProcessPoolExecutor.""" loop = asyncio.get_event_loop() @@ -548,7 +574,7 @@ async def schedule_formatting( sources: Set[Path], fast: bool, write_back: WriteBack, - mode: FileMode, + mode: Mode, report: "Report", loop: asyncio.AbstractEventLoop, executor: Executor, @@ -585,7 +611,7 @@ async def schedule_formatting( ): src for src in sorted(sources) } - pending: Iterable[asyncio.Future] = tasks.keys() + pending: Iterable["asyncio.Future[bool]"] = tasks.keys() try: loop.add_signal_handler(signal.SIGINT, cancel, pending) loop.add_signal_handler(signal.SIGTERM, cancel, pending) @@ -618,7 +644,7 @@ async def schedule_formatting( def format_file_in_place( src: Path, fast: bool, - mode: FileMode, + mode: Mode, write_back: WriteBack = WriteBack.NO, lock: Any = None, # multiprocessing.Manager().Lock() is some crazy proxy ) -> bool: @@ -629,7 +655,7 @@ def format_file_in_place( `mode` and `fast` options are passed to :func:`format_file_contents`. """ if src.suffix == ".pyi": - mode = evolve(mode, is_pyi=True) + mode = replace(mode, is_pyi=True) then = datetime.utcfromtimestamp(src.stat().st_mtime) with open(src, "rb") as buf: @@ -639,10 +665,10 @@ def format_file_in_place( except NothingChanged: return False - if write_back == write_back.YES: + if write_back == WriteBack.YES: with open(src, "w", encoding=encoding, newline=newline) as f: f.write(dst_contents) - elif write_back == write_back.DIFF: + elif write_back == WriteBack.DIFF: now = datetime.utcnow() src_name = f"{src}\t{then} +0000" dst_name = f"{src}\t{now} +0000" @@ -662,7 +688,7 @@ def format_file_in_place( def format_stdin_to_stdout( - fast: bool, *, write_back: WriteBack = WriteBack.NO, mode: FileMode + fast: bool, *, write_back: WriteBack = WriteBack.NO, mode: Mode ) -> bool: """Format file on stdin. Return True if changed. @@ -694,9 +720,7 @@ def format_stdin_to_stdout( f.detach() -def format_file_contents( - src_contents: str, *, fast: bool, mode: FileMode -) -> FileContent: +def format_file_contents(src_contents: str, *, fast: bool, mode: Mode) -> FileContent: """Reformat contents a file and return new contents. If `fast` is False, additionally confirm that the reformatted code is @@ -716,11 +740,34 @@ def format_file_contents( return dst_contents -def format_str(src_contents: str, *, mode: FileMode) -> FileContent: +def format_str(src_contents: str, *, mode: Mode) -> FileContent: """Reformat a string and return new contents. `mode` determines formatting options, such as how many characters per line are - allowed. + allowed. Example: + + >>> import black + >>> print(black.format_str("def f(arg:str='')->None:...", mode=Mode())) + def f(arg: str = "") -> None: + ... + + A more complex example: + >>> print( + ... black.format_str( + ... "def f(arg:str='')->None: hey", + ... mode=black.Mode( + ... target_versions={black.TargetVersion.PY36}, + ... line_length=10, + ... string_normalization=False, + ... is_pyi=False, + ... ), + ... ), + ... ) + def f( + arg: str = '', + ) -> None: + hey + """ src_node = lib2to3_parse(src_contents.lstrip(), mode.target_versions) dst_contents = [] @@ -745,11 +792,9 @@ def format_str(src_contents: str, *, mode: FileMode) -> FileContent: if supports_feature(versions, feature) } for current_line in lines.visit(src_node): - for _ in range(after): - dst_contents.append(str(empty_line)) + dst_contents.append(str(empty_line) * after) before, after = elt.maybe_empty_lines(current_line) - for _ in range(before): - dst_contents.append(str(empty_line)) + dst_contents.append(str(empty_line) * before) for line in split_line( current_line, line_length=mode.line_length, features=split_line_features ): @@ -865,8 +910,16 @@ class Visitor(Generic[T]): if node.type < 256: name = token.tok_name[node.type] else: - name = type_repr(node.type) - yield from getattr(self, f"visit_{name}", self.visit_default)(node) + name = str(type_repr(node.type)) + # We explicitly branch on whether a visitor exists (instead of + # using self.visit_default as the default arg to getattr) in order + # to save needing to create a bound method object and so mypyc can + # generate a native call to visit_default. + visitf = getattr(self, f"visit_{name}", None) + if visitf: + yield from visitf(node) + else: + yield from self.visit_default(node) def visit_default(self, node: LN) -> Iterator[T]: """Default `visit_*()` implementation. Recurses to children of `node`.""" @@ -911,8 +964,8 @@ class DebugVisitor(Visitor[T]): list(v.visit(code)) -WHITESPACE = {token.DEDENT, token.INDENT, token.NEWLINE} -STATEMENT = { +WHITESPACE: Final = {token.DEDENT, token.INDENT, token.NEWLINE} +STATEMENT: Final = { syms.if_stmt, syms.while_stmt, syms.for_stmt, @@ -922,10 +975,10 @@ STATEMENT = { syms.funcdef, syms.classdef, } -STANDALONE_COMMENT = 153 +STANDALONE_COMMENT: Final = 153 token.tok_name[STANDALONE_COMMENT] = "STANDALONE_COMMENT" -LOGIC_OPERATORS = {"and", "or"} -COMPARATORS = { +LOGIC_OPERATORS: Final = {"and", "or"} +COMPARATORS: Final = { token.LESS, token.GREATER, token.EQEQUAL, @@ -933,7 +986,7 @@ COMPARATORS = { token.LESSEQUAL, token.GREATEREQUAL, } -MATH_OPERATORS = { +MATH_OPERATORS: Final = { token.VBAR, token.CIRCUMFLEX, token.AMPER, @@ -949,23 +1002,23 @@ MATH_OPERATORS = { token.TILDE, token.DOUBLESTAR, } -STARS = {token.STAR, token.DOUBLESTAR} -VARARGS_SPECIALS = STARS | {token.SLASH} -VARARGS_PARENTS = { +STARS: Final = {token.STAR, token.DOUBLESTAR} +VARARGS_SPECIALS: Final = STARS | {token.SLASH} +VARARGS_PARENTS: Final = { syms.arglist, syms.argument, # double star in arglist syms.trailer, # single argument to call syms.typedargslist, syms.varargslist, # lambdas } -UNPACKING_PARENTS = { +UNPACKING_PARENTS: Final = { syms.atom, # single element of a list or set literal syms.dictsetmaker, syms.listmaker, syms.testlist_gexp, syms.testlist_star_expr, } -TEST_DESCENDANTS = { +TEST_DESCENDANTS: Final = { syms.test, syms.lambdef, syms.or_test, @@ -982,7 +1035,7 @@ TEST_DESCENDANTS = { syms.term, syms.power, } -ASSIGNMENTS = { +ASSIGNMENTS: Final = { "=", "+=", "-=", @@ -998,13 +1051,13 @@ ASSIGNMENTS = { "**=", "//=", } -COMPREHENSION_PRIORITY = 20 -COMMA_PRIORITY = 18 -TERNARY_PRIORITY = 16 -LOGIC_PRIORITY = 14 -STRING_PRIORITY = 12 -COMPARATOR_PRIORITY = 10 -MATH_PRIORITIES = { +COMPREHENSION_PRIORITY: Final = 20 +COMMA_PRIORITY: Final = 18 +TERNARY_PRIORITY: Final = 16 +LOGIC_PRIORITY: Final = 14 +STRING_PRIORITY: Final = 12 +COMPARATOR_PRIORITY: Final = 10 +MATH_PRIORITIES: Final = { token.VBAR: 9, token.CIRCUMFLEX: 8, token.AMPER: 7, @@ -1020,7 +1073,7 @@ MATH_PRIORITIES = { token.TILDE: 3, token.DOUBLESTAR: 2, } -DOT_PRIORITY = 1 +DOT_PRIORITY: Final = 1 @dataclass @@ -1028,11 +1081,11 @@ class BracketTracker: """Keeps track of brackets on a line.""" depth: int = 0 - bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = Factory(dict) - delimiters: Dict[LeafID, Priority] = Factory(dict) + bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = field(default_factory=dict) + delimiters: Dict[LeafID, Priority] = field(default_factory=dict) previous: Optional[Leaf] = None - _for_loop_depths: List[int] = Factory(list) - _lambda_argument_depths: List[int] = Factory(list) + _for_loop_depths: List[int] = field(default_factory=list) + _lambda_argument_depths: List[int] = field(default_factory=list) def mark(self, leaf: Leaf) -> None: """Mark `leaf` with bracket-related metadata. Keep track of delimiters. @@ -1160,9 +1213,10 @@ class Line: """Holds leaves and comments. Can be printed with `str(line)`.""" depth: int = 0 - leaves: List[Leaf] = Factory(list) - comments: Dict[LeafID, List[Leaf]] = Factory(dict) # keys ordered like `leaves` - bracket_tracker: BracketTracker = Factory(BracketTracker) + leaves: List[Leaf] = field(default_factory=list) + # keys ordered like `leaves` + comments: Dict[LeafID, List[Leaf]] = field(default_factory=dict) + bracket_tracker: BracketTracker = field(default_factory=BracketTracker) inside_brackets: bool = False should_explode: bool = False @@ -1250,6 +1304,7 @@ class Line: """ if not self.leaves or len(self.leaves) < 4: return False + # Look for and address a trailing colon. if self.leaves[-1].type == token.COLON: closer = self.leaves[-2] @@ -1259,6 +1314,7 @@ class Line: close_index = -1 if closer.type not in CLOSING_BRACKETS or self.inside_brackets: return False + if closer.type == token.RPAR: # Tuples require an extra check, because if there's only # one element in the tuple removing the comma unmakes the @@ -1273,9 +1329,11 @@ class Line: for _open_index, leaf in enumerate(self.leaves): if leaf is opener: break + else: # Couldn't find the matching opening paren, play it safe. return False + commas = 0 comma_depth = self.leaves[close_index - 1].bracket_depth for leaf in self.leaves[_open_index + 1 : close_index]: @@ -1285,16 +1343,20 @@ class Line: # We haven't looked yet for the trailing comma because # we might also have caught noop parens. return self.leaves[close_index - 1].type == token.COMMA + elif commas == 1: return False # it's either a one-tuple or didn't have a trailing comma + if self.leaves[close_index - 1].type in CLOSING_BRACKETS: close_index -= 1 closer = self.leaves[close_index] if closer.type == token.RPAR: # TODO: this is a gut feeling. Will we ever see this? return False + if self.leaves[close_index - 1].type != token.COMMA: return False + return True @property @@ -1344,9 +1406,9 @@ class Line: def contains_standalone_comments(self, depth_limit: int = sys.maxsize) -> bool: """If so, needs to be split before emitting.""" for leaf in self.leaves: - if leaf.type == STANDALONE_COMMENT: - if leaf.bracket_depth <= depth_limit: - return True + if leaf.type == STANDALONE_COMMENT and leaf.bracket_depth <= depth_limit: + return True + return False def contains_uncollapsable_type_comments(self) -> bool: @@ -1375,7 +1437,10 @@ class Line: for leaf_id, comments in self.comments.items(): for comment in comments: if is_type_comment(comment): - if leaf_id not in ignored_ids or comment_seen: + if comment_seen or ( + not is_type_comment(comment, " ignore") + and leaf_id not in ignored_ids + ): return True comment_seen = True @@ -1414,16 +1479,13 @@ class Line: return False def contains_multiline_strings(self) -> bool: - for leaf in self.leaves: - if is_multiline_string(leaf): - return True - - return False + return any(is_multiline_string(leaf) for leaf in self.leaves) def maybe_remove_trailing_comma(self, closing: Leaf) -> bool: """Remove trailing comma if there is one and it's safe.""" if not (self.leaves and self.leaves[-1].type == token.COMMA): return False + # We remove trailing commas only in the case of importing a # single name from a module. if not ( @@ -1488,6 +1550,7 @@ class Line: comment.type = STANDALONE_COMMENT comment.prefix = "" return False + last_leaf = self.leaves[-2] self.comments.setdefault(id(last_leaf), []).append(comment) return True @@ -1555,7 +1618,7 @@ class EmptyLineTracker: is_pyi: bool = False previous_line: Optional[Line] = None previous_after: int = 0 - previous_defs: List[int] = Factory(list) + previous_defs: List[int] = field(default_factory=list) def maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]: """Return the number of extra empty lines before and after the `current_line`. @@ -1669,7 +1732,7 @@ class LineGenerator(Visitor[Line]): is_pyi: bool = False normalize_strings: bool = True - current_line: Line = Factory(Line) + current_line: Line = field(default_factory=Line) remove_u_prefix: bool = False def line(self, indent: int = 0) -> Iterator[Line]: @@ -1718,13 +1781,13 @@ class LineGenerator(Visitor[Line]): self.current_line.append(node) yield from super().visit_default(node) - def visit_INDENT(self, node: Node) -> Iterator[Line]: + def visit_INDENT(self, node: Leaf) -> Iterator[Line]: """Increase indentation level, maybe yield a line.""" # In blib2to3 INDENT never holds comments. yield from self.line(+1) yield from self.visit_default(node) - def visit_DEDENT(self, node: Node) -> Iterator[Line]: + def visit_DEDENT(self, node: Leaf) -> Iterator[Line]: """Decrease indentation level, maybe yield a line.""" # The current line might still wait for trailing comments. At DEDENT time # there won't be any (they would be prefixes on the preceding NEWLINE). @@ -1834,7 +1897,7 @@ class LineGenerator(Visitor[Line]): node.insert_child(index, Node(syms.atom, [lpar, operand, rpar])) yield from self.visit_default(node) - def __attrs_post_init__(self) -> None: + def __post_init__(self) -> None: """You are in a twisty little maze of passages.""" v = self.visit_stmt Ø: Set[str] = set() @@ -2452,7 +2515,7 @@ def left_hand_split(line: Line, features: Collection[Feature] = ()) -> Iterator[ body_leaves: List[Leaf] = [] head_leaves: List[Leaf] = [] current_leaves = head_leaves - matching_bracket = None + matching_bracket: Optional[Leaf] = None for leaf in line.leaves: if ( current_leaves is body_leaves @@ -2495,8 +2558,8 @@ def right_hand_split( body_leaves: List[Leaf] = [] head_leaves: List[Leaf] = [] current_leaves = tail_leaves - opening_bracket = None - closing_bracket = None + opening_bracket: Optional[Leaf] = None + closing_bracket: Optional[Leaf] = None for leaf in reversed(line.leaves): if current_leaves is body_leaves: if leaf is opening_bracket: @@ -2619,11 +2682,11 @@ def bracket_split_build_line( for i in range(len(leaves) - 1, -1, -1): if leaves[i].type == STANDALONE_COMMENT: continue - elif leaves[i].type == token.COMMA: - break - else: + + if leaves[i].type != token.COMMA: leaves.insert(i + 1, Leaf(token.COMMA, ",")) - break + break + # Populate the line for leaf in leaves: result.append(leaf, preformatted=True) @@ -2800,7 +2863,7 @@ def normalize_string_prefix(leaf: Leaf, remove_u_prefix: bool = False) -> None: match = re.match(r"^([furbFURB]*)(.*)$", leaf.value, re.DOTALL) assert match is not None, f"failed to match string {leaf.value!r}" orig_prefix = match.group(1) - new_prefix = orig_prefix.lower() + new_prefix = orig_prefix.replace("F", "f").replace("B", "b").replace("U", "u") if remove_u_prefix: new_prefix = new_prefix.replace("u", "") leaf.value = f"{new_prefix}{match.group(2)}" @@ -2867,6 +2930,7 @@ def normalize_string_quotes(leaf: Leaf) -> None: if "\\" in str(m): # Do not introduce backslashes in interpolated expressions return + if new_quote == '"""' and new_body[-1:] == '"': # edge case: new_body = new_body[:-1] + '\\"' @@ -2939,7 +3003,6 @@ def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None: if pc.value in FMT_OFF: # This `node` has a prefix with `# fmt: off`, don't mess with parens. return - check_lpar = False for index, child in enumerate(list(node.children)): # Add parentheses around long tuple unpacking in assignments. @@ -2953,26 +3016,12 @@ def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None: if check_lpar: if is_walrus_assignment(child): continue - if child.type == syms.atom: - # Determines if the underlying atom should be surrounded with - # invisible params - also makes parens invisible recursively - # within the atom and removes repeated invisible parens within - # the atom - should_surround_with_parens = maybe_make_parens_invisible_in_atom( - child, parent=node - ) - if should_surround_with_parens: - lpar = Leaf(token.LPAR, "") - rpar = Leaf(token.RPAR, "") - index = child.remove() or 0 - node.insert_child(index, Node(syms.atom, [lpar, child, rpar])) + if child.type == syms.atom: + if maybe_make_parens_invisible_in_atom(child, parent=node): + wrap_in_parentheses(node, child, visible=False) elif is_one_tuple(child): - # wrap child in visible parentheses - lpar = Leaf(token.LPAR, "(") - rpar = Leaf(token.RPAR, ")") - child.remove() - node.insert_child(index, Node(syms.atom, [lpar, child, rpar])) + wrap_in_parentheses(node, child, visible=True) elif node.type == syms.import_from: # "import from" nodes store parentheses directly as part of # the statement @@ -2987,15 +3036,7 @@ def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None: break elif not (isinstance(child, Leaf) and is_multiline_string(child)): - # wrap child in invisible parentheses - lpar = Leaf(token.LPAR, "") - rpar = Leaf(token.RPAR, "") - index = child.remove() or 0 - prefix = child.prefix - child.prefix = "" - new_child = Node(syms.atom, [lpar, child, rpar]) - new_child.prefix = prefix - node.insert_child(index, new_child) + wrap_in_parentheses(node, child, visible=False) check_lpar = isinstance(child, Leaf) and child.value in parens_after @@ -3039,7 +3080,7 @@ def convert_one_fmt_off_pair(node: Node) -> bool: # That happens when one of the `ignored_nodes` ended with a NEWLINE # leaf (possibly followed by a DEDENT). hidden_value = hidden_value[:-1] - first_idx = None + first_idx: Optional[int] = None for ignored in ignored_nodes: index = ignored.remove() if first_idx is None: @@ -3068,9 +3109,14 @@ def generate_ignored_nodes(leaf: Leaf) -> Iterator[LN]: """ container: Optional[LN] = container_of(leaf) while container is not None and container.type != token.ENDMARKER: + is_fmt_on = False for comment in list_comments(container.prefix, is_endmarker=False): if comment.value in FMT_ON: - return + is_fmt_on = True + elif comment.value in FMT_OFF: + is_fmt_on = False + if is_fmt_on: + return yield container @@ -3147,6 +3193,7 @@ def unwrap_singleton_parenthesis(node: LN) -> Optional[LN]: Parenthesis can be optional. Returns None otherwise""" if len(node.children) != 3: return None + lpar, wrapped, rpar = node.children if not (lpar.type == token.LPAR and rpar.type == token.RPAR): return None @@ -3154,6 +3201,24 @@ def unwrap_singleton_parenthesis(node: LN) -> Optional[LN]: return wrapped +def wrap_in_parentheses(parent: Node, child: LN, *, visible: bool = True) -> None: + """Wrap `child` in parentheses. + + This replaces `child` with an atom holding the parentheses and the old + child. That requires moving the prefix. + + If `visible` is False, the leaves will be valueless (and thus invisible). + """ + lpar = Leaf(token.LPAR, "(" if visible else "") + rpar = Leaf(token.RPAR, ")" if visible else "") + prefix = child.prefix + child.prefix = "" + index = child.remove() or 0 + new_child = Node(syms.atom, [lpar, child, rpar]) + new_child.prefix = prefix + parent.insert_child(index, new_child) + + def is_one_tuple(node: LN) -> bool: """Return True if `node` holds a tuple with one element, with or without parens.""" if node.type == syms.atom: @@ -3386,8 +3451,8 @@ def generate_trailers_to_omit(line: Line, line_length: int) -> Iterator[Set[Leaf yield omit length = 4 * line.depth - opening_bracket = None - closing_bracket = None + opening_bracket: Optional[Leaf] = None + closing_bracket: Optional[Leaf] = None inner_brackets: Set[LeafID] = set() for index, leaf, leaf_length in enumerate_with_length(line, reversed=True): length += leaf_length @@ -3431,19 +3496,23 @@ def get_future_imports(node: Node) -> Set[str]: if isinstance(child, Leaf): if child.type == token.NAME: yield child.value + elif child.type == syms.import_as_name: orig_name = child.children[0] assert isinstance(orig_name, Leaf), "Invalid syntax parsing imports" assert orig_name.type == token.NAME, "Invalid syntax parsing imports" yield orig_name.value + elif child.type == syms.import_as_names: yield from get_imports_from_children(child.children) + else: raise AssertionError("Invalid syntax parsing imports") for child in node.children: if child.type != syms.simple_stmt: break + first_child = child.children[0] if isinstance(first_child, Leaf): # Continue looking if we see a docstring; otherwise stop. @@ -3453,15 +3522,18 @@ def get_future_imports(node: Node) -> Set[str]: and child.children[1].type == token.NEWLINE ): continue - else: - break + + break + elif first_child.type == syms.import_from: module_name = first_child.children[1] if not isinstance(module_name, Leaf) or module_name.value != "__future__": break + imports |= set(get_imports_from_children(first_child.children[3:])) else: break + return imports @@ -3469,10 +3541,11 @@ def get_future_imports(node: Node) -> Set[str]: def get_gitignore(root: Path) -> PathSpec: """ Return a PathSpec matching gitignore content if present.""" gitignore = root / ".gitignore" - if not gitignore.is_file(): - return PathSpec.from_lines("gitwildmatch", []) - else: - return PathSpec.from_lines("gitwildmatch", gitignore.open()) + lines: List[str] = [] + if gitignore.is_file(): + with gitignore.open() as gf: + lines = gf.readlines() + return PathSpec.from_lines("gitwildmatch", lines) def gen_python_files_in_dir( @@ -3494,7 +3567,7 @@ def gen_python_files_in_dir( for child in path.iterdir(): # First ignore files matching .gitignore if gitignore.match_file(child.as_posix()): - report.path_ignored(child, f"matches the .gitignore file content") + report.path_ignored(child, "matches the .gitignore file content") continue # Then ignore with `exclude` option. @@ -3503,6 +3576,7 @@ def gen_python_files_in_dir( except OSError as e: report.path_ignored(child, f"cannot be read because {e}") continue + except ValueError: if child.is_symlink(): report.path_ignored( @@ -3517,7 +3591,7 @@ def gen_python_files_in_dir( exclude_match = exclude.search(normalized_path) if exclude_match and exclude_match.group(0): - report.path_ignored(child, f"matches the --exclude regular expression") + report.path_ignored(child, "matches the --exclude regular expression") continue if child.is_dir(): @@ -3549,7 +3623,7 @@ def find_project_root(srcs: Iterable[str]) -> Path: # Append a fake file so `parents` below returns `common_base_dir`, too. common_base /= "fake-file" for directory in common_base.parents: - if (directory / ".git").is_dir(): + if (directory / ".git").exists(): return directory if (directory / ".hg").is_dir(): @@ -3566,6 +3640,7 @@ class Report: """Provides a reformatting counter. Can be rendered with `str(report)`.""" check: bool = False + diff: bool = False quiet: bool = False verbose: bool = False change_count: int = 0 @@ -3575,7 +3650,7 @@ class Report: def done(self, src: Path, changed: Changed) -> None: """Increment the counter for successful reformatting. Write out a message.""" if changed is Changed.YES: - reformatted = "would reformat" if self.check else "reformatted" + reformatted = "would reformat" if self.check or self.diff else "reformatted" if self.verbose or not self.quiet: out(f"{reformatted} {src}") self.change_count += 1 @@ -3621,7 +3696,7 @@ class Report: Use `click.unstyle` to remove colors. """ - if self.check: + if self.check or self.diff: reformatted = "would be reformatted" unchanged = "would be left unchanged" failed = "would fail to reformat" @@ -3669,14 +3744,15 @@ def _fixup_ast_constants( node: Union[ast.AST, ast3.AST, ast27.AST] ) -> Union[ast.AST, ast3.AST, ast27.AST]: """Map ast nodes deprecated in 3.8 to Constant.""" - # casts are required until this is released: - # https://github.com/python/typeshed/pull/3142 if isinstance(node, (ast.Str, ast3.Str, ast27.Str, ast.Bytes, ast3.Bytes)): - return cast(ast.AST, ast.Constant(value=node.s)) - elif isinstance(node, (ast.Num, ast3.Num, ast27.Num)): - return cast(ast.AST, ast.Constant(value=node.n)) - elif isinstance(node, (ast.NameConstant, ast3.NameConstant)): - return cast(ast.AST, ast.Constant(value=node.value)) + return ast.Constant(value=node.s) + + if isinstance(node, (ast.Num, ast3.Num, ast27.Num)): + return ast.Constant(value=node.n) + + if isinstance(node, (ast.NameConstant, ast3.NameConstant)): + return ast.Constant(value=node.value) + return node @@ -3690,7 +3766,7 @@ def assert_equivalent(src: str, dst: str) -> None: yield f"{' ' * depth}{node.__class__.__name__}(" - for field in sorted(node._fields): + for field in sorted(node._fields): # noqa: F402 # TypeIgnore has only one field 'lineno' which breaks this comparison type_ignore_classes = (ast3.TypeIgnore, ast27.TypeIgnore) if sys.version_info >= (3, 8): @@ -3716,6 +3792,7 @@ def assert_equivalent(src: str, dst: str) -> None: ): for item in item.elts: yield from _v(item, depth + 2) + elif isinstance(item, (ast.AST, ast3.AST, ast27.AST)): yield from _v(item, depth + 2) @@ -3757,7 +3834,7 @@ def assert_equivalent(src: str, dst: str) -> None: ) from None -def assert_stable(src: str, dst: str, mode: FileMode) -> None: +def assert_stable(src: str, dst: str, mode: Mode) -> None: """Raise AssertionError if `dst` reformats differently the second time.""" newdst = format_str(dst, mode=mode) if dst != newdst: @@ -3773,6 +3850,7 @@ def assert_stable(src: str, dst: str, mode: FileMode) -> None: ) from None +@mypyc_attr(patchable=True) def dump_to_file(*output: str) -> str: """Dump `output` to a temporary file. Return path to the file.""" with tempfile.NamedTemporaryFile( @@ -3798,14 +3876,14 @@ def diff(a: str, b: str, a_name: str, b_name: str) -> str: """Return a unified diff string between strings `a` and `b`.""" import difflib - a_lines = [line + "\n" for line in a.split("\n")] - b_lines = [line + "\n" for line in b.split("\n")] + a_lines = [line + "\n" for line in a.splitlines()] + b_lines = [line + "\n" for line in b.splitlines()] return "".join( difflib.unified_diff(a_lines, b_lines, fromfile=a_name, tofile=b_name, n=5) ) -def cancel(tasks: Iterable[asyncio.Task]) -> None: +def cancel(tasks: Iterable["asyncio.Task[Any]"]) -> None: """asyncio signal handler that cancels all `tasks` and reports to stderr.""" err("Aborted!") for task in tasks: @@ -4027,11 +4105,11 @@ def can_omit_invisible_parens(line: Line, line_length: int) -> bool: return False -def get_cache_file(mode: FileMode) -> Path: +def get_cache_file(mode: Mode) -> Path: return CACHE_DIR / f"cache.{mode.get_cache_key()}.pickle" -def read_cache(mode: FileMode) -> Cache: +def read_cache(mode: Mode) -> Cache: """Read the cache if it exists and is well formed. If it is not well formed, the call to write_cache later should resolve the issue. @@ -4071,7 +4149,7 @@ def filter_cached(cache: Cache, sources: Iterable[Path]) -> Tuple[Set[Path], Set return todo, done -def write_cache(cache: Cache, sources: Iterable[Path], mode: FileMode) -> None: +def write_cache(cache: Cache, sources: Iterable[Path], mode: Mode) -> None: """Update the cache file.""" cache_file = get_cache_file(mode) try: