git.madduck.net Git - etc/vim.git/blobdiff - black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access, can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Implement `# fmt: off` and `# fmt: on`
[etc/vim.git] / black.py
index 0fe8fef255be119d3569071404544b153d5ce89f..dc02128313154690de62de810abf399efcdb3433 100644 (file)
--- a/black.py
+++ b/black.py
@@ -1,4 +1,5 @@
 #!/usr/bin/env python3
 #!/usr/bin/env python3
+
 import asyncio
 from asyncio.base_events import BaseEventLoop
 from concurrent.futures import Executor, ProcessPoolExecutor
 import asyncio
 from asyncio.base_events import BaseEventLoop
 from concurrent.futures import Executor, ProcessPoolExecutor
@@ -9,7 +10,7 @@ from pathlib import Path
 import tokenize
 import sys
 from typing import (
 import tokenize
 import sys
 from typing import (
-    Dict, Generic, Iterable, Iterator, List, Optional, Set, Tuple, TypeVar, Union
+    Dict, Generic, Iterable, Iterator, List, Optional, Set, Tuple, Type, TypeVar, Union
 )
 
 from attr import dataclass, Factory
 )
 
 from attr import dataclass, Factory
@@ -21,7 +22,7 @@ from blib2to3 import pygram, pytree
 from blib2to3.pgen2 import driver, token
 from blib2to3.pgen2.parse import ParseError
 
 from blib2to3.pgen2 import driver, token
 from blib2to3.pgen2.parse import ParseError
 
-__version__ = "18.3a2"
+__version__ = "18.3a3"
 DEFAULT_LINE_LENGTH = 88
 # types
 syms = pygram.python_symbols
 DEFAULT_LINE_LENGTH = 88
 # types
 syms = pygram.python_symbols
@@ -43,9 +44,37 @@ class NothingChanged(UserWarning):
 class CannotSplit(Exception):
     """A readable split that fits the allotted line length is impossible.
 
 class CannotSplit(Exception):
     """A readable split that fits the allotted line length is impossible.
 
-    Raised by `left_hand_split()` and `right_hand_split()`.
+    Raised by `left_hand_split()`, `right_hand_split()`, and `delimiter_split()`.
+    """
+
+
+class FormatError(Exception):
+    """Base fmt: on/off error.
+
+    It holds the number of bytes of the prefix consumed before the format
+    control comment appeared.
     """
 
     """
 
+    def __init__(self, consumed: int) -> None:
+        super().__init__(consumed)
+        self.consumed = consumed
+
+    def trim_prefix(self, leaf: Leaf) -> None:
+        leaf.prefix = leaf.prefix[self.consumed:]
+
+    def leaf_from_consumed(self, leaf: Leaf) -> Leaf:
+        """Returns a new Leaf from the consumed part of the prefix."""
+        unformatted_prefix = leaf.prefix[:self.consumed]
+        return Leaf(token.NEWLINE, unformatted_prefix)
+
+
+class FormatOn(FormatError):
+    """Found a comment like `# fmt: on` in the file."""
+
+
+class FormatOff(FormatError):
+    """Found a comment like `# fmt: off` in the file."""
+
 
 @click.command()
 @click.option(
 
 @click.command()
 @click.option(
@@ -61,7 +90,7 @@ class CannotSplit(Exception):
     is_flag=True,
     help=(
         "Don't write back the files, just return the status.  Return code 0 "
     is_flag=True,
     help=(
         "Don't write back the files, just return the status.  Return code 0 "
-        "means nothing changed.  Return code 1 means some files were "
+        "means nothing would change.  Return code 1 means some files would be "
         "reformatted.  Return code 123 means there was an internal error."
     ),
 )
         "reformatted.  Return code 123 means there was an internal error."
     ),
 )
