Callable,
Collection,
Dict,
+ Generator,
Generic,
Iterable,
Iterator,
NO = 0
YES = 1
DIFF = 2
+ CHECK = 3
@classmethod
def from_configuration(cls, *, check: bool, diff: bool) -> "WriteBack":
if check and not diff:
- return cls.NO
+ return cls.CHECK
return cls.DIFF if diff else cls.YES
mode=mode,
):
changed = Changed.YES
- if write_back == WriteBack.YES and changed is not Changed.NO:
+ if (write_back is WriteBack.YES and changed is not Changed.CACHED) or (
+ write_back is WriteBack.CHECK and changed is Changed.NO
+ ):
write_cache(cache, [src], line_length, mode)
report.done(src, changed)
except Exception as exc:
sources, cached = filter_cached(cache, sources)
for src in sorted(cached):
report.done(src, Changed.CACHED)
+ if not sources:
+ return
+
cancelled = []
- formatted = []
- if sources:
- lock = None
- if write_back == WriteBack.DIFF:
- # For diff output, we need locks to ensure we don't interleave output
- # from different processes.
- manager = Manager()
- lock = manager.Lock()
- tasks = {
- loop.run_in_executor(
- executor,
- format_file_in_place,
- src,
- line_length,
- fast,
- write_back,
- mode,
- lock,
- ): src
- for src in sorted(sources)
- }
- pending: Iterable[asyncio.Task] = tasks.keys()
- try:
- loop.add_signal_handler(signal.SIGINT, cancel, pending)
- loop.add_signal_handler(signal.SIGTERM, cancel, pending)
- except NotImplementedError:
- # There are no good alternatives for these on Windows
- pass
- while pending:
- done, _ = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
- for task in done:
- src = tasks.pop(task)
- if task.cancelled():
- cancelled.append(task)
- elif task.exception():
- report.failed(src, str(task.exception()))
- else:
- formatted.append(src)
- report.done(src, Changed.YES if task.result() else Changed.NO)
+ sources_to_cache = []
+ lock = None
+ if write_back == WriteBack.DIFF:
+ # For diff output, we need locks to ensure we don't interleave output
+ # from different processes.
+ manager = Manager()
+ lock = manager.Lock()
+ tasks = {
+ loop.run_in_executor(
+ executor,
+ format_file_in_place,
+ src,
+ line_length,
+ fast,
+ write_back,
+ mode,
+ lock,
+ ): src
+ for src in sorted(sources)
+ }
+ pending: Iterable[asyncio.Task] = tasks.keys()
+ try:
+ loop.add_signal_handler(signal.SIGINT, cancel, pending)
+ loop.add_signal_handler(signal.SIGTERM, cancel, pending)
+ except NotImplementedError:
+ # There are no good alternatives for these on Windows.
+ pass
+ while pending:
+ done, _ = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
+ for task in done:
+ src = tasks.pop(task)
+ if task.cancelled():
+ cancelled.append(task)
+ elif task.exception():
+ report.failed(src, str(task.exception()))
+ else:
+ changed = Changed.YES if task.result() else Changed.NO
+ # If the file was written back or was successfully checked as
+ # well-formatted, store this information in the cache.
+ if write_back is WriteBack.YES or (
+ write_back is WriteBack.CHECK and changed is Changed.NO
+ ):
+ sources_to_cache.append(src)
+ report.done(src, changed)
if cancelled:
await asyncio.gather(*cancelled, loop=loop, return_exceptions=True)
- if write_back == WriteBack.YES and formatted:
- write_cache(cache, formatted, line_length, mode)
+ if sources_to_cache:
+ write_cache(cache, sources_to_cache, line_length, mode)
def format_file_in_place(
) -> bool:
"""Format file under `src` path. Return True if changed.
- If `write_back` is True, write reformatted code back to stdout.
+ If `write_back` is DIFF, write a diff to stdout. If it is YES, write reformatted
+ code to the file.
`line_length` and `fast` options are passed to :func:`format_file_contents`.
"""
if src.suffix == ".pyi":
) -> bool:
"""Format file on stdin. Return True if changed.
- If `write_back` is True, write reformatted code back to stdout.
+ If `write_back` is YES, write reformatted code back to stdout. If it is DIFF,
+ write a diff to stdout.
`line_length`, `fast`, `is_pyi`, and `force_py36` arguments are passed to
:func:`format_file_contents`.
"""
remove_u_prefix=py36 or "unicode_literals" in future_imports,
is_pyi=is_pyi,
normalize_strings=normalize_strings,
+ allow_underscores=py36,
)
elt = EmptyLineTracker(is_pyi=is_pyi)
empty_line = Line()
syms.testlist_gexp,
syms.testlist_star_expr,
}
-SURROUNDED_BY_BRACKETS = {
- syms.typedargslist,
- syms.arglist,
- syms.subscriptlist,
- syms.vfplist,
- syms.import_as_names,
- syms.yield_expr,
- syms.testlist_gexp,
- syms.testlist_star_expr,
- syms.listmaker,
- syms.dictsetmaker,
-}
TEST_DESCENDANTS = {
syms.test,
syms.lambdef,
normalize_strings: bool = True
current_line: Line = Factory(Line)
remove_u_prefix: bool = False
+ allow_underscores: bool = False
def line(self, indent: int = 0) -> Iterator[Line]:
"""Generate a line.
if self.normalize_strings and node.type == token.STRING:
normalize_string_prefix(node, remove_u_prefix=self.remove_u_prefix)
normalize_string_quotes(node)
+ if node.type == token.NUMBER:
+ normalize_numeric_literal(node, self.allow_underscores)
if node.type not in WHITESPACE:
self.current_line.append(node)
yield from super().visit_default(node)
if parent.type == syms.file_input:
break
- if parent.type in SURROUNDED_BY_BRACKETS:
+ if parent.prev_sibling is not None and parent.prev_sibling.type in BRACKETS:
break
container = parent
leaf.value = f"{prefix}{new_quote}{new_body}{new_quote}"
def normalize_numeric_literal(leaf: Leaf, allow_underscores: bool) -> None:
    """Normalize a numeric (int, float, or complex) literal in place.

    All letters used in the representation (prefixes, hex digits, the exponent
    marker, and the "j"/"l" suffixes) are lowercased. Octal, hex, and binary
    literals are otherwise left alone. For decimal literals, a missing integer
    or fractional part around the decimal point is filled in with "0", a
    redundant "+" in the exponent is dropped and, when `allow_underscores` is
    True, long digit runs are regrouped with underscores (see
    :func:`format_int_string`).
    """
    text = leaf.value.lower()
    if text.startswith(("0o", "0x", "0b")):
        # Leave octal, hex, and binary literals alone (apart from lowercasing).
        leaf.value = text
        return

    suffix = ""
    if text.endswith(("j", "l")):
        # Strip the imaginary ("j") or long ("l") suffix *before* looking at the
        # exponent, so that e.g. "1e1234567j" does not hand "1234567j" to the
        # digit formatter as if it were part of the exponent.
        text, suffix = text[:-1], text[-1]

    if "e" in text:
        before, after = text.split("e")
        sign = ""
        if after.startswith("-"):
            after = after[1:]
            sign = "-"
        elif after.startswith("+"):
            # An explicit "+" in the exponent is redundant; drop it.
            after = after[1:]
        before = format_float_or_int_string(before, allow_underscores)
        after = format_int_string(after, allow_underscores)
        text = f"{before}e{sign}{after}"
    else:
        text = format_float_or_int_string(text, allow_underscores)
    leaf.value = f"{text}{suffix}"


def format_float_or_int_string(text: str, allow_underscores: bool) -> str:
    """Format a float string like "1.0", or an int string like "1000".

    A missing integer or fractional part is filled in with "0" (so "1." becomes
    "1.0" and ".5" becomes "0.5"). Each part is then passed through
    :func:`format_int_string`.
    """
    if "." not in text:
        return format_int_string(text, allow_underscores)

    before, after = text.split(".")
    before = format_int_string(before, allow_underscores) if before else "0"
    after = format_int_string(after, allow_underscores) if after else "0"
    return f"{before}.{after}"


def format_int_string(text: str, allow_underscores: bool) -> str:
    """Normalize underscores in a string of digits to e.g. 1_000_000.

    Input must be a string of decimal digits and optional underscores. When
    `allow_underscores` is False the input is returned unchanged. Otherwise
    existing underscores are removed and, if more than six digits remain, the
    digits are regrouped into threes counting from the right. Every digit is
    kept, so leading zeros survive. This matters for the fractional part of a
    float ("1.0000001" must not become "1.1") and for literals like "00000000".
    """
    if not allow_underscores:
        return text

    text = text.replace("_", "")
    if len(text) <= 6:
        # No underscores for numbers <= 6 digits long.
        return text

    # Slice the string directly instead of round-tripping through int():
    # int() would drop leading zeros, and format(..., "3_") would pad results
    # shorter than three characters with spaces.
    head = len(text) % 3 or 3
    groups = [text[:head]]
    groups.extend(text[i : i + 3] for i in range(head, len(text), 3))
    return "_".join(groups)
+
+
def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None:
"""Make existing optional parentheses invisible or create new ones.
def get_future_imports(node: Node) -> Set[str]:
"""Return a set of __future__ imports in the file."""
- imports = set()
+ imports: Set[str] = set()
+
+ def get_imports_from_children(children: List[LN]) -> Generator[str, None, None]:
+ for child in children:
+ if isinstance(child, Leaf):
+ if child.type == token.NAME:
+ yield child.value
+ elif child.type == syms.import_as_name:
+ orig_name = child.children[0]
+ assert isinstance(orig_name, Leaf), "Invalid syntax parsing imports"
+ assert orig_name.type == token.NAME, "Invalid syntax parsing imports"
+ yield orig_name.value
+ elif child.type == syms.import_as_names:
+ yield from get_imports_from_children(child.children)
+ else:
+ assert False, "Invalid syntax parsing imports"
+
for child in node.children:
if child.type != syms.simple_stmt:
break
module_name = first_child.children[1]
if not isinstance(module_name, Leaf) or module_name.value != "__future__":
break
- for import_from_child in first_child.children[3:]:
- if isinstance(import_from_child, Leaf):
- if import_from_child.type == token.NAME:
- imports.add(import_from_child.value)
- else:
- assert import_from_child.type == syms.import_as_names
- for leaf in import_from_child.children:
- if isinstance(leaf, Leaf) and leaf.type == token.NAME:
- imports.add(leaf.value)
+ imports |= set(get_imports_from_children(first_child.children[3:]))
else:
break
return imports