git.madduck.net Git - etc/vim.git/blobdiff - src/black/__init__.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access, can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Use 'args' to Avoid GH workflow warning (#1990)
[etc/vim.git] / src / black / __init__.py
index 9b9903c66458f1ac1583fcc369e1fd65c5ff7838..08930d11ceac497a429510e1710219b424eeb028 100644 (file)
@@ -93,7 +93,7 @@ Transformer = Callable[["Line", Collection["Feature"]], Iterator["Line"]]
 Timestamp = float
 FileSize = int
 CacheInfo = Tuple[Timestamp, FileSize]
-Cache = Dict[Path, CacheInfo]
+Cache = Dict[str, CacheInfo]
 out = partial(click.secho, bold=True, err=True)
 err = partial(click.secho, fg="red", err=True)
 
@@ -260,6 +260,7 @@ class Mode:
     target_versions: Set[TargetVersion] = field(default_factory=set)
     line_length: int = DEFAULT_LINE_LENGTH
     string_normalization: bool = True
+    magic_trailing_comma: bool = True
     experimental_string_processing: bool = False
     is_pyi: bool = False
 
@@ -397,6 +398,12 @@ def target_version_option_callback(
     is_flag=True,
     help="Don't normalize string quotes or prefixes.",
 )
+@click.option(
+    "-C",
+    "--skip-magic-trailing-comma",
+    is_flag=True,
+    help="Don't use trailing commas as a reason to split lines.",
+)
 @click.option(
     "--experimental-string-processing",
     is_flag=True,
@@ -524,6 +531,7 @@ def main(
     fast: bool,
     pyi: bool,
     skip_string_normalization: bool,
+    skip_magic_trailing_comma: bool,
     experimental_string_processing: bool,
     quiet: bool,
     verbose: bool,
@@ -546,6 +554,7 @@ def main(
         line_length=line_length,
         is_pyi=pyi,
         string_normalization=not skip_string_normalization,
+        magic_trailing_comma=not skip_magic_trailing_comma,
         experimental_string_processing=experimental_string_processing,
     )
     if config and verbose:
@@ -715,7 +724,8 @@ def reformat_one(
             if write_back not in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
                 cache = read_cache(mode)
                 res_src = src.resolve()
-                if res_src in cache and cache[res_src] == get_cache_info(res_src):
+                res_src_s = str(res_src)
+                if res_src_s in cache and cache[res_src_s] == get_cache_info(res_src):
                     changed = Changed.CACHED
             if changed is not Changed.CACHED and format_file_in_place(
                 src, fast=fast, write_back=write_back, mode=mode
@@ -741,7 +751,7 @@ def reformat_many(
     worker_count = os.cpu_count()
     if sys.platform == "win32":
         # Work around https://bugs.python.org/issue26903
-        worker_count = min(worker_count, 61)
+        worker_count = min(worker_count, 60)
     try:
         executor = ProcessPoolExecutor(max_workers=worker_count)
     except (ImportError, OSError):
@@ -1022,13 +1032,12 @@ def format_str(src_contents: str, *, mode: Mode) -> FileContent:
         versions = detect_target_versions(src_node)
     normalize_fmt_off(src_node)
     lines = LineGenerator(
+        mode=mode,
         remove_u_prefix="unicode_literals" in future_imports
         or supports_feature(versions, Feature.UNICODE_LITERALS),
-        is_pyi=mode.is_pyi,
-        normalize_strings=mode.string_normalization,
     )
     elt = EmptyLineTracker(is_pyi=mode.is_pyi)
-    empty_line = Line()
+    empty_line = Line(mode=mode)
     after = 0
     split_line_features = {
         feature
@@ -1464,6 +1473,7 @@ class BracketTracker:
 class Line:
     """Holds leaves and comments. Can be printed with `str(line)`."""
 
+    mode: Mode
     depth: int = 0
     leaves: List[Leaf] = field(default_factory=list)
     # keys ordered like `leaves`
@@ -1496,8 +1506,11 @@ class Line:
             )
         if self.inside_brackets or not preformatted:
             self.bracket_tracker.mark(leaf)
-            if self.maybe_should_explode(leaf):
-                self.should_explode = True
+            if self.mode.magic_trailing_comma:
+                if self.has_magic_trailing_comma(leaf):
+                    self.should_explode = True
+            elif self.has_magic_trailing_comma(leaf, ensure_removable=True):
+                self.remove_trailing_comma()
         if not self.append_comment(leaf):
             self.leaves.append(leaf)
 
@@ -1673,10 +1686,14 @@ class Line:
     def contains_multiline_strings(self) -> bool:
         return any(is_multiline_string(leaf) for leaf in self.leaves)
 
-    def maybe_should_explode(self, closing: Leaf) -> bool:
-        """Return True if this line should explode (always be split), that is when:
-        - there's a trailing comma here; and
-        - it's not a one-tuple.
+    def has_magic_trailing_comma(
+        self, closing: Leaf, ensure_removable: bool = False
+    ) -> bool:
+        """Return True if we have a magic trailing comma, that is when:
+        - there's a trailing comma here
+        - it's not a one-tuple
+        Additionally, if ensure_removable:
+        - it's not from square bracket indexing
         """
         if not (
             closing.type in CLOSING_BRACKETS
@@ -1685,9 +1702,15 @@ class Line:
         ):
             return False
 
-        if closing.type in {token.RBRACE, token.RSQB}:
+        if closing.type == token.RBRACE:
             return True
 
+        if closing.type == token.RSQB:
+            if not ensure_removable:
+                return True
+            comma = self.leaves[-1]
+            return bool(comma.parent and comma.parent.type == syms.listmaker)
+
         if self.is_import:
             return True
 
@@ -1765,6 +1788,7 @@ class Line:
 
     def clone(self) -> "Line":
         return Line(
+            mode=self.mode,
             depth=self.depth,
             inside_brackets=self.inside_brackets,
             should_explode=self.should_explode,
@@ -1923,10 +1947,9 @@ class LineGenerator(Visitor[Line]):
     in ways that will no longer stringify to valid Python code on the tree.
     """
 
-    is_pyi: bool = False
-    normalize_strings: bool = True
-    current_line: Line = field(default_factory=Line)
+    mode: Mode
     remove_u_prefix: bool = False
+    current_line: Line = field(init=False)
 
     def line(self, indent: int = 0) -> Iterator[Line]:
         """Generate a line.
@@ -1941,7 +1964,7 @@ class LineGenerator(Visitor[Line]):
             return  # Line is empty, don't emit. Creating a new one unnecessary.
 
         complete_line = self.current_line
-        self.current_line = Line(depth=complete_line.depth + indent)
+        self.current_line = Line(mode=self.mode, depth=complete_line.depth + indent)
         yield complete_line
 
     def visit_default(self, node: LN) -> Iterator[Line]:
@@ -1965,7 +1988,7 @@ class LineGenerator(Visitor[Line]):
                     yield from self.line()
 
             normalize_prefix(node, inside_brackets=any_open_brackets)
-            if self.normalize_strings and node.type == token.STRING:
+            if self.mode.string_normalization and node.type == token.STRING:
                 normalize_string_prefix(node, remove_u_prefix=self.remove_u_prefix)
                 normalize_string_quotes(node)
             if node.type == token.NUMBER:
@@ -2017,7 +2040,7 @@ class LineGenerator(Visitor[Line]):
 
     def visit_suite(self, node: Node) -> Iterator[Line]:
         """Visit a suite."""
-        if self.is_pyi and is_stub_suite(node):
+        if self.mode.is_pyi and is_stub_suite(node):
             yield from self.visit(node.children[2])
         else:
             yield from self.visit_default(node)
@@ -2026,7 +2049,7 @@ class LineGenerator(Visitor[Line]):
         """Visit a statement without nested statements."""
         is_suite_like = node.parent and node.parent.type in STATEMENT
         if is_suite_like:
-            if self.is_pyi and is_stub_body(node):
+            if self.mode.is_pyi and is_stub_body(node):
                 yield from self.visit_default(node)
             else:
                 yield from self.line(+1)
@@ -2034,7 +2057,11 @@ class LineGenerator(Visitor[Line]):
                 yield from self.line(-1)
 
         else:
-            if not self.is_pyi or not node.parent or not is_stub_suite(node.parent):
+            if (
+                not self.mode.is_pyi
+                or not node.parent
+                or not is_stub_suite(node.parent)
+            ):
                 yield from self.line()
             yield from self.visit_default(node)
 
@@ -2110,6 +2137,8 @@ class LineGenerator(Visitor[Line]):
 
     def __post_init__(self) -> None:
         """You are in a twisty little maze of passages."""
+        self.current_line = Line(mode=self.mode)
+
         v = self.visit_stmt
         Ø: Set[str] = set()
         self.visit_assert_stmt = partial(v, keywords={"assert"}, parens={"assert", ","})
@@ -2552,6 +2581,8 @@ def is_split_before_delimiter(leaf: Leaf, previous: Optional[Leaf] = None) -> Pr
 
 
 FMT_OFF = {"# fmt: off", "# fmt:off", "# yapf: disable"}
+FMT_SKIP = {"# fmt: skip", "# fmt:skip"}
+FMT_PASS = {*FMT_OFF, *FMT_SKIP}
 FMT_ON = {"# fmt: on", "# fmt:on", "# yapf: enable"}
 
 
@@ -2606,7 +2637,7 @@ def list_comments(prefix: str, *, is_endmarker: bool) -> List[ProtoComment]:
     consumed = 0
     nlines = 0
     ignored_lines = 0
-    for index, line in enumerate(prefix.split("\n")):
+    for index, line in enumerate(re.split("\r?\n", prefix)):
         consumed += len(line) + 1  # adding the length of the split '\n'
         line = line.lstrip()
         if not line:
@@ -4350,6 +4381,7 @@ class StringParenWrapper(CustomSplitMapMixin, BaseStringSplitter):
         # `StringSplitter` will break it down further if necessary.
         string_value = LL[string_idx].value
         string_line = Line(
+            mode=line.mode,
             depth=line.depth + 1,
             inside_brackets=True,
             should_explode=line.should_explode,
@@ -4943,7 +4975,7 @@ def bracket_split_build_line(
     If `is_body` is True, the result line is one-indented inside brackets and as such
     has its first leaf's prefix normalized and a trailing comma added when expected.
     """
-    result = Line(depth=original.depth)
+    result = Line(mode=original.mode, depth=original.depth)
     if is_body:
         result.inside_brackets = True
         result.depth += 1
@@ -5015,7 +5047,9 @@ def delimiter_split(line: Line, features: Collection[Feature] = ()) -> Iterator[
         if bt.delimiter_count_with_priority(delimiter_priority) == 1:
             raise CannotSplit("Splitting a single attribute from its owner looks wrong")
 
-    current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
+    current_line = Line(
+        mode=line.mode, depth=line.depth, inside_brackets=line.inside_brackets
+    )
     lowest_depth = sys.maxsize
     trailing_comma_safe = True
 
@@ -5027,7 +5061,9 @@ def delimiter_split(line: Line, features: Collection[Feature] = ()) -> Iterator[
         except ValueError:
             yield current_line
 
-            current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
+            current_line = Line(
+                mode=line.mode, depth=line.depth, inside_brackets=line.inside_brackets
+            )
             current_line.append(leaf)
 
     for leaf in line.leaves:
@@ -5051,7 +5087,9 @@ def delimiter_split(line: Line, features: Collection[Feature] = ()) -> Iterator[
         if leaf_priority == delimiter_priority:
             yield current_line
 
-            current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
+            current_line = Line(
+                mode=line.mode, depth=line.depth, inside_brackets=line.inside_brackets
+            )
     if current_line:
         if (
             trailing_comma_safe
@@ -5072,7 +5110,9 @@ def standalone_comment_split(
     if not line.contains_standalone_comments(0):
         raise CannotSplit("Line does not have any standalone comments")
 
-    current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
+    current_line = Line(
+        mode=line.mode, depth=line.depth, inside_brackets=line.inside_brackets
+    )
 
     def append_to_line(leaf: Leaf) -> Iterator[Line]:
         """Append `leaf` to current line or to new line if appending impossible."""
@@ -5082,7 +5122,9 @@ def standalone_comment_split(
         except ValueError:
             yield current_line
 
-            current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
+            current_line = Line(
+                line.mode, depth=line.depth, inside_brackets=line.inside_brackets
+            )
             current_line.append(leaf)
 
     for leaf in line.leaves:
@@ -5364,58 +5406,80 @@ def convert_one_fmt_off_pair(node: Node) -> bool:
     for leaf in node.leaves():
         previous_consumed = 0
         for comment in list_comments(leaf.prefix, is_endmarker=False):
-            if comment.value in FMT_OFF:
-                # We only want standalone comments. If there's no previous leaf or
-                # the previous leaf is indentation, it's a standalone comment in
-                # disguise.
-                if comment.type != STANDALONE_COMMENT:
-                    prev = preceding_leaf(leaf)
-                    if prev and prev.type not in WHITESPACE:
+            if comment.value not in FMT_PASS:
+                previous_consumed = comment.consumed
+                continue
+            # We only want standalone comments. If there's no previous leaf or
+            # the previous leaf is indentation, it's a standalone comment in
+            # disguise.
+            if comment.value in FMT_PASS and comment.type != STANDALONE_COMMENT:
+                prev = preceding_leaf(leaf)
+                if prev:
+                    if comment.value in FMT_OFF and prev.type not in WHITESPACE:
+                        continue
+                    if comment.value in FMT_SKIP and prev.type in WHITESPACE:
                         continue
 
-                ignored_nodes = list(generate_ignored_nodes(leaf))
-                if not ignored_nodes:
-                    continue
-
-                first = ignored_nodes[0]  # Can be a container node with the `leaf`.
-                parent = first.parent
-                prefix = first.prefix
-                first.prefix = prefix[comment.consumed :]
-                hidden_value = (
-                    comment.value + "\n" + "".join(str(n) for n in ignored_nodes)
-                )
-                if hidden_value.endswith("\n"):
-                    # That happens when one of the `ignored_nodes` ended with a NEWLINE
-                    # leaf (possibly followed by a DEDENT).
-                    hidden_value = hidden_value[:-1]
-                first_idx: Optional[int] = None
-                for ignored in ignored_nodes:
-                    index = ignored.remove()
-                    if first_idx is None:
-                        first_idx = index
-                assert parent is not None, "INTERNAL ERROR: fmt: on/off handling (1)"
-                assert first_idx is not None, "INTERNAL ERROR: fmt: on/off handling (2)"
-                parent.insert_child(
-                    first_idx,
-                    Leaf(
-                        STANDALONE_COMMENT,
-                        hidden_value,
-                        prefix=prefix[:previous_consumed] + "\n" * comment.newlines,
-                    ),
-                )
-                return True
+            ignored_nodes = list(generate_ignored_nodes(leaf, comment))
+            if not ignored_nodes:
+                continue
 
-            previous_consumed = comment.consumed
+            first = ignored_nodes[0]  # Can be a container node with the `leaf`.
+            parent = first.parent
+            prefix = first.prefix
+            first.prefix = prefix[comment.consumed :]
+            hidden_value = "".join(str(n) for n in ignored_nodes)
+            if comment.value in FMT_OFF:
+                hidden_value = comment.value + "\n" + hidden_value
+            if comment.value in FMT_SKIP:
+                hidden_value += "  " + comment.value
+            if hidden_value.endswith("\n"):
+                # That happens when one of the `ignored_nodes` ended with a NEWLINE
+                # leaf (possibly followed by a DEDENT).
+                hidden_value = hidden_value[:-1]
+            first_idx: Optional[int] = None
+            for ignored in ignored_nodes:
+                index = ignored.remove()
+                if first_idx is None:
+                    first_idx = index
+            assert parent is not None, "INTERNAL ERROR: fmt: on/off handling (1)"
+            assert first_idx is not None, "INTERNAL ERROR: fmt: on/off handling (2)"
+            parent.insert_child(
+                first_idx,
+                Leaf(
+                    STANDALONE_COMMENT,
+                    hidden_value,
+                    prefix=prefix[:previous_consumed] + "\n" * comment.newlines,
+                ),
+            )
+            return True
 
     return False
 
 
-def generate_ignored_nodes(leaf: Leaf) -> Iterator[LN]:
+def generate_ignored_nodes(leaf: Leaf, comment: ProtoComment) -> Iterator[LN]:
     """Starting from the container of `leaf`, generate all leaves until `# fmt: on`.
 
+    If comment is skip, returns leaf only.
     Stops at the end of the block.
     """
     container: Optional[LN] = container_of(leaf)
+    if comment.value in FMT_SKIP:
+        prev_sibling = leaf.prev_sibling
+        if comment.value in leaf.prefix and prev_sibling is not None:
+            leaf.prefix = leaf.prefix.replace(comment.value, "")
+            siblings = [prev_sibling]
+            while (
+                "\n" not in prev_sibling.prefix
+                and prev_sibling.prev_sibling is not None
+            ):
+                prev_sibling = prev_sibling.prev_sibling
+                siblings.insert(0, prev_sibling)
+            for sibling in siblings:
+                yield sibling
+        elif leaf.parent is not None:
+            yield leaf.parent
+        return
     while container is not None and container.type != token.ENDMARKER:
         if is_fmt_on(container):
             return
@@ -5767,7 +5831,7 @@ def should_split_body_explode(line: Line, opening_bracket: Leaf) -> bool:
         return False
 
     return max_priority == COMMA_PRIORITY and (
-        trailing_comma
+        (line.mode.magic_trailing_comma and trailing_comma)
         # always explode imports
         or opening_bracket.parent.type in {syms.atom, syms.import_from}
     )
@@ -6742,8 +6806,8 @@ def filter_cached(cache: Cache, sources: Iterable[Path]) -> Tuple[Set[Path], Set
     """
     todo, done = set(), set()
     for src in sources:
-        src = src.resolve()
-        if cache.get(src) != get_cache_info(src):
+        res_src = src.resolve()
+        if cache.get(str(res_src)) != get_cache_info(res_src):
             todo.add(src)
         else:
             done.add(src)
@@ -6755,7 +6819,10 @@ def write_cache(cache: Cache, sources: Iterable[Path], mode: Mode) -> None:
     cache_file = get_cache_file(mode)
     try:
         CACHE_DIR.mkdir(parents=True, exist_ok=True)
-        new_cache = {**cache, **{src.resolve(): get_cache_info(src) for src in sources}}
+        new_cache = {
+            **cache,
+            **{str(src.resolve()): get_cache_info(src) for src in sources},
+        }
         with tempfile.NamedTemporaryFile(dir=str(cache_file.parent), delete=False) as f:
             pickle.dump(new_cache, f, protocol=4)
         os.replace(f.name, cache_file)