]> git.madduck.net Git - etc/vim.git/blobdiff - black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

use conda for rtd
[etc/vim.git] / black.py
index 89155f60103a32e2f89cd86ac1e6c92aa1bed22f..dab3f004ff3289711dc26006b533844ad2268bac 100644 (file)
--- a/black.py
+++ b/black.py
@@ -1,4 +1,5 @@
 #!/usr/bin/env python3
+
 import asyncio
 from asyncio.base_events import BaseEventLoop
 from concurrent.futures import Executor, ProcessPoolExecutor
@@ -9,7 +10,7 @@ from pathlib import Path
 import tokenize
 import sys
 from typing import (
-    Dict, Generic, Iterable, Iterator, List, Optional, Set, Tuple, TypeVar, Union
+    Dict, Generic, Iterable, Iterator, List, Optional, Set, Tuple, Type, TypeVar, Union
 )
 
 from attr import dataclass, Factory
@@ -21,7 +22,7 @@ from blib2to3 import pygram, pytree
 from blib2to3.pgen2 import driver, token
 from blib2to3.pgen2.parse import ParseError
 
-__version__ = "18.3a2"
+__version__ = "18.3a4"
 DEFAULT_LINE_LENGTH = 88
 # types
 syms = pygram.python_symbols
@@ -37,16 +38,45 @@ err = partial(click.secho, fg='red', err=True)
 
 
 class NothingChanged(UserWarning):
-    """Raised by `format_file` when the reformatted code is the same as source."""
+    """Raised by :func:`format_file` when reformatted code is the same as source."""
 
 
 class CannotSplit(Exception):
     """A readable split that fits the allotted line length is impossible.
 
-    Raised by `left_hand_split()` and `right_hand_split()`.
+    Raised by :func:`left_hand_split`, :func:`right_hand_split`, and
+    :func:`delimiter_split`.
     """
 
 
+class FormatError(Exception):
+    """Base exception for `# fmt: on` and `# fmt: off` handling.
+
+    It holds the number of bytes of the prefix consumed before the format
+    control comment appeared.
+    """
+
+    def __init__(self, consumed: int) -> None:
+        super().__init__(consumed)
+        self.consumed = consumed
+
+    def trim_prefix(self, leaf: Leaf) -> None:
+        leaf.prefix = leaf.prefix[self.consumed:]
+
+    def leaf_from_consumed(self, leaf: Leaf) -> Leaf:
+        """Returns a new Leaf from the consumed part of the prefix."""
+        unformatted_prefix = leaf.prefix[:self.consumed]
+        return Leaf(token.NEWLINE, unformatted_prefix)
+
+
+class FormatOn(FormatError):
+    """Found a comment like `# fmt: on` in the file."""
+
+
+class FormatOff(FormatError):
+    """Found a comment like `# fmt: off` in the file."""
+
+
 @click.command()
 @click.option(
     '-l',
@@ -61,7 +91,7 @@ class CannotSplit(Exception):
     is_flag=True,
     help=(
         "Don't write back the files, just return the status.  Return code 0 "
-        "means nothing changed.  Return code 1 means some files were "
+        "means nothing would change.  Return code 1 means some files would be "
         "reformatted.  Return code 123 means there was an internal error."
     ),
 )
