git.madduck.net Git - etc/vim.git/blobdiff - black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access, can be arranged individually.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Add build & dist directories to .gitignore (#487)
[etc/vim.git] / black.py
index 7edf2aea059bf2b9c2a84f056e160e7c14b96a82..3a51f21f4cd16b22e63a76918e4e46709f680316 100644 (file)
--- a/black.py
+++ b/black.py
@@ -94,11 +94,12 @@ class WriteBack(Enum):
     NO = 0
     YES = 1
     DIFF = 2
     NO = 0
     YES = 1
     DIFF = 2
+    CHECK = 3
 
     @classmethod
     def from_configuration(cls, *, check: bool, diff: bool) -> "WriteBack":
         if check and not diff:
 
     @classmethod
     def from_configuration(cls, *, check: bool, diff: bool) -> "WriteBack":
         if check and not diff:
-            return cls.NO
+            return cls.CHECK
 
         return cls.DIFF if diff else cls.YES
 
 
         return cls.DIFF if diff else cls.YES
 
@@ -169,7 +170,7 @@ def read_pyproject_toml(
     "--line-length",
     type=int,
     default=DEFAULT_LINE_LENGTH,
     "--line-length",
     type=int,
     default=DEFAULT_LINE_LENGTH,
-    help="How many character per line to allow.",
+    help="How many characters per line to allow.",
     show_default=True,
 )
 @click.option(
     show_default=True,
 )
 @click.option(
@@ -398,7 +399,9 @@ def reformat_one(
                 mode=mode,
             ):
                 changed = Changed.YES
                 mode=mode,
             ):
                 changed = Changed.YES
-            if write_back == WriteBack.YES and changed is not Changed.NO:
+            if (write_back is WriteBack.YES and changed is not Changed.CACHED) or (
+                write_back is WriteBack.CHECK and changed is Changed.NO
+            ):
                 write_cache(cache, [src], line_length, mode)
         report.done(src, changed)
     except Exception as exc:
                 write_cache(cache, [src], line_length, mode)
         report.done(src, changed)
     except Exception as exc:
@@ -428,50 +431,58 @@ async def schedule_formatting(
         sources, cached = filter_cached(cache, sources)
         for src in sorted(cached):
             report.done(src, Changed.CACHED)
         sources, cached = filter_cached(cache, sources)
         for src in sorted(cached):
             report.done(src, Changed.CACHED)
+    if not sources:
+        return
+
     cancelled = []
     cancelled = []
-    formatted = []
-    if sources:
-        lock = None
-        if write_back == WriteBack.DIFF:
-            # For diff output, we need locks to ensure we don't interleave output
-            # from different processes.
-            manager = Manager()
-            lock = manager.Lock()
-        tasks = {
-            loop.run_in_executor(
-                executor,
-                format_file_in_place,
-                src,
-                line_length,
-                fast,
-                write_back,
-                mode,
-                lock,
-            ): src
-            for src in sorted(sources)
-        }
-        pending: Iterable[asyncio.Task] = tasks.keys()
-        try:
-            loop.add_signal_handler(signal.SIGINT, cancel, pending)
-            loop.add_signal_handler(signal.SIGTERM, cancel, pending)
-        except NotImplementedError:
-            # There are no good alternatives for these on Windows
-            pass
-        while pending:
-            done, _ = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
-            for task in done:
-                src = tasks.pop(task)
-                if task.cancelled():
-                    cancelled.append(task)
-                elif task.exception():
-                    report.failed(src, str(task.exception()))
-                else:
-                    formatted.append(src)
-                    report.done(src, Changed.YES if task.result() else Changed.NO)
+    sources_to_cache = []
+    lock = None
+    if write_back == WriteBack.DIFF:
+        # For diff output, we need locks to ensure we don't interleave output
+        # from different processes.
+        manager = Manager()
+        lock = manager.Lock()
+    tasks = {
+        loop.run_in_executor(
+            executor,
+            format_file_in_place,
+            src,
+            line_length,
+            fast,
+            write_back,
+            mode,
+            lock,
+        ): src
+        for src in sorted(sources)
+    }
+    pending: Iterable[asyncio.Task] = tasks.keys()
+    try:
+        loop.add_signal_handler(signal.SIGINT, cancel, pending)
+        loop.add_signal_handler(signal.SIGTERM, cancel, pending)
+    except NotImplementedError:
+        # There are no good alternatives for these on Windows.
+        pass
+    while pending:
+        done, _ = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
+        for task in done:
+            src = tasks.pop(task)
+            if task.cancelled():
+                cancelled.append(task)
+            elif task.exception():
+                report.failed(src, str(task.exception()))
+            else:
+                changed = Changed.YES if task.result() else Changed.NO
+                # If the file was written back or was successfully checked as
+                # well-formatted, store this information in the cache.
+                if write_back is WriteBack.YES or (
+                    write_back is WriteBack.CHECK and changed is Changed.NO
+                ):
+                    sources_to_cache.append(src)
+                report.done(src, changed)
     if cancelled:
         await asyncio.gather(*cancelled, loop=loop, return_exceptions=True)
     if cancelled:
         await asyncio.gather(*cancelled, loop=loop, return_exceptions=True)
-    if write_back == WriteBack.YES and formatted:
-        write_cache(cache, formatted, line_length, mode)
+    if sources_to_cache:
+        write_cache(cache, sources_to_cache, line_length, mode)
 
 
 def format_file_in_place(
 
 
 def format_file_in_place(
@@ -484,7 +495,8 @@ def format_file_in_place(
 ) -> bool:
     """Format file under `src` path. Return True if changed.
 
 ) -> bool:
     """Format file under `src` path. Return True if changed.
 
-    If `write_back` is True, write reformatted code back to stdout.
+    If `write_back` is DIFF, write a diff to stdout. If it is YES, write reformatted
+    code to the file.
     `line_length` and `fast` options are passed to :func:`format_file_contents`.
     """
     if src.suffix == ".pyi":
     `line_length` and `fast` options are passed to :func:`format_file_contents`.
     """
     if src.suffix == ".pyi":
@@ -533,7 +545,8 @@ def format_stdin_to_stdout(
 ) -> bool:
     """Format file on stdin. Return True if changed.
 
 ) -> bool:
     """Format file on stdin. Return True if changed.
 
-    If `write_back` is True, write reformatted code back to stdout.
+    If `write_back` is YES, write reformatted code back to stdout. If it is DIFF,
+    write a diff to stdout.
     `line_length`, `fast`, `is_pyi`, and `force_py36` arguments are passed to
     :func:`format_file_contents`.
     """
     `line_length`, `fast`, `is_pyi`, and `force_py36` arguments are passed to
     :func:`format_file_contents`.
     """
@@ -864,8 +877,8 @@ class BracketTracker:
     bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = Factory(dict)
     delimiters: Dict[LeafID, Priority] = Factory(dict)
     previous: Optional[Leaf] = None
     bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = Factory(dict)
     delimiters: Dict[LeafID, Priority] = Factory(dict)
     previous: Optional[Leaf] = None
-    _for_loop_variable: int = 0
-    _lambda_arguments: int = 0
+    _for_loop_depths: List[int] = Factory(list)
+    _lambda_argument_depths: List[int] = Factory(list)
 
     def mark(self, leaf: Leaf) -> None:
         """Mark `leaf` with bracket-related metadata. Keep track of delimiters.
 
     def mark(self, leaf: Leaf) -> None:
         """Mark `leaf` with bracket-related metadata. Keep track of delimiters.
@@ -938,16 +951,21 @@ class BracketTracker:
         """
         if leaf.type == token.NAME and leaf.value == "for":
             self.depth += 1
         """
         if leaf.type == token.NAME and leaf.value == "for":
             self.depth += 1
-            self._for_loop_variable += 1
+            self._for_loop_depths.append(self.depth)
             return True
 
         return False
 
     def maybe_decrement_after_for_loop_variable(self, leaf: Leaf) -> bool:
         """See `maybe_increment_for_loop_variable` above for explanation."""
             return True
 
         return False
 
     def maybe_decrement_after_for_loop_variable(self, leaf: Leaf) -> bool:
         """See `maybe_increment_for_loop_variable` above for explanation."""
-        if self._for_loop_variable and leaf.type == token.NAME and leaf.value == "in":
+        if (
+            self._for_loop_depths
+            and self._for_loop_depths[-1] == self.depth
+            and leaf.type == token.NAME
+            and leaf.value == "in"
+        ):
             self.depth -= 1
             self.depth -= 1
-            self._for_loop_variable -= 1
+            self._for_loop_depths.pop()
             return True
 
         return False
             return True
 
         return False
@@ -960,16 +978,20 @@ class BracketTracker:
         """
         if leaf.type == token.NAME and leaf.value == "lambda":
             self.depth += 1
         """
         if leaf.type == token.NAME and leaf.value == "lambda":
             self.depth += 1
-            self._lambda_arguments += 1
+            self._lambda_argument_depths.append(self.depth)
             return True
 
         return False
 
     def maybe_decrement_after_lambda_arguments(self, leaf: Leaf) -> bool:
         """See `maybe_increment_lambda_arguments` above for explanation."""
             return True
 
         return False
 
     def maybe_decrement_after_lambda_arguments(self, leaf: Leaf) -> bool:
         """See `maybe_increment_lambda_arguments` above for explanation."""
-        if self._lambda_arguments and leaf.type == token.COLON:
+        if (
+            self._lambda_argument_depths
+            and self._lambda_argument_depths[-1] == self.depth
+            and leaf.type == token.COLON
+        ):
             self.depth -= 1
             self.depth -= 1
-            self._lambda_arguments -= 1
+            self._lambda_argument_depths.pop()
             return True
 
         return False
             return True
 
         return False
@@ -1151,7 +1173,7 @@ class Line:
             self.remove_trailing_comma()
             return True
 
             self.remove_trailing_comma()
             return True
 
-        # Otheriwsse, if the trailing one is the only one, we might mistakenly
+        # Otherwise, if the trailing one is the only one, we might mistakenly
         # change a tuple into a different type by removing the comma.
         depth = closing.bracket_depth + 1
         commas = 0
         # change a tuple into a different type by removing the comma.
         depth = closing.bracket_depth + 1
         commas = 0
@@ -1364,7 +1386,7 @@ class EmptyLineTracker:
                 newlines = 1
             elif current_line.is_class or self.previous_line.is_class:
                 if current_line.is_stub_class and self.previous_line.is_stub_class:
                 newlines = 1
             elif current_line.is_class or self.previous_line.is_class:
                 if current_line.is_stub_class and self.previous_line.is_stub_class:
-                    # No blank line between classes with an emty body
+                    # No blank line between classes with an empty body
                     newlines = 0
                 else:
                     newlines = 1
                     newlines = 0
                 else:
                     newlines = 1
@@ -2498,30 +2520,32 @@ def normalize_string_quotes(leaf: Leaf) -> None:
 
 
 def normalize_numeric_literal(leaf: Leaf, allow_underscores: bool) -> None:
 
 
 def normalize_numeric_literal(leaf: Leaf, allow_underscores: bool) -> None:
-    """Normalizes numeric (float, int, and complex) literals."""
-    # We want all letters (e in exponents, j in complex literals, a-f
-    # in hex literals) to be lowercase.
+    """Normalizes numeric (float, int, and complex) literals.
+
+    All letters used in the representation are normalized to lowercase (except
+    in Python 2 long literals), and long number literals are split using underscores.
+    """
     text = leaf.value.lower()
     if text.startswith(("0o", "0x", "0b")):
     text = leaf.value.lower()
     if text.startswith(("0o", "0x", "0b")):
-        # Leave octal, hex, and binary literals alone for now.
+        # Leave octal, hex, and binary literals alone.
         pass
     elif "e" in text:
         before, after = text.split("e")
         pass
     elif "e" in text:
         before, after = text.split("e")
+        sign = ""
         if after.startswith("-"):
             after = after[1:]
             sign = "-"
         elif after.startswith("+"):
             after = after[1:]
         if after.startswith("-"):
             after = after[1:]
             sign = "-"
         elif after.startswith("+"):
             after = after[1:]
-            sign = ""
-        else:
-            sign = ""
         before = format_float_or_int_string(before, allow_underscores)
         after = format_int_string(after, allow_underscores)
         text = f"{before}e{sign}{after}"
         before = format_float_or_int_string(before, allow_underscores)
         after = format_int_string(after, allow_underscores)
         text = f"{before}e{sign}{after}"
-    # Complex numbers and Python 2 longs
-    elif "j" in text or "l" in text:
+    elif text.endswith(("j", "l")):
         number = text[:-1]
         suffix = text[-1]
         number = text[:-1]
         suffix = text[-1]
+        # Capitalize in "2L" because "l" looks too similar to "1".
+        if suffix == "l":
+            suffix = "L"
         text = f"{format_float_or_int_string(number, allow_underscores)}{suffix}"
     else:
         text = format_float_or_int_string(text, allow_underscores)
         text = f"{format_float_or_int_string(number, allow_underscores)}{suffix}"
     else:
         text = format_float_or_int_string(text, allow_underscores)
@@ -2532,24 +2556,40 @@ def format_float_or_int_string(text: str, allow_underscores: bool) -> str:
     """Formats a float string like "1.0"."""
     if "." not in text:
         return format_int_string(text, allow_underscores)
     """Formats a float string like "1.0"."""
     if "." not in text:
         return format_int_string(text, allow_underscores)
+
     before, after = text.split(".")
     before = format_int_string(before, allow_underscores) if before else "0"
     before, after = text.split(".")
     before = format_int_string(before, allow_underscores) if before else "0"
-    after = format_int_string(after, allow_underscores) if after else "0"
+    if after:
+        after = format_int_string(after, allow_underscores, count_from_end=False)
+    else:
+        after = "0"
     return f"{before}.{after}"
 
 
     return f"{before}.{after}"
 
 
-def format_int_string(text: str, allow_underscores: bool) -> str:
+def format_int_string(
+    text: str, allow_underscores: bool, count_from_end: bool = True
+) -> str:
     """Normalizes underscores in a string to e.g. 1_000_000.
 
     """Normalizes underscores in a string to e.g. 1_000_000.
 
-    Input must be a string consisting only of digits and underscores.
+    Input must be a string of digits and optional underscores.
+    If count_from_end is False, we add underscores after groups of three digits
+    counting from the beginning instead of the end of the strings. This is used
+    for the fractional part of float literals.
     """
     if not allow_underscores:
         return text
     """
     if not allow_underscores:
         return text
+
     text = text.replace("_", "")
     if len(text) <= 6:
         # No underscores for numbers <= 6 digits long.
         return text
     text = text.replace("_", "")
     if len(text) <= 6:
         # No underscores for numbers <= 6 digits long.
         return text
-    return format(int(text), "3_")
+
+    if count_from_end:
+        # Avoid removing leading zeros, which are important if we're formatting
+        # part of a number like "0.001".
+        return format(int("1" + text), "3_")[1:].lstrip("_")
+    else:
+        return "_".join(text[i : i + 3] for i in range(0, len(text), 3))
 
 
 def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None:
 
 
 def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None:
@@ -2570,7 +2610,11 @@ def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None:
     for index, child in enumerate(list(node.children)):
         if check_lpar:
             if child.type == syms.atom:
     for index, child in enumerate(list(node.children)):
         if check_lpar:
             if child.type == syms.atom:
-                maybe_make_parens_invisible_in_atom(child)
+                if maybe_make_parens_invisible_in_atom(child):
+                    lpar = Leaf(token.LPAR, "")
+                    rpar = Leaf(token.RPAR, "")
+                    index = child.remove() or 0
+                    node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
             elif is_one_tuple(child):
                 # wrap child in visible parentheses
                 lpar = Leaf(token.LPAR, "(")
             elif is_one_tuple(child):
                 # wrap child in visible parentheses
                 lpar = Leaf(token.LPAR, "(")
@@ -2678,7 +2722,11 @@ def generate_ignored_nodes(leaf: Leaf) -> Iterator[LN]:
 
 
 def maybe_make_parens_invisible_in_atom(node: LN) -> bool:
 
 
 def maybe_make_parens_invisible_in_atom(node: LN) -> bool:
-    """If it's safe, make the parens in the atom `node` invisible, recursively."""
+    """If it's safe, make the parens in the atom `node` invisible, recursively.
+
+    Returns whether the node should itself be wrapped in invisible parentheses.
+
+    """
     if (
         node.type != syms.atom
         or is_empty_tuple(node)
     if (
         node.type != syms.atom
         or is_empty_tuple(node)
@@ -2696,9 +2744,9 @@ def maybe_make_parens_invisible_in_atom(node: LN) -> bool:
         last.value = ""  # type: ignore
         if len(node.children) > 1:
             maybe_make_parens_invisible_in_atom(node.children[1])
         last.value = ""  # type: ignore
         if len(node.children) > 1:
             maybe_make_parens_invisible_in_atom(node.children[1])
-        return True
+        return False
 
 
-    return False
+    return True
 
 
 def is_empty_tuple(node: LN) -> bool:
 
 
 def is_empty_tuple(node: LN) -> bool:
@@ -2876,7 +2924,8 @@ def is_python36(node: Node) -> bool:
     """Return True if the current file is using Python 3.6+ features.
 
     Currently looking for:
     """Return True if the current file is using Python 3.6+ features.
 
     Currently looking for:
-    - f-strings; and
+    - f-strings;
+    - underscores in numeric literals; and
     - trailing commas after * or ** in function signatures and calls.
     """
     for n in node.pre_order():
     - trailing commas after * or ** in function signatures and calls.
     """
     for n in node.pre_order():
@@ -2885,6 +2934,10 @@ def is_python36(node: Node) -> bool:
             if value_head in {'f"', 'F"', "f'", "F'", "rf", "fr", "RF", "FR"}:
                 return True
 
             if value_head in {'f"', 'F"', "f'", "F'", "rf", "fr", "RF", "FR"}:
                 return True
 
+        elif n.type == token.NUMBER:
+            if "_" in n.value:  # type: ignore
+                return True
+
         elif (
             n.type in {syms.typedargslist, syms.arglist}
             and n.children
         elif (
             n.type in {syms.typedargslist, syms.arglist}
             and n.children
@@ -3118,7 +3171,7 @@ class Report:
         - otherwise return 0.
         """
         # According to http://tldp.org/LDP/abs/html/exitcodes.html starting with
         - otherwise return 0.
         """
         # According to http://tldp.org/LDP/abs/html/exitcodes.html starting with
-        # 126 we have special returncodes reserved by the shell.
+        # 126 we have special return codes reserved by the shell.
         if self.failure_count:
             return 123
 
         if self.failure_count:
             return 123