git.madduck.net Git - etc/vim.git/blobdiff - black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access, can be arranged individually.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Support PEP-570 (positional only arguments) (#946)
[etc/vim.git] / black.py
index 97393e164005404ed76951f5612d8fcd1b45fb3c..910a0ed1e3748f5996620d2cf4bedd3b199b01b3 100644 (file)
--- a/black.py
+++ b/black.py
@@ -1,5 +1,7 @@
+import ast
 import asyncio
 from concurrent.futures import Executor, ProcessPoolExecutor
 import asyncio
 from concurrent.futures import Executor, ProcessPoolExecutor
+from contextlib import contextmanager
 from datetime import datetime
 from enum import Enum
 from functools import lru_cache, partial, wraps
 from datetime import datetime
 from enum import Enum
 from functools import lru_cache, partial, wraps
@@ -140,6 +142,8 @@ class Feature(Enum):
     # set for every version of python.
     ASYNC_IDENTIFIERS = 6
     ASYNC_KEYWORDS = 7
     # set for every version of python.
     ASYNC_IDENTIFIERS = 6
     ASYNC_KEYWORDS = 7
+    ASSIGNMENT_EXPRESSIONS = 8
+    POS_ONLY_ARGUMENTS = 9
 
 
 VERSION_TO_FEATURES: Dict[TargetVersion, Set[Feature]] = {
 
 
 VERSION_TO_FEATURES: Dict[TargetVersion, Set[Feature]] = {
@@ -174,6 +178,8 @@ VERSION_TO_FEATURES: Dict[TargetVersion, Set[Feature]] = {
         Feature.TRAILING_COMMA_IN_CALL,
         Feature.TRAILING_COMMA_IN_DEF,
         Feature.ASYNC_KEYWORDS,
         Feature.TRAILING_COMMA_IN_CALL,
         Feature.TRAILING_COMMA_IN_DEF,
         Feature.ASYNC_KEYWORDS,
+        Feature.ASSIGNMENT_EXPRESSIONS,
+        Feature.POS_ONLY_ARGUMENTS,
     },
 }
 
     },
 }
 
@@ -523,6 +529,7 @@ def reformat_many(
         )
     finally:
         shutdown(loop)
         )
     finally:
         shutdown(loop)
