All patches and comments are welcome. Please squash your changes to logical
commits before using git-format-patch and git-send-email to
patches@git.madduck.net.
If you'd read over the Git project's submission guidelines and adhered to them,
I'd be especially grateful.
   4 from asyncio.base_events import BaseEventLoop
 
   5 from concurrent.futures import Executor, ProcessPoolExecutor
 
   7 from functools import partial, wraps
 
  10 from multiprocessing import Manager
 
  12 from pathlib import Path
 
  32 from attr import dataclass, Factory
 
  36 from blib2to3.pytree import Node, Leaf, type_repr
 
  37 from blib2to3 import pygram, pytree
 
  38 from blib2to3.pgen2 import driver, token
 
  39 from blib2to3.pgen2.parse import ParseError
 
  41 __version__ = "18.4a0"
 
  42 DEFAULT_LINE_LENGTH = 88
 
  44 syms = pygram.python_symbols
 
  52 LN = Union[Leaf, Node]
 
  53 SplitFunc = Callable[["Line", bool], Iterator["Line"]]
 
  54 out = partial(click.secho, bold=True, err=True)
 
  55 err = partial(click.secho, fg="red", err=True)
 
  58 class NothingChanged(UserWarning):
 
  59     """Raised by :func:`format_file` when reformatted code is the same as source."""
 
  62 class CannotSplit(Exception):
 
  63     """A readable split that fits the allotted line length is impossible.
 
  65     Raised by :func:`left_hand_split`, :func:`right_hand_split`, and
 
  66     :func:`delimiter_split`.
 
  70 class FormatError(Exception):
 
  71     """Base exception for `# fmt: on` and `# fmt: off` handling.
 
  73     It holds the number of bytes of the prefix consumed before the format
 
  74     control comment appeared.
 
  77     def __init__(self, consumed: int) -> None:
 
  78         super().__init__(consumed)
 
  79         self.consumed = consumed
 
  81     def trim_prefix(self, leaf: Leaf) -> None:
 
  82         leaf.prefix = leaf.prefix[self.consumed:]
 
  84     def leaf_from_consumed(self, leaf: Leaf) -> Leaf:
 
  85         """Returns a new Leaf from the consumed part of the prefix."""
 
  86         unformatted_prefix = leaf.prefix[:self.consumed]
 
  87         return Leaf(token.NEWLINE, unformatted_prefix)
 
  90 class FormatOn(FormatError):
 
  91     """Found a comment like `# fmt: on` in the file."""
 
  94 class FormatOff(FormatError):
 
  95     """Found a comment like `# fmt: off` in the file."""
 
  98 class WriteBack(Enum):
 
 109     default=DEFAULT_LINE_LENGTH,
 
 110     help="How many character per line to allow.",
 
 117         "Don't write the files back, just return the status.  Return code 0 "
 
 118         "means nothing would change.  Return code 1 means some files would be "
 
 119         "reformatted.  Return code 123 means there was an internal error."
 
 125     help="Don't write the files back, just output a diff for each file on stdout.",
 
 130     help="If --fast given, skip temporary sanity checks. [default: --safe]",
 
 137         "Don't emit non-error messages to stderr. Errors are still emitted, "
 
 138         "silence those with 2>/dev/null."
 
 141 @click.version_option(version=__version__)
 
 146         exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
 
 159     """The uncompromising code formatter."""
 
 160     sources: List[Path] = []
 
 164             sources.extend(gen_python_files_in_dir(p))
 
 166             # if a file was explicitly given, we don't care about its extension
 
 169             sources.append(Path("-"))
 
 171             err(f"invalid path: {s}")
 
 173         exc = click.ClickException("Options --check and --diff are mutually exclusive")
 
 178         write_back = WriteBack.NO
 
 180         write_back = WriteBack.DIFF
 
 182         write_back = WriteBack.YES
 
 183     if len(sources) == 0:
 
 185     elif len(sources) == 1:
 
 187         report = Report(check=check, quiet=quiet)
 
 189             if not p.is_file() and str(p) == "-":
 
 190                 changed = format_stdin_to_stdout(
 
 191                     line_length=line_length, fast=fast, write_back=write_back
 
 194                 changed = format_file_in_place(
 
 195                     p, line_length=line_length, fast=fast, write_back=write_back
 
 197             report.done(p, changed)
 
 198         except Exception as exc:
 
 199             report.failed(p, str(exc))
 
 200         ctx.exit(report.return_code)
 
 202         loop = asyncio.get_event_loop()
 
 203         executor = ProcessPoolExecutor(max_workers=os.cpu_count())
 
 206             return_code = loop.run_until_complete(
 
 208                     sources, line_length, write_back, fast, quiet, loop, executor
 
 213             ctx.exit(return_code)
 
 216 async def schedule_formatting(
 
 219     write_back: WriteBack,
 
 225     """Run formatting of `sources` in parallel using the provided `executor`.
 
 227     (Use ProcessPoolExecutors for actual parallelism.)
 
 229     `line_length`, `write_back`, and `fast` options are passed to
 
 230     :func:`format_file_in_place`.
 
 233     if write_back == WriteBack.DIFF:
 
 234         # For diff output, we need locks to ensure we don't interleave output
 
 235         # from different processes.
 
 237         lock = manager.Lock()
 
 239         src: loop.run_in_executor(
 
 240             executor, format_file_in_place, src, line_length, fast, write_back, lock
 
 244     _task_values = list(tasks.values())
 
 245     loop.add_signal_handler(signal.SIGINT, cancel, _task_values)
 
 246     loop.add_signal_handler(signal.SIGTERM, cancel, _task_values)
 
 247     await asyncio.wait(tasks.values())
 
 249     report = Report(check=write_back is WriteBack.NO, quiet=quiet)
 
 250     for src, task in tasks.items():
 
 252             report.failed(src, "timed out, cancelling")
 
 254             cancelled.append(task)
 
 255         elif task.cancelled():
 
 256             cancelled.append(task)
 
 257         elif task.exception():
 
 258             report.failed(src, str(task.exception()))
 
 260             report.done(src, task.result())
 
 262         await asyncio.gather(*cancelled, loop=loop, return_exceptions=True)
 
 264         out("All done! ✨ 🍰 ✨")
 
 266         click.echo(str(report))
 
 267     return report.return_code
 
 270 def format_file_in_place(
 
 274     write_back: WriteBack = WriteBack.NO,
 
 275     lock: Any = None,  # multiprocessing.Manager().Lock() is some crazy proxy
 
 277     """Format file under `src` path. Return True if changed.
 
 279     If `write_back` is True, write reformatted code back to stdout.
 
 280     `line_length` and `fast` options are passed to :func:`format_file_contents`.
 
 282     with tokenize.open(src) as src_buffer:
 
 283         src_contents = src_buffer.read()
 
 285         dst_contents = format_file_contents(
 
 286             src_contents, line_length=line_length, fast=fast
 
 288     except NothingChanged:
 
 291     if write_back == write_back.YES:
 
 292         with open(src, "w", encoding=src_buffer.encoding) as f:
 
 293             f.write(dst_contents)
 
 294     elif write_back == write_back.DIFF:
 
 295         src_name = f"{src.name}  (original)"
 
 296         dst_name = f"{src.name}  (formatted)"
 
 297         diff_contents = diff(src_contents, dst_contents, src_name, dst_name)
 
 301             sys.stdout.write(diff_contents)
 
 308 def format_stdin_to_stdout(
 
 309     line_length: int, fast: bool, write_back: WriteBack = WriteBack.NO
 
 311     """Format file on stdin. Return True if changed.
 
 313     If `write_back` is True, write reformatted code back to stdout.
 
 314     `line_length` and `fast` arguments are passed to :func:`format_file_contents`.
 
 316     src = sys.stdin.read()
 
 318         dst = format_file_contents(src, line_length=line_length, fast=fast)
 
 321     except NothingChanged:
 
 326         if write_back == WriteBack.YES:
 
 327             sys.stdout.write(dst)
 
 328         elif write_back == WriteBack.DIFF:
 
 329             src_name = "<stdin>  (original)"
 
 330             dst_name = "<stdin>  (formatted)"
 
 331             sys.stdout.write(diff(src, dst, src_name, dst_name))
 
 334 def format_file_contents(
 
 335     src_contents: str, line_length: int, fast: bool
 
 337     """Reformat contents a file and return new contents.
 
 339     If `fast` is False, additionally confirm that the reformatted code is
 
 340     valid by calling :func:`assert_equivalent` and :func:`assert_stable` on it.
 
 341     `line_length` is passed to :func:`format_str`.
 
 343     if src_contents.strip() == "":
 
 346     dst_contents = format_str(src_contents, line_length=line_length)
 
 347     if src_contents == dst_contents:
 
 351         assert_equivalent(src_contents, dst_contents)
 
 352         assert_stable(src_contents, dst_contents, line_length=line_length)
 
 356 def format_str(src_contents: str, line_length: int) -> FileContent:
 
 357     """Reformat a string and return new contents.
 
 359     `line_length` determines how many characters per line are allowed.
 
 361     src_node = lib2to3_parse(src_contents)
 
 363     lines = LineGenerator()
 
 364     elt = EmptyLineTracker()
 
 365     py36 = is_python36(src_node)
 
 368     for current_line in lines.visit(src_node):
 
 369         for _ in range(after):
 
 370             dst_contents += str(empty_line)
 
 371         before, after = elt.maybe_empty_lines(current_line)
 
 372         for _ in range(before):
 
 373             dst_contents += str(empty_line)
 
 374         for line in split_line(current_line, line_length=line_length, py36=py36):
 
 375             dst_contents += str(line)
 
 380     pygram.python_grammar_no_print_statement_no_exec_statement,
 
 381     pygram.python_grammar_no_print_statement,
 
 382     pygram.python_grammar_no_exec_statement,
 
 383     pygram.python_grammar,
 
 387 def lib2to3_parse(src_txt: str) -> Node:
 
 388     """Given a string with source, return the lib2to3 Node."""
 
 389     grammar = pygram.python_grammar_no_print_statement
 
 390     if src_txt[-1] != "\n":
 
 391         nl = "\r\n" if "\r\n" in src_txt[:1024] else "\n"
 
 393     for grammar in GRAMMARS:
 
 394         drv = driver.Driver(grammar, pytree.convert)
 
 396             result = drv.parse_string(src_txt, True)
 
 399         except ParseError as pe:
 
 400             lineno, column = pe.context[1]
 
 401             lines = src_txt.splitlines()
 
 403                 faulty_line = lines[lineno - 1]
 
 405                 faulty_line = "<line number missing in source>"
 
 406             exc = ValueError(f"Cannot parse: {lineno}:{column}: {faulty_line}")
 
 410     if isinstance(result, Leaf):
 
 411         result = Node(syms.file_input, [result])
 
 415 def lib2to3_unparse(node: Node) -> str:
 
 416     """Given a lib2to3 node, return its string representation."""
 
 424 class Visitor(Generic[T]):
 
 425     """Basic lib2to3 visitor that yields things of type `T` on `visit()`."""
 
 427     def visit(self, node: LN) -> Iterator[T]:
 
 428         """Main method to visit `node` and its children.
 
 430         It tries to find a `visit_*()` method for the given `node.type`, like
 
 431         `visit_simple_stmt` for Node objects or `visit_INDENT` for Leaf objects.
 
 432         If no dedicated `visit_*()` method is found, chooses `visit_default()`
 
 435         Then yields objects of type `T` from the selected visitor.
 
 438             name = token.tok_name[node.type]
 
 440             name = type_repr(node.type)
 
 441         yield from getattr(self, f"visit_{name}", self.visit_default)(node)
 
 443     def visit_default(self, node: LN) -> Iterator[T]:
 
 444         """Default `visit_*()` implementation. Recurses to children of `node`."""
 
 445         if isinstance(node, Node):
 
 446             for child in node.children:
 
 447                 yield from self.visit(child)
 
 451 class DebugVisitor(Visitor[T]):
 
 454     def visit_default(self, node: LN) -> Iterator[T]:
 
 455         indent = " " * (2 * self.tree_depth)
 
 456         if isinstance(node, Node):
 
 457             _type = type_repr(node.type)
 
 458             out(f"{indent}{_type}", fg="yellow")
 
 460             for child in node.children:
 
 461                 yield from self.visit(child)
 
 464             out(f"{indent}/{_type}", fg="yellow", bold=False)
 
 466             _type = token.tok_name.get(node.type, str(node.type))
 
 467             out(f"{indent}{_type}", fg="blue", nl=False)
 
 469                 # We don't have to handle prefixes for `Node` objects since
 
 470                 # that delegates to the first child anyway.
 
 471                 out(f" {node.prefix!r}", fg="green", bold=False, nl=False)
 
 472             out(f" {node.value!r}", fg="blue", bold=False)
 
 475     def show(cls, code: str) -> None:
 
 476         """Pretty-print the lib2to3 AST of a given string of `code`.
 
 478         Convenience method for debugging.
 
 480         v: DebugVisitor[None] = DebugVisitor()
 
 481         list(v.visit(lib2to3_parse(code)))
 
 484 KEYWORDS = set(keyword.kwlist)
 
 485 WHITESPACE = {token.DEDENT, token.INDENT, token.NEWLINE}
 
 486 FLOW_CONTROL = {"return", "raise", "break", "continue"}
 
 497 STANDALONE_COMMENT = 153
 
 498 LOGIC_OPERATORS = {"and", "or"}
 
 522 VARARGS = {token.STAR, token.DOUBLESTAR}
 
 523 COMPREHENSION_PRIORITY = 20
 
 527 COMPARATOR_PRIORITY = 3
 
 532 class BracketTracker:
 
 533     """Keeps track of brackets on a line."""
 
 536     bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = Factory(dict)
 
 537     delimiters: Dict[LeafID, Priority] = Factory(dict)
 
 538     previous: Optional[Leaf] = None
 
 540     def mark(self, leaf: Leaf) -> None:
 
 541         """Mark `leaf` with bracket-related metadata. Keep track of delimiters.
 
 543         All leaves receive an int `bracket_depth` field that stores how deep
 
 544         within brackets a given leaf is. 0 means there are no enclosing brackets
 
 545         that started on this line.
 
 547         If a leaf is itself a closing bracket, it receives an `opening_bracket`
 
 548         field that it forms a pair with. This is a one-directional link to
 
 549         avoid reference cycles.
 
 551         If a leaf is a delimiter (a token on which Black can split the line if
 
 552         needed) and it's on depth 0, its `id()` is stored in the tracker's
 
 555         if leaf.type == token.COMMENT:
 
 558         if leaf.type in CLOSING_BRACKETS:
 
 560             opening_bracket = self.bracket_match.pop((self.depth, leaf.type))
 
 561             leaf.opening_bracket = opening_bracket
 
 562         leaf.bracket_depth = self.depth
 
 564             after_delim = is_split_after_delimiter(leaf, self.previous)
 
 565             before_delim = is_split_before_delimiter(leaf, self.previous)
 
 566             if after_delim > before_delim:
 
 567                 self.delimiters[id(leaf)] = after_delim
 
 568             elif before_delim > after_delim and self.previous is not None:
 
 569                 self.delimiters[id(self.previous)] = before_delim
 
 570         if leaf.type in OPENING_BRACKETS:
 
 571             self.bracket_match[self.depth, BRACKET[leaf.type]] = leaf
 
 575     def any_open_brackets(self) -> bool:
 
 576         """Return True if there is an yet unmatched open bracket on the line."""
 
 577         return bool(self.bracket_match)
 
 579     def max_delimiter_priority(self, exclude: Iterable[LeafID] = ()) -> int:
 
 580         """Return the highest priority of a delimiter found on the line.
 
 582         Values are consistent with what `is_delimiter()` returns.
 
 584         return max(v for k, v in self.delimiters.items() if k not in exclude)
 
 589     """Holds leaves and comments. Can be printed with `str(line)`."""
 
 592     leaves: List[Leaf] = Factory(list)
 
 593     comments: List[Tuple[Index, Leaf]] = Factory(list)
 
 594     bracket_tracker: BracketTracker = Factory(BracketTracker)
 
 595     inside_brackets: bool = False
 
 596     has_for: bool = False
 
 597     _for_loop_variable: bool = False
 
 599     def append(self, leaf: Leaf, preformatted: bool = False) -> None:
 
 600         """Add a new `leaf` to the end of the line.
 
 602         Unless `preformatted` is True, the `leaf` will receive a new consistent
 
 603         whitespace prefix and metadata applied by :class:`BracketTracker`.
 
 604         Trailing commas are maybe removed, unpacked for loop variables are
 
 605         demoted from being delimiters.
 
 607         Inline comments are put aside.
 
 609         has_value = leaf.value.strip()
 
 613         if self.leaves and not preformatted:
 
 614             # Note: at this point leaf.prefix should be empty except for
 
 615             # imports, for which we only preserve newlines.
 
 616             leaf.prefix += whitespace(leaf)
 
 617         if self.inside_brackets or not preformatted:
 
 618             self.maybe_decrement_after_for_loop_variable(leaf)
 
 619             self.bracket_tracker.mark(leaf)
 
 620             self.maybe_remove_trailing_comma(leaf)
 
 621             self.maybe_increment_for_loop_variable(leaf)
 
 623         if not self.append_comment(leaf):
 
 624             self.leaves.append(leaf)
 
 626     def append_safe(self, leaf: Leaf, preformatted: bool = False) -> None:
 
 627         """Like :func:`append()` but disallow invalid standalone comment structure.
 
 629         Raises ValueError when any `leaf` is appended after a standalone comment
 
 630         or when a standalone comment is not the first leaf on the line.
 
 632         if self.bracket_tracker.depth == 0:
 
 634                 raise ValueError("cannot append to standalone comments")
 
 636             if self.leaves and leaf.type == STANDALONE_COMMENT:
 
 638                     "cannot append standalone comments to a populated line"
 
 641         self.append(leaf, preformatted=preformatted)
 
 644     def is_comment(self) -> bool:
 
 645         """Is this line a standalone comment?"""
 
 646         return len(self.leaves) == 1 and self.leaves[0].type == STANDALONE_COMMENT
 
 649     def is_decorator(self) -> bool:
 
 650         """Is this line a decorator?"""
 
 651         return bool(self) and self.leaves[0].type == token.AT
 
 654     def is_import(self) -> bool:
 
 655         """Is this an import line?"""
 
 656         return bool(self) and is_import(self.leaves[0])
 
 659     def is_class(self) -> bool:
 
 660         """Is this line a class definition?"""
 
 663             and self.leaves[0].type == token.NAME
 
 664             and self.leaves[0].value == "class"
 
 668     def is_def(self) -> bool:
 
 669         """Is this a function definition? (Also returns True for async defs.)"""
 
 671             first_leaf = self.leaves[0]
 
 676             second_leaf: Optional[Leaf] = self.leaves[1]
 
 680             (first_leaf.type == token.NAME and first_leaf.value == "def")
 
 682                 first_leaf.type == token.ASYNC
 
 683                 and second_leaf is not None
 
 684                 and second_leaf.type == token.NAME
 
 685                 and second_leaf.value == "def"
 
 690     def is_flow_control(self) -> bool:
 
 691         """Is this line a flow control statement?
 
 693         Those are `return`, `raise`, `break`, and `continue`.
 
 697             and self.leaves[0].type == token.NAME
 
 698             and self.leaves[0].value in FLOW_CONTROL
 
 702     def is_yield(self) -> bool:
 
 703         """Is this line a yield statement?"""
 
 706             and self.leaves[0].type == token.NAME
 
 707             and self.leaves[0].value == "yield"
 
 711     def contains_standalone_comments(self) -> bool:
 
 712         """If so, needs to be split before emitting."""
 
 713         for leaf in self.leaves:
 
 714             if leaf.type == STANDALONE_COMMENT:
 
 719     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
 
 720         """Remove trailing comma if there is one and it's safe."""
 
 723             and self.leaves[-1].type == token.COMMA
 
 724             and closing.type in CLOSING_BRACKETS
 
 728         if closing.type == token.RBRACE:
 
 729             self.remove_trailing_comma()
 
 732         if closing.type == token.RSQB:
 
 733             comma = self.leaves[-1]
 
 734             if comma.parent and comma.parent.type == syms.listmaker:
 
 735                 self.remove_trailing_comma()
 
 738         # For parens let's check if it's safe to remove the comma.  If the
 
 739         # trailing one is the only one, we might mistakenly change a tuple
 
 740         # into a different type by removing the comma.
 
 741         depth = closing.bracket_depth + 1
 
 743         opening = closing.opening_bracket
 
 744         for _opening_index, leaf in enumerate(self.leaves):
 
 751         for leaf in self.leaves[_opening_index + 1:]:
 
 755             bracket_depth = leaf.bracket_depth
 
 756             if bracket_depth == depth and leaf.type == token.COMMA:
 
 758                 if leaf.parent and leaf.parent.type == syms.arglist:
 
 763             self.remove_trailing_comma()
 
 768     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
 
 769         """In a for loop, or comprehension, the variables are often unpacks.
 
 771         To avoid splitting on the comma in this situation, increase the depth of
 
 772         tokens between `for` and `in`.
 
 774         if leaf.type == token.NAME and leaf.value == "for":
 
 776             self.bracket_tracker.depth += 1
 
 777             self._for_loop_variable = True
 
 782     def maybe_decrement_after_for_loop_variable(self, leaf: Leaf) -> bool:
 
 783         """See `maybe_increment_for_loop_variable` above for explanation."""
 
 784         if self._for_loop_variable and leaf.type == token.NAME and leaf.value == "in":
 
 785             self.bracket_tracker.depth -= 1
 
 786             self._for_loop_variable = False
 
 791     def append_comment(self, comment: Leaf) -> bool:
 
 792         """Add an inline or standalone comment to the line."""
 
 794             comment.type == STANDALONE_COMMENT
 
 795             and self.bracket_tracker.any_open_brackets()
 
 800         if comment.type != token.COMMENT:
 
 803         after = len(self.leaves) - 1
 
 805             comment.type = STANDALONE_COMMENT
 
 810             self.comments.append((after, comment))
 
 813     def comments_after(self, leaf: Leaf) -> Iterator[Leaf]:
 
 814         """Generate comments that should appear directly after `leaf`."""
 
 815         for _leaf_index, _leaf in enumerate(self.leaves):
 
 822         for index, comment_after in self.comments:
 
 823             if _leaf_index == index:
 
 826     def remove_trailing_comma(self) -> None:
 
 827         """Remove the trailing comma and moves the comments attached to it."""
 
 828         comma_index = len(self.leaves) - 1
 
 829         for i in range(len(self.comments)):
 
 830             comment_index, comment = self.comments[i]
 
 831             if comment_index == comma_index:
 
 832                 self.comments[i] = (comma_index - 1, comment)
 
 835     def __str__(self) -> str:
 
 836         """Render the line."""
 
 840         indent = "    " * self.depth
 
 841         leaves = iter(self.leaves)
 
 843         res = f"{first.prefix}{indent}{first.value}"
 
 846         for _, comment in self.comments:
 
 850     def __bool__(self) -> bool:
 
 851         """Return True if the line has leaves or comments."""
 
 852         return bool(self.leaves or self.comments)
 
 855 class UnformattedLines(Line):
 
 856     """Just like :class:`Line` but stores lines which aren't reformatted."""
 
 858     def append(self, leaf: Leaf, preformatted: bool = True) -> None:
 
 859         """Just add a new `leaf` to the end of the lines.
 
 861         The `preformatted` argument is ignored.
 
 863         Keeps track of indentation `depth`, which is useful when the user
 
 864         says `# fmt: on`. Otherwise, doesn't do anything with the `leaf`.
 
 867             list(generate_comments(leaf))
 
 868         except FormatOn as f_on:
 
 869             self.leaves.append(f_on.leaf_from_consumed(leaf))
 
 872         self.leaves.append(leaf)
 
 873         if leaf.type == token.INDENT:
 
 875         elif leaf.type == token.DEDENT:
 
 878     def __str__(self) -> str:
 
 879         """Render unformatted lines from leaves which were added with `append()`.
 
 881         `depth` is not used for indentation in this case.
 
 887         for leaf in self.leaves:
 
 891     def append_comment(self, comment: Leaf) -> bool:
 
 892         """Not implemented in this class. Raises `NotImplementedError`."""
 
 893         raise NotImplementedError("Unformatted lines don't store comments separately.")
 
 895     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
 
 896         """Does nothing and returns False."""
 
 899     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
 
 900         """Does nothing and returns False."""
 
 905 class EmptyLineTracker:
 
 906     """Provides a stateful method that returns the number of potential extra
 
 907     empty lines needed before and after the currently processed line.
 
 909     Note: this tracker works on lines that haven't been split yet.  It assumes
 
 910     the prefix of the first leaf consists of optional newlines.  Those newlines
 
 911     are consumed by `maybe_empty_lines()` and included in the computation.
 
 913     previous_line: Optional[Line] = None
 
 914     previous_after: int = 0
 
 915     previous_defs: List[int] = Factory(list)
 
 917     def maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
 
 918         """Return the number of extra empty lines before and after the `current_line`.
 
 920         This is for separating `def`, `async def` and `class` with extra empty
 
 921         lines (two on module-level), as well as providing an extra empty line
 
 922         after flow control keywords to make them more prominent.
 
 924         if isinstance(current_line, UnformattedLines):
 
 927         before, after = self._maybe_empty_lines(current_line)
 
 928         before -= self.previous_after
 
 929         self.previous_after = after
 
 930         self.previous_line = current_line
 
 933     def _maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
 
 935         if current_line.depth == 0:
 
 937         if current_line.leaves:
 
 938             # Consume the first leaf's extra newlines.
 
 939             first_leaf = current_line.leaves[0]
 
 940             before = first_leaf.prefix.count("\n")
 
 941             before = min(before, max_allowed)
 
 942             first_leaf.prefix = ""
 
 945         depth = current_line.depth
 
 946         while self.previous_defs and self.previous_defs[-1] >= depth:
 
 947             self.previous_defs.pop()
 
 948             before = 1 if depth else 2
 
 949         is_decorator = current_line.is_decorator
 
 950         if is_decorator or current_line.is_def or current_line.is_class:
 
 952                 self.previous_defs.append(depth)
 
 953             if self.previous_line is None:
 
 954                 # Don't insert empty lines before the first line in the file.
 
 957             if self.previous_line and self.previous_line.is_decorator:
 
 958                 # Don't insert empty lines between decorators.
 
 962             if current_line.depth:
 
 966         if current_line.is_flow_control:
 
 971             and self.previous_line.is_import
 
 972             and not current_line.is_import
 
 973             and depth == self.previous_line.depth
 
 975             return (before or 1), 0
 
 979             and self.previous_line.is_yield
 
 980             and (not current_line.is_yield or depth != self.previous_line.depth)
 
 982             return (before or 1), 0
 
 988 class LineGenerator(Visitor[Line]):
 
 989     """Generates reformatted Line objects.  Empty lines are not emitted.
 
 991     Note: destroys the tree it's visiting by mutating prefixes of its leaves
 
 992     in ways that will no longer stringify to valid Python code on the tree.
 
 994     current_line: Line = Factory(Line)
 
 996     def line(self, indent: int = 0, type: Type[Line] = Line) -> Iterator[Line]:
 
 999         If the line is empty, only emit if it makes sense.
 
1000         If the line is too long, split it first and then generate.
 
1002         If any lines were generated, set up a new current_line.
 
1004         if not self.current_line:
 
1005             if self.current_line.__class__ == type:
 
1006                 self.current_line.depth += indent
 
1008                 self.current_line = type(depth=self.current_line.depth + indent)
 
1009             return  # Line is empty, don't emit. Creating a new one unnecessary.
 
1011         complete_line = self.current_line
 
1012         self.current_line = type(depth=complete_line.depth + indent)
 
1015     def visit(self, node: LN) -> Iterator[Line]:
 
1016         """Main method to visit `node` and its children.
 
1018         Yields :class:`Line` objects.
 
1020         if isinstance(self.current_line, UnformattedLines):
 
1021             # File contained `# fmt: off`
 
1022             yield from self.visit_unformatted(node)
 
1025             yield from super().visit(node)
 
1027     def visit_default(self, node: LN) -> Iterator[Line]:
 
1028         """Default `visit_*()` implementation. Recurses to children of `node`."""
 
1029         if isinstance(node, Leaf):
 
1030             any_open_brackets = self.current_line.bracket_tracker.any_open_brackets()
 
1032                 for comment in generate_comments(node):
 
1033                     if any_open_brackets:
 
1034                         # any comment within brackets is subject to splitting
 
1035                         self.current_line.append(comment)
 
1036                     elif comment.type == token.COMMENT:
 
1037                         # regular trailing comment
 
1038                         self.current_line.append(comment)
 
1039                         yield from self.line()
 
1042                         # regular standalone comment
 
1043                         yield from self.line()
 
1045                         self.current_line.append(comment)
 
1046                         yield from self.line()
 
1048             except FormatOff as f_off:
 
1049                 f_off.trim_prefix(node)
 
1050                 yield from self.line(type=UnformattedLines)
 
1051                 yield from self.visit(node)
 
1053             except FormatOn as f_on:
 
1054                 # This only happens here if somebody says "fmt: on" multiple
 
1056                 f_on.trim_prefix(node)
 
1057                 yield from self.visit_default(node)
 
1060                 normalize_prefix(node, inside_brackets=any_open_brackets)
 
1061                 if node.type == token.STRING:
 
1062                     normalize_string_quotes(node)
 
1063                 if node.type not in WHITESPACE:
 
1064                     self.current_line.append(node)
 
1065         yield from super().visit_default(node)
 
1067     def visit_INDENT(self, node: Node) -> Iterator[Line]:
 
1068         """Increase indentation level, maybe yield a line."""
 
1069         # In blib2to3 INDENT never holds comments.
 
1070         yield from self.line(+1)
 
1071         yield from self.visit_default(node)
 
1073     def visit_DEDENT(self, node: Node) -> Iterator[Line]:
 
1074         """Decrease indentation level, maybe yield a line."""
 
1075         # DEDENT has no value. Additionally, in blib2to3 it never holds comments.
 
1076         yield from self.line(-1)
 
1078     def visit_stmt(self, node: Node, keywords: Set[str]) -> Iterator[Line]:
 
1079         """Visit a statement.
 
1081         This implementation is shared for `if`, `while`, `for`, `try`, `except`,
 
1082         `def`, `with`, and `class`.
 
1084         The relevant Python language `keywords` for a given statement will be NAME
 
1085         leaves within it. This methods puts those on a separate line.
 
1087         for child in node.children:
 
1088             if child.type == token.NAME and child.value in keywords:  # type: ignore
 
1089                 yield from self.line()
 
1091             yield from self.visit(child)
 
1093     def visit_simple_stmt(self, node: Node) -> Iterator[Line]:
 
1094         """Visit a statement without nested statements."""
 
1095         is_suite_like = node.parent and node.parent.type in STATEMENT
 
1097             yield from self.line(+1)
 
1098             yield from self.visit_default(node)
 
1099             yield from self.line(-1)
 
1102             yield from self.line()
 
1103             yield from self.visit_default(node)
 
1105     def visit_async_stmt(self, node: Node) -> Iterator[Line]:
 
1106         """Visit `async def`, `async for`, `async with`."""
 
1107         yield from self.line()
 
1109         children = iter(node.children)
 
1110         for child in children:
 
1111             yield from self.visit(child)
 
1113             if child.type == token.ASYNC:
 
1116         internal_stmt = next(children)
 
1117         for child in internal_stmt.children:
 
1118             yield from self.visit(child)
 
1120     def visit_decorators(self, node: Node) -> Iterator[Line]:
 
1121         """Visit decorators."""
 
1122         for child in node.children:
 
1123             yield from self.line()
 
1124             yield from self.visit(child)
 
1126     def visit_SEMI(self, leaf: Leaf) -> Iterator[Line]:
 
1127         """Remove a semicolon and put the other statement on a separate line."""
 
1128         yield from self.line()
 
1130     def visit_ENDMARKER(self, leaf: Leaf) -> Iterator[Line]:
 
1131         """End of file. Process outstanding comments and end with a newline."""
 
1132         yield from self.visit_default(leaf)
 
1133         yield from self.line()
 
1135     def visit_unformatted(self, node: LN) -> Iterator[Line]:
 
1136         """Used when file contained a `# fmt: off`."""
 
1137         if isinstance(node, Node):
 
1138             for child in node.children:
 
1139                 yield from self.visit(child)
 
1143                 self.current_line.append(node)
 
1144             except FormatOn as f_on:
 
1145                 f_on.trim_prefix(node)
 
1146                 yield from self.line()
 
1147                 yield from self.visit(node)
 
1149             if node.type == token.ENDMARKER:
 
1150                 # somebody decided not to put a final `# fmt: on`
 
1151                 yield from self.line()
 
1153     def __attrs_post_init__(self) -> None:
 
1154         """You are in a twisty little maze of passages."""
 
1156         self.visit_if_stmt = partial(v, keywords={"if", "else", "elif"})
 
1157         self.visit_while_stmt = partial(v, keywords={"while", "else"})
 
1158         self.visit_for_stmt = partial(v, keywords={"for", "else"})
 
1159         self.visit_try_stmt = partial(v, keywords={"try", "except", "else", "finally"})
 
1160         self.visit_except_clause = partial(v, keywords={"except"})
 
1161         self.visit_funcdef = partial(v, keywords={"def"})
 
1162         self.visit_with_stmt = partial(v, keywords={"with"})
 
1163         self.visit_classdef = partial(v, keywords={"class"})
 
1164         self.visit_async_funcdef = self.visit_async_stmt
 
1165         self.visit_decorated = self.visit_decorators
 
1168 BRACKET = {token.LPAR: token.RPAR, token.LSQB: token.RSQB, token.LBRACE: token.RBRACE}
 
1169 OPENING_BRACKETS = set(BRACKET.keys())
 
1170 CLOSING_BRACKETS = set(BRACKET.values())
 
1171 BRACKETS = OPENING_BRACKETS | CLOSING_BRACKETS
 
1172 ALWAYS_NO_SPACE = CLOSING_BRACKETS | {token.COMMA, STANDALONE_COMMENT}
 
1175 def whitespace(leaf: Leaf) -> str:  # noqa C901
 
1176     """Return whitespace prefix if needed for the given `leaf`."""
 
1183     if t in ALWAYS_NO_SPACE:
 
1186     if t == token.COMMENT:
 
1189     assert p is not None, f"INTERNAL ERROR: hand-made leaf without parent: {leaf!r}"
 
1190     if t == token.COLON and p.type not in {syms.subscript, syms.subscriptlist}:
 
1193     prev = leaf.prev_sibling
 
1195         prevp = preceding_leaf(p)
 
1196         if not prevp or prevp.type in OPENING_BRACKETS:
 
1199         if t == token.COLON:
 
1200             return SPACE if prevp.type == token.COMMA else NO
 
1202         if prevp.type == token.EQUAL:
 
1204                 if prevp.parent.type in {
 
1205                     syms.arglist, syms.argument, syms.parameters, syms.varargslist
 
1209                 elif prevp.parent.type == syms.typedargslist:
 
1210                     # A bit hacky: if the equal sign has whitespace, it means we
 
1211                     # previously found it's a typed argument.  So, we're using
 
1215         elif prevp.type == token.DOUBLESTAR:
 
1216             if prevp.parent and prevp.parent.type in {
 
1226         elif prevp.type == token.COLON:
 
1227             if prevp.parent and prevp.parent.type in {syms.subscript, syms.sliceop}:
 
1232             and prevp.parent.type in {syms.factor, syms.star_expr}
 
1233             and prevp.type in MATH_OPERATORS
 
1238             prevp.type == token.RIGHTSHIFT
 
1240             and prevp.parent.type == syms.shift_expr
 
1241             and prevp.prev_sibling
 
1242             and prevp.prev_sibling.type == token.NAME
 
1243             and prevp.prev_sibling.value == "print"  # type: ignore
 
1245             # Python 2 print chevron
 
1248     elif prev.type in OPENING_BRACKETS:
 
1251     if p.type in {syms.parameters, syms.arglist}:
 
1252         # untyped function signatures or calls
 
1256         if not prev or prev.type != token.COMMA:
 
1259     elif p.type == syms.varargslist:
 
1264         if prev and prev.type != token.COMMA:
 
1267     elif p.type == syms.typedargslist:
 
1268         # typed function signatures
 
1272         if t == token.EQUAL:
 
1273             if prev.type != syms.tname:
 
1276         elif prev.type == token.EQUAL:
 
1277             # A bit hacky: if the equal sign has whitespace, it means we
 
1278             # previously found it's a typed argument.  So, we're using that, too.
 
1281         elif prev.type != token.COMMA:
 
1284     elif p.type == syms.tname:
 
1287             prevp = preceding_leaf(p)
 
1288             if not prevp or prevp.type != token.COMMA:
 
1291     elif p.type == syms.trailer:
 
1292         # attributes and calls
 
1293         if t == token.LPAR or t == token.RPAR:
 
1298                 prevp = preceding_leaf(p)
 
1299                 if not prevp or prevp.type != token.NUMBER:
 
1302             elif t == token.LSQB:
 
1305         elif prev.type != token.COMMA:
 
1308     elif p.type == syms.argument:
 
1310         if t == token.EQUAL:
 
1314             prevp = preceding_leaf(p)
 
1315             if not prevp or prevp.type == token.LPAR:
 
1318         elif prev.type == token.EQUAL or prev.type == token.DOUBLESTAR:
 
1321     elif p.type == syms.decorator:
 
1325     elif p.type == syms.dotted_name:
 
1329         prevp = preceding_leaf(p)
 
1330         if not prevp or prevp.type == token.AT or prevp.type == token.DOT:
 
1333     elif p.type == syms.classdef:
 
1337         if prev and prev.type == token.LPAR:
 
1340     elif p.type == syms.subscript:
 
1343             assert p.parent is not None, "subscripts are always parented"
 
1344             if p.parent.type == syms.subscriptlist:
 
1352     elif p.type == syms.atom:
 
1353         if prev and t == token.DOT:
 
1354             # dots, but not the first one.
 
1358         p.type == syms.listmaker
 
1359         or p.type == syms.testlist_gexp
 
1360         or p.type == syms.subscriptlist
 
1362         # list interior, including unpacking
 
1366     elif p.type == syms.dictsetmaker:
 
1367         # dict and set interior, including unpacking
 
1371         if prev.type == token.DOUBLESTAR:
 
1374     elif p.type in {syms.factor, syms.star_expr}:
 
1377             prevp = preceding_leaf(p)
 
1378             if not prevp or prevp.type in OPENING_BRACKETS:
 
1381             prevp_parent = prevp.parent
 
1382             assert prevp_parent is not None
 
1383             if prevp.type == token.COLON and prevp_parent.type in {
 
1384                 syms.subscript, syms.sliceop
 
1388             elif prevp.type == token.EQUAL and prevp_parent.type == syms.argument:
 
1391         elif t == token.NAME or t == token.NUMBER:
 
1394     elif p.type == syms.import_from:
 
1396             if prev and prev.type == token.DOT:
 
1399         elif t == token.NAME:
 
1403             if prev and prev.type == token.DOT:
 
1406     elif p.type == syms.sliceop:
 
1412 def preceding_leaf(node: Optional[LN]) -> Optional[Leaf]:
 
1413     """Return the first leaf that precedes `node`, if any."""
 
1415         res = node.prev_sibling
 
1417             if isinstance(res, Leaf):
 
1421                 return list(res.leaves())[-1]
 
1430 def is_split_after_delimiter(leaf: Leaf, previous: Leaf = None) -> int:
 
1431     """Return the priority of the `leaf` delimiter, given a line break after it.
 
1433     The delimiter priorities returned here are from those delimiters that would
 
1434     cause a line break after themselves.
 
1436     Higher numbers are higher priority.
 
1438     if leaf.type == token.COMMA:
 
1439         return COMMA_PRIORITY
 
1442         leaf.type in VARARGS
 
1444         and leaf.parent.type in {syms.argument, syms.typedargslist}
 
1446         return MATH_PRIORITY
 
1451 def is_split_before_delimiter(leaf: Leaf, previous: Leaf = None) -> int:
 
1452     """Return the priority of the `leaf` delimiter, given a line before after it.
 
1454     The delimiter priorities returned here are from those delimiters that would
 
1455     cause a line break before themselves.
 
1457     Higher numbers are higher priority.
 
1460         leaf.type in MATH_OPERATORS
 
1462         and leaf.parent.type not in {syms.factor, syms.star_expr}
 
1464         return MATH_PRIORITY
 
1466     if leaf.type in COMPARATORS:
 
1467         return COMPARATOR_PRIORITY
 
1470         leaf.type == token.STRING
 
1471         and previous is not None
 
1472         and previous.type == token.STRING
 
1474         return STRING_PRIORITY
 
1477         leaf.type == token.NAME
 
1478         and leaf.value == "for"
 
1480         and leaf.parent.type in {syms.comp_for, syms.old_comp_for}
 
1482         return COMPREHENSION_PRIORITY
 
1485         leaf.type == token.NAME
 
1486         and leaf.value == "if"
 
1488         and leaf.parent.type in {syms.comp_if, syms.old_comp_if}
 
1490         return COMPREHENSION_PRIORITY
 
1492     if leaf.type == token.NAME and leaf.value in LOGIC_OPERATORS and leaf.parent:
 
1493         return LOGIC_PRIORITY
 
1498 def is_delimiter(leaf: Leaf, previous: Leaf = None) -> int:
 
1499     """Return the priority of the `leaf` delimiter. Return 0 if not delimiter.
 
1501     Higher numbers are higher priority.
 
1504         is_split_before_delimiter(leaf, previous),
 
1505         is_split_after_delimiter(leaf, previous),
 
1509 def generate_comments(leaf: Leaf) -> Iterator[Leaf]:
 
1510     """Clean the prefix of the `leaf` and generate comments from it, if any.
 
1512     Comments in lib2to3 are shoved into the whitespace prefix.  This happens
 
1513     in `pgen2/driver.py:Driver.parse_tokens()`.  This was a brilliant implementation
 
1514     move because it does away with modifying the grammar to include all the
 
1515     possible places in which comments can be placed.
 
1517     The sad consequence for us though is that comments don't "belong" anywhere.
 
1518     This is why this function generates simple parentless Leaf objects for
 
1519     comments.  We simply don't know what the correct parent should be.
 
1521     No matter though, we can live without this.  We really only need to
 
1522     differentiate between inline and standalone comments.  The latter don't
 
1523     share the line with any code.
 
1525     Inline comments are emitted as regular token.COMMENT leaves.  Standalone
 
1526     are emitted with a fake STANDALONE_COMMENT token identifier.
 
1537     for index, line in enumerate(p.split("\n")):
 
1538         consumed += len(line) + 1  # adding the length of the split '\n'
 
1539         line = line.lstrip()
 
1542         if not line.startswith("#"):
 
1545         if index == 0 and leaf.type != token.ENDMARKER:
 
1546             comment_type = token.COMMENT  # simple trailing comment
 
1548             comment_type = STANDALONE_COMMENT
 
1549         comment = make_comment(line)
 
1550         yield Leaf(comment_type, comment, prefix="\n" * nlines)
 
1552         if comment in {"# fmt: on", "# yapf: enable"}:
 
1553             raise FormatOn(consumed)
 
1555         if comment in {"# fmt: off", "# yapf: disable"}:
 
1556             if comment_type == STANDALONE_COMMENT:
 
1557                 raise FormatOff(consumed)
 
1559             prev = preceding_leaf(leaf)
 
1560             if not prev or prev.type in WHITESPACE:  # standalone comment in disguise
 
1561                 raise FormatOff(consumed)
 
1566 def make_comment(content: str) -> str:
 
1567     """Return a consistently formatted comment from the given `content` string.
 
1569     All comments (except for "##", "#!", "#:") should have a single space between
 
1570     the hash sign and the content.
 
1572     If `content` didn't start with a hash sign, one is provided.
 
1574     content = content.rstrip()
 
1578     if content[0] == "#":
 
1579         content = content[1:]
 
1580     if content and content[0] not in " !:#":
 
1581         content = " " + content
 
1582     return "#" + content
 
1586     line: Line, line_length: int, inner: bool = False, py36: bool = False
 
1587 ) -> Iterator[Line]:
 
1588     """Split a `line` into potentially many lines.
 
1590     They should fit in the allotted `line_length` but might not be able to.
 
1591     `inner` signifies that there were a pair of brackets somewhere around the
 
1592     current `line`, possibly transitively. This means we can fallback to splitting
 
1593     by delimiters if the LHS/RHS don't yield any results.
 
1595     If `py36` is True, splitting may generate syntax that is only compatible
 
1596     with Python 3.6 and later.
 
1598     if isinstance(line, UnformattedLines) or line.is_comment:
 
1602     line_str = str(line).strip("\n")
 
1604         len(line_str) <= line_length
 
1605         and "\n" not in line_str  # multiline strings
 
1606         and not line.contains_standalone_comments
 
1611     split_funcs: List[SplitFunc]
 
1613         split_funcs = [left_hand_split]
 
1614     elif line.inside_brackets:
 
1615         split_funcs = [delimiter_split, standalone_comment_split, right_hand_split]
 
1617         split_funcs = [right_hand_split]
 
1618     for split_func in split_funcs:
 
1619         # We are accumulating lines in `result` because we might want to abort
 
1620         # mission and return the original line in the end, or attempt a different
 
1622         result: List[Line] = []
 
1624             for l in split_func(line, py36):
 
1625                 if str(l).strip("\n") == line_str:
 
1626                     raise CannotSplit("Split function returned an unchanged result")
 
1629                     split_line(l, line_length=line_length, inner=True, py36=py36)
 
1631         except CannotSplit as cs:
 
1642 def left_hand_split(line: Line, py36: bool = False) -> Iterator[Line]:
 
1643     """Split line into many lines, starting with the first matching bracket pair.
 
1645     Note: this usually looks weird, only use this for function definitions.
 
1646     Prefer RHS otherwise.
 
1648     head = Line(depth=line.depth)
 
1649     body = Line(depth=line.depth + 1, inside_brackets=True)
 
1650     tail = Line(depth=line.depth)
 
1651     tail_leaves: List[Leaf] = []
 
1652     body_leaves: List[Leaf] = []
 
1653     head_leaves: List[Leaf] = []
 
1654     current_leaves = head_leaves
 
1655     matching_bracket = None
 
1656     for leaf in line.leaves:
 
1658             current_leaves is body_leaves
 
1659             and leaf.type in CLOSING_BRACKETS
 
1660             and leaf.opening_bracket is matching_bracket
 
1662             current_leaves = tail_leaves if body_leaves else head_leaves
 
1663         current_leaves.append(leaf)
 
1664         if current_leaves is head_leaves:
 
1665             if leaf.type in OPENING_BRACKETS:
 
1666                 matching_bracket = leaf
 
1667                 current_leaves = body_leaves
 
1668     # Since body is a new indent level, remove spurious leading whitespace.
 
1670         normalize_prefix(body_leaves[0], inside_brackets=True)
 
1671     # Build the new lines.
 
1672     for result, leaves in (
 
1673         (head, head_leaves), (body, body_leaves), (tail, tail_leaves)
 
1676             result.append(leaf, preformatted=True)
 
1677             for comment_after in line.comments_after(leaf):
 
1678                 result.append(comment_after, preformatted=True)
 
1679     bracket_split_succeeded_or_raise(head, body, tail)
 
1680     for result in (head, body, tail):
 
1685 def right_hand_split(line: Line, py36: bool = False) -> Iterator[Line]:
 
1686     """Split line into many lines, starting with the last matching bracket pair."""
 
1687     head = Line(depth=line.depth)
 
1688     body = Line(depth=line.depth + 1, inside_brackets=True)
 
1689     tail = Line(depth=line.depth)
 
1690     tail_leaves: List[Leaf] = []
 
1691     body_leaves: List[Leaf] = []
 
1692     head_leaves: List[Leaf] = []
 
1693     current_leaves = tail_leaves
 
1694     opening_bracket = None
 
1695     for leaf in reversed(line.leaves):
 
1696         if current_leaves is body_leaves:
 
1697             if leaf is opening_bracket:
 
1698                 current_leaves = head_leaves if body_leaves else tail_leaves
 
1699         current_leaves.append(leaf)
 
1700         if current_leaves is tail_leaves:
 
1701             if leaf.type in CLOSING_BRACKETS:
 
1702                 opening_bracket = leaf.opening_bracket
 
1703                 current_leaves = body_leaves
 
1704     tail_leaves.reverse()
 
1705     body_leaves.reverse()
 
1706     head_leaves.reverse()
 
1707     # Since body is a new indent level, remove spurious leading whitespace.
 
1709         normalize_prefix(body_leaves[0], inside_brackets=True)
 
1710     # Build the new lines.
 
1711     for result, leaves in (
 
1712         (head, head_leaves), (body, body_leaves), (tail, tail_leaves)
 
1715             result.append(leaf, preformatted=True)
 
1716             for comment_after in line.comments_after(leaf):
 
1717                 result.append(comment_after, preformatted=True)
 
1718     bracket_split_succeeded_or_raise(head, body, tail)
 
1719     for result in (head, body, tail):
 
1724 def bracket_split_succeeded_or_raise(head: Line, body: Line, tail: Line) -> None:
 
1725     """Raise :exc:`CannotSplit` if the last left- or right-hand split failed.
 
1727     Do nothing otherwise.
 
1729     A left- or right-hand split is based on a pair of brackets. Content before
 
1730     (and including) the opening bracket is left on one line, content inside the
 
1731     brackets is put on a separate line, and finally content starting with and
 
1732     following the closing bracket is put on a separate line.
 
1734     Those are called `head`, `body`, and `tail`, respectively. If the split
 
1735     produced the same line (all content in `head`) or ended up with an empty `body`
 
1736     and the `tail` is just the closing bracket, then it's considered failed.
 
1738     tail_len = len(str(tail).strip())
 
1741             raise CannotSplit("Splitting brackets produced the same line")
 
1745                 f"Splitting brackets on an empty body to save "
 
1746                 f"{tail_len} characters is not worth it"
 
1750 def dont_increase_indentation(split_func: SplitFunc) -> SplitFunc:
 
1751     """Normalize prefix of the first leaf in every line returned by `split_func`.
 
1753     This is a decorator over relevant split functions.
 
1757     def split_wrapper(line: Line, py36: bool = False) -> Iterator[Line]:
 
1758         for l in split_func(line, py36):
 
1759             normalize_prefix(l.leaves[0], inside_brackets=True)
 
1762     return split_wrapper
 
1765 @dont_increase_indentation
 
1766 def delimiter_split(line: Line, py36: bool = False) -> Iterator[Line]:
 
1767     """Split according to delimiters of the highest priority.
 
1769     If `py36` is True, the split will add trailing commas also in function
 
1770     signatures that contain `*` and `**`.
 
1773         last_leaf = line.leaves[-1]
 
1775         raise CannotSplit("Line empty")
 
1777     delimiters = line.bracket_tracker.delimiters
 
1779         delimiter_priority = line.bracket_tracker.max_delimiter_priority(
 
1780             exclude={id(last_leaf)}
 
1783         raise CannotSplit("No delimiters found")
 
1785     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
 
1786     lowest_depth = sys.maxsize
 
1787     trailing_comma_safe = True
 
1789     def append_to_line(leaf: Leaf) -> Iterator[Line]:
 
1790         """Append `leaf` to current line or to new line if appending impossible."""
 
1791         nonlocal current_line
 
1793             current_line.append_safe(leaf, preformatted=True)
 
1794         except ValueError as ve:
 
1797             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
 
1798             current_line.append(leaf)
 
1800     for leaf in line.leaves:
 
1801         yield from append_to_line(leaf)
 
1803         for comment_after in line.comments_after(leaf):
 
1804             yield from append_to_line(comment_after)
 
1806         lowest_depth = min(lowest_depth, leaf.bracket_depth)
 
1808             leaf.bracket_depth == lowest_depth
 
1809             and leaf.type == token.STAR
 
1810             or leaf.type == token.DOUBLESTAR
 
1812             trailing_comma_safe = trailing_comma_safe and py36
 
1813         leaf_priority = delimiters.get(id(leaf))
 
1814         if leaf_priority == delimiter_priority:
 
1817             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
 
1821             and delimiter_priority == COMMA_PRIORITY
 
1822             and current_line.leaves[-1].type != token.COMMA
 
1823             and current_line.leaves[-1].type != STANDALONE_COMMENT
 
1825             current_line.append(Leaf(token.COMMA, ","))
 
1829 @dont_increase_indentation
 
1830 def standalone_comment_split(line: Line, py36: bool = False) -> Iterator[Line]:
 
1831     """Split standalone comments from the rest of the line."""
 
1832     for leaf in line.leaves:
 
1833         if leaf.type == STANDALONE_COMMENT:
 
1834             if leaf.bracket_depth == 0:
 
1838         raise CannotSplit("Line does not have any standalone comments")
 
1840     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
 
1842     def append_to_line(leaf: Leaf) -> Iterator[Line]:
 
1843         """Append `leaf` to current line or to new line if appending impossible."""
 
1844         nonlocal current_line
 
1846             current_line.append_safe(leaf, preformatted=True)
 
1847         except ValueError as ve:
 
1850             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
 
1851             current_line.append(leaf)
 
1853     for leaf in line.leaves:
 
1854         yield from append_to_line(leaf)
 
1856         for comment_after in line.comments_after(leaf):
 
1857             yield from append_to_line(comment_after)
 
1863 def is_import(leaf: Leaf) -> bool:
 
1864     """Return True if the given leaf starts an import statement."""
 
1871             (v == "import" and p and p.type == syms.import_name)
 
1872             or (v == "from" and p and p.type == syms.import_from)
 
1877 def normalize_prefix(leaf: Leaf, *, inside_brackets: bool) -> None:
 
1878     """Leave existing extra newlines if not `inside_brackets`. Remove everything
 
1881     Note: don't use backslashes for formatting or you'll lose your voting rights.
 
1883     if not inside_brackets:
 
1884         spl = leaf.prefix.split("#")
 
1885         if "\\" not in spl[0]:
 
1886             nl_count = spl[-1].count("\n")
 
1889             leaf.prefix = "\n" * nl_count
 
1895 def normalize_string_quotes(leaf: Leaf) -> None:
 
1896     """Prefer double quotes but only if it doesn't cause more escaping.
 
1898     Adds or removes backslashes as appropriate. Doesn't parse and fix
 
1899     strings nested in f-strings (yet).
 
1901     Note: Mutates its argument.
 
1903     value = leaf.value.lstrip("furbFURB")
 
1904     if value[:3] == '"""':
 
1907     elif value[:3] == "'''":
 
1910     elif value[0] == '"':
 
1916     first_quote_pos = leaf.value.find(orig_quote)
 
1917     if first_quote_pos == -1:
 
1918         return  # There's an internal error
 
1920     body = leaf.value[first_quote_pos + len(orig_quote):-len(orig_quote)]
 
1921     new_body = body.replace(f"\\{orig_quote}", orig_quote).replace(
 
1922         new_quote, f"\\{new_quote}"
 
1924     if new_quote == '"""' and new_body[-1] == '"':
 
1926         new_body = new_body[:-1] + '\\"'
 
1927     orig_escape_count = body.count("\\")
 
1928     new_escape_count = new_body.count("\\")
 
1929     if new_escape_count > orig_escape_count:
 
1930         return  # Do not introduce more escaping
 
1932     if new_escape_count == orig_escape_count and orig_quote == '"':
 
1933         return  # Prefer double quotes
 
1935     prefix = leaf.value[:first_quote_pos]
 
1936     leaf.value = f"{prefix}{new_quote}{new_body}{new_quote}"
 
1939 def is_python36(node: Node) -> bool:
 
1940     """Return True if the current file is using Python 3.6+ features.
 
1942     Currently looking for:
 
1944     - trailing commas after * or ** in function signatures.
 
1946     for n in node.pre_order():
 
1947         if n.type == token.STRING:
 
1948             value_head = n.value[:2]  # type: ignore
 
1949             if value_head in {'f"', 'F"', "f'", "F'", "rf", "fr", "RF", "FR"}:
 
1953             n.type == syms.typedargslist
 
1955             and n.children[-1].type == token.COMMA
 
1957             for ch in n.children:
 
1958                 if ch.type == token.STAR or ch.type == token.DOUBLESTAR:
 
1964 PYTHON_EXTENSIONS = {".py"}
 
1965 BLACKLISTED_DIRECTORIES = {
 
1966     "build", "buck-out", "dist", "_build", ".git", ".hg", ".mypy_cache", ".tox", ".venv"
 
1970 def gen_python_files_in_dir(path: Path) -> Iterator[Path]:
 
1971     """Generate all files under `path` which aren't under BLACKLISTED_DIRECTORIES
 
1972     and have one of the PYTHON_EXTENSIONS.
 
1974     for child in path.iterdir():
 
1976             if child.name in BLACKLISTED_DIRECTORIES:
 
1979             yield from gen_python_files_in_dir(child)
 
1981         elif child.suffix in PYTHON_EXTENSIONS:
 
1987     """Provides a reformatting counter. Can be rendered with `str(report)`."""
 
1990     change_count: int = 0
 
1992     failure_count: int = 0
 
1994     def done(self, src: Path, changed: bool) -> None:
 
1995         """Increment the counter for successful reformatting. Write out a message."""
 
1997             reformatted = "would reformat" if self.check else "reformatted"
 
1999                 out(f"{reformatted} {src}")
 
2000             self.change_count += 1
 
2003                 out(f"{src} already well formatted, good job.", bold=False)
 
2004             self.same_count += 1
 
2006     def failed(self, src: Path, message: str) -> None:
 
2007         """Increment the counter for failed reformatting. Write out a message."""
 
2008         err(f"error: cannot format {src}: {message}")
 
2009         self.failure_count += 1
 
2012     def return_code(self) -> int:
 
2013         """Return the exit code that the app should use.
 
2015         This considers the current state of changed files and failures:
 
2016         - if there were any failures, return 123;
 
2017         - if any files were changed and --check is being used, return 1;
 
2018         - otherwise return 0.
 
2020         # According to http://tldp.org/LDP/abs/html/exitcodes.html starting with
 
2021         # 126 we have special returncodes reserved by the shell.
 
2022         if self.failure_count:
 
2025         elif self.change_count and self.check:
 
2030     def __str__(self) -> str:
 
2031         """Render a color report of the current state.
 
2033         Use `click.unstyle` to remove colors.
 
2036             reformatted = "would be reformatted"
 
2037             unchanged = "would be left unchanged"
 
2038             failed = "would fail to reformat"
 
2040             reformatted = "reformatted"
 
2041             unchanged = "left unchanged"
 
2042             failed = "failed to reformat"
 
2044         if self.change_count:
 
2045             s = "s" if self.change_count > 1 else ""
 
2047                 click.style(f"{self.change_count} file{s} {reformatted}", bold=True)
 
2050             s = "s" if self.same_count > 1 else ""
 
2051             report.append(f"{self.same_count} file{s} {unchanged}")
 
2052         if self.failure_count:
 
2053             s = "s" if self.failure_count > 1 else ""
 
2055                 click.style(f"{self.failure_count} file{s} {failed}", fg="red")
 
2057         return ", ".join(report) + "."
 
2060 def assert_equivalent(src: str, dst: str) -> None:
 
2061     """Raise AssertionError if `src` and `dst` aren't equivalent."""
 
2066     def _v(node: ast.AST, depth: int = 0) -> Iterator[str]:
 
2067         """Simple visitor generating strings to compare ASTs by content."""
 
2068         yield f"{'  ' * depth}{node.__class__.__name__}("
 
2070         for field in sorted(node._fields):
 
2072                 value = getattr(node, field)
 
2073             except AttributeError:
 
2076             yield f"{'  ' * (depth+1)}{field}="
 
2078             if isinstance(value, list):
 
2080                     if isinstance(item, ast.AST):
 
2081                         yield from _v(item, depth + 2)
 
2083             elif isinstance(value, ast.AST):
 
2084                 yield from _v(value, depth + 2)
 
2087                 yield f"{'  ' * (depth+2)}{value!r},  # {value.__class__.__name__}"
 
2089         yield f"{'  ' * depth})  # /{node.__class__.__name__}"
 
2092         src_ast = ast.parse(src)
 
2093     except Exception as exc:
 
2094         major, minor = sys.version_info[:2]
 
2095         raise AssertionError(
 
2096             f"cannot use --safe with this file; failed to parse source file "
 
2097             f"with Python {major}.{minor}'s builtin AST. Re-run with --fast "
 
2098             f"or stop using deprecated Python 2 syntax. AST error message: {exc}"
 
2102         dst_ast = ast.parse(dst)
 
2103     except Exception as exc:
 
2104         log = dump_to_file("".join(traceback.format_tb(exc.__traceback__)), dst)
 
2105         raise AssertionError(
 
2106             f"INTERNAL ERROR: Black produced invalid code: {exc}. "
 
2107             f"Please report a bug on https://github.com/ambv/black/issues.  "
 
2108             f"This invalid output might be helpful: {log}"
 
2111     src_ast_str = "\n".join(_v(src_ast))
 
2112     dst_ast_str = "\n".join(_v(dst_ast))
 
2113     if src_ast_str != dst_ast_str:
 
2114         log = dump_to_file(diff(src_ast_str, dst_ast_str, "src", "dst"))
 
2115         raise AssertionError(
 
2116             f"INTERNAL ERROR: Black produced code that is not equivalent to "
 
2118             f"Please report a bug on https://github.com/ambv/black/issues.  "
 
2119             f"This diff might be helpful: {log}"
 
2123 def assert_stable(src: str, dst: str, line_length: int) -> None:
 
2124     """Raise AssertionError if `dst` reformats differently the second time."""
 
2125     newdst = format_str(dst, line_length=line_length)
 
2128             diff(src, dst, "source", "first pass"),
 
2129             diff(dst, newdst, "first pass", "second pass"),
 
2131         raise AssertionError(
 
2132             f"INTERNAL ERROR: Black produced different code on the second pass "
 
2133             f"of the formatter.  "
 
2134             f"Please report a bug on https://github.com/ambv/black/issues.  "
 
2135             f"This diff might be helpful: {log}"
 
2139 def dump_to_file(*output: str) -> str:
 
2140     """Dump `output` to a temporary file. Return path to the file."""
 
2143     with tempfile.NamedTemporaryFile(
 
2144         mode="w", prefix="blk_", suffix=".log", delete=False
 
2146         for lines in output:
 
2148             if lines and lines[-1] != "\n":
 
2153 def diff(a: str, b: str, a_name: str, b_name: str) -> str:
 
2154     """Return a unified diff string between strings `a` and `b`."""
 
2157     a_lines = [line + "\n" for line in a.split("\n")]
 
2158     b_lines = [line + "\n" for line in b.split("\n")]
 
2160         difflib.unified_diff(a_lines, b_lines, fromfile=a_name, tofile=b_name, n=5)
 
2164 def cancel(tasks: List[asyncio.Task]) -> None:
 
2165     """asyncio signal handler that cancels all `tasks` and reports to stderr."""
 
2171 def shutdown(loop: BaseEventLoop) -> None:
 
2172     """Cancel all pending tasks on `loop`, wait for them, and close the loop."""
 
2174         # This part is borrowed from asyncio/runners.py in Python 3.7b2.
 
2175         to_cancel = [task for task in asyncio.Task.all_tasks(loop) if not task.done()]
 
2179         for task in to_cancel:
 
2181         loop.run_until_complete(
 
2182             asyncio.gather(*to_cancel, loop=loop, return_exceptions=True)
 
2185         # `concurrent.futures.Future` objects cannot be cancelled once they
 
2186         # are already running. There might be some when the `shutdown()` happened.
 
2187         # Silence their logger's spew about the event loop being closed.
 
2188         cf_logger = logging.getLogger("concurrent.futures")
 
2189         cf_logger.setLevel(logging.CRITICAL)
 
2193 if __name__ == "__main__":