git.madduck.net Git - etc/vim.git/blobdiff - black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you read the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access, can be arranged individually.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Implement `# fmt: off` and `# fmt: on`
[etc/vim.git] / black.py
index d3e0761e273a58fe4b7318827e12bbd853f53eea..dc02128313154690de62de810abf399efcdb3433 100644 (file)
--- a/black.py
+++ b/black.py
@@ -10,7 +10,7 @@ from pathlib import Path
 import tokenize
 import sys
 from typing import (
-    Dict, Generic, Iterable, Iterator, List, Optional, Set, Tuple, TypeVar, Union
+    Dict, Generic, Iterable, Iterator, List, Optional, Set, Tuple, Type, TypeVar, Union
 )
 
 from attr import dataclass, Factory
@@ -44,10 +44,38 @@ class NothingChanged(UserWarning):
 class CannotSplit(Exception):
     """A readable split that fits the allotted line length is impossible.
 
-    Raised by `left_hand_split()` and `right_hand_split()`.
+    Raised by `left_hand_split()`, `right_hand_split()`, and `delimiter_split()`.
     """
 
 
+class FormatError(Exception):
+    """Base fmt: on/off error.
+
+    It holds the number of bytes of the prefix consumed before the format
+    control comment appeared.
+    """
+
+    def __init__(self, consumed: int) -> None:
+        super().__init__(consumed)
+        self.consumed = consumed
+
+    def trim_prefix(self, leaf: Leaf) -> None:
+        leaf.prefix = leaf.prefix[self.consumed:]
+
+    def leaf_from_consumed(self, leaf: Leaf) -> Leaf:
+        """Returns a new Leaf from the consumed part of the prefix."""
+        unformatted_prefix = leaf.prefix[:self.consumed]
+        return Leaf(token.NEWLINE, unformatted_prefix)
+
+
+class FormatOn(FormatError):
+    """Found a comment like `# fmt: on` in the file."""
+
+
+class FormatOff(FormatError):
+    """Found a comment like `# fmt: off` in the file."""
+
+
 @click.command()
 @click.option(
     '-l',
@@ -62,7 +90,7 @@ class CannotSplit(Exception):
     is_flag=True,
     help=(
         "Don't write back the files, just return the status.  Return code 0 "
-        "means nothing changed.  Return code 1 means some files were "
+        "means nothing would change.  Return code 1 means some files would be "
         "reformatted.  Return code 123 means there was an internal error."
     ),
 )