@@ -74,7 +104,9 @@ class CannotSplit(Exception):
 @click.argument(
     'src',
     nargs=-1,
-    type=click.Path(exists=True, file_okay=True, dir_okay=True, readable=True),
+    type=click.Path(
+        exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
+    ),
 )
 @click.pass_context
 def main(
@@ -89,17 +121,24 @@ def main(
         elif p.is_file():
             # if a file was explicitly given, we don't care about its extension
             sources.append(p)
+        elif s == '-':
+            sources.append(Path('-'))
         else:
             err(f'invalid path: {s}')
     if len(sources) == 0:
         ctx.exit(0)
     elif len(sources) == 1:
         p = sources[0]
-        report = Report()
+        report = Report(check=check)
         try:
-            changed = format_file_in_place(
-                p, line_length=line_length, fast=fast, write_back=not check
-            )
+            if not p.is_file() and str(p) == '-':
+                changed = format_stdin_to_stdout(
+                    line_length=line_length, fast=fast, write_back=not check
+                )
+            else:
+                changed = format_file_in_place(
+                    p, line_length=line_length, fast=fast, write_back=not check
+                )
             report.done(p, changed)
         except Exception as exc:
             report.failed(p, str(exc))
@@ -127,6 +166,13 @@ async def schedule_formatting(
     loop: BaseEventLoop,
     executor: Executor,
 ) -> int:
+    """Run formatting of `sources` in parallel using the provided `executor`.
+
+    (Use ProcessPoolExecutors for actual parallelism.)
+
+    `line_length`, `write_back`, and `fast` options are passed to
+    :func:`format_file_in_place`.
+    """
     tasks = {
         src: loop.run_in_executor(
             executor, format_file_in_place, src, line_length, fast, write_back
@@ -135,7 +181,7 @@ async def schedule_formatting(
     }
     await asyncio.wait(tasks.values())
     cancelled = []
-    report = Report()
+    report = Report(check=not write_back)
     for src, task in tasks.items():
         if not task.done():
             report.failed(src, 'timed out, cancelling')
@@ -155,42 +201,76 @@ async def schedule_formatting(
 def format_file_in_place(
     src: Path, line_length: int, fast: bool, write_back: bool = False
 ) -> bool:
-    """Format the file and rewrite if changed. Return True if changed."""
+    """Format file under `src` path. Return True if changed.
+
+    If `write_back` is True, write reformatted code back to stdout.
+    `line_length` and `fast` options are passed to :func:`format_file_contents`.
+    """
+    with tokenize.open(src) as src_buffer:
+        src_contents = src_buffer.read()
     try:
-        contents, encoding = format_file(src, line_length=line_length, fast=fast)
+        contents = format_file_contents(
+            src_contents, line_length=line_length, fast=fast
+        )
     except NothingChanged:
         return False
 
     if write_back:
-        with open(src, "w", encoding=encoding) as f:
+        with open(src, "w", encoding=src_buffer.encoding) as f:
             f.write(contents)
     return True
 
 
-def format_file(
-    src: Path, line_length: int, fast: bool
-) -> Tuple[FileContent, Encoding]:
-    """Reformats a file and returns its contents and encoding."""
-    with tokenize.open(src) as src_buffer:
-        src_contents = src_buffer.read()
+def format_stdin_to_stdout(
+    line_length: int, fast: bool, write_back: bool = False
+) -> bool:
+    """Format file on stdin. Return True if changed.
+
+    If `write_back` is True, write reformatted code back to stdout.
+    `line_length` and `fast` arguments are passed to :func:`format_file_contents`.
+    """
+    contents = sys.stdin.read()
+    try:
+        contents = format_file_contents(contents, line_length=line_length, fast=fast)
+        return True
+
+    except NothingChanged:
+        return False
+
+    finally:
+        if write_back:
+            sys.stdout.write(contents)
+
+
+def format_file_contents(
+    src_contents: str, line_length: int, fast: bool
+) -> FileContent:
+    """Reformat contents a file and return new contents.
+
+    If `fast` is False, additionally confirm that the reformatted code is
+    valid by calling :func:`assert_equivalent` and :func:`assert_stable` on it.
+    `line_length` is passed to :func:`format_str`.
+    """
     if src_contents.strip() == '':
-        raise NothingChanged(src)
+        raise NothingChanged
 
     dst_contents = format_str(src_contents, line_length=line_length)
     if src_contents == dst_contents:
-        raise NothingChanged(src)
+        raise NothingChanged
 
     if not fast:
         assert_equivalent(src_contents, dst_contents)
         assert_stable(src_contents, dst_contents, line_length=line_length)
-    return dst_contents, src_buffer.encoding
+    return dst_contents
 
 
 def format_str(src_contents: str, line_length: int) -> FileContent:
-    """Reformats a string and returns new contents."""
+    """Reformat a string and return new contents.
+
+    `line_length` determines how many characters per line are allowed.
+    """
     src_node = lib2to3_parse(src_contents)
     dst_contents = ""
-    comments: List[Line] = []
     lines = LineGenerator()
     elt = EmptyLineTracker()
     py36 = is_python36(src_node)
@@ -202,41 +282,41 @@ def format_str(src_contents: str, line_length: int) -> FileContent:
         before, after = elt.maybe_empty_lines(current_line)
         for _ in range(before):
             dst_contents += str(empty_line)
-        if not current_line.is_comment:
-            for comment in comments:
-                dst_contents += str(comment)
-            comments = []
-            for line in split_line(current_line, line_length=line_length, py36=py36):
-                dst_contents += str(line)
-        else:
-            comments.append(current_line)
-    if comments:
-        if elt.previous_defs:
-            # Separate postscriptum comments from the last module-level def.
-            dst_contents += str(empty_line)
-            dst_contents += str(empty_line)
-        for comment in comments:
-            dst_contents += str(comment)
+        for line in split_line(current_line, line_length=line_length, py36=py36):
+            dst_contents += str(line)
     return dst_contents
 
 
+GRAMMARS = [
+    pygram.python_grammar_no_print_statement_no_exec_statement,
+    pygram.python_grammar_no_print_statement,
+    pygram.python_grammar_no_exec_statement,
+    pygram.python_grammar,
+]
+
+
 def lib2to3_parse(src_txt: str) -> Node:
     """Given a string with source, return the lib2to3 Node."""
     grammar = pygram.python_grammar_no_print_statement
-    drv = driver.Driver(grammar, pytree.convert)
     if src_txt[-1] != '\n':
         nl = '\r\n' if '\r\n' in src_txt[:1024] else '\n'
         src_txt += nl
-    try:
-        result = drv.parse_string(src_txt, True)
-    except ParseError as pe:
-        lineno, column = pe.context[1]
-        lines = src_txt.splitlines()
+    for grammar in GRAMMARS:
+        drv = driver.Driver(grammar, pytree.convert)
         try:
-            faulty_line = lines[lineno - 1]
-        except IndexError:
-            faulty_line = "<line number missing in source>"
-        raise ValueError(f"Cannot parse: {lineno}:{column}: {faulty_line}") from None
+            result = drv.parse_string(src_txt, True)
+            break
+
+        except ParseError as pe:
+            lineno, column = pe.context[1]
+            lines = src_txt.splitlines()
+            try:
+                faulty_line = lines[lineno - 1]
+            except IndexError:
+                faulty_line = "<line number missing in source>"
+            exc = ValueError(f"Cannot parse: {lineno}:{column}: {faulty_line}")
+    else:
+        raise exc from None
 
     if isinstance(result, Leaf):
         result = Node(syms.file_input, [result])
@@ -253,9 +333,18 @@ T = TypeVar('T')
 
 
 class Visitor(Generic[T]):
-    """Basic lib2to3 visitor that yields things on visiting."""
+    """Basic lib2to3 visitor that yields things of type `T` on `visit()`."""
 
     def visit(self, node: LN) -> Iterator[T]:
+        """Main method to visit `node` and its children.
+
+        It tries to find a `visit_*()` method for the given `node.type`, like
+        `visit_simple_stmt` for Node objects or `visit_INDENT` for Leaf objects.
+        If no dedicated `visit_*()` method is found, chooses `visit_default()`
+        instead.
+
+        Then yields objects of type `T` from the selected visitor.
+        """
         if node.type < 256:
             name = token.tok_name[node.type]
         else:
@@ -263,6 +352,7 @@ class Visitor(Generic[T]):
         yield from getattr(self, f'visit_{name}', self.visit_default)(node)
 
     def visit_default(self, node: LN) -> Iterator[T]:
+        """Default `visit_*()` implementation. Recurses to children of `node`."""
         if isinstance(node, Node):
             for child in node.children:
                 yield from self.visit(child)
@@ -292,6 +382,15 @@ class DebugVisitor(Visitor[T]):
                 out(f' {node.prefix!r}', fg='green', bold=False, nl=False)
             out(f' {node.value!r}', fg='blue', bold=False)
 
+    @classmethod
+    def show(cls, code: str) -> None:
+        """Pretty-print the lib2to3 AST of a given string of `code`.
+
+        Convenience method for debugging.
+        """
+        v: DebugVisitor[None] = DebugVisitor()
+        list(v.visit(lib2to3_parse(code)))
+
 
 KEYWORDS = set(keyword.kwlist)
 WHITESPACE = {token.DEDENT, token.INDENT, token.NEWLINE}
@@ -325,6 +424,7 @@ MATH_OPERATORS = {
     token.AMPER,
     token.PERCENT,
     token.CIRCUMFLEX,
+    token.TILDE,
     token.LEFTSHIFT,
     token.RIGHTSHIFT,
     token.DOUBLESTAR,
@@ -340,12 +440,28 @@ MATH_PRIORITY = 1
 
 @dataclass
 class BracketTracker:
+    """Keeps track of brackets on a line."""
+
     depth: int = 0
     bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = Factory(dict)
     delimiters: Dict[LeafID, Priority] = Factory(dict)
     previous: Optional[Leaf] = None
 
     def mark(self, leaf: Leaf) -> None:
+        """Mark `leaf` with bracket-related metadata. Keep track of delimiters.
+
+        All leaves receive an int `bracket_depth` field that stores how deep
+        within brackets a given leaf is. 0 means there are no enclosing brackets
+        that started on this line.
+
+        If a leaf is itself a closing bracket, it receives an `opening_bracket`
+        field that it forms a pair with. This is a one-directional link to
+        avoid reference cycles.
+
+        If a leaf is a delimiter (a token on which Black can split the line if
+        needed) and it's on depth 0, its `id()` is stored in the tracker's
+        `delimiters` field.
+        """
         if leaf.type == token.COMMENT:
             return
 
@@ -387,11 +503,11 @@ class BracketTracker:
         self.previous = leaf
 
     def any_open_brackets(self) -> bool:
-        """Returns True if there is an yet unmatched open bracket on the line."""
+        """Return True if there is an yet unmatched open bracket on the line."""
         return bool(self.bracket_match)
 
-    def max_priority(self, exclude: Iterable[LeafID] =()) -> int:
-        """Returns the highest priority of a delimiter found on the line.
+    def max_delimiter_priority(self, exclude: Iterable[LeafID] = ()) -> int:
+        """Return the highest priority of a delimiter found on the line.
 
         Values are consistent with what `is_delimiter()` returns.
         """
@@ -400,6 +516,8 @@ class BracketTracker:
 
 @dataclass
 class Line:
+    """Holds leaves and comments. Can be printed with `str(line)`."""
+
     depth: int = 0
     leaves: List[Leaf] = Factory(list)
     comments: Dict[LeafID, Leaf] = Factory(dict)
@@ -409,6 +527,15 @@ class Line:
     _for_loop_variable: bool = False
 
     def append(self, leaf: Leaf, preformatted: bool = False) -> None:
+        """Add a new `leaf` to the end of the line.
+
+        Unless `preformatted` is True, the `leaf` will receive a new consistent
+        whitespace prefix and metadata applied by :class:`BracketTracker`.
+        Trailing commas are maybe removed, unpacked for loop variables are
+        demoted from being delimiters.
+
+        Inline comments are put aside.
+        """
         has_value = leaf.value.strip()
         if not has_value:
             return
@@ -430,18 +557,22 @@ class Line:
 
     @property
     def is_comment(self) -> bool:
+        """Is this line a standalone comment?"""
         return bool(self) and self.leaves[0].type == STANDALONE_COMMENT
 
     @property
     def is_decorator(self) -> bool:
+        """Is this line a decorator?"""
         return bool(self) and self.leaves[0].type == token.AT
 
     @property
     def is_import(self) -> bool:
+        """Is this an import line?"""
         return bool(self) and is_import(self.leaves[0])
 
     @property
     def is_class(self) -> bool:
+        """Is this line a class definition?"""
         return (
             bool(self)
             and self.leaves[0].type == token.NAME
@@ -450,7 +581,7 @@ class Line:
 
     @property
     def is_def(self) -> bool:
-        """Also returns True for async defs."""
+        """Is this a function definition? (Also returns True for async defs.)"""
         try:
             first_leaf = self.leaves[0]
         except IndexError:
@@ -463,8 +594,7 @@ class Line:
         return (
             (first_leaf.type == token.NAME and first_leaf.value == 'def')
             or (
-                first_leaf.type == token.NAME
-                and first_leaf.value == 'async'
+                first_leaf.type == token.ASYNC
                 and second_leaf is not None
                 and second_leaf.type == token.NAME
                 and second_leaf.value == 'def'
@@ -473,6 +603,10 @@ class Line:
 
     @property
     def is_flow_control(self) -> bool:
+        """Is this line a flow control statement?
+
+        Those are `return`, `raise`, `break`, and `continue`.
+        """
         return (
             bool(self)
             and self.leaves[0].type == token.NAME
@@ -481,6 +615,7 @@ class Line:
 
     @property
     def is_yield(self) -> bool:
+        """Is this line a yield statement?"""
         return (
             bool(self)
             and self.leaves[0].type == token.NAME
@@ -488,6 +623,7 @@ class Line:
         )
 
     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
+        """Remove trailing comma if there is one and it's safe."""
         if not (
             self.leaves
             and self.leaves[-1].type == token.COMMA
@@ -495,10 +631,16 @@ class Line:
         ):
             return False
 
-        if closing.type == token.RSQB or closing.type == token.RBRACE:
+        if closing.type == token.RBRACE:
             self.leaves.pop()
             return True
 
+        if closing.type == token.RSQB:
+            comma = self.leaves[-1]
+            if comma.parent and comma.parent.type == syms.listmaker:
+                self.leaves.pop()
+                return True
+
         # For parens let's check if it's safe to remove the comma.  If the
         # trailing one is the only one, we might mistakenly change a tuple
         # into a different type by removing the comma.
@@ -532,8 +674,8 @@ class Line:
     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
         """In a for loop, or comprehension, the variables are often unpacks.
 
-        To avoid splitting on the comma in this situation, we will increase
-        the depth of tokens between `for` and `in`.
+        To avoid splitting on the comma in this situation, increase the depth of
+        tokens between `for` and `in`.
         """
         if leaf.type == token.NAME and leaf.value == 'for':
             self.has_for = True
@@ -544,7 +686,7 @@ class Line:
         return False
 
     def maybe_decrement_after_for_loop_variable(self, leaf: Leaf) -> bool:
-        # See `maybe_increment_for_loop_variable` above for explanation.
+        """See `maybe_increment_for_loop_variable` above for explanation."""
         if self._for_loop_variable and leaf.type == token.NAME and leaf.value == 'in':
             self.bracket_tracker.depth -= 1
             self._for_loop_variable = False
@@ -572,6 +714,7 @@ class Line:
         return self.append_comment(comment)
 
     def append_comment(self, comment: Leaf) -> bool:
+        """Add an inline comment to the line."""
         if comment.type != token.COMMENT:
             return False
 
@@ -590,6 +733,7 @@ class Line:
             return True
 
     def last_non_delimiter(self) -> Leaf:
+        """Return the last non-delimiter on the line. Raise LookupError otherwise."""
         for i in range(len(self.leaves)):
             last = self.leaves[-i - 1]
             if not is_delimiter(last):
@@ -598,6 +742,7 @@ class Line:
         raise LookupError("No non-delimiters found")
 
     def __str__(self) -> str:
+        """Render the line."""
         if not self:
             return '\n'
 
@@ -612,9 +757,63 @@ class Line:
         return res + '\n'
 
     def __bool__(self) -> bool:
+        """Return True if the line has leaves or comments."""
         return bool(self.leaves or self.comments)
 
 
+class UnformattedLines(Line):
+    """Just like :class:`Line` but stores lines which aren't reformatted."""
+
+    def append(self, leaf: Leaf, preformatted: bool = True) -> None:
+        """Just add a new `leaf` to the end of the lines.
+
+        The `preformatted` argument is ignored.
+
+        Keeps track of indentation `depth`, which is useful when the user
+        says `# fmt: on`. Otherwise, doesn't do anything with the `leaf`.
+        """
+        try:
+            list(generate_comments(leaf))
+        except FormatOn as f_on:
+            self.leaves.append(f_on.leaf_from_consumed(leaf))
+            raise
+
+        self.leaves.append(leaf)
+        if leaf.type == token.INDENT:
+            self.depth += 1
+        elif leaf.type == token.DEDENT:
+            self.depth -= 1
+
+    def __str__(self) -> str:
+        """Render unformatted lines from leaves which were added with `append()`.
+
+        `depth` is not used for indentation in this case.
+        """
+        if not self:
+            return '\n'
+
+        res = ''
+        for leaf in self.leaves:
+            res += str(leaf)
+        return res
+
+    def append_comment(self, comment: Leaf) -> bool:
+        """Not implemented in this class. Raises `NotImplementedError`."""
+        raise NotImplementedError("Unformatted lines don't store comments separately.")
+
+    def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
+        """Does nothing and returns False."""
+        return False
+
+    def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
+        """Does nothing and returns False."""
+        return False
+
+    def maybe_adapt_standalone_comment(self, comment: Leaf) -> bool:
+        """Does nothing and returns False."""
+        return False
+
+
 @dataclass
 class EmptyLineTracker:
     """Provides a stateful method that returns the number of potential extra
@@ -629,14 +828,13 @@ class EmptyLineTracker:
     previous_defs: List[int] = Factory(list)
 
     def maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
-        """Returns the number of extra empty lines before and after the `current_line`.
+        """Return the number of extra empty lines before and after the `current_line`.
 
-        This is for separating `def`, `async def` and `class` with extra empty lines
-        (two on module-level), as well as providing an extra empty line after flow
-        control keywords to make them more prominent.
+        This is for separating `def`, `async def` and `class` with extra empty
+        lines (two on module-level), as well as providing an extra empty line
+        after flow control keywords to make them more prominent.
         """
-        if current_line.is_comment:
-            # Don't count standalone comments towards previous empty lines.
+        if isinstance(current_line, UnformattedLines):
             return 0, 0
 
         before, after = self._maybe_empty_lines(current_line)
@@ -646,10 +844,14 @@ class EmptyLineTracker:
         return before, after
 
     def _maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
+        max_allowed = 1
+        if current_line.depth == 0:
+            max_allowed = 2
         if current_line.leaves:
             # Consume the first leaf's extra newlines.
             first_leaf = current_line.leaves[0]
-            before = int('\n' in first_leaf.prefix)
+            before = first_leaf.prefix.count('\n')
+            before = min(before, max_allowed)
             first_leaf.prefix = ''
         else:
             before = 0
@@ -703,9 +905,8 @@ class LineGenerator(Visitor[Line]):
     in ways that will no longer stringify to valid Python code on the tree.
     """
     current_line: Line = Factory(Line)
-    standalone_comments: List[Leaf] = Factory(list)
 
-    def line(self, indent: int = 0) -> Iterator[Line]:
+    def line(self, indent: int = 0, type: Type[Line] = Line) -> Iterator[Line]:
         """Generate a line.
 
         If the line is empty, only emit if it makes sense.
@@ -714,59 +915,85 @@ class LineGenerator(Visitor[Line]):
         If any lines were generated, set up a new current_line.
         """
         if not self.current_line:
-            self.current_line.depth += indent
+            if self.current_line.__class__ == type:
+                self.current_line.depth += indent
+            else:
+                self.current_line = type(depth=self.current_line.depth + indent)
             return  # Line is empty, don't emit. Creating a new one unnecessary.
 
         complete_line = self.current_line
-        self.current_line = Line(depth=complete_line.depth + indent)
+        self.current_line = type(depth=complete_line.depth + indent)
         yield complete_line
 
+    def visit(self, node: LN) -> Iterator[Line]:
+        """Main method to visit `node` and its children.
+
+        Yields :class:`Line` objects.
+        """
+        if isinstance(self.current_line, UnformattedLines):
+            # File contained `# fmt: off`
+            yield from self.visit_unformatted(node)
+
+        else:
+            yield from super().visit(node)
+
     def visit_default(self, node: LN) -> Iterator[Line]:
+        """Default `visit_*()` implementation. Recurses to children of `node`."""
         if isinstance(node, Leaf):
-            for comment in generate_comments(node):
-                if self.current_line.bracket_tracker.any_open_brackets():
-                    # any comment within brackets is subject to splitting
-                    self.current_line.append(comment)
-                elif comment.type == token.COMMENT:
-                    # regular trailing comment
-                    self.current_line.append(comment)
-                    yield from self.line()
-
-                else:
-                    # regular standalone comment, to be processed later (see
-                    # docstring in `generate_comments()`
-                    self.standalone_comments.append(comment)
-            normalize_prefix(node)
-            if node.type not in WHITESPACE:
-                for comment in self.standalone_comments:
-                    yield from self.line()
-
-                    self.current_line.append(comment)
-                    yield from self.line()
-
-                self.standalone_comments = []
-                self.current_line.append(node)
+            any_open_brackets = self.current_line.bracket_tracker.any_open_brackets()
+            try:
+                for comment in generate_comments(node):
+                    if any_open_brackets:
+                        # any comment within brackets is subject to splitting
+                        self.current_line.append(comment)
+                    elif comment.type == token.COMMENT:
+                        # regular trailing comment
+                        self.current_line.append(comment)
+                        yield from self.line()
+
+                    else:
+                        # regular standalone comment
+                        yield from self.line()
+
+                        self.current_line.append(comment)
+                        yield from self.line()
+
+            except FormatOff as f_off:
+                f_off.trim_prefix(node)
+                yield from self.line(type=UnformattedLines)
+                yield from self.visit(node)
+
+            except FormatOn as f_on:
+                # This only happens here if somebody says "fmt: on" multiple
+                # times in a row.
+                f_on.trim_prefix(node)
+                yield from self.visit_default(node)
+
+            else:
+                normalize_prefix(node, inside_brackets=any_open_brackets)
+                if node.type not in WHITESPACE:
+                    self.current_line.append(node)
         yield from super().visit_default(node)
 
-    def visit_suite(self, node: Node) -> Iterator[Line]:
-        """Body of a statement after a colon."""
-        children = iter(node.children)
-        # Process newline before indenting.  It might contain an inline
-        # comment that should go right after the colon.
-        newline = next(children)
-        yield from self.visit(newline)
+    def visit_INDENT(self, node: Node) -> Iterator[Line]:
+        """Increase indentation level, maybe yield a line."""
+        # In blib2to3 INDENT never holds comments.
         yield from self.line(+1)
+        yield from self.visit_default(node)
 
-        for child in children:
-            yield from self.visit(child)
-
+    def visit_DEDENT(self, node: Node) -> Iterator[Line]:
+        """Decrease indentation level, maybe yield a line."""
+        # DEDENT has no value. Additionally, in blib2to3 it never holds comments.
         yield from self.line(-1)
 
     def visit_stmt(self, node: Node, keywords: Set[str]) -> Iterator[Line]:
         """Visit a statement.
 
-        The relevant Python language keywords for this statement are NAME leaves
-        within it.
+        This implementation is shared for `if`, `while`, `for`, `try`, `except`,
+        `def`, `with`, and `class`.
+
+        The relevant Python language `keywords` for a given statement will be NAME
+        leaves within it. This methods puts those on a separate line.
         """
         for child in node.children:
             if child.type == token.NAME and child.value in keywords:  # type: ignore
@@ -775,7 +1002,7 @@ class LineGenerator(Visitor[Line]):
             yield from self.visit(child)
 
     def visit_simple_stmt(self, node: Node) -> Iterator[Line]:
-        """A statement without nested statements."""
+        """Visit a statement without nested statements."""
         is_suite_like = node.parent and node.parent.type in STATEMENT
         if is_suite_like:
             yield from self.line(+1)
@@ -787,13 +1014,14 @@ class LineGenerator(Visitor[Line]):
             yield from self.visit_default(node)
 
     def visit_async_stmt(self, node: Node) -> Iterator[Line]:
+        """Visit `async def`, `async for`, `async with`."""
         yield from self.line()
 
         children = iter(node.children)
         for child in children:
             yield from self.visit(child)
 
-            if child.type == token.NAME and child.value == 'async':  # type: ignore
+            if child.type == token.ASYNC:
                 break
 
         internal_stmt = next(children)
@@ -801,17 +1029,34 @@ class LineGenerator(Visitor[Line]):
             yield from self.visit(child)
 
     def visit_decorators(self, node: Node) -> Iterator[Line]:
+        """Visit decorators."""
         for child in node.children:
             yield from self.line()
             yield from self.visit(child)
 
     def visit_SEMI(self, leaf: Leaf) -> Iterator[Line]:
+        """Remove a semicolon and put the other statement on a separate line."""
         yield from self.line()
 
     def visit_ENDMARKER(self, leaf: Leaf) -> Iterator[Line]:
+        """End of file. Process outstanding comments and end with a newline."""
         yield from self.visit_default(leaf)
         yield from self.line()
 
+    def visit_unformatted(self, node: LN) -> Iterator[Line]:
+        """Used when file contained a `# fmt: off`."""
+        if isinstance(node, Node):
+            for child in node.children:
+                yield from self.visit(child)
+
+        else:
+            try:
+                self.current_line.append(node)
+            except FormatOn as f_on:
+                f_on.trim_prefix(node)
+                yield from self.line()
+                yield from self.visit(node)
+
     def __attrs_post_init__(self) -> None:
         """You are in a twisty little maze of passages."""
         v = self.visit_stmt
@@ -831,7 +1076,7 @@ BRACKET = {token.LPAR: token.RPAR, token.LSQB: token.RSQB, token.LBRACE: token.R
 OPENING_BRACKETS = set(BRACKET.keys())
 CLOSING_BRACKETS = set(BRACKET.values())
 BRACKETS = OPENING_BRACKETS | CLOSING_BRACKETS
-ALWAYS_NO_SPACE = CLOSING_BRACKETS | {token.COMMA, token.COLON, STANDALONE_COMMENT}
+ALWAYS_NO_SPACE = CLOSING_BRACKETS | {token.COMMA, STANDALONE_COMMENT}
 
 
 def whitespace(leaf: Leaf) -> str:  # noqa C901
@@ -849,37 +1094,62 @@ def whitespace(leaf: Leaf) -> str:  # noqa C901
         return DOUBLESPACE
 
     assert p is not None, f"INTERNAL ERROR: hand-made leaf without parent: {leaf!r}"
+    if t == token.COLON and p.type not in {syms.subscript, syms.subscriptlist}:
+        return NO
+
     prev = leaf.prev_sibling
     if not prev:
         prevp = preceding_leaf(p)
         if not prevp or prevp.type in OPENING_BRACKETS:
             return NO
 
+        if t == token.COLON:
+            return SPACE if prevp.type == token.COMMA else NO
+
         if prevp.type == token.EQUAL:
-            if prevp.parent and prevp.parent.type in {
-                syms.typedargslist,
-                syms.varargslist,
-                syms.parameters,
-                syms.arglist,
-                syms.argument,
-            }:
-                return NO
+            if prevp.parent:
+                if prevp.parent.type in {
+                    syms.arglist, syms.argument, syms.parameters, syms.varargslist
+                }:
+                    return NO
+
+                elif prevp.parent.type == syms.typedargslist:
+                    # A bit hacky: if the equal sign has whitespace, it means we
+                    # previously found it's a typed argument.  So, we're using
+                    # that, too.
+                    return prevp.prefix
 
         elif prevp.type == token.DOUBLESTAR:
             if prevp.parent and prevp.parent.type in {
-                syms.typedargslist,
-                syms.varargslist,
-                syms.parameters,
                 syms.arglist,
+                syms.argument,
                 syms.dictsetmaker,
+                syms.parameters,
+                syms.typedargslist,
+                syms.varargslist,
             }:
                 return NO
 
         elif prevp.type == token.COLON:
-            if prevp.parent and prevp.parent.type == syms.subscript:
+            if prevp.parent and prevp.parent.type in {syms.subscript, syms.sliceop}:
                 return NO
 
-        elif prevp.parent and prevp.parent.type in {syms.factor, syms.star_expr}:
+        elif (
+            prevp.parent
+            and prevp.parent.type in {syms.factor, syms.star_expr}
+            and prevp.type in MATH_OPERATORS
+        ):
+            return NO
+
+        elif (
+            prevp.type == token.RIGHTSHIFT
+            and prevp.parent
+            and prevp.parent.type == syms.shift_expr
+            and prevp.prev_sibling
+            and prevp.prev_sibling.type == token.NAME
+            and prevp.prev_sibling.value == 'print'  # type: ignore
+        ):
+            # Python 2 print chevron
             return NO
 
     elif prev.type in OPENING_BRACKETS:
@@ -893,7 +1163,7 @@ def whitespace(leaf: Leaf) -> str:  # noqa C901
         if not prev or prev.type != token.COMMA:
             return NO
 
-    if p.type == syms.varargslist:
+    elif p.type == syms.varargslist:
         # lambdas
         if t == token.RPAR:
             return NO
@@ -983,7 +1253,7 @@ def whitespace(leaf: Leaf) -> str:  # noqa C901
 
             return NO
 
-        elif prev.type == token.COLON:
+        else:
             return NO
 
     elif p.type == syms.atom:
@@ -1047,7 +1317,7 @@ def whitespace(leaf: Leaf) -> str:  # noqa C901
 
 
 def preceding_leaf(node: Optional[LN]) -> Optional[Leaf]:
-    """Returns the first leaf that precedes `node`, if any."""
+    """Return the first leaf that precedes `node`, if any."""
     while node:
         res = node.prev_sibling
         if res:
@@ -1065,7 +1335,7 @@ def preceding_leaf(node: Optional[LN]) -> Optional[Leaf]:
 
 
 def is_delimiter(leaf: Leaf) -> int:
-    """Returns the priority of the `leaf` delimiter. Returns 0 if not delimiter.
+    """Return the priority of the `leaf` delimiter. Return 0 if not delimiter.
 
     Higher numbers are higher priority.
     """
@@ -1086,7 +1356,7 @@ def is_delimiter(leaf: Leaf) -> int:
 
 
 def generate_comments(leaf: Leaf) -> Iterator[Leaf]:
-    """Cleans the prefix of the `leaf` and generates comments from it, if any.
+    """Clean the prefix of the `leaf` and generate comments from it, if any.
 
     Comments in lib2to3 are shoved into the whitespace prefix.  This happens
     in `pgen2/driver.py:Driver.parse_tokens()`.  This was a brilliant implementation
@@ -1104,36 +1374,62 @@ def generate_comments(leaf: Leaf) -> Iterator[Leaf]:
     Inline comments are emitted as regular token.COMMENT leaves.  Standalone
     are emitted with a fake STANDALONE_COMMENT token identifier.
     """
-    if not leaf.prefix:
-        return
-
-    if '#' not in leaf.prefix:
+    p = leaf.prefix
+    if not p:
         return
 
-    before_comment, content = leaf.prefix.split('#', 1)
-    content = content.rstrip()
-    if content and (content[0] not in {' ', '!', '#'}):
-        content = ' ' + content
-    is_standalone_comment = (
-        '\n' in before_comment or '\n' in content or leaf.type == token.DEDENT
-    )
-    if not is_standalone_comment:
-        # simple trailing comment
-        yield Leaf(token.COMMENT, value='#' + content)
+    if '#' not in p:
         return
 
-    for line in ('#' + content).split('\n'):
+    consumed = 0
+    nlines = 0
+    for index, line in enumerate(p.split('\n')):
+        consumed += len(line) + 1  # adding the length of the split '\n'
         line = line.lstrip()
+        if not line:
+            nlines += 1
         if not line.startswith('#'):
             continue
 
-        yield Leaf(STANDALONE_COMMENT, line)
+        if index == 0 and leaf.type != token.ENDMARKER:
+            comment_type = token.COMMENT  # simple trailing comment
+        else:
+            comment_type = STANDALONE_COMMENT
+        comment = make_comment(line)
+        yield Leaf(comment_type, comment, prefix='\n' * nlines)
+
+        if comment in {'# fmt: on', '# yapf: enable'}:
+            raise FormatOn(consumed)
+
+        if comment in {'# fmt: off', '# yapf: disable'}:
+            raise FormatOff(consumed)
+
+        nlines = 0
+
+
+def make_comment(content: str) -> str:
+    """Return a consistently formatted comment from the given `content` string.
+
+    All comments (except for "##", "#!", "#:") should have a single space between
+    the hash sign and the content.
+
+    If `content` didn't start with a hash sign, one is provided.
+    """
+    content = content.rstrip()
+    if not content:
+        return '#'
+
+    if content[0] == '#':
+        content = content[1:]
+    if content and content[0] not in ' !:#':
+        content = ' ' + content
+    return '#' + content
 
 
 def split_line(
     line: Line, line_length: int, inner: bool = False, py36: bool = False
 ) -> Iterator[Line]:
-    """Splits a `line` into potentially many lines.
+    """Split a `line` into potentially many lines.
 
     They should fit in the allotted `line_length` but might not be able to.
     `inner` signifies that there were a pair of brackets somewhere around the
@@ -1143,6 +1439,10 @@ def split_line(
     If `py36` is True, splitting may generate syntax that is only compatible
     with Python 3.6 and later.
     """
+    if isinstance(line, UnformattedLines):
+        yield line
+        return
+
     line_str = str(line).strip('\n')
     if len(line_str) <= line_length and '\n' not in line_str:
         yield line
@@ -1210,7 +1510,7 @@ def left_hand_split(line: Line, py36: bool = False) -> Iterator[Line]:
                 current_leaves = body_leaves
     # Since body is a new indent level, remove spurious leading whitespace.
     if body_leaves:
-        normalize_prefix(body_leaves[0])
+        normalize_prefix(body_leaves[0], inside_brackets=True)
     # Build the new lines.
     for result, leaves in (
         (head, head_leaves), (body, body_leaves), (tail, tail_leaves)
@@ -1220,7 +1520,7 @@ def left_hand_split(line: Line, py36: bool = False) -> Iterator[Line]:
             comment_after = line.comments.get(id(leaf))
             if comment_after:
                 result.append(comment_after, preformatted=True)
-    split_succeeded_or_raise(head, body, tail)
+    bracket_split_succeeded_or_raise(head, body, tail)
     for result in (head, body, tail):
         if result:
             yield result
@@ -1250,7 +1550,7 @@ def right_hand_split(line: Line, py36: bool = False) -> Iterator[Line]:
     head_leaves.reverse()
     # Since body is a new indent level, remove spurious leading whitespace.
     if body_leaves:
-        normalize_prefix(body_leaves[0])
+        normalize_prefix(body_leaves[0], inside_brackets=True)
     # Build the new lines.
     for result, leaves in (
         (head, head_leaves), (body, body_leaves), (tail, tail_leaves)
@@ -1260,13 +1560,26 @@ def right_hand_split(line: Line, py36: bool = False) -> Iterator[Line]:
             comment_after = line.comments.get(id(leaf))
             if comment_after:
                 result.append(comment_after, preformatted=True)
-    split_succeeded_or_raise(head, body, tail)
+    bracket_split_succeeded_or_raise(head, body, tail)
     for result in (head, body, tail):
         if result:
             yield result
 
 
-def split_succeeded_or_raise(head: Line, body: Line, tail: Line) -> None:
+def bracket_split_succeeded_or_raise(head: Line, body: Line, tail: Line) -> None:
+    """Raise :exc:`CannotSplit` if the last left- or right-hand split failed.
+
+    Do nothing otherwise.
+
+    A left- or right-hand split is based on a pair of brackets. Content before
+    (and including) the opening bracket is left on one line, content inside the
+    brackets is put on a separate line, and finally content starting with and
+    following the closing bracket is put on a separate line.
+
+    Those are called `head`, `body`, and `tail`, respectively. If the split
+    produced the same line (all content in `head`) or ended up with an empty `body`
+    and the `tail` is just the closing bracket, then it's considered failed.
+    """
     tail_len = len(str(tail).strip())
     if not body:
         if tail_len == 0:
@@ -1284,7 +1597,7 @@ def delimiter_split(line: Line, py36: bool = False) -> Iterator[Line]:
 
     This kind of split doesn't increase indentation.
     If `py36` is True, the split will add trailing commas also in function
-    signatures that contain * and **.
+    signatures that contain `*` and `**`.
     """
     try:
         last_leaf = line.leaves[-1]
@@ -1293,7 +1606,9 @@ def delimiter_split(line: Line, py36: bool = False) -> Iterator[Line]:
 
     delimiters = line.bracket_tracker.delimiters
     try:
-        delimiter_priority = line.bracket_tracker.max_priority(exclude={id(last_leaf)})
+        delimiter_priority = line.bracket_tracker.max_delimiter_priority(
+            exclude={id(last_leaf)}
+        )
     except ValueError:
         raise CannotSplit("No delimiters found")
 
@@ -1314,7 +1629,7 @@ def delimiter_split(line: Line, py36: bool = False) -> Iterator[Line]:
             trailing_comma_safe = trailing_comma_safe and py36
         leaf_priority = delimiters.get(id(leaf))
         if leaf_priority == delimiter_priority:
-            normalize_prefix(current_line.leaves[0])
+            normalize_prefix(current_line.leaves[0], inside_brackets=True)
             yield current_line
 
             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
@@ -1325,12 +1640,12 @@ def delimiter_split(line: Line, py36: bool = False) -> Iterator[Line]:
             and trailing_comma_safe
         ):
             current_line.append(Leaf(token.COMMA, ','))
-        normalize_prefix(current_line.leaves[0])
+        normalize_prefix(current_line.leaves[0], inside_brackets=True)
         yield current_line
 
 
 def is_import(leaf: Leaf) -> bool:
-    """Returns True if the given leaf starts an import statement."""
+    """Return True if the given leaf starts an import statement."""
     p = leaf.parent
     t = leaf.type
     v = leaf.value
@@ -1343,19 +1658,26 @@ def is_import(leaf: Leaf) -> bool:
     )
 
 
-def normalize_prefix(leaf: Leaf) -> None:
-    """Leave existing extra newlines for imports.  Remove everything else."""
-    if is_import(leaf):
-        spl = leaf.prefix.split('#', 1)
-        nl_count = spl[0].count('\n')
-        leaf.prefix = '\n' * nl_count
-        return
+def normalize_prefix(leaf: Leaf, *, inside_brackets: bool) -> None:
+    """Leave existing extra newlines if not `inside_brackets`. Remove everything
+    else.
+
+    Note: don't use backslashes for formatting or you'll lose your voting rights.
+    """
+    if not inside_brackets:
+        spl = leaf.prefix.split('#')
+        if '\\' not in spl[0]:
+            nl_count = spl[-1].count('\n')
+            if len(spl) > 1:
+                nl_count -= 1
+            leaf.prefix = '\n' * nl_count
+            return
 
     leaf.prefix = ''
 
 
 def is_python36(node: Node) -> bool:
-    """Returns True if the current file is using Python 3.6+ features.
+    """Return True if the current file is using Python 3.6+ features.
 
     Currently looking for:
     - f-strings; and
@@ -1386,6 +1708,9 @@ BLACKLISTED_DIRECTORIES = {
 
 
 def gen_python_files_in_dir(path: Path) -> Iterator[Path]:
+    """Generate all files under `path` which aren't under BLACKLISTED_DIRECTORIES
+    and have one of the PYTHON_EXTENSIONS.
+    """
     for child in path.iterdir():
         if child.is_dir():
             if child.name in BLACKLISTED_DIRECTORIES:
@@ -1399,7 +1724,8 @@ def gen_python_files_in_dir(path: Path) -> Iterator[Path]:
 
 @dataclass
 class Report:
-    """Provides a reformatting counter."""
+    """Provides a reformatting counter. Can be rendered with `str(report)`."""
+    check: bool = False
     change_count: int = 0
     same_count: int = 0
     failure_count: int = 0
@@ -1407,7 +1733,8 @@ class Report:
     def done(self, src: Path, changed: bool) -> None:
         """Increment the counter for successful reformatting. Write out a message."""
         if changed:
-            out(f'reformatted {src}')
+            reformatted = 'would reformat' if self.check else 'reformatted'
+            out(f'{reformatted} {src}')
             self.change_count += 1
         else:
             out(f'{src} already well formatted, good job.', bold=False)
@@ -1420,46 +1747,55 @@ class Report:
 
     @property
     def return_code(self) -> int:
-        """Which return code should the app use considering the current state."""
+        """Return the exit code that the app should use.
+
+        This considers the current state of changed files and failures:
+        - if there were any failures, return 123;
+        - if any files were changed and --check is being used, return 1;
+        - otherwise return 0.
+        """
         # According to http://tldp.org/LDP/abs/html/exitcodes.html starting with
         # 126 we have special returncodes reserved by the shell.
         if self.failure_count:
             return 123
 
-        elif self.change_count:
+        elif self.change_count and self.check:
             return 1
 
         return 0
 
     def __str__(self) -> str:
-        """A color report of the current state.
+        """Render a color report of the current state.
 
         Use `click.unstyle` to remove colors.
         """
+        if self.check:
+            reformatted = "would be reformatted"
+            unchanged = "would be left unchanged"
+            failed = "would fail to reformat"
+        else:
+            reformatted = "reformatted"
+            unchanged = "left unchanged"
+            failed = "failed to reformat"
         report = []
         if self.change_count:
             s = 's' if self.change_count > 1 else ''
             report.append(
-                click.style(f'{self.change_count} file{s} reformatted', bold=True)
+                click.style(f'{self.change_count} file{s} {reformatted}', bold=True)
             )
         if self.same_count:
             s = 's' if self.same_count > 1 else ''
-            report.append(f'{self.same_count} file{s} left unchanged')
+            report.append(f'{self.same_count} file{s} {unchanged}')
         if self.failure_count:
             s = 's' if self.failure_count > 1 else ''
             report.append(
-                click.style(
-                    f'{self.failure_count} file{s} failed to reformat', fg='red'
-                )
+                click.style(f'{self.failure_count} file{s} {failed}', fg='red')
             )
         return ', '.join(report) + '.'
 
 
 def assert_equivalent(src: str, dst: str) -> None:
-    """Raises AssertionError if `src` and `dst` aren't equivalent.
-
-    This is a temporary sanity check until Black becomes stable.
-    """
+    """Raise AssertionError if `src` and `dst` aren't equivalent."""
 
     import ast
     import traceback
@@ -1492,7 +1828,12 @@ def assert_equivalent(src: str, dst: str) -> None:
     try:
         src_ast = ast.parse(src)
     except Exception as exc:
-        raise AssertionError(f"cannot parse source: {exc}") from None
+        major, minor = sys.version_info[:2]
+        raise AssertionError(
+            f"cannot use --safe with this file; failed to parse source file "
+            f"with Python {major}.{minor}'s builtin AST. Re-run with --fast "
+            f"or stop using deprecated Python 2 syntax. AST error message: {exc}"
+        )
 
     try:
         dst_ast = ast.parse(dst)
@@ -1517,10 +1858,7 @@ def assert_equivalent(src: str, dst: str) -> None:
 
 
 def assert_stable(src: str, dst: str, line_length: int) -> None:
-    """Raises AssertionError if `dst` reformats differently the second time.
-
-    This is a temporary sanity check until Black becomes stable.
-    """
+    """Raise AssertionError if `dst` reformats differently the second time."""
     newdst = format_str(dst, line_length=line_length)
     if dst != newdst:
         log = dump_to_file(
@@ -1536,7 +1874,7 @@ def assert_stable(src: str, dst: str, line_length: int) -> None:
 
 
 def dump_to_file(*output: str) -> str:
-    """Dumps `output` to a temporary file. Returns path to the file."""
+    """Dump `output` to a temporary file. Return path to the file."""
     import tempfile
 
     with tempfile.NamedTemporaryFile(
@@ -1549,7 +1887,7 @@ def dump_to_file(*output: str) -> str:
 
 
 def diff(a: str, b: str, a_name: str, b_name: str) -> str:
-    """Returns a udiff string between strings `a` and `b`."""
+    """Return a unified diff string between strings `a` and `b`."""
     import difflib
 
     a_lines = [line + '\n' for line in a.split('\n')]