]> git.madduck.net Git - etc/vim.git/blobdiff - black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you read the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access, can be arranged individually.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Nits around numeral normalization.
[etc/vim.git] / black.py
index 36a180da702a3a003276dc3a8e1f25b96b4abd01..d6887483903ff3282d6e1f63b071a7c587b21178 100644 (file)
--- a/black.py
+++ b/black.py
@@ -94,11 +94,12 @@ class WriteBack(Enum):
     NO = 0
     YES = 1
     DIFF = 2
+    CHECK = 3
 
     @classmethod
     def from_configuration(cls, *, check: bool, diff: bool) -> "WriteBack":
         if check and not diff:
-            return cls.NO
+            return cls.CHECK
 
         return cls.DIFF if diff else cls.YES
 
@@ -398,7 +399,9 @@ def reformat_one(
                 mode=mode,
             ):
                 changed = Changed.YES
-            if write_back == WriteBack.YES and changed is not Changed.NO:
+            if (write_back is WriteBack.YES and changed is not Changed.CACHED) or (
+                write_back is WriteBack.CHECK and changed is Changed.NO
+            ):
                 write_cache(cache, [src], line_length, mode)
         report.done(src, changed)
     except Exception as exc:
@@ -428,50 +431,58 @@ async def schedule_formatting(
         sources, cached = filter_cached(cache, sources)
         for src in sorted(cached):
             report.done(src, Changed.CACHED)
+    if not sources:
+        return
+
     cancelled = []
-    formatted = []
-    if sources:
-        lock = None
-        if write_back == WriteBack.DIFF:
-            # For diff output, we need locks to ensure we don't interleave output
-            # from different processes.
-            manager = Manager()
-            lock = manager.Lock()
-        tasks = {
-            loop.run_in_executor(
-                executor,
-                format_file_in_place,
-                src,
-                line_length,
-                fast,
-                write_back,
-                mode,
-                lock,
-            ): src
-            for src in sorted(sources)
-        }
-        pending: Iterable[asyncio.Task] = tasks.keys()
-        try:
-            loop.add_signal_handler(signal.SIGINT, cancel, pending)
-            loop.add_signal_handler(signal.SIGTERM, cancel, pending)
-        except NotImplementedError:
-            # There are no good alternatives for these on Windows
-            pass
-        while pending:
-            done, _ = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
-            for task in done:
-                src = tasks.pop(task)
-                if task.cancelled():
-                    cancelled.append(task)
-                elif task.exception():
-                    report.failed(src, str(task.exception()))
-                else:
-                    formatted.append(src)
-                    report.done(src, Changed.YES if task.result() else Changed.NO)
+    sources_to_cache = []
+    lock = None
+    if write_back == WriteBack.DIFF:
+        # For diff output, we need locks to ensure we don't interleave output
+        # from different processes.
+        manager = Manager()
+        lock = manager.Lock()
+    tasks = {
+        loop.run_in_executor(
+            executor,
+            format_file_in_place,
+            src,
+            line_length,
+            fast,
+            write_back,
+            mode,
+            lock,
+        ): src
+        for src in sorted(sources)
+    }
+    pending: Iterable[asyncio.Task] = tasks.keys()
+    try:
+        loop.add_signal_handler(signal.SIGINT, cancel, pending)
+        loop.add_signal_handler(signal.SIGTERM, cancel, pending)
+    except NotImplementedError:
+        # There are no good alternatives for these on Windows.
+        pass
+    while pending:
+        done, _ = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
+        for task in done:
+            src = tasks.pop(task)
+            if task.cancelled():
+                cancelled.append(task)
+            elif task.exception():
+                report.failed(src, str(task.exception()))
+            else:
+                changed = Changed.YES if task.result() else Changed.NO
+                # If the file was written back or was successfully checked as
+                # well-formatted, store this information in the cache.
+                if write_back is WriteBack.YES or (
+                    write_back is WriteBack.CHECK and changed is Changed.NO
+                ):
+                    sources_to_cache.append(src)
+                report.done(src, changed)
     if cancelled:
         await asyncio.gather(*cancelled, loop=loop, return_exceptions=True)
-    if write_back == WriteBack.YES and formatted:
-        write_cache(cache, formatted, line_length, mode)
+    if sources_to_cache:
+        write_cache(cache, sources_to_cache, line_length, mode)
 
 
 def format_file_in_place(
@@ -484,7 +495,8 @@ def format_file_in_place(
 ) -> bool:
     """Format file under `src` path. Return True if changed.
 
-    If `write_back` is True, write reformatted code back to stdout.
+    If `write_back` is DIFF, write a diff to stdout. If it is YES, write reformatted
+    code to the file.
     `line_length` and `fast` options are passed to :func:`format_file_contents`.
     """
     if src.suffix == ".pyi":
@@ -533,7 +545,8 @@ def format_stdin_to_stdout(
 ) -> bool:
     """Format file on stdin. Return True if changed.
 
-    If `write_back` is True, write reformatted code back to stdout.
+    If `write_back` is YES, write reformatted code back to stdout. If it is DIFF,
+    write a diff to stdout.
     `line_length`, `fast`, `is_pyi`, and `force_py36` arguments are passed to
     :func:`format_file_contents`.
     """
@@ -605,6 +618,7 @@ def format_str(
         remove_u_prefix=py36 or "unicode_literals" in future_imports,
         is_pyi=is_pyi,
         normalize_strings=normalize_strings,
+        allow_underscores=py36,
     )
     elt = EmptyLineTracker(is_pyi=is_pyi)
     empty_line = Line()
@@ -797,18 +811,6 @@ UNPACKING_PARENTS = {
     syms.testlist_gexp,
     syms.testlist_star_expr,
 }
-SURROUNDED_BY_BRACKETS = {
-    syms.typedargslist,
-    syms.arglist,
-    syms.subscriptlist,
-    syms.vfplist,
-    syms.import_as_names,
-    syms.yield_expr,
-    syms.testlist_gexp,
-    syms.testlist_star_expr,
-    syms.listmaker,
-    syms.dictsetmaker,
-}
 TEST_DESCENDANTS = {
     syms.test,
     syms.lambdef,
@@ -1403,6 +1405,7 @@ class LineGenerator(Visitor[Line]):
     normalize_strings: bool = True
     current_line: Line = Factory(Line)
     remove_u_prefix: bool = False
+    allow_underscores: bool = False
 
     def line(self, indent: int = 0) -> Iterator[Line]:
         """Generate a line.
@@ -1444,6 +1447,8 @@ class LineGenerator(Visitor[Line]):
             if self.normalize_strings and node.type == token.STRING:
                 normalize_string_prefix(node, remove_u_prefix=self.remove_u_prefix)
                 normalize_string_quotes(node)
+            if node.type == token.NUMBER:
+                normalize_numeric_literal(node, self.allow_underscores)
             if node.type not in WHITESPACE:
                 self.current_line.append(node)
         yield from super().visit_default(node)
@@ -1853,7 +1858,7 @@ def container_of(leaf: Leaf) -> LN:
         if parent.type == syms.file_input:
             break
 
-        if parent.type in SURROUNDED_BY_BRACKETS:
+        if parent.prev_sibling is not None and parent.prev_sibling.type in BRACKETS:
             break
 
         container = parent
@@ -2505,6 +2510,63 @@ def normalize_string_quotes(leaf: Leaf) -> None:
     leaf.value = f"{prefix}{new_quote}{new_body}{new_quote}"
 
 
+def normalize_numeric_literal(leaf: Leaf, allow_underscores: bool) -> None:
+    """Normalizes numeric (float, int, and complex) literals.
+
+    Lowercases every letter and stores the result back in `leaf.value`. When
+    `allow_underscores` is true, long digit runs are regrouped with underscores.
+    """
+    text = leaf.value.lower()
+    if text.startswith(("0o", "0x", "0b")):
+        # Octal, hex, and binary literals are only lowercased, never regrouped.
+        pass
+    elif "e" in text:
+        before, after = text.split("e")
+        sign = ""
+        if after.startswith("-"):
+            after = after[1:]
+            sign = "-"
+        elif after.startswith("+"):
+            after = after[1:]  # an explicit "+" exponent sign is dropped
+        before = format_float_or_int_string(before, allow_underscores)
+        after = format_int_string(after, allow_underscores)
+        text = f"{before}e{sign}{after}"
+    elif text.endswith(("j", "l")):
+        number = text[:-1]
+        suffix = text[-1]
+        text = f"{format_float_or_int_string(number, allow_underscores)}{suffix}"
+    else:
+        text = format_float_or_int_string(text, allow_underscores)
+    leaf.value = text
+
+
+def format_float_or_int_string(text: str, allow_underscores: bool) -> str:
+    """Formats a float or int string; an empty side of "." becomes "0"."""
+    if "." not in text:
+        return format_int_string(text, allow_underscores)
+
+    before, after = text.split(".")  # unpacking fails on more than one "."
+    before = format_int_string(before, allow_underscores) if before else "0"
+    after = format_int_string(after, allow_underscores) if after else "0"
+    return f"{before}.{after}"
+
+
+def format_int_string(text: str, allow_underscores: bool) -> str:
+    """Normalizes underscores in a string to e.g. 1_000_000.
+
+    Input must be a string of digits and optional underscores (any length).
+    """
+    if not allow_underscores:
+        return text  # existing underscores are kept exactly as written
+
+    text = text.replace("_", "")
+    if len(text) <= 6:
+        # No underscores for numbers <= 6 digits long.
+        return text
+
+    return format(int(text), "3_")  # NOTE(review): int() drops leading zeros
+
+
 def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None:
     """Make existing optional parentheses invisible or create new ones.