List,
Optional,
Pattern,
+ Sequence,
Set,
Tuple,
Type,
from blib2to3.pgen2 import driver, token
from blib2to3.pgen2.parse import ParseError
+
__version__ = "18.4a6"
DEFAULT_LINE_LENGTH = 88
If `write_back` is True, write reformatted code back to stdout.
`line_length` and `fast` options are passed to :func:`format_file_contents`.
"""
+ is_pyi = src.suffix == ".pyi"
with tokenize.open(src) as src_buffer:
src_contents = src_buffer.read()
try:
dst_contents = format_file_contents(
- src_contents, line_length=line_length, fast=fast
+ src_contents, line_length=line_length, fast=fast, is_pyi=is_pyi
)
except NothingChanged:
return False
def format_file_contents(
- src_contents: str, line_length: int, fast: bool
+ src_contents: str, *, line_length: int, fast: bool, is_pyi: bool = False
) -> FileContent:
"""Reformat contents a file and return new contents.
if src_contents.strip() == "":
raise NothingChanged
- dst_contents = format_str(src_contents, line_length=line_length)
+ dst_contents = format_str(src_contents, line_length=line_length, is_pyi=is_pyi)
if src_contents == dst_contents:
raise NothingChanged
if not fast:
assert_equivalent(src_contents, dst_contents)
- assert_stable(src_contents, dst_contents, line_length=line_length)
+ assert_stable(
+ src_contents, dst_contents, line_length=line_length, is_pyi=is_pyi
+ )
return dst_contents
-def format_str(src_contents: str, line_length: int) -> FileContent:
+def format_str(
+ src_contents: str, line_length: int, *, is_pyi: bool = False
+) -> FileContent:
"""Reformat a string and return new contents.
`line_length` determines how many characters per line are allowed.
"""
src_node = lib2to3_parse(src_contents)
dst_contents = ""
- lines = LineGenerator()
- elt = EmptyLineTracker()
+ future_imports = get_future_imports(src_node)
+ elt = EmptyLineTracker(is_pyi=is_pyi)
py36 = is_python36(src_node)
+ lines = LineGenerator(
+ remove_u_prefix=py36 or "unicode_literals" in future_imports, is_pyi=is_pyi
+ )
empty_line = Line()
after = 0
for current_line in lines.visit(src_node):
bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = Factory(dict)
delimiters: Dict[LeafID, Priority] = Factory(dict)
previous: Optional[Leaf] = None
- _for_loop_variable: bool = False
- _lambda_arguments: bool = False
+ _for_loop_variable: int = 0
+ _lambda_arguments: int = 0
def mark(self, leaf: Leaf) -> None:
"""Mark `leaf` with bracket-related metadata. Keep track of delimiters.
"""
if leaf.type == token.NAME and leaf.value == "for":
self.depth += 1
- self._for_loop_variable = True
+ self._for_loop_variable += 1
return True
return False
"""See `maybe_increment_for_loop_variable` above for explanation."""
if self._for_loop_variable and leaf.type == token.NAME and leaf.value == "in":
self.depth -= 1
- self._for_loop_variable = False
+ self._for_loop_variable -= 1
return True
return False
"""
if leaf.type == token.NAME and leaf.value == "lambda":
self.depth += 1
- self._lambda_arguments = True
+ self._lambda_arguments += 1
return True
return False
"""See `maybe_increment_lambda_arguments` above for explanation."""
if self._lambda_arguments and leaf.type == token.COLON:
self.depth -= 1
- self._lambda_arguments = False
+ self._lambda_arguments -= 1
return True
return False
and self.leaves[0].value == "class"
)
+    @property
+    def is_stub_class(self) -> bool:
+        """Is this a class definition line whose last three leaves spell "..."?"""
+        # Three consecutive DOT leaves are how the ellipsis body shows up here.
+        return self.is_class and self.leaves[-3:] == [Leaf(token.DOT, ".")] * 3
+
@property
def is_def(self) -> bool:
"""Is this a function definition? (Also returns True for async defs.)"""
self.comments.append((after, comment))
return True
- def comments_after(self, leaf: Leaf) -> Iterator[Leaf]:
- """Generate comments that should appear directly after `leaf`."""
- for _leaf_index, _leaf in enumerate(self.leaves):
- if leaf is _leaf:
- break
+ def comments_after(self, leaf: Leaf, _index: int = -1) -> Iterator[Leaf]:
+ """Generate comments that should appear directly after `leaf`.
- else:
- return
+ Provide a non-negative leaf `_index` to speed up the function.
+ """
+ if _index == -1:
+ for _index, _leaf in enumerate(self.leaves):
+ if leaf is _leaf:
+ break
+
+ else:
+ return
for index, comment_after in self.comments:
- if _leaf_index == index:
+ if _index == index:
yield comment_after
def remove_trailing_comma(self) -> None:
the prefix of the first leaf consists of optional newlines. Those newlines
are consumed by `maybe_empty_lines()` and included in the computation.
"""
+ is_pyi: bool = False
previous_line: Optional[Line] = None
previous_after: int = 0
previous_defs: List[int] = Factory(list)
def _maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
max_allowed = 1
if current_line.depth == 0:
- max_allowed = 2
+ max_allowed = 1 if self.is_pyi else 2
if current_line.leaves:
# Consume the first leaf's extra newlines.
first_leaf = current_line.leaves[0]
depth = current_line.depth
while self.previous_defs and self.previous_defs[-1] >= depth:
self.previous_defs.pop()
- before = 1 if depth else 2
+ if self.is_pyi:
+ before = 0 if depth else 1
+ else:
+ before = 1 if depth else 2
is_decorator = current_line.is_decorator
if is_decorator or current_line.is_def or current_line.is_class:
if not is_decorator:
):
return 0, 0
- newlines = 2
- if current_line.depth:
+ if self.is_pyi:
+ if self.previous_line.depth > current_line.depth:
+ newlines = 1
+ elif current_line.is_class or self.previous_line.is_class:
+ if current_line.is_stub_class and self.previous_line.is_stub_class:
+ newlines = 0
+ else:
+ newlines = 1
+ else:
+ newlines = 0
+ else:
+ newlines = 2
+ if current_line.depth and newlines:
newlines -= 1
return newlines, 0
Note: destroys the tree it's visiting by mutating prefixes of its leaves
in ways that will no longer stringify to valid Python code on the tree.
"""
+ is_pyi: bool = False
current_line: Line = Factory(Line)
+ remove_u_prefix: bool = False
def line(self, indent: int = 0, type: Type[Line] = Line) -> Iterator[Line]:
"""Generate a line.
else:
normalize_prefix(node, inside_brackets=any_open_brackets)
if node.type == token.STRING:
+ normalize_string_prefix(node, remove_u_prefix=self.remove_u_prefix)
normalize_string_quotes(node)
if node.type not in WHITESPACE:
self.current_line.append(node)
yield from self.visit(child)
+    def visit_suite(self, node: Node) -> Iterator[Line]:
+        """Visit a suite.
+
+        In `.pyi` mode a stub suite is replaced by a visit of its single
+        statement; every other suite goes through `visit_default`.
+        """
+        if not (self.is_pyi and is_stub_suite(node)):
+            yield from self.visit_default(node)
+            return
+
+        # `is_stub_suite` guarantees four children; index 2 is the statement.
+        yield from self.visit(node.children[2])
+
def visit_simple_stmt(self, node: Node) -> Iterator[Line]:
"""Visit a statement without nested statements."""
is_suite_like = node.parent and node.parent.type in STATEMENT
if is_suite_like:
- yield from self.line(+1)
- yield from self.visit_default(node)
- yield from self.line(-1)
+ if self.is_pyi and is_stub_body(node):
+ yield from self.visit_default(node)
+ else:
+ yield from self.line(+1)
+ yield from self.visit_default(node)
+ yield from self.line(-1)
else:
- yield from self.line()
+ if not self.is_pyi or not node.parent or not is_stub_suite(node.parent):
+ yield from self.line()
yield from self.visit_default(node)
def visit_async_stmt(self, node: Node) -> Iterator[Line]:
v = self.visit_stmt
Ø: Set[str] = set()
self.visit_assert_stmt = partial(v, keywords={"assert"}, parens={"assert", ","})
- self.visit_if_stmt = partial(v, keywords={"if", "else", "elif"}, parens={"if"})
+ self.visit_if_stmt = partial(
+ v, keywords={"if", "else", "elif"}, parens={"if", "elif"}
+ )
self.visit_while_stmt = partial(v, keywords={"while", "else"}, parens={"while"})
self.visit_for_stmt = partial(v, keywords={"for", "else"}, parens={"for", "in"})
self.visit_try_stmt = partial(
return
line_str = str(line).strip("\n")
- if (
- len(line_str) <= line_length
- and "\n" not in line_str # multiline strings
- and not line.contains_standalone_comments()
- ):
+ if is_line_short_enough(line, line_length=line_length, line_str=line_str):
yield line
return
split_funcs = [left_hand_split]
elif line.is_import:
split_funcs = [explode_split]
- elif line.inside_brackets:
- split_funcs = [delimiter_split, standalone_comment_split, right_hand_split]
else:
- split_funcs = [right_hand_split]
+
+ def rhs(line: Line, py36: bool = False) -> Iterator[Line]:
+ for omit in generate_trailers_to_omit(line, line_length):
+ lines = list(right_hand_split(line, py36, omit=omit))
+ if is_line_short_enough(lines[0], line_length=line_length):
+ yield from lines
+ return
+
+ # All splits failed, best effort split with no omits.
+ yield from right_hand_split(line, py36)
+
+ if line.inside_brackets:
+ split_funcs = [delimiter_split, standalone_comment_split, rhs]
+ else:
+ split_funcs = [rhs]
for split_func in split_funcs:
# We are accumulating lines in `result` because we might want to abort
# mission and return the original line in the end, or attempt a different
"""Split line into many lines, starting with the last matching bracket pair.
If the split was by optional parentheses, attempt splitting without them, too.
+ `omit` is a collection of closing bracket IDs that shouldn't be considered for
+ this split.
+
+ Note: running this function modifies `bracket_depth` on the leaves of `line`.
"""
head = Line(depth=line.depth)
body = Line(depth=line.depth + 1, inside_brackets=True)
current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
current_line.append(leaf)
- for leaf in line.leaves:
+ for index, leaf in enumerate(line.leaves):
yield from append_to_line(leaf)
- for comment_after in line.comments_after(leaf):
+ for comment_after in line.comments_after(leaf, index):
yield from append_to_line(comment_after)
lowest_depth = min(lowest_depth, leaf.bracket_depth)
current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
current_line.append(leaf)
- for leaf in line.leaves:
+ for index, leaf in enumerate(line.leaves):
yield from append_to_line(leaf)
- for comment_after in line.comments_after(leaf):
+ for comment_after in line.comments_after(leaf, index):
yield from append_to_line(comment_after)
if current_line:
leaf.prefix = ""
+def normalize_string_prefix(leaf: Leaf, remove_u_prefix: bool = False) -> None:
+    """Lowercase the prefix characters of the string held in `leaf`.
+
+    When `remove_u_prefix` is true, any "u" is dropped from the prefix as well.
+
+    Note: Mutates its argument (`leaf.value` is rewritten in place).
+    """
+    parsed = re.match(r"^([furbFURB]*)(.*)$", leaf.value, re.DOTALL)
+    assert parsed is not None, f"failed to match string {leaf.value!r}"
+    prefix, remainder = parsed.groups()
+    prefix = prefix.lower()
+    if remove_u_prefix:
+        prefix = prefix.replace("u", "")
+    leaf.value = prefix + remainder
+
+
def normalize_string_quotes(leaf: Leaf) -> None:
"""Prefer double quotes but only if it doesn't cause more escaping.
node.type != syms.atom
or is_empty_tuple(node)
or is_one_tuple(node)
+ or is_yield(node)
or max_delimiter_priority_in_atom(node) >= COMMA_PRIORITY
):
return False
)
+def is_yield(node: LN) -> bool:
+    """Return True if `node` holds a `yield` or `yield from` expression.
+
+    Any number of enclosing parenthesized atoms is peeled off before deciding.
+    """
+    while True:
+        if node.type == syms.yield_expr:
+            return True
+
+        if node.type == token.NAME and node.value == "yield":  # type: ignore
+            return True
+
+        # Anything other than a three-child atom cannot be "(expr)".
+        if node.type != syms.atom:
+            return False
+
+        if len(node.children) != 3:
+            return False
+
+        lpar, inner, rpar = node.children
+        if not (lpar.type == token.LPAR and rpar.type == token.RPAR):
+            return False
+
+        # Step inside the parentheses and check again.
+        node = inner
+
+
def is_vararg(leaf: Leaf, within: Set[NodeType]) -> bool:
"""Return True if `leaf` is a star or double star in a vararg or kwarg.
If `within` includes VARARGS_PARENTS, this applies to function signatures.
- If `within` includes COLLECTION_LIBERALS_PARENTS, it applies to right
- hand-side extended iterable unpacking (PEP 3132) and additional unpacking
+ If `within` includes UNPACKING_PARENTS, it applies to right hand-side
+ extended iterable unpacking (PEP 3132) and additional unpacking
generalizations (PEP 448).
"""
if leaf.type not in STARS or not leaf.parent:
return p.type in within
+def is_stub_suite(node: Node) -> bool:
+    """Return True if `node` is a suite with a stub body.
+
+    That is: exactly NEWLINE, INDENT, one statement, DEDENT, where the
+    statement satisfies `is_stub_body`.
+    """
+    children = node.children
+    if len(children) != 4:
+        return False
+
+    newline, indent, body, dedent = children
+    framing = (newline.type, indent.type, dedent.type)
+    if framing != (token.NEWLINE, token.INDENT, token.DEDENT):
+        return False
+
+    return is_stub_body(body)
+
+
+def is_stub_body(node: LN) -> bool:
+    """Return True if `node` is a simple statement containing an ellipsis.
+
+    The statement must have exactly two children, the first being an atom
+    made of three "." leaves.
+    """
+    if not isinstance(node, Node) or node.type != syms.simple_stmt:
+        return False
+
+    if len(node.children) != 2:
+        return False
+
+    stmt = node.children[0]
+    if stmt.type != syms.atom or len(stmt.children) != 3:
+        return False
+
+    dot = Leaf(token.DOT, ".")
+    return all(part == dot for part in stmt.children)
+
+
def max_delimiter_priority_in_atom(node: LN) -> int:
"""Return maximum delimiter priority inside `node`.
return False
-PYTHON_EXTENSIONS = {".py"}
+def generate_trailers_to_omit(line: Line, line_length: int) -> Iterator[Set[LeafID]]:
+    """Generate sets of closing bracket IDs that should be omitted in a RHS.
+
+    Brackets can be omitted if the entire trailer up to and including
+    a preceding closing bracket fits in one line.
+
+    Yielded sets are cumulative (contain results of previous yields, too).  First
+    set is empty.
+
+    Note: the very same `set` object is mutated and yielded again each time, so
+    a caller that wants to keep an earlier result has to copy it.
+    """
+
+    omit: Set[LeafID] = set()
+    yield omit
+
+    # Running width of the trailer scanned so far, seeded with `4 * depth`
+    # (presumably the indentation the line will be rendered with).
+    length = 4 * line.depth
+    opening_bracket = None
+    closing_bracket = None
+    optional_brackets: Set[LeafID] = set()
+    inner_brackets: Set[LeafID] = set()
+    # Walk the leaves right to left, i.e. from the end of the trailer.
+    for index, leaf in enumerate_reversed(line.leaves):
+        length += len(leaf.prefix) + len(leaf.value)
+        if length > line_length:
+            break
+
+        # `comment` stays bound to the offending comment if the inner loop
+        # breaks; the `else` clause resets it when no comment forced a break.
+        comment: Optional[Leaf]
+        for comment in line.comments_after(leaf, index):
+            if "\n" in comment.prefix:
+                break  # Oops, standalone comment!
+
+            length += len(comment.value)
+        else:
+            comment = None
+        if comment is not None:
+            break  # There was a standalone comment, we can't continue.
+
+        optional_brackets.discard(id(leaf))
+        if opening_bracket:
+            # Still inside the bracket pair opened (from the right) earlier.
+            if leaf is opening_bracket:
+                opening_bracket = None
+            elif leaf.type in CLOSING_BRACKETS:
+                inner_brackets.add(id(leaf))
+        elif leaf.type in CLOSING_BRACKETS:
+            if not leaf.value:
+                # Closing bracket with an empty value: remember its opening
+                # partner so the `discard()` above can drop it once reached.
+                # (`opening_bracket` is always falsy in this branch, so it must
+                # not be used as the key here.)
+                optional_brackets.add(id(leaf.opening_bracket))
+                continue
+
+            if index > 0 and line.leaves[index - 1].type in OPENING_BRACKETS:
+                # Empty brackets would fail a split so treat them as "inner"
+                # brackets (i.e. only add them to the `omit` set if another
+                # pair of brackets was good enough).
+                inner_brackets.add(id(leaf))
+                continue
+
+            opening_bracket = leaf.opening_bracket
+            if closing_bracket:
+                omit.add(id(closing_bracket))
+                omit.update(inner_brackets)
+                inner_brackets.clear()
+                yield omit
+            closing_bracket = leaf
+
+
+def get_future_imports(node: Node) -> Set[str]:
+    """Return a set of __future__ imports in the file.
+
+    Scanning stops at the first top-level child that is neither a docstring-like
+    string statement nor a `from __future__ import ...` statement.
+    """
+    found: Set[str] = set()
+    for stmt in node.children:
+        if stmt.type != syms.simple_stmt:
+            break
+
+        head = stmt.children[0]
+        if isinstance(head, Leaf):
+            # Continue looking if we see a docstring; otherwise stop.
+            looks_like_docstring = (
+                len(stmt.children) == 2
+                and head.type == token.STRING
+                and stmt.children[1].type == token.NEWLINE
+            )
+            if not looks_like_docstring:
+                break
+            continue
+
+        if head.type != syms.import_from:
+            break
+
+        module_name = head.children[1]
+        if not isinstance(module_name, Leaf) or module_name.value != "__future__":
+            break
+
+        # Everything from the fourth child on is what is being imported.
+        for imported in head.children[3:]:
+            if not isinstance(imported, Leaf):
+                assert imported.type == syms.import_as_names
+                found.update(
+                    name.value
+                    for name in imported.children
+                    if isinstance(name, Leaf) and name.type == token.NAME
+                )
+            elif imported.type == token.NAME:
+                found.add(imported.value)
+    return found
+
+
+PYTHON_EXTENSIONS = {".py", ".pyi"}
BLACKLISTED_DIRECTORIES = {
"build", "buck-out", "dist", "_build", ".git", ".hg", ".mypy_cache", ".tox", ".venv"
}
yield from gen_python_files_in_dir(child)
- elif child.suffix in PYTHON_EXTENSIONS:
+ elif child.is_file() and child.suffix in PYTHON_EXTENSIONS:
yield child
) from None
-def assert_stable(src: str, dst: str, line_length: int) -> None:
+def assert_stable(src: str, dst: str, line_length: int, is_pyi: bool = False) -> None:
"""Raise AssertionError if `dst` reformats differently the second time."""
- newdst = format_str(dst, line_length=line_length)
+ newdst = format_str(dst, line_length=line_length, is_pyi=is_pyi)
if dst != newdst:
log = dump_to_file(
diff(src, dst, "source", "first pass"),
return regex.sub(replacement, regex.sub(replacement, original))
+def enumerate_reversed(sequence: Sequence[T]) -> Iterator[Tuple[Index, T]]:
+    """Like `reversed(enumerate(sequence))` if that were possible.
+
+    Yields `(index, element)` pairs starting from the last element.
+    """
+    last = len(sequence) - 1
+    yield from zip(range(last, -1, -1), reversed(sequence))
+
+
+def is_line_short_enough(line: Line, *, line_length: int, line_str: str = "") -> bool:
+    """Return True if `line` is no longer than `line_length`.
+
+    A non-empty `line_str` is used as the rendering; otherwise `line` is
+    rendered here.  Renderings containing a newline, and lines that contain
+    standalone comments, never count as short enough.
+    """
+    rendered = line_str or str(line).strip("\n")
+    if len(rendered) > line_length:
+        return False
+
+    if "\n" in rendered:  # multiline strings
+        return False
+
+    return not line.contains_standalone_comments()
+
+
CACHE_DIR = Path(user_cache_dir("black", version=__version__))