+        executor.shutdown()
 
 
 async def schedule_formatting(
 
 
 async def schedule_formatting(
@@ -628,9 +635,8 @@ def format_file_in_place(
         src_name = f"{src}\t{then} +0000"
         dst_name = f"{src}\t{now} +0000"
         diff_contents = diff(src_contents, dst_contents, src_name, dst_name)
         src_name = f"{src}\t{then} +0000"
         dst_name = f"{src}\t{now} +0000"
         diff_contents = diff(src_contents, dst_contents, src_name, dst_name)
-        if lock:
-            lock.acquire()
-        try:
+
+        with lock or nullcontext():
             f = io.TextIOWrapper(
                 sys.stdout.buffer,
                 encoding=encoding,
             f = io.TextIOWrapper(
                 sys.stdout.buffer,
                 encoding=encoding,
@@ -639,9 +645,7 @@ def format_file_in_place(
             )
             f.write(diff_contents)
             f.detach()
             )
             f.write(diff_contents)
             f.detach()
-        finally:
-            if lock:
-                lock.release()
+
     return True
 
 
     return True
 
 
@@ -933,6 +937,7 @@ MATH_OPERATORS = {
     token.DOUBLESTAR,
 }
 STARS = {token.STAR, token.DOUBLESTAR}
     token.DOUBLESTAR,
 }
 STARS = {token.STAR, token.DOUBLESTAR}
+VARARGS_SPECIALS = STARS | {token.SLASH}
 VARARGS_PARENTS = {
     syms.arglist,
     syms.argument,  # double star in arglist
 VARARGS_PARENTS = {
     syms.arglist,
     syms.argument,  # double star in arglist
@@ -1643,6 +1648,19 @@ class LineGenerator(Visitor[Line]):
             node.children[2].value = ""
         yield from super().visit_default(node)
 
             node.children[2].value = ""
         yield from super().visit_default(node)
 
+    def visit_factor(self, node: Node) -> Iterator[Line]:
+        """Force parentheses between a unary op and a binary power:
+
+        -2 ** 8 -> -(2 ** 8)
+        """
+        child = node.children[1]
+        if child.type == syms.power and len(child.children) == 3:
+            lpar = Leaf(token.LPAR, "(")
+            rpar = Leaf(token.RPAR, ")")
+            index = child.remove() or 0
+            node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
+        yield from self.visit_default(node)
+
     def visit_INDENT(self, node: Node) -> Iterator[Line]:
         """Increase indentation level, maybe yield a line."""
         # In blib2to3 INDENT never holds comments.
     def visit_INDENT(self, node: Node) -> Iterator[Line]:
         """Increase indentation level, maybe yield a line."""
         # In blib2to3 INDENT never holds comments.
@@ -1832,7 +1850,7 @@ def whitespace(leaf: Leaf, *, complex_subscript: bool) -> str:  # noqa: C901
                     # that, too.
                     return prevp.prefix
 
                     # that, too.
                     return prevp.prefix
 
-        elif prevp.type in STARS:
+        elif prevp.type in VARARGS_SPECIALS:
             if is_vararg(prevp, within=VARARGS_PARENTS | UNPACKING_PARENTS):
                 return NO
 
             if is_vararg(prevp, within=VARARGS_PARENTS | UNPACKING_PARENTS):
                 return NO
 
@@ -1922,7 +1940,7 @@ def whitespace(leaf: Leaf, *, complex_subscript: bool) -> str:  # noqa: C901
             if not prevp or prevp.type == token.LPAR:
                 return NO
 
             if not prevp or prevp.type == token.LPAR:
                 return NO
 
-        elif prev.type in {token.EQUAL} | STARS:
+        elif prev.type in {token.EQUAL} | VARARGS_SPECIALS:
             return NO
 
     elif p.type == syms.decorator:
             return NO
 
     elif p.type == syms.decorator:
@@ -2851,6 +2869,8 @@ def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None:
             check_lpar = True
 
         if check_lpar:
             check_lpar = True
 
         if check_lpar:
+            if is_walrus_assignment(child):
+                continue
             if child.type == syms.atom:
                 if maybe_make_parens_invisible_in_atom(child, parent=node):
                     lpar = Leaf(token.LPAR, "")
             if child.type == syms.atom:
                 if maybe_make_parens_invisible_in_atom(child, parent=node):
                     lpar = Leaf(token.LPAR, "")
@@ -3005,18 +3025,24 @@ def is_empty_tuple(node: LN) -> bool:
     )
 
 
     )
 
 
+def unwrap_singleton_parenthesis(node: LN) -> Optional[LN]:
+    """Returns `wrapped` if `node` is of the shape ( wrapped ).
+
+    Parenthesis can be optional. Returns None otherwise"""
+    if len(node.children) != 3:
+        return None
+    lpar, wrapped, rpar = node.children
+    if not (lpar.type == token.LPAR and rpar.type == token.RPAR):
+        return None
+
+    return wrapped
+
+
 def is_one_tuple(node: LN) -> bool:
     """Return True if `node` holds a tuple with one element, with or without parens."""
     if node.type == syms.atom:
 def is_one_tuple(node: LN) -> bool:
     """Return True if `node` holds a tuple with one element, with or without parens."""
     if node.type == syms.atom:
-        if len(node.children) != 3:
-            return False
-
-        lpar, gexp, rpar = node.children
-        if not (
-            lpar.type == token.LPAR
-            and gexp.type == syms.testlist_gexp
-            and rpar.type == token.RPAR
-        ):
+        gexp = unwrap_singleton_parenthesis(node)
+        if gexp is None or gexp.type != syms.testlist_gexp:
             return False
 
         return len(gexp.children) == 2 and gexp.children[1].type == token.COMMA
             return False
 
         return len(gexp.children) == 2 and gexp.children[1].type == token.COMMA
@@ -3028,6 +3054,12 @@ def is_one_tuple(node: LN) -> bool:
     )
 
 
     )
 
 
+def is_walrus_assignment(node: LN) -> bool:
+    """Return True iff `node` is of the shape ( test := test )"""
+    inner = unwrap_singleton_parenthesis(node)
+    return inner is not None and inner.type == syms.namedexpr_test
+
+
 def is_yield(node: LN) -> bool:
     """Return True if `node` holds a `yield` or `yield from` expression."""
     if node.type == syms.yield_expr:
 def is_yield(node: LN) -> bool:
     """Return True if `node` holds a `yield` or `yield from` expression."""
     if node.type == syms.yield_expr:
@@ -3057,7 +3089,7 @@ def is_vararg(leaf: Leaf, within: Set[NodeType]) -> bool:
     extended iterable unpacking (PEP 3132) and additional unpacking
     generalizations (PEP 448).
     """
     extended iterable unpacking (PEP 3132) and additional unpacking
     generalizations (PEP 448).
     """
-    if leaf.type not in STARS or not leaf.parent:
+    if leaf.type not in VARARGS_SPECIALS or not leaf.parent:
         return False
 
     p = leaf.parent
         return False
 
     p = leaf.parent
@@ -3139,7 +3171,7 @@ def ensure_visible(leaf: Leaf) -> None:
     """Make sure parentheses are visible.
 
     They could be invisible as part of some statements (see
     """Make sure parentheses are visible.
 
     They could be invisible as part of some statements (see
-    :func:`normalize_invible_parens` and :func:`visit_import_from`).
+    :func:`normalize_invisible_parens` and :func:`visit_import_from`).
     """
     if leaf.type == token.LPAR:
         leaf.value = "("
     """
     if leaf.type == token.LPAR:
         leaf.value = "("
