git.madduck.net Git - etc/vim.git/blobdiff - black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access, can be arranged individually.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Update calver version number (#835)
[etc/vim.git] / black.py
index 2850ae1a19cb0038ce5acd4bbb86837db171cda0..9ecfbe18fe8443a6bb9c7640de036523127cad35 100644 (file)
--- a/black.py
+++ b/black.py
@@ -49,7 +49,7 @@ from blib2to3.pgen2.grammar import Grammar
 from blib2to3.pgen2.parse import ParseError
 
 
-__version__ = "18.9b0"
+__version__ = "19.3b0"
 DEFAULT_LINE_LENGTH = 88
 DEFAULT_EXCLUDES = (
     r"/(\.eggs|\.git|\.hg|\.mypy_cache|\.nox|\.tox|\.venv|_build|buck-out|build|dist)/"
@@ -68,7 +68,7 @@ LeafID = int
 Priority = int
 Index = int
 LN = Union[Leaf, Node]
-SplitFunc = Callable[["Line", bool], Iterator["Line"]]
+SplitFunc = Callable[["Line", Collection["Feature"]], Iterator["Line"]]
 Timestamp = float
 FileSize = int
 CacheInfo = Tuple[Timestamp, FileSize]
@@ -113,20 +113,19 @@ class Changed(Enum):
 
 
 class TargetVersion(Enum):
-    PYPY35 = 1
-    CPY27 = 2
-    CPY33 = 3
-    CPY34 = 4
-    CPY35 = 5
-    CPY36 = 6
-    CPY37 = 7
-    CPY38 = 8
+    PY27 = 2
+    PY33 = 3
+    PY34 = 4
+    PY35 = 5
+    PY36 = 6
+    PY37 = 7
+    PY38 = 8
 
     def is_python2(self) -> bool:
-        return self is TargetVersion.CPY27
+        return self is TargetVersion.PY27
 
 
-PY36_VERSIONS = {TargetVersion.CPY36, TargetVersion.CPY37, TargetVersion.CPY38}
+PY36_VERSIONS = {TargetVersion.PY36, TargetVersion.PY37, TargetVersion.PY38}
 
 
 class Feature(Enum):
@@ -134,32 +133,35 @@ class Feature(Enum):
     UNICODE_LITERALS = 1
     F_STRINGS = 2
     NUMERIC_UNDERSCORES = 3
-    TRAILING_COMMA = 4
+    TRAILING_COMMA_IN_CALL = 4
+    TRAILING_COMMA_IN_DEF = 5
 
 
 VERSION_TO_FEATURES: Dict[TargetVersion, Set[Feature]] = {
-    TargetVersion.CPY27: set(),
-    TargetVersion.PYPY35: {Feature.UNICODE_LITERALS, Feature.F_STRINGS},
-    TargetVersion.CPY33: {Feature.UNICODE_LITERALS},
-    TargetVersion.CPY34: {Feature.UNICODE_LITERALS},
-    TargetVersion.CPY35: {Feature.UNICODE_LITERALS, Feature.TRAILING_COMMA},
-    TargetVersion.CPY36: {
+    TargetVersion.PY27: set(),
+    TargetVersion.PY33: {Feature.UNICODE_LITERALS},
+    TargetVersion.PY34: {Feature.UNICODE_LITERALS},
+    TargetVersion.PY35: {Feature.UNICODE_LITERALS, Feature.TRAILING_COMMA_IN_CALL},
+    TargetVersion.PY36: {
         Feature.UNICODE_LITERALS,
         Feature.F_STRINGS,
         Feature.NUMERIC_UNDERSCORES,
-        Feature.TRAILING_COMMA,
+        Feature.TRAILING_COMMA_IN_CALL,
+        Feature.TRAILING_COMMA_IN_DEF,
     },
-    TargetVersion.CPY37: {
+    TargetVersion.PY37: {
         Feature.UNICODE_LITERALS,
         Feature.F_STRINGS,
         Feature.NUMERIC_UNDERSCORES,
-        Feature.TRAILING_COMMA,
+        Feature.TRAILING_COMMA_IN_CALL,
+        Feature.TRAILING_COMMA_IN_DEF,
     },
-    TargetVersion.CPY38: {
+    TargetVersion.PY38: {
         Feature.UNICODE_LITERALS,
         Feature.F_STRINGS,
         Feature.NUMERIC_UNDERSCORES,
-        Feature.TRAILING_COMMA,
+        Feature.TRAILING_COMMA_IN_CALL,
+        Feature.TRAILING_COMMA_IN_DEF,
     },
 }
 
@@ -248,6 +250,16 @@ def read_pyproject_toml(
         "per-file auto-detection]"
     ),
 )
+@click.option(
+    "--py36",
+    is_flag=True,
+    help=(
+        "Allow using Python 3.6-only syntax on all input files.  This will put "
+        "trailing commas in function signatures and calls also after *args and "
+        "**kwargs. Deprecated; use --target-version instead. "
+        "[default: per-file auto-detection]"
+    ),
+)
 @click.option(
     "--pyi",
     is_flag=True,
@@ -351,6 +363,7 @@ def main(
     diff: bool,
     fast: bool,
     pyi: bool,
+    py36: bool,
     skip_string_normalization: bool,
     quiet: bool,
     verbose: bool,
@@ -362,7 +375,17 @@ def main(
     """The uncompromising code formatter."""
     write_back = WriteBack.from_configuration(check=check, diff=diff)
     if target_version:
-        versions = set(target_version)
+        if py36:
+            err(f"Cannot use both --target-version and --py36")
+            ctx.exit(2)
+        else:
+            versions = set(target_version)
+    elif py36:
+        err(
+            "--py36 is deprecated and will be removed in a future version. "
+            "Use --target-version py36 instead."
+        )
+        versions = PY36_VERSIONS
     else:
         # We'll autodetect later.
         versions = set()
@@ -503,12 +526,14 @@ async def schedule_formatting(
         manager = Manager()
         lock = manager.Lock()
     tasks = {
-        loop.run_in_executor(
-            executor, format_file_in_place, src, fast, mode, write_back, lock
+        asyncio.ensure_future(
+            loop.run_in_executor(
+                executor, format_file_in_place, src, fast, mode, write_back, lock
+            )
         ): src
         for src in sorted(sources)
     }
-    pending: Iterable[asyncio.Task] = tasks.keys()
+    pending: Iterable[asyncio.Future] = tasks.keys()
     try:
         loop.add_signal_handler(signal.SIGINT, cancel, pending)
         loop.add_signal_handler(signal.SIGTERM, cancel, pending)
@@ -664,6 +689,11 @@ def format_str(src_contents: str, *, mode: FileMode) -> FileContent:
     elt = EmptyLineTracker(is_pyi=mode.is_pyi)
     empty_line = Line()
     after = 0
+    split_line_features = {
+        feature
+        for feature in {Feature.TRAILING_COMMA_IN_CALL, Feature.TRAILING_COMMA_IN_DEF}
+        if supports_feature(versions, feature)
+    }
     for current_line in lines.visit(src_node):
         for _ in range(after):
             dst_contents += str(empty_line)
@@ -671,9 +701,7 @@ def format_str(src_contents: str, *, mode: FileMode) -> FileContent:
         for _ in range(before):
             dst_contents += str(empty_line)
         for line in split_line(
-            current_line,
-            line_length=mode.line_length,
-            supports_trailing_commas=supports_feature(versions, Feature.TRAILING_COMMA),
+            current_line, line_length=mode.line_length, features=split_line_features
         ):
             dst_contents += str(line)
     return dst_contents
@@ -696,24 +724,20 @@ def decode_bytes(src: bytes) -> Tuple[FileContent, Encoding, NewLine]:
         return tiow.read(), encoding, newline
 
 
-GRAMMARS = [
-    pygram.python_grammar_no_print_statement_no_exec_statement,
-    pygram.python_grammar_no_print_statement,
-    pygram.python_grammar,
-]
-
-
 def get_grammars(target_versions: Set[TargetVersion]) -> List[Grammar]:
     if not target_versions:
-        return GRAMMARS
-    elif all(not version.is_python2() for version in target_versions):
-        # Python 2-compatible code, so don't try Python 3 grammar.
+        # No target_version specified, so try all grammars.
         return [
             pygram.python_grammar_no_print_statement_no_exec_statement,
             pygram.python_grammar_no_print_statement,
+            pygram.python_grammar,
         ]
+    elif all(version.is_python2() for version in target_versions):
+        # Python 2-only code, so try Python 2 grammars.
+        return [pygram.python_grammar_no_print_statement, pygram.python_grammar]
     else:
-        return [pygram.python_grammar]
+        # Python 3-compatible code, so only try Python 3 grammar.
+        return [pygram.python_grammar_no_print_statement_no_exec_statement]
 
 
 def lib2to3_parse(src_txt: str, target_versions: Iterable[TargetVersion] = ()) -> Node:
@@ -1063,9 +1087,7 @@ class Line:
 
     depth: int = 0
     leaves: List[Leaf] = Factory(list)
-    # The LeafID keys of comments must remain ordered by the corresponding leaf's index
-    # in leaves
-    comments: Dict[LeafID, List[Leaf]] = Factory(dict)
+    comments: Dict[LeafID, List[Leaf]] = Factory(dict)  # keys ordered like `leaves`
     bracket_tracker: BracketTracker = Factory(BracketTracker)
     inside_brackets: bool = False
     should_explode: bool = False
@@ -1196,6 +1218,29 @@ class Line:
             if leaf.type == STANDALONE_COMMENT:
                 if leaf.bracket_depth <= depth_limit:
                     return True
+        return False
+
+    def contains_inner_type_comments(self) -> bool:
+        ignored_ids = set()
+        try:
+            last_leaf = self.leaves[-1]
+            ignored_ids.add(id(last_leaf))
+            if last_leaf.type == token.COMMA:
+                # When trailing commas are inserted by Black for consistency, comments
+                # after the previous last element are not moved (they don't have to,
+                # rendering will still be correct).  So we ignore trailing commas.
+                last_leaf = self.leaves[-2]
+                ignored_ids.add(id(last_leaf))
+        except IndexError:
+            return False
+
+        for leaf_id, comments in self.comments.items():
+            if leaf_id in ignored_ids:
+                continue
+
+            for comment in comments:
+                if is_type_comment(comment):
+                    return True
 
         return False
 
@@ -1277,13 +1322,8 @@ class Line:
             comment.prefix = ""
             return False
 
-        else:
-            leaf_id = id(self.leaves[-1])
-            if leaf_id not in self.comments:
-                self.comments[leaf_id] = [comment]
-            else:
-                self.comments[leaf_id].append(comment)
-            return True
+        self.comments.setdefault(id(self.leaves[-1]), []).append(comment)
+        return True
 
     def comments_after(self, leaf: Leaf) -> List[Leaf]:
         """Generate comments that should appear directly after `leaf`."""
@@ -1291,17 +1331,11 @@ class Line:
 
     def remove_trailing_comma(self) -> None:
         """Remove the trailing comma and moves the comments attached to it."""
-        # Remember, the LeafID keys of self.comments are ordered by the
-        # corresponding leaf's index in self.leaves
-        # If id(self.leaves[-2]) is in self.comments, the order doesn't change.
-        # Otherwise, we insert it into self.comments, and it becomes the last entry.
-        # However, since we delete id(self.leaves[-1]) from self.comments, the invariant
-        # is maintained
-        self.comments.setdefault(id(self.leaves[-2]), []).extend(
-            self.comments.get(id(self.leaves[-1]), [])
+        trailing_comma = self.leaves.pop()
+        trailing_comma_comments = self.comments.pop(id(trailing_comma), [])
+        self.comments.setdefault(id(self.leaves[-1]), []).extend(
+            trailing_comma_comments
         )
-        self.comments.pop(id(self.leaves[-1]), None)
-        self.leaves.pop()
 
     def is_complex_subscript(self, leaf: Leaf) -> bool:
         """Return True iff `leaf` is part of a slice with non-trivial exprs."""
@@ -2133,7 +2167,7 @@ def split_line(
     line: Line,
     line_length: int,
     inner: bool = False,
-    supports_trailing_commas: bool = False,
+    features: Collection[Feature] = (),
 ) -> Iterator[Line]:
     """Split a `line` into potentially many lines.
 
@@ -2142,7 +2176,7 @@ def split_line(
     current `line`, possibly transitively. This means we can fallback to splitting
     by delimiters if the LHS/RHS don't yield any results.
 
-    If `supports_trailing_commas` is True, splitting may use the TRAILING_COMMA feature.
+    `features` are syntactical features that may be used in the output.
     """
     if line.is_comment:
         yield line
@@ -2150,16 +2184,8 @@ def split_line(
 
     line_str = str(line).strip("\n")
 
-    # we don't want to split special comments like type annotations
-    # https://github.com/python/typing/issues/186
-    has_special_comment = False
-    for leaf in line.leaves:
-        for comment in line.comments_after(leaf):
-            if leaf.type == token.COMMA and is_special_comment(comment):
-                has_special_comment = True
-
     if (
-        not has_special_comment
+        not line.contains_inner_type_comments()
         and not line.should_explode
         and is_line_short_enough(line, line_length=line_length, line_str=line_str)
     ):
@@ -2171,13 +2197,9 @@ def split_line(
         split_funcs = [left_hand_split]
     else:
 
-        def rhs(line: Line, supports_trailing_commas: bool = False) -> Iterator[Line]:
+        def rhs(line: Line, features: Collection[Feature]) -> Iterator[Line]:
             for omit in generate_trailers_to_omit(line, line_length):
-                lines = list(
-                    right_hand_split(
-                        line, line_length, supports_trailing_commas, omit=omit
-                    )
-                )
+                lines = list(right_hand_split(line, line_length, features, omit=omit))
                 if is_line_short_enough(lines[0], line_length=line_length):
                     yield from lines
                     return
@@ -2185,7 +2207,7 @@ def split_line(
             # All splits failed, best effort split with no omits.
             # This mostly happens to multiline strings that are by definition
             # reported as not fitting a single line.
-            yield from right_hand_split(line, supports_trailing_commas)
+            yield from right_hand_split(line, line_length, features=features)
 
         if line.inside_brackets:
             split_funcs = [delimiter_split, standalone_comment_split, rhs]
@@ -2197,16 +2219,13 @@ def split_line(
         # split altogether.
         result: List[Line] = []
         try:
-            for l in split_func(line, supports_trailing_commas):
+            for l in split_func(line, features):
                 if str(l).strip("\n") == line_str:
                     raise CannotSplit("Split function returned an unchanged result")
 
                 result.extend(
                     split_line(
-                        l,
-                        line_length=line_length,
-                        inner=True,
-                        supports_trailing_commas=supports_trailing_commas,
+                        l, line_length=line_length, inner=True, features=features
                     )
                 )
         except CannotSplit:
@@ -2220,9 +2239,7 @@ def split_line(
         yield line
 
 
-def left_hand_split(
-    line: Line, supports_trailing_commas: bool = False
-) -> Iterator[Line]:
+def left_hand_split(line: Line, features: Collection[Feature] = ()) -> Iterator[Line]:
     """Split line into many lines, starting with the first matching bracket pair.
 
     Note: this usually looks weird, only use this for function definitions.
@@ -2261,7 +2278,7 @@ def left_hand_split(
 def right_hand_split(
     line: Line,
     line_length: int,
-    supports_trailing_commas: bool = False,
+    features: Collection[Feature] = (),
     omit: Collection[LeafID] = (),
 ) -> Iterator[Line]:
     """Split line into many lines, starting with the last matching bracket pair.
@@ -2320,12 +2337,7 @@ def right_hand_split(
     ):
         omit = {id(closing_bracket), *omit}
         try:
-            yield from right_hand_split(
-                line,
-                line_length,
-                supports_trailing_commas=supports_trailing_commas,
-                omit=omit,
-            )
+            yield from right_hand_split(line, line_length, features=features, omit=omit)
             return
 
         except CannotSplit:
@@ -2393,10 +2405,17 @@ def bracket_split_build_line(
         if leaves:
             # Since body is a new indent level, remove spurious leading whitespace.
             normalize_prefix(leaves[0], inside_brackets=True)
-            # Ensure a trailing comma when expected.
+            # Ensure a trailing comma for imports, but be careful not to add one after
+            # any comments.
             if original.is_import:
-                if leaves[-1].type != token.COMMA:
-                    leaves.append(Leaf(token.COMMA, ","))
+                for i in range(len(leaves) - 1, -1, -1):
+                    if leaves[i].type == STANDALONE_COMMENT:
+                        continue
+                    elif leaves[i].type == token.COMMA:
+                        break
+                    else:
+                        leaves.insert(i + 1, Leaf(token.COMMA, ","))
+                        break
     # Populate the line
     for leaf in leaves:
         result.append(leaf, preformatted=True)
@@ -2414,10 +2433,8 @@ def dont_increase_indentation(split_func: SplitFunc) -> SplitFunc:
     """
 
     @wraps(split_func)
-    def split_wrapper(
-        line: Line, supports_trailing_commas: bool = False
-    ) -> Iterator[Line]:
-        for l in split_func(line, supports_trailing_commas):
+    def split_wrapper(line: Line, features: Collection[Feature] = ()) -> Iterator[Line]:
+        for l in split_func(line, features):
             normalize_prefix(l.leaves[0], inside_brackets=True)
             yield l
 
@@ -2425,13 +2442,11 @@ def dont_increase_indentation(split_func: SplitFunc) -> SplitFunc:
 
 
 @dont_increase_indentation
-def delimiter_split(
-    line: Line, supports_trailing_commas: bool = False
-) -> Iterator[Line]:
+def delimiter_split(line: Line, features: Collection[Feature] = ()) -> Iterator[Line]:
     """Split according to delimiters of the highest priority.
 
-    If `supports_trailing_commas` is True, the split will add trailing commas
-    also in function signatures that contain `*` and `**`.
+    If the appropriate Features are given, the split will add trailing commas
+    also in function signatures and calls that contain `*` and `**`.
     """
     try:
         last_leaf = line.leaves[-1]
@@ -2470,10 +2485,16 @@ def delimiter_split(
             yield from append_to_line(comment_after)
 
         lowest_depth = min(lowest_depth, leaf.bracket_depth)
-        if leaf.bracket_depth == lowest_depth and is_vararg(
-            leaf, within=VARARGS_PARENTS
-        ):
-            trailing_comma_safe = trailing_comma_safe and supports_trailing_commas
+        if leaf.bracket_depth == lowest_depth:
+            if is_vararg(leaf, within={syms.typedargslist}):
+                trailing_comma_safe = (
+                    trailing_comma_safe and Feature.TRAILING_COMMA_IN_DEF in features
+                )
+            elif is_vararg(leaf, within={syms.arglist, syms.argument}):
+                trailing_comma_safe = (
+                    trailing_comma_safe and Feature.TRAILING_COMMA_IN_CALL in features
+                )
+
         leaf_priority = bt.delimiters.get(id(leaf))
         if leaf_priority == delimiter_priority:
             yield current_line
@@ -2492,7 +2513,7 @@ def delimiter_split(
 
 @dont_increase_indentation
 def standalone_comment_split(
-    line: Line, supports_trailing_commas: bool = False
+    line: Line, features: Collection[Feature] = ()
 ) -> Iterator[Line]:
     """Split standalone comments from the rest of the line."""
     if not line.contains_standalone_comments(0):
@@ -2535,14 +2556,12 @@ def is_import(leaf: Leaf) -> bool:
     )
 
 
-def is_special_comment(leaf: Leaf) -> bool:
+def is_type_comment(leaf: Leaf) -> bool:
     """Return True if the given leaf is a special comment.
     Only returns true for type comments for now."""
     t = leaf.type
     v = leaf.value
-    return bool(
-        (t == token.COMMENT or t == STANDALONE_COMMENT) and (v.startswith("# type:"))
-    )
+    return t in {token.COMMENT, t == STANDALONE_COMMENT} and v.startswith("# type:")
 
 
 def normalize_prefix(leaf: Leaf, *, inside_brackets: bool) -> None:
@@ -3044,14 +3063,19 @@ def get_features_used(node: Node) -> Set[Feature]:
             and n.children
             and n.children[-1].type == token.COMMA
         ):
+            if n.type == syms.typedargslist:
+                feature = Feature.TRAILING_COMMA_IN_DEF
+            else:
+                feature = Feature.TRAILING_COMMA_IN_CALL
+
             for ch in n.children:
                 if ch.type in STARS:
-                    features.add(Feature.TRAILING_COMMA)
+                    features.add(feature)
 
                 if ch.type == syms.argument:
                     for argch in ch.children:
                         if argch.type in STARS:
-                            features.add(Feature.TRAILING_COMMA)
+                            features.add(feature)
 
     return features
 
@@ -3131,7 +3155,7 @@ def get_future_imports(node: Node) -> Set[str]:
             elif child.type == syms.import_as_names:
                 yield from get_imports_from_children(child.children)
             else:
-                assert False, "Invalid syntax parsing imports"
+                raise AssertionError("Invalid syntax parsing imports")
 
     for child in node.children:
         if child.type != syms.simple_stmt:
@@ -3371,7 +3395,7 @@ def assert_equivalent(src: str, dst: str) -> None:
         log = dump_to_file("".join(traceback.format_tb(exc.__traceback__)), dst)
         raise AssertionError(
             f"INTERNAL ERROR: Black produced invalid code: {exc}. "
-            f"Please report a bug on https://github.com/ambv/black/issues.  "
+            f"Please report a bug on https://github.com/python/black/issues.  "
             f"This invalid output might be helpful: {log}"
         ) from None
 
@@ -3382,7 +3406,7 @@ def assert_equivalent(src: str, dst: str) -> None:
         raise AssertionError(
             f"INTERNAL ERROR: Black produced code that is not equivalent to "
             f"the source.  "
-            f"Please report a bug on https://github.com/ambv/black/issues.  "
+            f"Please report a bug on https://github.com/python/black/issues.  "
             f"This diff might be helpful: {log}"
         ) from None
 
@@ -3398,7 +3422,7 @@ def assert_stable(src: str, dst: str, mode: FileMode) -> None:
         raise AssertionError(
             f"INTERNAL ERROR: Black produced different code on the second pass "
             f"of the formatter.  "
-            f"Please report a bug on https://github.com/ambv/black/issues.  "
+            f"Please report a bug on https://github.com/python/black/issues.  "
             f"This diff might be helpful: {log}"
         ) from None
 
@@ -3438,8 +3462,12 @@ def cancel(tasks: Iterable[asyncio.Task]) -> None:
 def shutdown(loop: BaseEventLoop) -> None:
     """Cancel all pending tasks on `loop`, wait for them, and close the loop."""
     try:
+        if sys.version_info[:2] >= (3, 7):
+            all_tasks = asyncio.all_tasks
+        else:
+            all_tasks = asyncio.Task.all_tasks
         # This part is borrowed from asyncio/runners.py in Python 3.7b2.
-        to_cancel = [task for task in asyncio.Task.all_tasks(loop) if not task.done()]
+        to_cancel = [task for task in all_tasks(loop) if not task.done()]
         if not to_cancel:
             return