@@ -74,7 +103,9 @@ class CannotSplit(Exception):
 @click.argument(
     'src',
     nargs=-1,
 @click.argument(
     'src',
     nargs=-1,
-    type=click.Path(exists=True, file_okay=True, dir_okay=True, readable=True),
+    type=click.Path(
+        exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
+    ),
 )
 @click.pass_context
 def main(
 )
 @click.pass_context
 def main(
@@ -89,17 +120,24 @@ def main(
         elif p.is_file():
             # if a file was explicitly given, we don't care about its extension
             sources.append(p)
         elif p.is_file():
             # if a file was explicitly given, we don't care about its extension
             sources.append(p)
+        elif s == '-':
+            sources.append(Path('-'))
         else:
             err(f'invalid path: {s}')
     if len(sources) == 0:
         ctx.exit(0)
     elif len(sources) == 1:
         p = sources[0]
         else:
             err(f'invalid path: {s}')
     if len(sources) == 0:
         ctx.exit(0)
     elif len(sources) == 1:
         p = sources[0]
-        report = Report()
+        report = Report(check=check)
         try:
         try:
-            changed = format_file_in_place(
-                p, line_length=line_length, fast=fast, write_back=not check
-            )
+            if not p.is_file() and str(p) == '-':
+                changed = format_stdin_to_stdout(
+                    line_length=line_length, fast=fast, write_back=not check
+                )
+            else:
+                changed = format_file_in_place(
+                    p, line_length=line_length, fast=fast, write_back=not check
+                )
             report.done(p, changed)
         except Exception as exc:
             report.failed(p, str(exc))
             report.done(p, changed)
         except Exception as exc:
             report.failed(p, str(exc))
@@ -156,41 +194,59 @@ def format_file_in_place(
     src: Path, line_length: int, fast: bool, write_back: bool = False
 ) -> bool:
     """Format the file and rewrite if changed. Return True if changed."""
     src: Path, line_length: int, fast: bool, write_back: bool = False
 ) -> bool:
     """Format the file and rewrite if changed. Return True if changed."""
+    with tokenize.open(src) as src_buffer:
+        src_contents = src_buffer.read()
     try:
     try:
-        contents, encoding = format_file(src, line_length=line_length, fast=fast)
+        contents = format_file_contents(
+            src_contents, line_length=line_length, fast=fast
+        )
     except NothingChanged:
         return False
 
     if write_back:
     except NothingChanged:
         return False
 
     if write_back:
-        with open(src, "w", encoding=encoding) as f:
+        with open(src, "w", encoding=src_buffer.encoding) as f:
             f.write(contents)
     return True
 
 
             f.write(contents)
     return True
 
 
-def format_file(
-    src: Path, line_length: int, fast: bool
-) -> Tuple[FileContent, Encoding]:
+def format_stdin_to_stdout(
+    line_length: int, fast: bool, write_back: bool = False
+) -> bool:
+    """Format file on stdin and pipe output to stdout. Return True if changed."""
+    contents = sys.stdin.read()
+    try:
+        contents = format_file_contents(contents, line_length=line_length, fast=fast)
+        return True
+
+    except NothingChanged:
+        return False
+
+    finally:
+        if write_back:
+            sys.stdout.write(contents)
+
+
+def format_file_contents(
+    src_contents: str, line_length: int, fast: bool
+) -> FileContent:
     """Reformats a file and returns its contents and encoding."""
     """Reformats a file and returns its contents and encoding."""
-    with tokenize.open(src) as src_buffer:
-        src_contents = src_buffer.read()
     if src_contents.strip() == '':
     if src_contents.strip() == '':
-        raise NothingChanged(src)
+        raise NothingChanged
 
     dst_contents = format_str(src_contents, line_length=line_length)
     if src_contents == dst_contents:
 
     dst_contents = format_str(src_contents, line_length=line_length)
     if src_contents == dst_contents:
-        raise NothingChanged(src)
+        raise NothingChanged
 
     if not fast:
         assert_equivalent(src_contents, dst_contents)
         assert_stable(src_contents, dst_contents, line_length=line_length)
 
     if not fast:
         assert_equivalent(src_contents, dst_contents)
         assert_stable(src_contents, dst_contents, line_length=line_length)
-    return dst_contents, src_buffer.encoding
+    return dst_contents
 
 
 def format_str(src_contents: str, line_length: int) -> FileContent:
     """Reformats a string and returns new contents."""
     src_node = lib2to3_parse(src_contents)
     dst_contents = ""
 
 
 def format_str(src_contents: str, line_length: int) -> FileContent:
     """Reformats a string and returns new contents."""
     src_node = lib2to3_parse(src_contents)
     dst_contents = ""
-    comments: List[Line] = []
     lines = LineGenerator()
     elt = EmptyLineTracker()
     py36 = is_python36(src_node)
     lines = LineGenerator()
     elt = EmptyLineTracker()
     py36 = is_python36(src_node)
@@ -202,36 +258,41 @@ def format_str(src_contents: str, line_length: int) -> FileContent:
         before, after = elt.maybe_empty_lines(current_line)
         for _ in range(before):
             dst_contents += str(empty_line)
         before, after = elt.maybe_empty_lines(current_line)
         for _ in range(before):
             dst_contents += str(empty_line)
-        if not current_line.is_comment:
-            for comment in comments:
-                dst_contents += str(comment)
-            comments = []
-            for line in split_line(current_line, line_length=line_length, py36=py36):
-                dst_contents += str(line)
-        else:
-            comments.append(current_line)
-    for comment in comments:
-        dst_contents += str(comment)
+        for line in split_line(current_line, line_length=line_length, py36=py36):
+            dst_contents += str(line)
     return dst_contents
 
 
     return dst_contents
 
 
+GRAMMARS = [
+    pygram.python_grammar_no_print_statement_no_exec_statement,
+    pygram.python_grammar_no_print_statement,
+    pygram.python_grammar_no_exec_statement,
+    pygram.python_grammar,
+]
+
+
 def lib2to3_parse(src_txt: str) -> Node:
     """Given a string with source, return the lib2to3 Node."""
     grammar = pygram.python_grammar_no_print_statement
 def lib2to3_parse(src_txt: str) -> Node:
     """Given a string with source, return the lib2to3 Node."""
     grammar = pygram.python_grammar_no_print_statement
-    drv = driver.Driver(grammar, pytree.convert)
     if src_txt[-1] != '\n':
         nl = '\r\n' if '\r\n' in src_txt[:1024] else '\n'
         src_txt += nl
     if src_txt[-1] != '\n':
         nl = '\r\n' if '\r\n' in src_txt[:1024] else '\n'
         src_txt += nl
-    try:
-        result = drv.parse_string(src_txt, True)
-    except ParseError as pe:
-        lineno, column = pe.context[1]
-        lines = src_txt.splitlines()
+    for grammar in GRAMMARS:
+        drv = driver.Driver(grammar, pytree.convert)
         try:
         try:
-            faulty_line = lines[lineno - 1]
-        except IndexError:
-            faulty_line = "<line number missing in source>"
-        raise ValueError(f"Cannot parse: {lineno}:{column}: {faulty_line}") from None
+            result = drv.parse_string(src_txt, True)
+            break
+
+        except ParseError as pe:
+            lineno, column = pe.context[1]
+            lines = src_txt.splitlines()
+            try:
+                faulty_line = lines[lineno - 1]
+            except IndexError:
+                faulty_line = "<line number missing in source>"
+            exc = ValueError(f"Cannot parse: {lineno}:{column}: {faulty_line}")
+    else:
+        raise exc from None
 
     if isinstance(result, Leaf):
         result = Node(syms.file_input, [result])
 
     if isinstance(result, Leaf):
         result = Node(syms.file_input, [result])
@@ -287,6 +348,15 @@ class DebugVisitor(Visitor[T]):
                 out(f' {node.prefix!r}', fg='green', bold=False, nl=False)
             out(f' {node.value!r}', fg='blue', bold=False)
 
                 out(f' {node.prefix!r}', fg='green', bold=False, nl=False)
             out(f' {node.value!r}', fg='blue', bold=False)
 
+    @classmethod
+    def show(cls, code: str) -> None:
+        """Pretty-prints a given string of `code`.
+
+        Convenience method for debugging.
+        """
+        v: DebugVisitor[None] = DebugVisitor()
+        list(v.visit(lib2to3_parse(code)))
+
 
 KEYWORDS = set(keyword.kwlist)
 WHITESPACE = {token.DEDENT, token.INDENT, token.NEWLINE}
 
 KEYWORDS = set(keyword.kwlist)
 WHITESPACE = {token.DEDENT, token.INDENT, token.NEWLINE}
@@ -320,6 +390,7 @@ MATH_OPERATORS = {
     token.AMPER,
     token.PERCENT,
     token.CIRCUMFLEX,
     token.AMPER,
     token.PERCENT,
     token.CIRCUMFLEX,
+    token.TILDE,
     token.LEFTSHIFT,
     token.RIGHTSHIFT,
     token.DOUBLESTAR,
     token.LEFTSHIFT,
     token.RIGHTSHIFT,
     token.DOUBLESTAR,
@@ -385,7 +456,7 @@ class BracketTracker:
         """Returns True if there is an yet unmatched open bracket on the line."""
         return bool(self.bracket_match)
 
         """Returns True if there is an yet unmatched open bracket on the line."""
         return bool(self.bracket_match)
 
-    def max_priority(self, exclude: Iterable[LeafID] =()) -> int:
+    def max_priority(self, exclude: Iterable[LeafID] = ()) -> int:
         """Returns the highest priority of a delimiter found on the line.
 
         Values are consistent with what `is_delimiter()` returns.
         """Returns the highest priority of a delimiter found on the line.
 
         Values are consistent with what `is_delimiter()` returns.
@@ -458,8 +529,7 @@ class Line:
         return (
             (first_leaf.type == token.NAME and first_leaf.value == 'def')
             or (
         return (
             (first_leaf.type == token.NAME and first_leaf.value == 'def')
             or (
-                first_leaf.type == token.NAME
-                and first_leaf.value == 'async'
+                first_leaf.type == token.ASYNC
                 and second_leaf is not None
                 and second_leaf.type == token.NAME
                 and second_leaf.value == 'def'
                 and second_leaf is not None
                 and second_leaf.type == token.NAME
                 and second_leaf.value == 'def'
@@ -490,10 +560,16 @@ class Line:
         ):
             return False
 
         ):
             return False
 
-        if closing.type == token.RSQB or closing.type == token.RBRACE:
+        if closing.type == token.RBRACE:
             self.leaves.pop()
             return True
 
             self.leaves.pop()
             return True
 
+        if closing.type == token.RSQB:
+            comma = self.leaves[-1]
+            if comma.parent and comma.parent.type == syms.listmaker:
+                self.leaves.pop()
+                return True
+
         # For parens let's check if it's safe to remove the comma.  If the
         # trailing one is the only one, we might mistakenly change a tuple
         # into a different type by removing the comma.
         # For parens let's check if it's safe to remove the comma.  If the
         # trailing one is the only one, we might mistakenly change a tuple
         # into a different type by removing the comma.
@@ -610,12 +686,51 @@ class Line:
         return bool(self.leaves or self.comments)
 
 
         return bool(self.leaves or self.comments)
 
 
+class UnformattedLines(Line):
+
+    def append(self, leaf: Leaf, preformatted: bool = False) -> None:
+        try:
+            list(generate_comments(leaf))
+        except FormatOn as f_on:
+            self.leaves.append(f_on.leaf_from_consumed(leaf))
+            raise
+
+        self.leaves.append(leaf)
+        if leaf.type == token.INDENT:
+            self.depth += 1
+        elif leaf.type == token.DEDENT:
+            self.depth -= 1
+
+    def append_comment(self, comment: Leaf) -> bool:
+        raise NotImplementedError("Unformatted lines don't store comments separately.")
+
+    def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
+        return False
+
+    def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
+        return False
+
+    def maybe_adapt_standalone_comment(self, comment: Leaf) -> bool:
+        return False
+
+    def __str__(self) -> str:
+        if not self:
+            return '\n'
+
+        res = ''
+        for leaf in self.leaves:
+            res += str(leaf)
+        return res
+
+
 @dataclass
 class EmptyLineTracker:
     """Provides a stateful method that returns the number of potential extra
     empty lines needed before and after the currently processed line.
 
 @dataclass
 class EmptyLineTracker:
     """Provides a stateful method that returns the number of potential extra
     empty lines needed before and after the currently processed line.
 
-    Note: this tracker works on lines that haven't been split yet.
+    Note: this tracker works on lines that haven't been split yet.  It assumes
+    the prefix of the first leaf consists of optional newlines.  Those newlines
+    are consumed by `maybe_empty_lines()` and included in the computation.
     """
     previous_line: Optional[Line] = None
     previous_after: int = 0
     """
     previous_line: Optional[Line] = None
     previous_after: int = 0
@@ -628,21 +743,31 @@ class EmptyLineTracker:
         (two on module-level), as well as providing an extra empty line after flow
         control keywords to make them more prominent.
         """
         (two on module-level), as well as providing an extra empty line after flow
         control keywords to make them more prominent.
         """
-        if current_line.is_comment:
-            # Don't count standalone comments towards previous empty lines.
+        if isinstance(current_line, UnformattedLines):
             return 0, 0
 
         before, after = self._maybe_empty_lines(current_line)
             return 0, 0
 
         before, after = self._maybe_empty_lines(current_line)
+        before -= self.previous_after
         self.previous_after = after
         self.previous_line = current_line
         return before, after
 
     def _maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
         self.previous_after = after
         self.previous_line = current_line
         return before, after
 
     def _maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
-        before = 0
+        max_allowed = 1
+        if current_line.is_comment and current_line.depth == 0:
+            max_allowed = 2
+        if current_line.leaves:
+            # Consume the first leaf's extra newlines.
+            first_leaf = current_line.leaves[0]
+            before = first_leaf.prefix.count('\n')
+            before = min(before, max(before, max_allowed))
+            first_leaf.prefix = ''
+        else:
+            before = 0
         depth = current_line.depth
         while self.previous_defs and self.previous_defs[-1] >= depth:
             self.previous_defs.pop()
         depth = current_line.depth
         while self.previous_defs and self.previous_defs[-1] >= depth:
             self.previous_defs.pop()
-            before = (1 if depth else 2) - self.previous_after
+            before = 1 if depth else 2
         is_decorator = current_line.is_decorator
         if is_decorator or current_line.is_def or current_line.is_class:
             if not is_decorator:
         is_decorator = current_line.is_decorator
         if is_decorator or current_line.is_def or current_line.is_class:
             if not is_decorator:
@@ -658,7 +783,6 @@ class EmptyLineTracker:
             newlines = 2
             if current_line.depth:
                 newlines -= 1
             newlines = 2
             if current_line.depth:
                 newlines -= 1
-            newlines -= self.previous_after
             return newlines, 0
 
         if current_line.is_flow_control:
             return newlines, 0
 
         if current_line.is_flow_control:
@@ -690,9 +814,8 @@ class LineGenerator(Visitor[Line]):
     in ways that will no longer stringify to valid Python code on the tree.
     """
     current_line: Line = Factory(Line)
     in ways that will no longer stringify to valid Python code on the tree.
     """
     current_line: Line = Factory(Line)
-    standalone_comments: List[Leaf] = Factory(list)
 
 
-    def line(self, indent: int = 0) -> Iterator[Line]:
+    def line(self, indent: int = 0, type: Type[Line] = Line) -> Iterator[Line]:
         """Generate a line.
 
         If the line is empty, only emit if it makes sense.
         """Generate a line.
 
         If the line is empty, only emit if it makes sense.
@@ -701,52 +824,68 @@ class LineGenerator(Visitor[Line]):
         If any lines were generated, set up a new current_line.
         """
         if not self.current_line:
         If any lines were generated, set up a new current_line.
         """
         if not self.current_line:
-            self.current_line.depth += indent
+            if self.current_line.__class__ == type:
+                self.current_line.depth += indent
+            else:
+                self.current_line = type(depth=self.current_line.depth + indent)
             return  # Line is empty, don't emit. Creating a new one unnecessary.
 
         complete_line = self.current_line
             return  # Line is empty, don't emit. Creating a new one unnecessary.
 
         complete_line = self.current_line
-        self.current_line = Line(depth=complete_line.depth + indent)
+        self.current_line = type(depth=complete_line.depth + indent)
         yield complete_line
 
         yield complete_line
 
+    def visit(self, node: LN) -> Iterator[Line]:
+        """High-level entry point to the visitor."""
+        if isinstance(self.current_line, UnformattedLines):
+            # File contained `# fmt: off`
+            yield from self.visit_unformatted(node)
+
+        else:
+            yield from super().visit(node)
+
     def visit_default(self, node: LN) -> Iterator[Line]:
         if isinstance(node, Leaf):
     def visit_default(self, node: LN) -> Iterator[Line]:
         if isinstance(node, Leaf):
-            for comment in generate_comments(node):
-                if self.current_line.bracket_tracker.any_open_brackets():
-                    # any comment within brackets is subject to splitting
-                    self.current_line.append(comment)
-                elif comment.type == token.COMMENT:
-                    # regular trailing comment
-                    self.current_line.append(comment)
-                    yield from self.line()
-
-                else:
-                    # regular standalone comment, to be processed later (see
-                    # docstring in `generate_comments()`
-                    self.standalone_comments.append(comment)
-            normalize_prefix(node)
-            if node.type not in WHITESPACE:
-                for comment in self.standalone_comments:
-                    yield from self.line()
-
-                    self.current_line.append(comment)
-                    yield from self.line()
-
-                self.standalone_comments = []
-                self.current_line.append(node)
+            any_open_brackets = self.current_line.bracket_tracker.any_open_brackets()
+            try:
+                for comment in generate_comments(node):
+                    if any_open_brackets:
+                        # any comment within brackets is subject to splitting
+                        self.current_line.append(comment)
+                    elif comment.type == token.COMMENT:
+                        # regular trailing comment
+                        self.current_line.append(comment)
+                        yield from self.line()
+
+                    else:
+                        # regular standalone comment
+                        yield from self.line()
+
+                        self.current_line.append(comment)
+                        yield from self.line()
+
+            except FormatOff as f_off:
+                f_off.trim_prefix(node)
+                yield from self.line(type=UnformattedLines)
+                yield from self.visit(node)
+
+            except FormatOn as f_on:
+                # This only happens here if somebody says "fmt: on" multiple
+                # times in a row.
+                f_on.trim_prefix(node)
+                yield from self.visit_default(node)
+
+            else:
+                normalize_prefix(node, inside_brackets=any_open_brackets)
+                if node.type not in WHITESPACE:
+                    self.current_line.append(node)
         yield from super().visit_default(node)
 
         yield from super().visit_default(node)
 
-    def visit_suite(self, node: Node) -> Iterator[Line]:
-        """Body of a statement after a colon."""
-        children = iter(node.children)
-        # Process newline before indenting.  It might contain an inline
-        # comment that should go right after the colon.
-        newline = next(children)
-        yield from self.visit(newline)
+    def visit_INDENT(self, node: Node) -> Iterator[Line]:
         yield from self.line(+1)
         yield from self.line(+1)
+        yield from self.visit_default(node)
 
 
-        for child in children:
-            yield from self.visit(child)
-
+    def visit_DEDENT(self, node: Node) -> Iterator[Line]:
+        # DEDENT has no value. Additionally, in blib2to3 it never holds comments.
         yield from self.line(-1)
 
     def visit_stmt(self, node: Node, keywords: Set[str]) -> Iterator[Line]:
         yield from self.line(-1)
 
     def visit_stmt(self, node: Node, keywords: Set[str]) -> Iterator[Line]:
@@ -780,7 +919,7 @@ class LineGenerator(Visitor[Line]):
         for child in children:
             yield from self.visit(child)
 
         for child in children:
             yield from self.visit(child)
 
-            if child.type == token.NAME and child.value == 'async':  # type: ignore
+            if child.type == token.ASYNC:
                 break
 
         internal_stmt = next(children)
                 break
 
         internal_stmt = next(children)
@@ -799,6 +938,19 @@ class LineGenerator(Visitor[Line]):
         yield from self.visit_default(leaf)
         yield from self.line()
 
         yield from self.visit_default(leaf)
         yield from self.line()
 
+    def visit_unformatted(self, node: LN) -> Iterator[Line]:
+        if isinstance(node, Node):
+            for child in node.children:
+                yield from self.visit(child)
+
+        else:
+            try:
+                self.current_line.append(node)
+            except FormatOn as f_on:
+                f_on.trim_prefix(node)
+                yield from self.line()
+                yield from self.visit(node)
+
     def __attrs_post_init__(self) -> None:
         """You are in a twisty little maze of passages."""
         v = self.visit_stmt
     def __attrs_post_init__(self) -> None:
         """You are in a twisty little maze of passages."""
         v = self.visit_stmt
@@ -818,7 +970,7 @@ BRACKET = {token.LPAR: token.RPAR, token.LSQB: token.RSQB, token.LBRACE: token.R
 OPENING_BRACKETS = set(BRACKET.keys())
 CLOSING_BRACKETS = set(BRACKET.values())
 BRACKETS = OPENING_BRACKETS | CLOSING_BRACKETS
 OPENING_BRACKETS = set(BRACKET.keys())
 CLOSING_BRACKETS = set(BRACKET.values())
 BRACKETS = OPENING_BRACKETS | CLOSING_BRACKETS
-ALWAYS_NO_SPACE = CLOSING_BRACKETS | {token.COMMA, token.COLON, STANDALONE_COMMENT}
+ALWAYS_NO_SPACE = CLOSING_BRACKETS | {token.COMMA, STANDALONE_COMMENT}
 
 
 def whitespace(leaf: Leaf) -> str:  # noqa C901
 
 
 def whitespace(leaf: Leaf) -> str:  # noqa C901
@@ -836,37 +988,62 @@ def whitespace(leaf: Leaf) -> str:  # noqa C901
         return DOUBLESPACE
 
     assert p is not None, f"INTERNAL ERROR: hand-made leaf without parent: {leaf!r}"
         return DOUBLESPACE
 
     assert p is not None, f"INTERNAL ERROR: hand-made leaf without parent: {leaf!r}"
+    if t == token.COLON and p.type not in {syms.subscript, syms.subscriptlist}:
+        return NO
+
     prev = leaf.prev_sibling
     if not prev:
         prevp = preceding_leaf(p)
         if not prevp or prevp.type in OPENING_BRACKETS:
             return NO
 
     prev = leaf.prev_sibling
     if not prev:
         prevp = preceding_leaf(p)
         if not prevp or prevp.type in OPENING_BRACKETS:
             return NO
 
+        if t == token.COLON:
+            return SPACE if prevp.type == token.COMMA else NO
+
         if prevp.type == token.EQUAL:
         if prevp.type == token.EQUAL:
-            if prevp.parent and prevp.parent.type in {
-                syms.typedargslist,
-                syms.varargslist,
-                syms.parameters,
-                syms.arglist,
-                syms.argument,
-            }:
-                return NO
+            if prevp.parent:
+                if prevp.parent.type in {
+                    syms.arglist, syms.argument, syms.parameters, syms.varargslist
+                }:
+                    return NO
+
+                elif prevp.parent.type == syms.typedargslist:
+                    # A bit hacky: if the equal sign has whitespace, it means we
+                    # previously found it's a typed argument.  So, we're using
+                    # that, too.
+                    return prevp.prefix
 
         elif prevp.type == token.DOUBLESTAR:
             if prevp.parent and prevp.parent.type in {
 
         elif prevp.type == token.DOUBLESTAR:
             if prevp.parent and prevp.parent.type in {
-                syms.typedargslist,
-                syms.varargslist,
-                syms.parameters,
                 syms.arglist,
                 syms.arglist,
+                syms.argument,
                 syms.dictsetmaker,
                 syms.dictsetmaker,
+                syms.parameters,
+                syms.typedargslist,
+                syms.varargslist,
             }:
                 return NO
 
         elif prevp.type == token.COLON:
             }:
                 return NO
 
         elif prevp.type == token.COLON:
-            if prevp.parent and prevp.parent.type == syms.subscript:
+            if prevp.parent and prevp.parent.type in {syms.subscript, syms.sliceop}:
                 return NO
 
                 return NO
 
-        elif prevp.parent and prevp.parent.type in {syms.factor, syms.star_expr}:
+        elif (
+            prevp.parent
+            and prevp.parent.type in {syms.factor, syms.star_expr}
+            and prevp.type in MATH_OPERATORS
+        ):
+            return NO
+
+        elif (
+            prevp.type == token.RIGHTSHIFT
+            and prevp.parent
+            and prevp.parent.type == syms.shift_expr
+            and prevp.prev_sibling
+            and prevp.prev_sibling.type == token.NAME
+            and prevp.prev_sibling.value == 'print'  # type: ignore
+        ):
+            # Python 2 print chevron
             return NO
 
     elif prev.type in OPENING_BRACKETS:
             return NO
 
     elif prev.type in OPENING_BRACKETS:
@@ -880,7 +1057,7 @@ def whitespace(leaf: Leaf) -> str:  # noqa C901
         if not prev or prev.type != token.COMMA:
             return NO
 
         if not prev or prev.type != token.COMMA:
             return NO
 
-    if p.type == syms.varargslist:
+    elif p.type == syms.varargslist:
         # lambdas
         if t == token.RPAR:
             return NO
         # lambdas
         if t == token.RPAR:
             return NO
@@ -970,7 +1147,7 @@ def whitespace(leaf: Leaf) -> str:  # noqa C901
 
             return NO
 
 
             return NO
 
-        elif prev.type == token.COLON:
+        else:
             return NO
 
     elif p.type == syms.atom:
             return NO
 
     elif p.type == syms.atom:
@@ -1091,30 +1268,49 @@ def generate_comments(leaf: Leaf) -> Iterator[Leaf]:
     Inline comments are emitted as regular token.COMMENT leaves.  Standalone
     are emitted with a fake STANDALONE_COMMENT token identifier.
     """
     Inline comments are emitted as regular token.COMMENT leaves.  Standalone
     are emitted with a fake STANDALONE_COMMENT token identifier.
     """
-    if not leaf.prefix:
+    p = leaf.prefix
+    if not p:
         return
 
         return
 
-    if '#' not in leaf.prefix:
+    if '#' not in p:
         return
 
         return
 
-    before_comment, content = leaf.prefix.split('#', 1)
-    content = content.rstrip()
-    if content and (content[0] not in {' ', '!', '#'}):
-        content = ' ' + content
-    is_standalone_comment = (
-        '\n' in before_comment or '\n' in content or leaf.type == token.DEDENT
-    )
-    if not is_standalone_comment:
-        # simple trailing comment
-        yield Leaf(token.COMMENT, value='#' + content)
-        return
-
-    for line in ('#' + content).split('\n'):
+    consumed = 0
+    nlines = 0
+    for index, line in enumerate(p.split('\n')):
+        consumed += len(line) + 1  # adding the length of the split '\n'
         line = line.lstrip()
         line = line.lstrip()
+        if not line:
+            nlines += 1
         if not line.startswith('#'):
             continue
 
         if not line.startswith('#'):
             continue
 
-        yield Leaf(STANDALONE_COMMENT, line)
+        if index == 0 and leaf.type != token.ENDMARKER:
+            comment_type = token.COMMENT  # simple trailing comment
+        else:
+            comment_type = STANDALONE_COMMENT
+        comment = make_comment(line)
+        yield Leaf(comment_type, comment, prefix='\n' * nlines)
+
+        if comment in {'# fmt: on', '# yapf: enable'}:
+            raise FormatOn(consumed)
+
+        if comment in {'# fmt: off', '# yapf: disable'}:
+            raise FormatOff(consumed)
+
+        nlines = 0
+
+
+def make_comment(content: str) -> str:
+    content = content.rstrip()
+    if not content:
+        return '#'
+
+    if content[0] == '#':
+        content = content[1:]
+    if content and content[0] not in ' !:#':
+        content = ' ' + content
+    return '#' + content
 
 
 def split_line(
 
 
 def split_line(
@@ -1130,6 +1326,10 @@ def split_line(
     If `py36` is True, splitting may generate syntax that is only compatible
     with Python 3.6 and later.
     """
     If `py36` is True, splitting may generate syntax that is only compatible
     with Python 3.6 and later.
     """
+    if isinstance(line, UnformattedLines):
+        yield line
+        return
+
     line_str = str(line).strip('\n')
     if len(line_str) <= line_length and '\n' not in line_str:
         yield line
     line_str = str(line).strip('\n')
     if len(line_str) <= line_length and '\n' not in line_str:
         yield line
@@ -1197,7 +1397,7 @@ def left_hand_split(line: Line, py36: bool = False) -> Iterator[Line]:
                 current_leaves = body_leaves
     # Since body is a new indent level, remove spurious leading whitespace.
     if body_leaves:
                 current_leaves = body_leaves
     # Since body is a new indent level, remove spurious leading whitespace.
     if body_leaves:
-        normalize_prefix(body_leaves[0])
+        normalize_prefix(body_leaves[0], inside_brackets=True)
     # Build the new lines.
     for result, leaves in (
         (head, head_leaves), (body, body_leaves), (tail, tail_leaves)
     # Build the new lines.
     for result, leaves in (
         (head, head_leaves), (body, body_leaves), (tail, tail_leaves)
@@ -1237,7 +1437,7 @@ def right_hand_split(line: Line, py36: bool = False) -> Iterator[Line]:
     head_leaves.reverse()
     # Since body is a new indent level, remove spurious leading whitespace.
     if body_leaves:
     head_leaves.reverse()
     # Since body is a new indent level, remove spurious leading whitespace.
     if body_leaves:
-        normalize_prefix(body_leaves[0])
+        normalize_prefix(body_leaves[0], inside_brackets=True)
     # Build the new lines.
     for result, leaves in (
         (head, head_leaves), (body, body_leaves), (tail, tail_leaves)
     # Build the new lines.
     for result, leaves in (
         (head, head_leaves), (body, body_leaves), (tail, tail_leaves)
@@ -1301,7 +1501,7 @@ def delimiter_split(line: Line, py36: bool = False) -> Iterator[Line]:
             trailing_comma_safe = trailing_comma_safe and py36
         leaf_priority = delimiters.get(id(leaf))
         if leaf_priority == delimiter_priority:
             trailing_comma_safe = trailing_comma_safe and py36
         leaf_priority = delimiters.get(id(leaf))
         if leaf_priority == delimiter_priority:
-            normalize_prefix(current_line.leaves[0])
+            normalize_prefix(current_line.leaves[0], inside_brackets=True)
             yield current_line
 
             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
             yield current_line
 
             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
@@ -1312,7 +1512,7 @@ def delimiter_split(line: Line, py36: bool = False) -> Iterator[Line]:
             and trailing_comma_safe
         ):
             current_line.append(Leaf(token.COMMA, ','))
             and trailing_comma_safe
         ):
             current_line.append(Leaf(token.COMMA, ','))
-        normalize_prefix(current_line.leaves[0])
+        normalize_prefix(current_line.leaves[0], inside_brackets=True)
         yield current_line
 
 
         yield current_line
 
 
@@ -1330,16 +1530,20 @@ def is_import(leaf: Leaf) -> bool:
     )
 
 
     )
 
 
-def normalize_prefix(leaf: Leaf) -> None:
-    """Leave existing extra newlines for imports.  Remove everything else."""
-    if is_import(leaf):
-        spl = leaf.prefix.split('#', 1)
-        nl_count = spl[0].count('\n')
-        if len(spl) > 1:
-            # Skip one newline since it was for a standalone comment.
-            nl_count -= 1
-        leaf.prefix = '\n' * nl_count
-        return
+def normalize_prefix(leaf: Leaf, *, inside_brackets: bool) -> None:
+    """Leave existing extra newlines if not `inside_brackets`.
+
+    Remove everything else.  Note: don't use backslashes for formatting or
+    you'll lose your voting rights.
+    """
+    if not inside_brackets:
+        spl = leaf.prefix.split('#')
+        if '\\' not in spl[0]:
+            nl_count = spl[-1].count('\n')
+            if len(spl) > 1:
+                nl_count -= 1
+            leaf.prefix = '\n' * nl_count
+            return
 
     leaf.prefix = ''
 
 
     leaf.prefix = ''
 
@@ -1390,6 +1594,7 @@ def gen_python_files_in_dir(path: Path) -> Iterator[Path]:
 @dataclass
 class Report:
     """Provides a reformatting counter."""
 @dataclass
 class Report:
     """Provides a reformatting counter."""
+    check: bool = False
     change_count: int = 0
     same_count: int = 0
     failure_count: int = 0
     change_count: int = 0
     same_count: int = 0
     failure_count: int = 0
@@ -1397,7 +1602,8 @@ class Report:
     def done(self, src: Path, changed: bool) -> None:
         """Increment the counter for successful reformatting. Write out a message."""
         if changed:
     def done(self, src: Path, changed: bool) -> None:
         """Increment the counter for successful reformatting. Write out a message."""
         if changed:
-            out(f'reformatted {src}')
+            reformatted = 'would reformat' if self.check else 'reformatted'
+            out(f'{reformatted} {src}')
             self.change_count += 1
         else:
             out(f'{src} already well formatted, good job.', bold=False)
             self.change_count += 1
         else:
             out(f'{src} already well formatted, good job.', bold=False)
@@ -1416,7 +1622,7 @@ class Report:
         if self.failure_count:
             return 123
 
         if self.failure_count:
             return 123
 
-        elif self.change_count:
+        elif self.change_count and self.check:
             return 1
 
         return 0
             return 1
 
         return 0
@@ -1426,21 +1632,27 @@ class Report:
 
         Use `click.unstyle` to remove colors.
         """
 
         Use `click.unstyle` to remove colors.
         """
+        if self.check:
+            reformatted = "would be reformatted"
+            unchanged = "would be left unchanged"
+            failed = "would fail to reformat"
+        else:
+            reformatted = "reformatted"
+            unchanged = "left unchanged"
+            failed = "failed to reformat"
         report = []
         if self.change_count:
             s = 's' if self.change_count > 1 else ''
             report.append(
         report = []
         if self.change_count:
             s = 's' if self.change_count > 1 else ''
             report.append(
-                click.style(f'{self.change_count} file{s} reformatted', bold=True)
+                click.style(f'{self.change_count} file{s} {reformatted}', bold=True)
             )
         if self.same_count:
             s = 's' if self.same_count > 1 else ''
             )
         if self.same_count:
             s = 's' if self.same_count > 1 else ''
-            report.append(f'{self.same_count} file{s} left unchanged')
+            report.append(f'{self.same_count} file{s} {unchanged}')
         if self.failure_count:
             s = 's' if self.failure_count > 1 else ''
             report.append(
         if self.failure_count:
             s = 's' if self.failure_count > 1 else ''
             report.append(
-                click.style(
-                    f'{self.failure_count} file{s} failed to reformat', fg='red'
-                )
+                click.style(f'{self.failure_count} file{s} {failed}', fg='red')
             )
         return ', '.join(report) + '.'
 
             )
         return ', '.join(report) + '.'
 
@@ -1482,7 +1694,12 @@ def assert_equivalent(src: str, dst: str) -> None:
     try:
         src_ast = ast.parse(src)
     except Exception as exc:
     try:
         src_ast = ast.parse(src)
     except Exception as exc:
-        raise AssertionError(f"cannot parse source: {exc}") from None
+        major, minor = sys.version_info[:2]
+        raise AssertionError(
+            f"cannot use --safe with this file; failed to parse source file "
+            f"with Python {major}.{minor}'s builtin AST. Re-run with --fast "
+            f"or stop using deprecated Python 2 syntax. AST error message: {exc}"
+        )
 
     try:
         dst_ast = ast.parse(dst)
 
     try:
         dst_ast = ast.parse(dst)