git.madduck.net Git - etc/vim.git/blobdiff - black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access, can be arranged individually.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

vim: Restore cursor/window position after format (#433)
[etc/vim.git] / black.py
index 0da2fad34a98422e732415776e893d98916209cf..b66ad0d00b9e7c1574011c5146c8f6e380488973 100644 (file)
--- a/black.py
+++ b/black.py
@@ -20,6 +20,7 @@ from typing import (
     Callable,
     Collection,
     Dict,
+    Generator,
     Generic,
     Iterable,
     Iterator,
@@ -46,7 +47,7 @@ from blib2to3.pgen2 import driver, token
 from blib2to3.pgen2.parse import ParseError
 
 
-__version__ = "18.6b3"
+__version__ = "18.6b4"
 DEFAULT_LINE_LENGTH = 88
 DEFAULT_EXCLUDES = (
     r"/(\.git|\.hg|\.mypy_cache|\.tox|\.venv|_build|buck-out|build|dist)/"
@@ -93,11 +94,12 @@ class WriteBack(Enum):
     NO = 0
     YES = 1
     DIFF = 2
+    CHECK = 3
 
     @classmethod
     def from_configuration(cls, *, check: bool, diff: bool) -> "WriteBack":
         if check and not diff:
-            return cls.NO
+            return cls.CHECK
 
         return cls.DIFF if diff else cls.YES
 
@@ -397,7 +399,14 @@ def reformat_one(
                 mode=mode,
             ):
                 changed = Changed.YES
-            if write_back == WriteBack.YES and changed is not Changed.NO:
+            if write_back is WriteBack.YES:
+                should_write = changed is not Changed.CACHED
+            elif write_back is WriteBack.CHECK:
+                should_write = changed is Changed.NO
+            else:
+                should_write = False
+
+            if should_write:
                 write_cache(cache, [src], line_length, mode)
         report.done(src, changed)
     except Exception as exc:
@@ -465,11 +474,17 @@ async def schedule_formatting(
                 elif task.exception():
                     report.failed(src, str(task.exception()))
                 else:
-                    formatted.append(src)
-                    report.done(src, Changed.YES if task.result() else Changed.NO)
+                    changed = Changed.YES if task.result() else Changed.NO
+                    # In normal mode, write all files to the cache.
+                    if write_back is WriteBack.YES:
+                        formatted.append(src)
+                    # In check mode, write only unchanged files to the cache.
+                    elif write_back is WriteBack.CHECK and changed is Changed.NO:
+                        formatted.append(src)
+                    report.done(src, changed)
     if cancelled:
         await asyncio.gather(*cancelled, loop=loop, return_exceptions=True)
-    if write_back == WriteBack.YES and formatted:
+    if write_back in (WriteBack.YES, WriteBack.CHECK) and formatted:
         write_cache(cache, formatted, line_length, mode)
 
 
@@ -483,7 +498,8 @@ def format_file_in_place(
 ) -> bool:
     """Format file under `src` path. Return True if changed.
 
-    If `write_back` is True, write reformatted code back to stdout.
+    If `write_back` is DIFF, write a diff to stdout. If it is YES, write reformatted
+    code to the file.
     `line_length` and `fast` options are passed to :func:`format_file_contents`.
     """
     if src.suffix == ".pyi":
@@ -532,7 +548,8 @@ def format_stdin_to_stdout(
 ) -> bool:
     """Format file on stdin. Return True if changed.
 
-    If `write_back` is True, write reformatted code back to stdout.
+    If `write_back` is YES, write reformatted code back to stdout. If it is DIFF,
+    write a diff to stdout.
     `line_length`, `fast`, `is_pyi`, and `force_py36` arguments are passed to
     :func:`format_file_contents`.
     """
@@ -604,6 +621,7 @@ def format_str(
         remove_u_prefix=py36 or "unicode_literals" in future_imports,
         is_pyi=is_pyi,
         normalize_strings=normalize_strings,
+        allow_underscores=py36,
     )
     elt = EmptyLineTracker(is_pyi=is_pyi)
     empty_line = Line()
@@ -796,18 +814,6 @@ UNPACKING_PARENTS = {
     syms.testlist_gexp,
     syms.testlist_star_expr,
 }
-SURROUNDED_BY_BRACKETS = {
-    syms.typedargslist,
-    syms.arglist,
-    syms.subscriptlist,
-    syms.vfplist,
-    syms.import_as_names,
-    syms.yield_expr,
-    syms.testlist_gexp,
-    syms.testlist_star_expr,
-    syms.listmaker,
-    syms.dictsetmaker,
-}
 TEST_DESCENDANTS = {
     syms.test,
     syms.lambdef,
@@ -1402,6 +1408,7 @@ class LineGenerator(Visitor[Line]):
     normalize_strings: bool = True
     current_line: Line = Factory(Line)
     remove_u_prefix: bool = False
+    allow_underscores: bool = False
 
     def line(self, indent: int = 0) -> Iterator[Line]:
         """Generate a line.
@@ -1443,6 +1450,8 @@ class LineGenerator(Visitor[Line]):
             if self.normalize_strings and node.type == token.STRING:
                 normalize_string_prefix(node, remove_u_prefix=self.remove_u_prefix)
                 normalize_string_quotes(node)
+            if node.type == token.NUMBER:
+                normalize_numeric_literal(node, self.allow_underscores)
             if node.type not in WHITESPACE:
                 self.current_line.append(node)
         yield from super().visit_default(node)
@@ -1852,7 +1861,7 @@ def container_of(leaf: Leaf) -> LN:
         if parent.type == syms.file_input:
             break
 
-        if parent.type in SURROUNDED_BY_BRACKETS:
+        if parent.prev_sibling is not None and parent.prev_sibling.type in BRACKETS:
             break
 
         container = parent
@@ -2504,6 +2513,61 @@ def normalize_string_quotes(leaf: Leaf) -> None:
     leaf.value = f"{prefix}{new_quote}{new_body}{new_quote}"
 
 
+def normalize_numeric_literal(leaf: Leaf, allow_underscores: bool) -> None:
+    """Normalizes numeric (float, int, and complex) literals."""
+    # We want all letters (e in exponents, j in complex literals, a-f
+    # in hex literals) to be lowercase.
+    text = leaf.value.lower()
+    if text.startswith(("0o", "0x", "0b")):
+        # Leave octal, hex, and binary literals alone for now.
+        pass
+    elif "e" in text:
+        before, after = text.split("e")
+        if after.startswith("-"):
+            after = after[1:]
+            sign = "-"
+        elif after.startswith("+"):
+            after = after[1:]
+            sign = ""
+        else:
+            sign = ""
+        before = format_float_or_int_string(before, allow_underscores)
+        after = format_int_string(after, allow_underscores)
+        text = f"{before}e{sign}{after}"
+    # Complex numbers and Python 2 longs
+    elif "j" in text or "l" in text:
+        number = text[:-1]
+        suffix = text[-1]
+        text = f"{format_float_or_int_string(number, allow_underscores)}{suffix}"
+    else:
+        text = format_float_or_int_string(text, allow_underscores)
+    leaf.value = text
+
+
+def format_float_or_int_string(text: str, allow_underscores: bool) -> str:
+    """Formats a float string like "1.0"."""
+    if "." not in text:
+        return format_int_string(text, allow_underscores)
+    before, after = text.split(".")
+    before = format_int_string(before, allow_underscores) if before else "0"
+    after = format_int_string(after, allow_underscores) if after else "0"
+    return f"{before}.{after}"
+
+
+def format_int_string(text: str, allow_underscores: bool) -> str:
+    """Normalizes underscores in a string to e.g. 1_000_000.
+
+    Input must be a string consisting only of digits and underscores.
+    """
+    if not allow_underscores:
+        return text
+    text = text.replace("_", "")
+    if len(text) <= 6:
+        # No underscores for numbers <= 6 digits long.
+        return text
+    return format(int(text), "3_")
+
+
 def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None:
     """Make existing optional parentheses invisible or create new ones.
 
@@ -2608,7 +2672,7 @@ def convert_one_fmt_off_pair(node: Node) -> bool:
                 )
                 return True
 
-            previous_consumed += comment.consumed
+            previous_consumed = comment.consumed
 
     return False
 
@@ -2910,7 +2974,23 @@ def generate_trailers_to_omit(line: Line, line_length: int) -> Iterator[Set[Leaf
 
 def get_future_imports(node: Node) -> Set[str]:
     """Return a set of __future__ imports in the file."""
-    imports = set()
+    imports: Set[str] = set()
+
+    def get_imports_from_children(children: List[LN]) -> Generator[str, None, None]:
+        for child in children:
+            if isinstance(child, Leaf):
+                if child.type == token.NAME:
+                    yield child.value
+            elif child.type == syms.import_as_name:
+                orig_name = child.children[0]
+                assert isinstance(orig_name, Leaf), "Invalid syntax parsing imports"
+                assert orig_name.type == token.NAME, "Invalid syntax parsing imports"
+                yield orig_name.value
+            elif child.type == syms.import_as_names:
+                yield from get_imports_from_children(child.children)
+            else:
+                assert False, "Invalid syntax parsing imports"
+
     for child in node.children:
         if child.type != syms.simple_stmt:
             break
@@ -2929,15 +3009,7 @@ def get_future_imports(node: Node) -> Set[str]:
             module_name = first_child.children[1]
             if not isinstance(module_name, Leaf) or module_name.value != "__future__":
                 break
-            for import_from_child in first_child.children[3:]:
-                if isinstance(import_from_child, Leaf):
-                    if import_from_child.type == token.NAME:
-                        imports.add(import_from_child.value)
-                else:
-                    assert import_from_child.type == syms.import_as_names
-                    for leaf in import_from_child.children:
-                        if isinstance(leaf, Leaf) and leaf.type == token.NAME:
-                            imports.add(leaf.value)
+            imports |= set(get_imports_from_children(first_child.children[3:]))
         else:
             break
     return imports