@@ -100,7 +128,7 @@ def main(
         ctx.exit(0)
     elif len(sources) == 1:
         p = sources[0]
-        report = Report()
+        report = Report(check=check)
         try:
             if not p.is_file() and str(p) == '-':
                 changed = format_stdin_to_stdout(
@@ -235,23 +263,36 @@ def format_str(src_contents: str, line_length: int) -> FileContent:
     return dst_contents
 
 
+GRAMMARS = [
+    pygram.python_grammar_no_print_statement_no_exec_statement,
+    pygram.python_grammar_no_print_statement,
+    pygram.python_grammar_no_exec_statement,
+    pygram.python_grammar,
+]
+
+
 def lib2to3_parse(src_txt: str) -> Node:
     """Given a string with source, return the lib2to3 Node."""
     grammar = pygram.python_grammar_no_print_statement
-    drv = driver.Driver(grammar, pytree.convert)
     if src_txt[-1] != '\n':
         nl = '\r\n' if '\r\n' in src_txt[:1024] else '\n'
         src_txt += nl
-    try:
-        result = drv.parse_string(src_txt, True)
-    except ParseError as pe:
-        lineno, column = pe.context[1]
-        lines = src_txt.splitlines()
+    for grammar in GRAMMARS:
+        drv = driver.Driver(grammar, pytree.convert)
         try:
-            faulty_line = lines[lineno - 1]
-        except IndexError:
-            faulty_line = "<line number missing in source>"
-        raise ValueError(f"Cannot parse: {lineno}:{column}: {faulty_line}") from None
+            result = drv.parse_string(src_txt, True)
+            break
+
+        except ParseError as pe:
+            lineno, column = pe.context[1]
+            lines = src_txt.splitlines()
+            try:
+                faulty_line = lines[lineno - 1]
+            except IndexError:
+                faulty_line = "<line number missing in source>"
+            exc = ValueError(f"Cannot parse: {lineno}:{column}: {faulty_line}")
+    else:
+        raise exc from None
 
     if isinstance(result, Leaf):
         result = Node(syms.file_input, [result])
@@ -307,6 +348,15 @@ class DebugVisitor(Visitor[T]):
                 out(f' {node.prefix!r}', fg='green', bold=False, nl=False)
             out(f' {node.value!r}', fg='blue', bold=False)
 
+    @classmethod
+    def show(cls, code: str) -> None:
+        """Pretty-prints a given string of `code`.
+
+        Convenience method for debugging.
+        """
+        v: DebugVisitor[None] = DebugVisitor()
+        list(v.visit(lib2to3_parse(code)))
+
 
 KEYWORDS = set(keyword.kwlist)
 WHITESPACE = {token.DEDENT, token.INDENT, token.NEWLINE}
@@ -406,7 +456,7 @@ class BracketTracker:
         """Returns True if there is an yet unmatched open bracket on the line."""
         return bool(self.bracket_match)
 
-    def max_priority(self, exclude: Iterable[LeafID] =()) -> int:
+    def max_priority(self, exclude: Iterable[LeafID] = ()) -> int:
         """Returns the highest priority of a delimiter found on the line.
 
         Values are consistent with what `is_delimiter()` returns.
@@ -510,10 +560,16 @@ class Line:
         ):
             return False
 
-        if closing.type == token.RSQB or closing.type == token.RBRACE:
+        if closing.type == token.RBRACE:
             self.leaves.pop()
             return True
 
+        if closing.type == token.RSQB:
+            comma = self.leaves[-1]
+            if comma.parent and comma.parent.type == syms.listmaker:
+                self.leaves.pop()
+                return True
+
         # For parens let's check if it's safe to remove the comma.  If the
         # trailing one is the only one, we might mistakenly change a tuple
         # into a different type by removing the comma.
@@ -630,6 +686,43 @@ class Line:
         return bool(self.leaves or self.comments)
 
 
+class UnformattedLines(Line):
+
+    def append(self, leaf: Leaf, preformatted: bool = False) -> None:
+        try:
+            list(generate_comments(leaf))
+        except FormatOn as f_on:
+            self.leaves.append(f_on.leaf_from_consumed(leaf))
+            raise
+
+        self.leaves.append(leaf)
+        if leaf.type == token.INDENT:
+            self.depth += 1
+        elif leaf.type == token.DEDENT:
+            self.depth -= 1
+
+    def append_comment(self, comment: Leaf) -> bool:
+        raise NotImplementedError("Unformatted lines don't store comments separately.")
+
+    def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
+        return False
+
+    def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
+        return False
+
+    def maybe_adapt_standalone_comment(self, comment: Leaf) -> bool:
+        return False
+
+    def __str__(self) -> str:
+        if not self:
+            return '\n'
+
+        res = ''
+        for leaf in self.leaves:
+            res += str(leaf)
+        return res
+
+
 @dataclass
 class EmptyLineTracker:
     """Provides a stateful method that returns the number of potential extra
@@ -650,6 +743,9 @@ class EmptyLineTracker:
         (two on module-level), as well as providing an extra empty line after flow
         control keywords to make them more prominent.
         """
+        if isinstance(current_line, UnformattedLines):
+            return 0, 0
+
         before, after = self._maybe_empty_lines(current_line)
         before -= self.previous_after
         self.previous_after = after
@@ -719,7 +815,7 @@ class LineGenerator(Visitor[Line]):
     """
     current_line: Line = Factory(Line)
 
-    def line(self, indent: int = 0) -> Iterator[Line]:
+    def line(self, indent: int = 0, type: Type[Line] = Line) -> Iterator[Line]:
         """Generate a line.
 
         If the line is empty, only emit if it makes sense.
@@ -728,35 +824,60 @@ class LineGenerator(Visitor[Line]):
         If any lines were generated, set up a new current_line.
         """
         if not self.current_line:
-            self.current_line.depth += indent
+            if self.current_line.__class__ == type:
+                self.current_line.depth += indent
+            else:
+                self.current_line = type(depth=self.current_line.depth + indent)
             return  # Line is empty, don't emit. Creating a new one unnecessary.
 
         complete_line = self.current_line
-        self.current_line = Line(depth=complete_line.depth + indent)
+        self.current_line = type(depth=complete_line.depth + indent)
         yield complete_line
 
+    def visit(self, node: LN) -> Iterator[Line]:
+        """High-level entry point to the visitor."""
+        if isinstance(self.current_line, UnformattedLines):
+            # File contained `# fmt: off`
+            yield from self.visit_unformatted(node)
+
+        else:
+            yield from super().visit(node)
+
     def visit_default(self, node: LN) -> Iterator[Line]:
         if isinstance(node, Leaf):
             any_open_brackets = self.current_line.bracket_tracker.any_open_brackets()
-            for comment in generate_comments(node):
-                if any_open_brackets:
-                    # any comment within brackets is subject to splitting
-                    self.current_line.append(comment)
-                elif comment.type == token.COMMENT:
-                    # regular trailing comment
-                    self.current_line.append(comment)
-                    yield from self.line()
-
-                else:
-                    # regular standalone comment
-                    yield from self.line()
-
-                    self.current_line.append(comment)
-                    yield from self.line()
-
-            normalize_prefix(node, inside_brackets=any_open_brackets)
-            if node.type not in WHITESPACE:
-                self.current_line.append(node)
+            try:
+                for comment in generate_comments(node):
+                    if any_open_brackets:
+                        # any comment within brackets is subject to splitting
+                        self.current_line.append(comment)
+                    elif comment.type == token.COMMENT:
+                        # regular trailing comment
+                        self.current_line.append(comment)
+                        yield from self.line()
+
+                    else:
+                        # regular standalone comment
+                        yield from self.line()
+
+                        self.current_line.append(comment)
+                        yield from self.line()
+
+            except FormatOff as f_off:
+                f_off.trim_prefix(node)
+                yield from self.line(type=UnformattedLines)
+                yield from self.visit(node)
+
+            except FormatOn as f_on:
+                # This only happens here if somebody says "fmt: on" multiple
+                # times in a row.
+                f_on.trim_prefix(node)
+                yield from self.visit_default(node)
+
+            else:
+                normalize_prefix(node, inside_brackets=any_open_brackets)
+                if node.type not in WHITESPACE:
+                    self.current_line.append(node)
         yield from super().visit_default(node)
 
     def visit_INDENT(self, node: Node) -> Iterator[Line]:
@@ -764,6 +885,7 @@ class LineGenerator(Visitor[Line]):
         yield from self.visit_default(node)
 
     def visit_DEDENT(self, node: Node) -> Iterator[Line]:
+        # DEDENT has no value. Additionally, in blib2to3 it never holds comments.
         yield from self.line(-1)
 
     def visit_stmt(self, node: Node, keywords: Set[str]) -> Iterator[Line]:
@@ -816,6 +938,19 @@ class LineGenerator(Visitor[Line]):
         yield from self.visit_default(leaf)
         yield from self.line()
 
+    def visit_unformatted(self, node: LN) -> Iterator[Line]:
+        if isinstance(node, Node):
+            for child in node.children:
+                yield from self.visit(child)
+
+        else:
+            try:
+                self.current_line.append(node)
+            except FormatOn as f_on:
+                f_on.trim_prefix(node)
+                yield from self.line()
+                yield from self.visit(node)
+
     def __attrs_post_init__(self) -> None:
         """You are in a twisty little maze of passages."""
         v = self.visit_stmt
@@ -866,14 +1001,17 @@ def whitespace(leaf: Leaf) -> str:  # noqa C901
             return SPACE if prevp.type == token.COMMA else NO
 
         if prevp.type == token.EQUAL:
-            if prevp.parent and prevp.parent.type in {
-                syms.arglist,
-                syms.argument,
-                syms.parameters,
-                syms.typedargslist,
-                syms.varargslist,
-            }:
-                return NO
+            if prevp.parent:
+                if prevp.parent.type in {
+                    syms.arglist, syms.argument, syms.parameters, syms.varargslist
+                }:
+                    return NO
+
+                elif prevp.parent.type == syms.typedargslist:
+                    # A bit hacky: if the equal sign has whitespace, it means we
+                    # previously found it's a typed argument.  So, we're using
+                    # that, too.
+                    return prevp.prefix
 
         elif prevp.type == token.DOUBLESTAR:
             if prevp.parent and prevp.parent.type in {
@@ -897,6 +1035,17 @@ def whitespace(leaf: Leaf) -> str:  # noqa C901
         ):
             return NO
 
+        elif (
+            prevp.type == token.RIGHTSHIFT
+            and prevp.parent
+            and prevp.parent.type == syms.shift_expr
+            and prevp.prev_sibling
+            and prevp.prev_sibling.type == token.NAME
+            and prevp.prev_sibling.value == 'print'  # type: ignore
+        ):
+            # Python 2 print chevron
+            return NO
+
     elif prev.type in OPENING_BRACKETS:
         return NO
 
@@ -908,7 +1057,7 @@ def whitespace(leaf: Leaf) -> str:  # noqa C901
         if not prev or prev.type != token.COMMA:
             return NO
 
-    if p.type == syms.varargslist:
+    elif p.type == syms.varargslist:
         # lambdas
         if t == token.RPAR:
             return NO
@@ -1126,8 +1275,10 @@ def generate_comments(leaf: Leaf) -> Iterator[Leaf]:
     if '#' not in p:
         return
 
+    consumed = 0
     nlines = 0
     for index, line in enumerate(p.split('\n')):
+        consumed += len(line) + 1  # adding the length of the split '\n'
         line = line.lstrip()
         if not line:
             nlines += 1
@@ -1138,7 +1289,14 @@ def generate_comments(leaf: Leaf) -> Iterator[Leaf]:
             comment_type = token.COMMENT  # simple trailing comment
         else:
             comment_type = STANDALONE_COMMENT
-        yield Leaf(comment_type, make_comment(line), prefix='\n' * nlines)
+        comment = make_comment(line)
+        yield Leaf(comment_type, comment, prefix='\n' * nlines)
+
+        if comment in {'# fmt: on', '# yapf: enable'}:
+            raise FormatOn(consumed)
+
+        if comment in {'# fmt: off', '# yapf: disable'}:
+            raise FormatOff(consumed)
 
         nlines = 0
 
@@ -1150,7 +1308,7 @@ def make_comment(content: str) -> str:
 
     if content[0] == '#':
         content = content[1:]
-    if content and content[0] not in {' ', '!', '#'}:
+    if content and content[0] not in ' !:#':
         content = ' ' + content
     return '#' + content
 
@@ -1168,6 +1326,10 @@ def split_line(
     If `py36` is True, splitting may generate syntax that is only compatible
     with Python 3.6 and later.
     """
+    if isinstance(line, UnformattedLines):
+        yield line
+        return
+
     line_str = str(line).strip('\n')
     if len(line_str) <= line_length and '\n' not in line_str:
         yield line
@@ -1432,6 +1594,7 @@ def gen_python_files_in_dir(path: Path) -> Iterator[Path]:
 @dataclass
 class Report:
     """Provides a reformatting counter."""
+    check: bool = False
     change_count: int = 0
     same_count: int = 0
     failure_count: int = 0
@@ -1439,7 +1602,8 @@ class Report:
     def done(self, src: Path, changed: bool) -> None:
         """Increment the counter for successful reformatting. Write out a message."""
         if changed:
-            out(f'reformatted {src}')
+            reformatted = 'would reformat' if self.check else 'reformatted'
+            out(f'{reformatted} {src}')
             self.change_count += 1
         else:
             out(f'{src} already well formatted, good job.', bold=False)
@@ -1458,7 +1622,7 @@ class Report:
         if self.failure_count:
             return 123
 
-        elif self.change_count:
+        elif self.change_count and self.check:
             return 1
 
         return 0
@@ -1468,21 +1632,27 @@ class Report:
 
         Use `click.unstyle` to remove colors.
         """
+        if self.check:
+            reformatted = "would be reformatted"
+            unchanged = "would be left unchanged"
+            failed = "would fail to reformat"
+        else:
+            reformatted = "reformatted"
+            unchanged = "left unchanged"
+            failed = "failed to reformat"
         report = []
         if self.change_count:
             s = 's' if self.change_count > 1 else ''
             report.append(
-                click.style(f'{self.change_count} file{s} reformatted', bold=True)
+                click.style(f'{self.change_count} file{s} {reformatted}', bold=True)
             )
         if self.same_count:
             s = 's' if self.same_count > 1 else ''
-            report.append(f'{self.same_count} file{s} left unchanged')
+            report.append(f'{self.same_count} file{s} {unchanged}')
         if self.failure_count:
             s = 's' if self.failure_count > 1 else ''
             report.append(
-                click.style(
-                    f'{self.failure_count} file{s} failed to reformat', fg='red'
-                )
+                click.style(f'{self.failure_count} file{s} {failed}', fg='red')
             )
         return ', '.join(report) + '.'
 
@@ -1524,7 +1694,12 @@ def assert_equivalent(src: str, dst: str) -> None:
     try:
         src_ast = ast.parse(src)
     except Exception as exc:
-        raise AssertionError(f"cannot parse source: {exc}") from None
+        major, minor = sys.version_info[:2]
+        raise AssertionError(
+            f"cannot use --safe with this file; failed to parse source file "
+            f"with Python {major}.{minor}'s builtin AST. Re-run with --fast "
+            f"or stop using deprecated Python 2 syntax. AST error message: {exc}"
+        )
 
     try:
         dst_ast = ast.parse(dst)