X-Git-Url: https://git.madduck.net/etc/vim.git/blobdiff_plain/7595dabb4387237b76e80bdee72fb6323b2d603b..3eab6d3131acd14b5900519d8447c8a1152e6d87:/black.py diff --git a/black.py b/black.py index efe5af9..7823ae0 100644 --- a/black.py +++ b/black.py @@ -24,6 +24,7 @@ from typing import ( List, Optional, Pattern, + Sequence, Set, Tuple, Type, @@ -41,6 +42,7 @@ from blib2to3 import pygram, pytree from blib2to3.pgen2 import driver, token from blib2to3.pgen2.parse import ParseError + __version__ = "18.4a6" DEFAULT_LINE_LENGTH = 88 @@ -409,9 +411,10 @@ def format_str(src_contents: str, line_length: int) -> FileContent: """ src_node = lib2to3_parse(src_contents) dst_contents = "" - lines = LineGenerator() - elt = EmptyLineTracker() + future_imports = get_future_imports(src_node) py36 = is_python36(src_node) + lines = LineGenerator(remove_u_prefix=py36 or "unicode_literals" in future_imports) + elt = EmptyLineTracker() empty_line = Line() after = 0 for current_line in lines.visit(src_node): @@ -974,17 +977,21 @@ class Line: self.comments.append((after, comment)) return True - def comments_after(self, leaf: Leaf) -> Iterator[Leaf]: - """Generate comments that should appear directly after `leaf`.""" - for _leaf_index, _leaf in enumerate(self.leaves): - if leaf is _leaf: - break + def comments_after(self, leaf: Leaf, _index: int = -1) -> Iterator[Leaf]: + """Generate comments that should appear directly after `leaf`. - else: - return + Provide a non-negative leaf `_index` to speed up the function. + """ + if _index == -1: + for _index, _leaf in enumerate(self.leaves): + if leaf is _leaf: + break + + else: + return for index, comment_after in self.comments: - if _leaf_index == index: + if _index == index: yield comment_after def remove_trailing_comma(self) -> None: @@ -1171,6 +1178,7 @@ class LineGenerator(Visitor[Line]): in ways that will no longer stringify to valid Python code on the tree. """ current_line: Line = Factory(Line) + remove_u_prefix: bool = False def line(self, indent: int = 0, type: Type[Line] = Line) -> Iterator[Line]: """Generate a line. @@ -1238,6 +1246,7 @@ class LineGenerator(Visitor[Line]): else: normalize_prefix(node, inside_brackets=any_open_brackets) if node.type == token.STRING: + normalize_string_prefix(node, remove_u_prefix=self.remove_u_prefix) normalize_string_quotes(node) if node.type not in WHITESPACE: self.current_line.append(node) @@ -1821,11 +1830,7 @@ def split_line( return line_str = str(line).strip("\n") - if ( - len(line_str) <= line_length - and "\n" not in line_str # multiline strings - and not line.contains_standalone_comments() - ): + if is_line_short_enough(line, line_length=line_length, line_str=line_str): yield line return @@ -1834,10 +1839,22 @@ def split_line( split_funcs = [left_hand_split] elif line.is_import: split_funcs = [explode_split] - elif line.inside_brackets: - split_funcs = [delimiter_split, standalone_comment_split, right_hand_split] else: - split_funcs = [right_hand_split] + + def rhs(line: Line, py36: bool = False) -> Iterator[Line]: + for omit in generate_trailers_to_omit(line, line_length): + lines = list(right_hand_split(line, py36, omit=omit)) + if is_line_short_enough(lines[0], line_length=line_length): + yield from lines + return + + # All splits failed, best effort split with no omits. + yield from right_hand_split(line, py36) + + if line.inside_brackets: + split_funcs = [delimiter_split, standalone_comment_split, rhs] + else: + split_funcs = [rhs] for split_func in split_funcs: # We are accumulating lines in `result` because we might want to abort # mission and return the original line in the end, or attempt a different @@ -1910,6 +1927,8 @@ def right_hand_split( """Split line into many lines, starting with the last matching bracket pair. If the split was by optional parentheses, attempt splitting without them, too. + `omit` is a collection of closing bracket IDs that shouldn't be considered for + this split. """ head = Line(depth=line.depth) body = Line(depth=line.depth + 1, inside_brackets=True) @@ -2052,10 +2071,10 @@ def delimiter_split(line: Line, py36: bool = False) -> Iterator[Line]: current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets) current_line.append(leaf) - for leaf in line.leaves: + for index, leaf in enumerate(line.leaves): yield from append_to_line(leaf) - for comment_after in line.comments_after(leaf): + for comment_after in line.comments_after(leaf, index): yield from append_to_line(comment_after) lowest_depth = min(lowest_depth, leaf.bracket_depth) @@ -2099,10 +2118,10 @@ def standalone_comment_split(line: Line, py36: bool = False) -> Iterator[Line]: current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets) current_line.append(leaf) - for leaf in line.leaves: + for index, leaf in enumerate(line.leaves): yield from append_to_line(leaf) - for comment_after in line.comments_after(leaf): + for comment_after in line.comments_after(leaf, index): yield from append_to_line(comment_after) if current_line: @@ -2161,6 +2180,22 @@ def normalize_prefix(leaf: Leaf, *, inside_brackets: bool) -> None: leaf.prefix = "" +def normalize_string_prefix(leaf: Leaf, remove_u_prefix: bool = False) -> None: + """Make all string prefixes lowercase. + + If remove_u_prefix is given, also removes any u prefix from the string. + + Note: Mutates its argument. + """ + match = re.match(r"^([furbFURB]*)(.*)$", leaf.value, re.DOTALL) + assert match is not None, f"failed to match string {leaf.value!r}" + orig_prefix = match.group(1) + new_prefix = orig_prefix.lower() + if remove_u_prefix: + new_prefix = new_prefix.replace("u", "") + leaf.value = f"{new_prefix}{match.group(2)}" + + def normalize_string_quotes(leaf: Leaf) -> None: """Prefer double quotes but only if it doesn't cause more escaping. @@ -2258,6 +2293,7 @@ def maybe_make_parens_invisible_in_atom(node: LN) -> bool: node.type != syms.atom or is_empty_tuple(node) or is_one_tuple(node) + or is_yield(node) or max_delimiter_priority_in_atom(node) >= COMMA_PRIORITY ): return False @@ -2308,6 +2344,27 @@ def is_one_tuple(node: LN) -> bool: ) +def is_yield(node: LN) -> bool: + """Return True if `node` holds a `yield` or `yield from` expression.""" + if node.type == syms.yield_expr: + return True + + if node.type == token.NAME and node.value == "yield": # type: ignore + return True + + if node.type != syms.atom: + return False + + if len(node.children) != 3: + return False + + lpar, expr, rpar = node.children + if lpar.type == token.LPAR and rpar.type == token.RPAR: + return is_yield(expr) + + return False + + def is_vararg(leaf: Leaf, within: Set[NodeType]) -> bool: """Return True if `leaf` is a star or double star in a vararg or kwarg. @@ -2401,6 +2458,102 @@ def is_python36(node: Node) -> bool: return False +def generate_trailers_to_omit(line: Line, line_length: int) -> Iterator[Set[LeafID]]: + """Generate sets of closing bracket IDs that should be omitted in a RHS. + + Brackets can be omitted if the entire trailer up to and including + a preceding closing bracket fits in one line. + + Yielded sets are cumulative (contain results of previous yields, too). First + set is empty. + """ + + omit: Set[LeafID] = set() + yield omit + + length = 4 * line.depth + opening_bracket = None + closing_bracket = None + optional_brackets: Set[LeafID] = set() + inner_brackets: Set[LeafID] = set() + for index, leaf in enumerate_reversed(line.leaves): + length += len(leaf.prefix) + len(leaf.value) + if length > line_length: + break + + comment: Optional[Leaf] + for comment in line.comments_after(leaf, index): + if "\n" in comment.prefix: + break # Oops, standalone comment! + + length += len(comment.value) + else: + comment = None + if comment is not None: + break # There was a standalone comment, we can't continue. + + optional_brackets.discard(id(leaf)) + if opening_bracket: + if leaf is opening_bracket: + opening_bracket = None + elif leaf.type in CLOSING_BRACKETS: + inner_brackets.add(id(leaf)) + elif leaf.type in CLOSING_BRACKETS: + if not leaf.value: + optional_brackets.add(id(opening_bracket)) + continue + + if index > 0 and line.leaves[index - 1].type in OPENING_BRACKETS: + # Empty brackets would fail a split so treat them as "inner" + # brackets (e.g. only add them to the `omit` set if another + # pair of brackets was good enough. + inner_brackets.add(id(leaf)) + continue + + opening_bracket = leaf.opening_bracket + if closing_bracket: + omit.add(id(closing_bracket)) + omit.update(inner_brackets) + inner_brackets.clear() + yield omit + closing_bracket = leaf + + +def get_future_imports(node: Node) -> Set[str]: + """Return a set of __future__ imports in the file.""" + imports = set() + for child in node.children: + if child.type != syms.simple_stmt: + break + first_child = child.children[0] + if isinstance(first_child, Leaf): + # Continue looking if we see a docstring; otherwise stop. + if ( + len(child.children) == 2 + and first_child.type == token.STRING + and child.children[1].type == token.NEWLINE + ): + continue + else: + break + elif first_child.type == syms.import_from: + module_name = first_child.children[1] + if not isinstance(module_name, Leaf) or module_name.value != "__future__": + break + for import_from_child in first_child.children[3:]: + if isinstance(import_from_child, Leaf): + if import_from_child.type == token.NAME: + imports.add(import_from_child.value) + else: + assert import_from_child.type == syms.import_as_names + for leaf in import_from_child.children: + if isinstance(leaf, Leaf) and leaf.type == token.NAME: + imports.add(leaf.value) + else: + break + return imports + + PYTHON_EXTENSIONS = {".py"} BLACKLISTED_DIRECTORIES = { "build", "buck-out", "dist", "_build", ".git", ".hg", ".mypy_cache", ".tox", ".venv" @@ -2418,7 +2571,7 @@ def gen_python_files_in_dir(path: Path) -> Iterator[Path]: yield from gen_python_files_in_dir(child) - elif child.suffix in PYTHON_EXTENSIONS: + elif child.is_file() and child.suffix in PYTHON_EXTENSIONS: yield child @@ -2643,6 +2796,28 @@ def sub_twice(regex: Pattern[str], replacement: str, original: str) -> str: return regex.sub(replacement, regex.sub(replacement, original)) +def enumerate_reversed(sequence: Sequence[T]) -> Iterator[Tuple[Index, T]]: + """Like `reversed(enumerate(sequence))` if that were possible.""" + index = len(sequence) - 1 + for element in reversed(sequence): + yield (index, element) + index -= 1 + + +def is_line_short_enough(line: Line, *, line_length: int, line_str: str = "") -> bool: + """Return True if `line` is no longer than `line_length`. + + Uses the provided `line_str` rendering, if any, otherwise computes a new one. + """ + if not line_str: + line_str = str(line).strip("\n") + return ( + len(line_str) <= line_length + and "\n" not in line_str # multiline strings + and not line.contains_standalone_comments() + ) + + CACHE_DIR = Path(user_cache_dir("black", version=__version__))