@@ -3172,8 +3204,9 @@ def get_features_used(node: Node) -> Set[Feature]:
 
     Currently looking for:
     - f-strings;
 
     Currently looking for:
     - f-strings;
-    - underscores in numeric literals; and
-    - trailing commas after * or ** in function signatures and calls.
+    - underscores in numeric literals;
+    - trailing commas after * or ** in function signatures and calls;
+    - positional only arguments in function signatures and lambdas;
     """
     features: Set[Feature] = set()
     for n in node.pre_order():
     """
     features: Set[Feature] = set()
     for n in node.pre_order():
@@ -3186,6 +3219,13 @@ def get_features_used(node: Node) -> Set[Feature]:
             if "_" in n.value:  # type: ignore
                 features.add(Feature.NUMERIC_UNDERSCORES)
 
             if "_" in n.value:  # type: ignore
                 features.add(Feature.NUMERIC_UNDERSCORES)
 
+        elif n.type == token.SLASH:
+            if n.parent and n.parent.type in {syms.typedargslist, syms.arglist}:
+                features.add(Feature.POS_ONLY_ARGUMENTS)
+
+        elif n.type == token.COLONEQUAL:
+            features.add(Feature.ASSIGNMENT_EXPRESSIONS)
+
         elif (
             n.type in {syms.typedargslist, syms.arglist}
             and n.children
         elif (
             n.type in {syms.typedargslist, syms.arglist}
             and n.children
@@ -3467,32 +3507,58 @@ class Report:
         return ", ".join(report) + "."
 
 
         return ", ".join(report) + "."
 
 
-def parse_ast(src: str) -> Union[ast3.AST, ast27.AST]:
-    for feature_version in (7, 6):
-        try:
-            return ast3.parse(src, feature_version=feature_version)
-        except SyntaxError:
-            continue
+def parse_ast(src: str) -> Union[ast.AST, ast3.AST, ast27.AST]:
+    filename = "<unknown>"
+    if sys.version_info >= (3, 8):
+        # TODO: support Python 4+ ;)
+        for minor_version in range(sys.version_info[1], 4, -1):
+            try:
+                return ast.parse(src, filename, feature_version=(3, minor_version))
+            except SyntaxError:
+                continue
+    else:
+        for feature_version in (7, 6):
+            try:
+                return ast3.parse(src, filename, feature_version=feature_version)
+            except SyntaxError:
+                continue
 
     return ast27.parse(src)
 
 
 
     return ast27.parse(src)
 
 
+def _fixup_ast_constants(
+    node: Union[ast.AST, ast3.AST, ast27.AST]
+) -> Union[ast.AST, ast3.AST, ast27.AST]:
+    """Map ast nodes deprecated in 3.8 to Constant."""
+    # casts are required until this is released:
+    # https://github.com/python/typeshed/pull/3142
+    if isinstance(node, (ast.Str, ast3.Str, ast27.Str, ast.Bytes, ast3.Bytes)):
+        return cast(ast.AST, ast.Constant(value=node.s))
+    elif isinstance(node, (ast.Num, ast3.Num, ast27.Num)):
+        return cast(ast.AST, ast.Constant(value=node.n))
+    elif isinstance(node, (ast.NameConstant, ast3.NameConstant)):
+        return cast(ast.AST, ast.Constant(value=node.value))
+    return node
+
+
 def assert_equivalent(src: str, dst: str) -> None:
     """Raise AssertionError if `src` and `dst` aren't equivalent."""
 
 def assert_equivalent(src: str, dst: str) -> None:
     """Raise AssertionError if `src` and `dst` aren't equivalent."""
 
-    def _v(node: Union[ast3.AST, ast27.AST], depth: int = 0) -> Iterator[str]:
+    def _v(node: Union[ast.AST, ast3.AST, ast27.AST], depth: int = 0) -> Iterator[str]:
         """Simple visitor generating strings to compare ASTs by content."""
         """Simple visitor generating strings to compare ASTs by content."""
+
+        node = _fixup_ast_constants(node)
+
         yield f"{'  ' * depth}{node.__class__.__name__}("
 
         for field in sorted(node._fields):
             # TypeIgnore has only one field 'lineno' which breaks this comparison
         yield f"{'  ' * depth}{node.__class__.__name__}("
 
         for field in sorted(node._fields):
             # TypeIgnore has only one field 'lineno' which breaks this comparison
-            if isinstance(node, (ast3.TypeIgnore, ast27.TypeIgnore)):
+            type_ignore_classes = (ast3.TypeIgnore, ast27.TypeIgnore)
+            if sys.version_info >= (3, 8):
+                type_ignore_classes += (ast.TypeIgnore,)
+            if isinstance(node, type_ignore_classes):
                 break
 
                 break
 
-            # Ignore str kind which is case sensitive / and ignores unicode_literals
-            if isinstance(node, (ast3.Str, ast27.Str, ast3.Bytes)) and field == "kind":
-                continue
-
             try:
                 value = getattr(node, field)
             except AttributeError:
             try:
                 value = getattr(node, field)
             except AttributeError:
@@ -3506,15 +3572,15 @@ def assert_equivalent(src: str, dst: str) -> None:
                     # parentheses and they change the AST.
                     if (
                         field == "targets"
                     # parentheses and they change the AST.
                     if (
                         field == "targets"
-                        and isinstance(node, (ast3.Delete, ast27.Delete))
-                        and isinstance(item, (ast3.Tuple, ast27.Tuple))
+                        and isinstance(node, (ast.Delete, ast3.Delete, ast27.Delete))
+                        and isinstance(item, (ast.Tuple, ast3.Tuple, ast27.Tuple))
                     ):
                         for item in item.elts:
                             yield from _v(item, depth + 2)
                     ):
                         for item in item.elts:
                             yield from _v(item, depth + 2)
-                    elif isinstance(item, (ast3.AST, ast27.AST)):
+                    elif isinstance(item, (ast.AST, ast3.AST, ast27.AST)):
                         yield from _v(item, depth + 2)
 
                         yield from _v(item, depth + 2)
 
-            elif isinstance(value, (ast3.AST, ast27.AST)):
+            elif isinstance(value, (ast.AST, ast3.AST, ast27.AST)):
                 yield from _v(value, depth + 2)
 
             else:
                 yield from _v(value, depth + 2)
 
             else:
@@ -3536,7 +3602,7 @@ def assert_equivalent(src: str, dst: str) -> None:
         log = dump_to_file("".join(traceback.format_tb(exc.__traceback__)), dst)
         raise AssertionError(
             f"INTERNAL ERROR: Black produced invalid code: {exc}. "
         log = dump_to_file("".join(traceback.format_tb(exc.__traceback__)), dst)
         raise AssertionError(
             f"INTERNAL ERROR: Black produced invalid code: {exc}. "
-            f"Please report a bug on https://github.com/python/black/issues.  "
+            f"Please report a bug on https://github.com/psf/black/issues.  "
             f"This invalid output might be helpful: {log}"
         ) from None
 
             f"This invalid output might be helpful: {log}"
         ) from None
 
@@ -3547,7 +3613,7 @@ def assert_equivalent(src: str, dst: str) -> None:
         raise AssertionError(
             f"INTERNAL ERROR: Black produced code that is not equivalent to "
             f"the source.  "
         raise AssertionError(
             f"INTERNAL ERROR: Black produced code that is not equivalent to "
             f"the source.  "
-            f"Please report a bug on https://github.com/python/black/issues.  "
+            f"Please report a bug on https://github.com/psf/black/issues.  "
             f"This diff might be helpful: {log}"
         ) from None
 
             f"This diff might be helpful: {log}"
         ) from None
 
@@ -3563,7 +3629,7 @@ def assert_stable(src: str, dst: str, mode: FileMode) -> None:
         raise AssertionError(
             f"INTERNAL ERROR: Black produced different code on the second pass "
             f"of the formatter.  "
         raise AssertionError(
             f"INTERNAL ERROR: Black produced different code on the second pass "
             f"of the formatter.  "
-            f"Please report a bug on https://github.com/python/black/issues.  "
+            f"Please report a bug on https://github.com/psf/black/issues.  "
             f"This diff might be helpful: {log}"
         ) from None
 
             f"This diff might be helpful: {log}"
         ) from None
 
@@ -3580,6 +3646,13 @@ def dump_to_file(*output: str) -> str:
     return f.name
 
 
     return f.name
 
 
+@contextmanager
+def nullcontext() -> Iterator[None]:
+    """Return context manager that does nothing.
+    Similar to `nullcontext` from python 3.7"""
+    yield
+
+
 def diff(a: str, b: str, a_name: str, b_name: str) -> str:
     """Return a unified diff string between strings `a` and `b`."""
     import difflib
 def diff(a: str, b: str, a_name: str, b_name: str) -> str:
     """Return a unified diff string between strings `a` and `b`."""
     import difflib