Callable,
Collection,
Dict,
+ Generator,
Generic,
Iterable,
Iterator,
from blib2to3.pgen2.parse import ParseError
-__version__ = "18.6b3"
+__version__ = "18.6b4"
DEFAULT_LINE_LENGTH = 88
DEFAULT_EXCLUDES = (
r"/(\.git|\.hg|\.mypy_cache|\.tox|\.venv|_build|buck-out|build|dist)/"
NO = 0
YES = 1
DIFF = 2
+ CHECK = 3
@classmethod
def from_configuration(cls, *, check: bool, diff: bool) -> "WriteBack":
if check and not diff:
- return cls.NO
+ return cls.CHECK
return cls.DIFF if diff else cls.YES
mode=mode,
):
changed = Changed.YES
- if write_back == WriteBack.YES and changed is not Changed.NO:
+ if write_back is WriteBack.YES:
+ should_write = changed is not Changed.CACHED
+ elif write_back is WriteBack.CHECK:
+ should_write = changed is Changed.NO
+ else:
+ should_write = False
+
+ if should_write:
write_cache(cache, [src], line_length, mode)
report.done(src, changed)
except Exception as exc:
elif task.exception():
report.failed(src, str(task.exception()))
else:
- formatted.append(src)
- report.done(src, Changed.YES if task.result() else Changed.NO)
+ changed = Changed.YES if task.result() else Changed.NO
+ # In normal mode, write all files to the cache.
+ if write_back is WriteBack.YES:
+ formatted.append(src)
+ # In check mode, write only unchanged files to the cache.
+ elif write_back is WriteBack.CHECK and changed is Changed.NO:
+ formatted.append(src)
+ report.done(src, changed)
if cancelled:
await asyncio.gather(*cancelled, loop=loop, return_exceptions=True)
- if write_back == WriteBack.YES and formatted:
+ if write_back in (WriteBack.YES, WriteBack.CHECK) and formatted:
write_cache(cache, formatted, line_length, mode)
) -> bool:
"""Format file under `src` path. Return True if changed.
- If `write_back` is True, write reformatted code back to stdout.
+ If `write_back` is DIFF, write a diff to stdout. If it is YES, write reformatted
+ code to the file.
`line_length` and `fast` options are passed to :func:`format_file_contents`.
"""
if src.suffix == ".pyi":
) -> bool:
"""Format file on stdin. Return True if changed.
- If `write_back` is True, write reformatted code back to stdout.
+ If `write_back` is YES, write reformatted code back to stdout. If it is DIFF,
+ write a diff to stdout.
`line_length`, `fast`, `is_pyi`, and `force_py36` arguments are passed to
:func:`format_file_contents`.
"""
remove_u_prefix=py36 or "unicode_literals" in future_imports,
is_pyi=is_pyi,
normalize_strings=normalize_strings,
+ allow_underscores=py36,
)
elt = EmptyLineTracker(is_pyi=is_pyi)
empty_line = Line()
syms.testlist_gexp,
syms.testlist_star_expr,
}
-SURROUNDED_BY_BRACKETS = {
- syms.typedargslist,
- syms.arglist,
- syms.subscriptlist,
- syms.vfplist,
- syms.import_as_names,
- syms.yield_expr,
- syms.testlist_gexp,
- syms.testlist_star_expr,
- syms.listmaker,
- syms.dictsetmaker,
-}
TEST_DESCENDANTS = {
syms.test,
syms.lambdef,
normalize_strings: bool = True
current_line: Line = Factory(Line)
remove_u_prefix: bool = False
+ allow_underscores: bool = False
def line(self, indent: int = 0) -> Iterator[Line]:
"""Generate a line.
if self.normalize_strings and node.type == token.STRING:
normalize_string_prefix(node, remove_u_prefix=self.remove_u_prefix)
normalize_string_quotes(node)
+ if node.type == token.NUMBER:
+ normalize_numeric_literal(node, self.allow_underscores)
if node.type not in WHITESPACE:
self.current_line.append(node)
yield from super().visit_default(node)
if parent.type == syms.file_input:
break
- if parent.type in SURROUNDED_BY_BRACKETS:
+ if parent.prev_sibling is not None and parent.prev_sibling.type in BRACKETS:
break
container = parent
leaf.value = f"{prefix}{new_quote}{new_body}{new_quote}"
def normalize_numeric_literal(leaf: Leaf, allow_underscores: bool) -> None:
    """Normalize a numeric (int, float, complex, or Python 2 long) literal in place.

    The literal is lowercased first, so the exponent marker, complex ``j``,
    hex digits and radix prefixes all come out lowercase. The one exception is
    the Python 2 long suffix, which is written as an uppercase ``L`` because a
    lowercase ``l`` is easily misread as the digit ``1``.

    Decimal digit runs are passed through :func:`format_float_or_int_string`
    and :func:`format_int_string`, which only touch underscores when
    ``allow_underscores`` is true. Octal, hex, and binary digits are otherwise
    left alone.

    :param leaf: the NUMBER leaf whose ``value`` is rewritten.
    :param allow_underscores: whether ``_`` digit separators may be emitted.
    """
    text = leaf.value.lower()

    # Peel off a trailing complex ("j") or Python 2 long ("l") suffix first, so
    # that the digit formatters below never see it. Previously a literal such
    # as "1e1234567j" passed "1234567j" to format_int_string, which crashed in
    # int(). Hex/octal/binary digits never include "j" or "l", so a trailing
    # one is always a suffix.
    suffix = ""
    if text.endswith(("j", "l")):
        text, suffix = text[:-1], text[-1]
        if suffix == "l":
            # Capitalize the long suffix: "l" looks too similar to "1".
            suffix = "L"

    if text.startswith(("0o", "0x", "0b")):
        # Leave octal, hex, and binary digits alone for now. This check must
        # come before the exponent check because hex digits may contain "e".
        pass
    elif "e" in text:
        before, after = text.split("e")
        sign = ""
        if after.startswith("-"):
            after = after[1:]
            sign = "-"
        elif after.startswith("+"):
            # An explicit "+" in the exponent is redundant; drop it.
            after = after[1:]
        before = format_float_or_int_string(before, allow_underscores)
        after = format_int_string(after, allow_underscores)
        text = f"{before}e{sign}{after}"
    else:
        text = format_float_or_int_string(text, allow_underscores)

    leaf.value = f"{text}{suffix}"
+
+
def format_float_or_int_string(text: str, allow_underscores: bool) -> str:
    """Format a decimal integer ("10") or float ("1.5", ".5", "5.") digit string.

    Text without a decimal point is treated as an integer and handed straight to
    :func:`format_int_string`. Otherwise each side of the point is formatted
    separately. An empty side is filled in with "0", so ".5" becomes "0.5" and
    "5." becomes "5.0".
    """
    pieces = text.split(".")
    if len(pieces) == 1:
        # No decimal point at all: a plain integer.
        return format_int_string(text, allow_underscores)

    whole, fraction = pieces
    whole_str = format_int_string(whole, allow_underscores) if whole else "0"
    fraction_str = format_int_string(fraction, allow_underscores) if fraction else "0"
    return f"{whole_str}.{fraction_str}"
+
+
def format_int_string(text: str, allow_underscores: bool) -> str:
    """Normalize underscores in a run of decimal digits, e.g. to 1_000_000.

    Input must be a string consisting only of digits and underscores. It may be
    the integer part, the fractional part, or the exponent of a numeric literal.

    When ``allow_underscores`` is false, ``text`` is returned unchanged.
    Otherwise existing underscores are removed. Runs of at most six digits are
    returned without separators, and longer runs get a ``_`` after every three
    digits, counting from the right.

    Every digit is preserved, including leading zeros. Round-tripping through
    ``int()`` would drop them. That changes the value of a fractional part: the
    "0000001" in ``0.0000001`` would become "1", i.e. ``0.1``.
    """
    if not allow_underscores:
        return text
    digits = text.replace("_", "")
    if len(digits) <= 6:
        # No underscores for numbers <= 6 digits long.
        return digits
    # The leftmost group holds the 1-3 digits left over once the rest of the
    # string is split into groups of exactly three.
    head = len(digits) % 3 or 3
    groups = [digits[:head]]
    groups.extend(digits[i : i + 3] for i in range(head, len(digits), 3))
    return "_".join(groups)
+
+
def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None:
"""Make existing optional parentheses invisible or create new ones.
def get_future_imports(node: Node) -> Set[str]:
"""Return a set of __future__ imports in the file."""
- imports = set()
+ imports: Set[str] = set()
+
+ def get_imports_from_children(children: List[LN]) -> Generator[str, None, None]:
+ for child in children:
+ if isinstance(child, Leaf):
+ if child.type == token.NAME:
+ yield child.value
+ elif child.type == syms.import_as_name:
+ orig_name = child.children[0]
+ assert isinstance(orig_name, Leaf), "Invalid syntax parsing imports"
+ assert orig_name.type == token.NAME, "Invalid syntax parsing imports"
+ yield orig_name.value
+ elif child.type == syms.import_as_names:
+ yield from get_imports_from_children(child.children)
+ else:
+ assert False, "Invalid syntax parsing imports"
+
for child in node.children:
if child.type != syms.simple_stmt:
break
module_name = first_child.children[1]
if not isinstance(module_name, Leaf) or module_name.value != "__future__":
break
- for import_from_child in first_child.children[3:]:
- if isinstance(import_from_child, Leaf):
- if import_from_child.type == token.NAME:
- imports.add(import_from_child.value)
- else:
- assert import_from_child.type == syms.import_as_names
- for leaf in import_from_child.children:
- if isinstance(leaf, Leaf) and leaf.type == token.NAME:
- imports.add(leaf.value)
+ imports |= set(get_imports_from_children(first_child.children[3:]))
else:
break
return imports