X-Git-Url: https://git.madduck.net/etc/vim.git/blobdiff_plain/41240e9a784fe11d9e1a76befaf09b7ab2d09893..33601ffa6d27f2c54864d0f54d65c65846bf8647:/black.py diff --git a/black.py b/black.py index f49e6df..b66ad0d 100644 --- a/black.py +++ b/black.py @@ -20,6 +20,7 @@ from typing import ( Callable, Collection, Dict, + Generator, Generic, Iterable, Iterator, @@ -93,11 +94,12 @@ class WriteBack(Enum): NO = 0 YES = 1 DIFF = 2 + CHECK = 3 @classmethod def from_configuration(cls, *, check: bool, diff: bool) -> "WriteBack": if check and not diff: - return cls.NO + return cls.CHECK return cls.DIFF if diff else cls.YES @@ -397,7 +399,14 @@ def reformat_one( mode=mode, ): changed = Changed.YES - if write_back == WriteBack.YES and changed is not Changed.NO: + if write_back is WriteBack.YES: + should_write = changed is not Changed.CACHED + elif write_back is WriteBack.CHECK: + should_write = changed is Changed.NO + else: + should_write = False + + if should_write: write_cache(cache, [src], line_length, mode) report.done(src, changed) except Exception as exc: @@ -465,11 +474,17 @@ async def schedule_formatting( elif task.exception(): report.failed(src, str(task.exception())) else: - formatted.append(src) - report.done(src, Changed.YES if task.result() else Changed.NO) + changed = Changed.YES if task.result() else Changed.NO + # In normal mode, write all files to the cache. + if write_back is WriteBack.YES: + formatted.append(src) + # In check mode, write only unchanged files to the cache. + elif write_back is WriteBack.CHECK and changed is Changed.NO: + formatted.append(src) + report.done(src, changed) if cancelled: await asyncio.gather(*cancelled, loop=loop, return_exceptions=True) - if write_back == WriteBack.YES and formatted: + if write_back in (WriteBack.YES, WriteBack.CHECK) and formatted: write_cache(cache, formatted, line_length, mode) @@ -483,7 +498,8 @@ def format_file_in_place( ) -> bool: """Format file under `src` path. Return True if changed. - If `write_back` is True, write reformatted code back to stdout. + If `write_back` is DIFF, write a diff to stdout. If it is YES, write reformatted + code to the file. `line_length` and `fast` options are passed to :func:`format_file_contents`. """ if src.suffix == ".pyi": @@ -532,7 +548,8 @@ def format_stdin_to_stdout( ) -> bool: """Format file on stdin. Return True if changed. - If `write_back` is True, write reformatted code back to stdout. + If `write_back` is YES, write reformatted code back to stdout. If it is DIFF, + write a diff to stdout. `line_length`, `fast`, `is_pyi`, and `force_py36` arguments are passed to :func:`format_file_contents`. """ @@ -604,6 +621,7 @@ def format_str( remove_u_prefix=py36 or "unicode_literals" in future_imports, is_pyi=is_pyi, normalize_strings=normalize_strings, + allow_underscores=py36, ) elt = EmptyLineTracker(is_pyi=is_pyi) empty_line = Line() @@ -796,18 +814,6 @@ UNPACKING_PARENTS = { syms.testlist_gexp, syms.testlist_star_expr, } -SURROUNDED_BY_BRACKETS = { - syms.typedargslist, - syms.arglist, - syms.subscriptlist, - syms.vfplist, - syms.import_as_names, - syms.yield_expr, - syms.testlist_gexp, - syms.testlist_star_expr, - syms.listmaker, - syms.dictsetmaker, -} TEST_DESCENDANTS = { syms.test, syms.lambdef, @@ -1402,6 +1408,7 @@ class LineGenerator(Visitor[Line]): normalize_strings: bool = True current_line: Line = Factory(Line) remove_u_prefix: bool = False + allow_underscores: bool = False def line(self, indent: int = 0) -> Iterator[Line]: """Generate a line. @@ -1443,6 +1450,8 @@ class LineGenerator(Visitor[Line]): if self.normalize_strings and node.type == token.STRING: normalize_string_prefix(node, remove_u_prefix=self.remove_u_prefix) normalize_string_quotes(node) + if node.type == token.NUMBER: + normalize_numeric_literal(node, self.allow_underscores) if node.type not in WHITESPACE: self.current_line.append(node) yield from super().visit_default(node) @@ -1852,7 +1861,7 @@ def container_of(leaf: Leaf) -> LN: if parent.type == syms.file_input: break - if parent.type in SURROUNDED_BY_BRACKETS: + if parent.prev_sibling is not None and parent.prev_sibling.type in BRACKETS: break container = parent @@ -2504,6 +2513,61 @@ def normalize_string_quotes(leaf: Leaf) -> None: leaf.value = f"{prefix}{new_quote}{new_body}{new_quote}" +def normalize_numeric_literal(leaf: Leaf, allow_underscores: bool) -> None: + """Normalizes numeric (float, int, and complex) literals.""" + # We want all letters (e in exponents, j in complex literals, a-f + # in hex literals) to be lowercase. + text = leaf.value.lower() + if text.startswith(("0o", "0x", "0b")): + # Leave octal, hex, and binary literals alone for now. + pass + elif "e" in text: + before, after = text.split("e") + if after.startswith("-"): + after = after[1:] + sign = "-" + elif after.startswith("+"): + after = after[1:] + sign = "" + else: + sign = "" + before = format_float_or_int_string(before, allow_underscores) + after = format_int_string(after, allow_underscores) + text = f"{before}e{sign}{after}" + # Complex numbers and Python 2 longs + elif "j" in text or "l" in text: + number = text[:-1] + suffix = text[-1] + text = f"{format_float_or_int_string(number, allow_underscores)}{suffix}" + else: + text = format_float_or_int_string(text, allow_underscores) + leaf.value = text + + +def format_float_or_int_string(text: str, allow_underscores: bool) -> str: + """Formats a float string like "1.0".""" + if "." not in text: + return format_int_string(text, allow_underscores) + before, after = text.split(".") + before = format_int_string(before, allow_underscores) if before else "0" + after = format_int_string(after, allow_underscores) if after else "0" + return f"{before}.{after}" + + +def format_int_string(text: str, allow_underscores: bool) -> str: + """Normalizes underscores in a string to e.g. 1_000_000. + + Input must be a string consisting only of digits and underscores. + """ + if not allow_underscores: + return text + text = text.replace("_", "") + if len(text) <= 6: + # No underscores for numbers <= 6 digits long. + return text + return format(int(text), "3_") + + def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None: """Make existing optional parentheses invisible or create new ones. @@ -2910,7 +2974,23 @@ def generate_trailers_to_omit(line: Line, line_length: int) -> Iterator[Set[Leaf def get_future_imports(node: Node) -> Set[str]: """Return a set of __future__ imports in the file.""" - imports = set() + imports: Set[str] = set() + + def get_imports_from_children(children: List[LN]) -> Generator[str, None, None]: + for child in children: + if isinstance(child, Leaf): + if child.type == token.NAME: + yield child.value + elif child.type == syms.import_as_name: + orig_name = child.children[0] + assert isinstance(orig_name, Leaf), "Invalid syntax parsing imports" + assert orig_name.type == token.NAME, "Invalid syntax parsing imports" + yield orig_name.value + elif child.type == syms.import_as_names: + yield from get_imports_from_children(child.children) + else: + assert False, "Invalid syntax parsing imports" + for child in node.children: if child.type != syms.simple_stmt: break @@ -2929,15 +3009,7 @@ def get_future_imports(node: Node) -> Set[str]: module_name = first_child.children[1] if not isinstance(module_name, Leaf) or module_name.value != "__future__": break - for import_from_child in first_child.children[3:]: - if isinstance(import_from_child, Leaf): - if import_from_child.type == token.NAME: - imports.add(import_from_child.value) - else: - assert import_from_child.type == syms.import_as_names - for leaf in import_from_child.children: - if isinstance(leaf, Leaf) and leaf.type == token.NAME: - imports.add(leaf.value) + imports |= set(get_imports_from_children(first_child.children[3:])) else: break return imports