git.madduck.net Git - etc/vim.git/blobdiff - black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits, then use git-format-patch and git-send-email to send them to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access, can be arranged individually.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Simplify caching logic.
[etc/vim.git] / black.py
index b9eca0adbaf29bc170e77f595b5c905945819e82..ab2394e485729151a0eef7615c30b1fc2f7faf5e 100644 (file)
--- a/black.py
+++ b/black.py
@@ -20,6 +20,7 @@ from typing import (
     Callable,
     Collection,
     Dict,
+    Generator,
     Generic,
     Iterable,
     Iterator,
@@ -46,7 +47,7 @@ from blib2to3.pgen2 import driver, token
 from blib2to3.pgen2.parse import ParseError
 
 
-__version__ = "18.6b2"
+__version__ = "18.6b4"
 DEFAULT_LINE_LENGTH = 88
 DEFAULT_EXCLUDES = (
     r"/(\.git|\.hg|\.mypy_cache|\.tox|\.venv|_build|buck-out|build|dist)/"
@@ -93,11 +94,12 @@ class WriteBack(Enum):
     NO = 0
     YES = 1
     DIFF = 2
+    CHECK = 3
 
     @classmethod
     def from_configuration(cls, *, check: bool, diff: bool) -> "WriteBack":
         if check and not diff:
-            return cls.NO
+            return cls.CHECK
 
         return cls.DIFF if diff else cls.YES
 
@@ -397,7 +399,9 @@ def reformat_one(
                 mode=mode,
             ):
                 changed = Changed.YES
-            if write_back == WriteBack.YES and changed is not Changed.NO:
+            if (write_back is WriteBack.YES and changed is not Changed.CACHED) or (
+                write_back is WriteBack.CHECK and changed is Changed.NO
+            ):
                 write_cache(cache, [src], line_length, mode)
         report.done(src, changed)
     except Exception as exc:
@@ -428,7 +432,7 @@ async def schedule_formatting(
         for src in sorted(cached):
             report.done(src, Changed.CACHED)
     cancelled = []
-    formatted = []
+    sources_to_cache = []
     if sources:
         lock = None
         if write_back == WriteBack.DIFF:
@@ -465,12 +469,18 @@ async def schedule_formatting(
                 elif task.exception():
                     report.failed(src, str(task.exception()))
                 else:
-                    formatted.append(src)
-                    report.done(src, Changed.YES if task.result() else Changed.NO)
+                    changed = Changed.YES if task.result() else Changed.NO
+                    # If the file was written back or was successfully checked as
+                    # well-formatted, store this information in the cache.
+                    if write_back is WriteBack.YES or (
+                        write_back is WriteBack.CHECK and changed is Changed.NO
+                    ):
+                        sources_to_cache.append(src)
+                    report.done(src, changed)
     if cancelled:
         await asyncio.gather(*cancelled, loop=loop, return_exceptions=True)
-    if write_back == WriteBack.YES and formatted:
-        write_cache(cache, formatted, line_length, mode)
+    if sources_to_cache:
+        write_cache(cache, sources_to_cache, line_length, mode)
 
 
 def format_file_in_place(
@@ -483,7 +493,8 @@ def format_file_in_place(
 ) -> bool:
     """Format file under `src` path. Return True if changed.
 
-    If `write_back` is True, write reformatted code back to stdout.
+    If `write_back` is DIFF, write a diff to stdout. If it is YES, write reformatted
+    code to the file.
     `line_length` and `fast` options are passed to :func:`format_file_contents`.
     """
     if src.suffix == ".pyi":
@@ -532,7 +543,8 @@ def format_stdin_to_stdout(
 ) -> bool:
     """Format file on stdin. Return True if changed.
 
-    If `write_back` is True, write reformatted code back to stdout.
+    If `write_back` is YES, write reformatted code back to stdout. If it is DIFF,
+    write a diff to stdout.
     `line_length`, `fast`, `is_pyi`, and `force_py36` arguments are passed to
     :func:`format_file_contents`.
     """
@@ -604,6 +616,7 @@ def format_str(
         remove_u_prefix=py36 or "unicode_literals" in future_imports,
         is_pyi=is_pyi,
         normalize_strings=normalize_strings,
+        allow_underscores=py36,
     )
     elt = EmptyLineTracker(is_pyi=is_pyi)
     empty_line = Line()
@@ -796,18 +809,6 @@ UNPACKING_PARENTS = {
     syms.testlist_gexp,
     syms.testlist_star_expr,
 }
-SURROUNDED_BY_BRACKETS = {
-    syms.typedargslist,
-    syms.arglist,
-    syms.subscriptlist,
-    syms.vfplist,
-    syms.import_as_names,
-    syms.yield_expr,
-    syms.testlist_gexp,
-    syms.testlist_star_expr,
-    syms.listmaker,
-    syms.dictsetmaker,
-}
 TEST_DESCENDANTS = {
     syms.test,
     syms.lambdef,
@@ -1402,6 +1403,7 @@ class LineGenerator(Visitor[Line]):
     normalize_strings: bool = True
     current_line: Line = Factory(Line)
     remove_u_prefix: bool = False
+    allow_underscores: bool = False
 
     def line(self, indent: int = 0) -> Iterator[Line]:
         """Generate a line.
@@ -1443,6 +1445,8 @@ class LineGenerator(Visitor[Line]):
             if self.normalize_strings and node.type == token.STRING:
                 normalize_string_prefix(node, remove_u_prefix=self.remove_u_prefix)
                 normalize_string_quotes(node)
+            if node.type == token.NUMBER:
+                normalize_numeric_literal(node, self.allow_underscores)
             if node.type not in WHITESPACE:
                 self.current_line.append(node)
         yield from super().visit_default(node)
@@ -1852,7 +1856,7 @@ def container_of(leaf: Leaf) -> LN:
         if parent.type == syms.file_input:
             break
 
-        if parent.type in SURROUNDED_BY_BRACKETS:
+        if parent.prev_sibling is not None and parent.prev_sibling.type in BRACKETS:
             break
 
         container = parent
@@ -2504,6 +2508,61 @@ def normalize_string_quotes(leaf: Leaf) -> None:
     leaf.value = f"{prefix}{new_quote}{new_body}{new_quote}"
 
 
+def normalize_numeric_literal(leaf: Leaf, allow_underscores: bool) -> None:
+    """Normalizes numeric (float, int, and complex) literals."""
+    # We want all letters (e in exponents, j in complex literals, a-f
+    # in hex literals) to be lowercase.
+    text = leaf.value.lower()
+    if text.startswith(("0o", "0x", "0b")):
+        # Leave octal, hex, and binary literals alone for now.
+        pass
+    elif "e" in text:
+        before, after = text.split("e")
+        if after.startswith("-"):
+            after = after[1:]
+            sign = "-"
+        elif after.startswith("+"):
+            after = after[1:]
+            sign = ""
+        else:
+            sign = ""
+        before = format_float_or_int_string(before, allow_underscores)
+        after = format_int_string(after, allow_underscores)
+        text = f"{before}e{sign}{after}"
+    # Complex numbers and Python 2 longs
+    elif "j" in text or "l" in text:
+        number = text[:-1]
+        suffix = text[-1]
+        text = f"{format_float_or_int_string(number, allow_underscores)}{suffix}"
+    else:
+        text = format_float_or_int_string(text, allow_underscores)
+    leaf.value = text
+
+
+def format_float_or_int_string(text: str, allow_underscores: bool) -> str:
+    """Formats a float string like "1.0"."""
+    if "." not in text:
+        return format_int_string(text, allow_underscores)
+    before, after = text.split(".")
+    before = format_int_string(before, allow_underscores) if before else "0"
+    after = format_int_string(after, allow_underscores) if after else "0"
+    return f"{before}.{after}"
+
+
+def format_int_string(text: str, allow_underscores: bool) -> str:
+    """Normalizes underscores in a string to e.g. 1_000_000.
+
+    Input must be a string consisting only of digits and underscores.
+    """
+    if not allow_underscores:
+        return text
+    text = text.replace("_", "")
+    if len(text) <= 6:
+        # No underscores for numbers <= 6 digits long.
+        return text
+    return format(int(text), "3_")
+
+
 def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None:
     """Make existing optional parentheses invisible or create new ones.
 
@@ -2577,6 +2636,9 @@ def convert_one_fmt_off_pair(node: Node) -> bool:
                         continue
 
                 ignored_nodes = list(generate_ignored_nodes(leaf))
+                if not ignored_nodes:
+                    continue
+
                 first = ignored_nodes[0]  # Can be a container node with the `leaf`.
                 parent = first.parent
                 prefix = first.prefix
@@ -2605,7 +2667,7 @@ def convert_one_fmt_off_pair(node: Node) -> bool:
                 )
                 return True
 
-            previous_consumed += comment.consumed
+            previous_consumed = comment.consumed
 
     return False
 
@@ -2907,7 +2969,23 @@ def generate_trailers_to_omit(line: Line, line_length: int) -> Iterator[Set[Leaf
 
 def get_future_imports(node: Node) -> Set[str]:
     """Return a set of __future__ imports in the file."""
-    imports = set()
+    imports: Set[str] = set()
+
+    def get_imports_from_children(children: List[LN]) -> Generator[str, None, None]:
+        for child in children:
+            if isinstance(child, Leaf):
+                if child.type == token.NAME:
+                    yield child.value
+            elif child.type == syms.import_as_name:
+                orig_name = child.children[0]
+                assert isinstance(orig_name, Leaf), "Invalid syntax parsing imports"
+                assert orig_name.type == token.NAME, "Invalid syntax parsing imports"
+                yield orig_name.value
+            elif child.type == syms.import_as_names:
+                yield from get_imports_from_children(child.children)
+            else:
+                assert False, "Invalid syntax parsing imports"
+
     for child in node.children:
         if child.type != syms.simple_stmt:
             break
@@ -2926,15 +3004,7 @@ def get_future_imports(node: Node) -> Set[str]:
             module_name = first_child.children[1]
             if not isinstance(module_name, Leaf) or module_name.value != "__future__":
                 break
-            for import_from_child in first_child.children[3:]:
-                if isinstance(import_from_child, Leaf):
-                    if import_from_child.type == token.NAME:
-                        imports.add(import_from_child.value)
-                else:
-                    assert import_from_child.type == syms.import_as_names
-                    for leaf in import_from_child.children:
-                        if isinstance(leaf, Leaf) and leaf.type == token.NAME:
-                            imports.add(leaf.value)
+            imports |= set(get_imports_from_children(first_child.children[3:]))
         else:
             break
     return imports
@@ -2950,7 +3020,7 @@ def gen_python_files_in_dir(
     """Generate all files under `path` whose paths are not excluded by the
     `exclude` regex, but are included by the `include` regex.
 
-    Symbolic links pointing outside of the root directory are ignored.
+    Symbolic links pointing outside of the `root` directory are ignored.
 
     `report` is where output about exclusions goes.
     """
@@ -2961,8 +3031,7 @@ def gen_python_files_in_dir(
         except ValueError:
             if child.is_symlink():
                 report.path_ignored(
-                    child,
-                    "is a symbolic link that points outside of the root directory",
+                    child, f"is a symbolic link that points outside {root}"
                 )
                 continue