]> git.madduck.net Git - etc/vim.git/blob - black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

774d91dc2a7f64ec2dbc252a55e0fe051d5c429c
[etc/vim.git] / black.py
1 #!/usr/bin/env python3
2 import asyncio
3 from asyncio.base_events import BaseEventLoop
4 from concurrent.futures import Executor, ProcessPoolExecutor
5 from functools import partial
6 import keyword
7 import os
8 from pathlib import Path
9 import tokenize
10 from typing import (
11     Dict, Generic, Iterable, Iterator, List, Optional, Set, Tuple, TypeVar, Union
12 )
13
14 from attr import attrib, dataclass, Factory
15 import click
16
17 # lib2to3 fork
18 from blib2to3.pytree import Node, Leaf, type_repr
19 from blib2to3 import pygram, pytree
20 from blib2to3.pgen2 import driver, token
21 from blib2to3.pgen2.parse import ParseError
22
23 __version__ = "18.3a1"
24 DEFAULT_LINE_LENGTH = 88
25 # types
26 syms = pygram.python_symbols
27 FileContent = str
28 Encoding = str
29 Depth = int
30 NodeType = int
31 LeafID = int
32 Priority = int
33 LN = Union[Leaf, Node]
34 out = partial(click.secho, bold=True, err=True)
35 err = partial(click.secho, fg='red', err=True)
36
37
38 class NothingChanged(UserWarning):
39     """Raised by `format_file` when the reformatted code is the same as source."""
40
41
42 class CannotSplit(Exception):
43     """A readable split that fits the allotted line length is impossible.
44
45     Raised by `left_hand_split()` and `right_hand_split()`.
46     """
47
48
49 @click.command()
50 @click.option(
51     '-l',
52     '--line-length',
53     type=int,
54     default=DEFAULT_LINE_LENGTH,
55     help='How many character per line to allow.',
56     show_default=True,
57 )
58 @click.option(
59     '--check',
60     is_flag=True,
61     help=(
62         "Don't write back the files, just return the status.  Return code 0 "
63         "means nothing changed.  Return code 1 means some files were "
64         "reformatted.  Return code 123 means there was an internal error."
65     ),
66 )
67 @click.option(
68     '--fast/--safe',
69     is_flag=True,
70     help='If --fast given, skip temporary sanity checks. [default: --safe]',
71 )
72 @click.version_option(version=__version__)
73 @click.argument(
74     'src',
75     nargs=-1,
76     type=click.Path(exists=True, file_okay=True, dir_okay=True, readable=True),
77 )
78 @click.pass_context
79 def main(
80     ctx: click.Context, line_length: int, check: bool, fast: bool, src: List[str]
81 ) -> None:
82     """The uncompromising code formatter."""
83     sources: List[Path] = []
84     for s in src:
85         p = Path(s)
86         if p.is_dir():
87             sources.extend(gen_python_files_in_dir(p))
88         elif p.is_file():
89             # if a file was explicitly given, we don't care about its extension
90             sources.append(p)
91         else:
92             err(f'invalid path: {s}')
93     if len(sources) == 0:
94         ctx.exit(0)
95     elif len(sources) == 1:
96         p = sources[0]
97         report = Report()
98         try:
99             changed = format_file_in_place(
100                 p, line_length=line_length, fast=fast, write_back=not check
101             )
102             report.done(p, changed)
103         except Exception as exc:
104             report.failed(p, str(exc))
105         ctx.exit(report.return_code)
106     else:
107         loop = asyncio.get_event_loop()
108         executor = ProcessPoolExecutor(max_workers=os.cpu_count())
109         return_code = 1
110         try:
111             return_code = loop.run_until_complete(
112                 schedule_formatting(
113                     sources, line_length, not check, fast, loop, executor
114                 )
115             )
116         finally:
117             loop.close()
118             ctx.exit(return_code)
119
120
121 async def schedule_formatting(
122     sources: List[Path],
123     line_length: int,
124     write_back: bool,
125     fast: bool,
126     loop: BaseEventLoop,
127     executor: Executor,
128 ) -> int:
129     tasks = {
130         src: loop.run_in_executor(
131             executor, format_file_in_place, src, line_length, fast, write_back
132         )
133         for src in sources
134     }
135     await asyncio.wait(tasks.values())
136     cancelled = []
137     report = Report()
138     for src, task in tasks.items():
139         if not task.done():
140             report.failed(src, 'timed out, cancelling')
141             task.cancel()
142             cancelled.append(task)
143         elif task.exception():
144             report.failed(src, str(task.exception()))
145         else:
146             report.done(src, task.result())
147     if cancelled:
148         await asyncio.wait(cancelled, timeout=2)
149     out('All done! ✨ 🍰 ✨')
150     click.echo(str(report))
151     return report.return_code
152
153
154 def format_file_in_place(
155     src: Path, line_length: int, fast: bool, write_back: bool = False
156 ) -> bool:
157     """Format the file and rewrite if changed. Return True if changed."""
158     try:
159         contents, encoding = format_file(src, line_length=line_length, fast=fast)
160     except NothingChanged:
161         return False
162
163     if write_back:
164         with open(src, "w", encoding=encoding) as f:
165             f.write(contents)
166     return True
167
168
169 def format_file(
170     src: Path, line_length: int, fast: bool
171 ) -> Tuple[FileContent, Encoding]:
172     """Reformats a file and returns its contents and encoding."""
173     with tokenize.open(src) as src_buffer:
174         src_contents = src_buffer.read()
175     if src_contents.strip() == '':
176         raise NothingChanged(src)
177
178     dst_contents = format_str(src_contents, line_length=line_length)
179     if src_contents == dst_contents:
180         raise NothingChanged(src)
181
182     if not fast:
183         assert_equivalent(src_contents, dst_contents)
184         assert_stable(src_contents, dst_contents, line_length=line_length)
185     return dst_contents, src_buffer.encoding
186
187
188 def format_str(src_contents: str, line_length: int) -> FileContent:
189     """Reformats a string and returns new contents."""
190     src_node = lib2to3_parse(src_contents)
191     dst_contents = ""
192     comments: List[Line] = []
193     lines = LineGenerator()
194     elt = EmptyLineTracker()
195     empty_line = Line()
196     after = 0
197     for current_line in lines.visit(src_node):
198         for _ in range(after):
199             dst_contents += str(empty_line)
200         before, after = elt.maybe_empty_lines(current_line)
201         for _ in range(before):
202             dst_contents += str(empty_line)
203         if not current_line.is_comment:
204             for comment in comments:
205                 dst_contents += str(comment)
206             comments = []
207             for line in split_line(current_line, line_length=line_length):
208                 dst_contents += str(line)
209         else:
210             comments.append(current_line)
211     for comment in comments:
212         dst_contents += str(comment)
213     return dst_contents
214
215
216 def lib2to3_parse(src_txt: str) -> Node:
217     """Given a string with source, return the lib2to3 Node."""
218     grammar = pygram.python_grammar_no_print_statement
219     drv = driver.Driver(grammar, pytree.convert)
220     if src_txt[-1] != '\n':
221         nl = '\r\n' if '\r\n' in src_txt[:1024] else '\n'
222         src_txt += nl
223     try:
224         result = drv.parse_string(src_txt, True)
225     except ParseError as pe:
226         lineno, column = pe.context[1]
227         lines = src_txt.splitlines()
228         try:
229             faulty_line = lines[lineno - 1]
230         except IndexError:
231             faulty_line = "<line number missing in source>"
232         raise ValueError(f"Cannot parse: {lineno}:{column}: {faulty_line}") from None
233
234     if isinstance(result, Leaf):
235         result = Node(syms.file_input, [result])
236     return result
237
238
239 def lib2to3_unparse(node: Node) -> str:
240     """Given a lib2to3 node, return its string representation."""
241     code = str(node)
242     return code
243
244
245 T = TypeVar('T')
246
247
248 class Visitor(Generic[T]):
249     """Basic lib2to3 visitor that yields things on visiting."""
250
251     def visit(self, node: LN) -> Iterator[T]:
252         if node.type < 256:
253             name = token.tok_name[node.type]
254         else:
255             name = type_repr(node.type)
256         yield from getattr(self, f'visit_{name}', self.visit_default)(node)
257
258     def visit_default(self, node: LN) -> Iterator[T]:
259         if isinstance(node, Node):
260             for child in node.children:
261                 yield from self.visit(child)
262
263
264 @dataclass
265 class DebugVisitor(Visitor[T]):
266     tree_depth: int = attrib(default=0)
267
268     def visit_default(self, node: LN) -> Iterator[T]:
269         indent = ' ' * (2 * self.tree_depth)
270         if isinstance(node, Node):
271             _type = type_repr(node.type)
272             out(f'{indent}{_type}', fg='yellow')
273             self.tree_depth += 1
274             for child in node.children:
275                 yield from self.visit(child)
276
277             self.tree_depth -= 1
278             out(f'{indent}/{_type}', fg='yellow', bold=False)
279         else:
280             _type = token.tok_name.get(node.type, str(node.type))
281             out(f'{indent}{_type}', fg='blue', nl=False)
282             if node.prefix:
283                 # We don't have to handle prefixes for `Node` objects since
284                 # that delegates to the first child anyway.
285                 out(f' {node.prefix!r}', fg='green', bold=False, nl=False)
286             out(f' {node.value!r}', fg='blue', bold=False)
287
288
289 KEYWORDS = set(keyword.kwlist)
290 WHITESPACE = {token.DEDENT, token.INDENT, token.NEWLINE}
291 FLOW_CONTROL = {'return', 'raise', 'break', 'continue'}
292 STATEMENT = {
293     syms.if_stmt,
294     syms.while_stmt,
295     syms.for_stmt,
296     syms.try_stmt,
297     syms.except_clause,
298     syms.with_stmt,
299     syms.funcdef,
300     syms.classdef,
301 }
302 STANDALONE_COMMENT = 153
303 LOGIC_OPERATORS = {'and', 'or'}
304 COMPARATORS = {
305     token.LESS,
306     token.GREATER,
307     token.EQEQUAL,
308     token.NOTEQUAL,
309     token.LESSEQUAL,
310     token.GREATEREQUAL,
311 }
312 MATH_OPERATORS = {
313     token.PLUS,
314     token.MINUS,
315     token.STAR,
316     token.SLASH,
317     token.VBAR,
318     token.AMPER,
319     token.PERCENT,
320     token.CIRCUMFLEX,
321     token.LEFTSHIFT,
322     token.RIGHTSHIFT,
323     token.DOUBLESTAR,
324     token.DOUBLESLASH,
325 }
326 COMPREHENSION_PRIORITY = 20
327 COMMA_PRIORITY = 10
328 LOGIC_PRIORITY = 5
329 STRING_PRIORITY = 4
330 COMPARATOR_PRIORITY = 3
331 MATH_PRIORITY = 1
332
333
334 @dataclass
335 class BracketTracker:
336     depth: int = attrib(default=0)
337     bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = attrib(default=Factory(dict))
338     delimiters: Dict[LeafID, Priority] = attrib(default=Factory(dict))
339     previous: Optional[Leaf] = attrib(default=None)
340
341     def mark(self, leaf: Leaf) -> None:
342         if leaf.type == token.COMMENT:
343             return
344
345         if leaf.type in CLOSING_BRACKETS:
346             self.depth -= 1
347             opening_bracket = self.bracket_match.pop((self.depth, leaf.type))
348             leaf.opening_bracket = opening_bracket  # type: ignore
349         leaf.bracket_depth = self.depth  # type: ignore
350         if self.depth == 0:
351             delim = is_delimiter(leaf)
352             if delim:
353                 self.delimiters[id(leaf)] = delim
354             elif self.previous is not None:
355                 if leaf.type == token.STRING and self.previous.type == token.STRING:
356                     self.delimiters[id(self.previous)] = STRING_PRIORITY
357                 elif (
358                     leaf.type == token.NAME and
359                     leaf.value == 'for' and
360                     leaf.parent and
361                     leaf.parent.type in {syms.comp_for, syms.old_comp_for}
362                 ):
363                     self.delimiters[id(self.previous)] = COMPREHENSION_PRIORITY
364                 elif (
365                     leaf.type == token.NAME and
366                     leaf.value == 'if' and
367                     leaf.parent and
368                     leaf.parent.type in {syms.comp_if, syms.old_comp_if}
369                 ):
370                     self.delimiters[id(self.previous)] = COMPREHENSION_PRIORITY
371         if leaf.type in OPENING_BRACKETS:
372             self.bracket_match[self.depth, BRACKET[leaf.type]] = leaf
373             self.depth += 1
374         self.previous = leaf
375
376     def any_open_brackets(self) -> bool:
377         """Returns True if there is an yet unmatched open bracket on the line."""
378         return bool(self.bracket_match)
379
380     def max_priority(self, exclude: Iterable[LeafID] =()) -> int:
381         """Returns the highest priority of a delimiter found on the line.
382
383         Values are consistent with what `is_delimiter()` returns.
384         """
385         return max(v for k, v in self.delimiters.items() if k not in exclude)
386
387
388 @dataclass
389 class Line:
390     depth: int = attrib(default=0)
391     leaves: List[Leaf] = attrib(default=Factory(list))
392     comments: Dict[LeafID, Leaf] = attrib(default=Factory(dict))
393     bracket_tracker: BracketTracker = attrib(default=Factory(BracketTracker))
394     inside_brackets: bool = attrib(default=False)
395     has_for: bool = attrib(default=False)
396     _for_loop_variable: bool = attrib(default=False, init=False)
397
398     def append(self, leaf: Leaf, preformatted: bool = False) -> None:
399         has_value = leaf.value.strip()
400         if not has_value:
401             return
402
403         if self.leaves and not preformatted:
404             # Note: at this point leaf.prefix should be empty except for
405             # imports, for which we only preserve newlines.
406             leaf.prefix += whitespace(leaf)
407         if self.inside_brackets or not preformatted:
408             self.maybe_decrement_after_for_loop_variable(leaf)
409             self.bracket_tracker.mark(leaf)
410             self.maybe_remove_trailing_comma(leaf)
411             self.maybe_increment_for_loop_variable(leaf)
412             if self.maybe_adapt_standalone_comment(leaf):
413                 return
414
415         if not self.append_comment(leaf):
416             self.leaves.append(leaf)
417
418     @property
419     def is_comment(self) -> bool:
420         return bool(self) and self.leaves[0].type == STANDALONE_COMMENT
421
422     @property
423     def is_decorator(self) -> bool:
424         return bool(self) and self.leaves[0].type == token.AT
425
426     @property
427     def is_import(self) -> bool:
428         return bool(self) and is_import(self.leaves[0])
429
430     @property
431     def is_class(self) -> bool:
432         return (
433             bool(self) and
434             self.leaves[0].type == token.NAME and
435             self.leaves[0].value == 'class'
436         )
437
438     @property
439     def is_def(self) -> bool:
440         """Also returns True for async defs."""
441         try:
442             first_leaf = self.leaves[0]
443         except IndexError:
444             return False
445
446         try:
447             second_leaf: Optional[Leaf] = self.leaves[1]
448         except IndexError:
449             second_leaf = None
450         return (
451             (first_leaf.type == token.NAME and first_leaf.value == 'def') or
452             (
453                 first_leaf.type == token.NAME and
454                 first_leaf.value == 'async' and
455                 second_leaf is not None and
456                 second_leaf.type == token.NAME and
457                 second_leaf.value == 'def'
458             )
459         )
460
461     @property
462     def is_flow_control(self) -> bool:
463         return (
464             bool(self) and
465             self.leaves[0].type == token.NAME and
466             self.leaves[0].value in FLOW_CONTROL
467         )
468
469     @property
470     def is_yield(self) -> bool:
471         return (
472             bool(self) and
473             self.leaves[0].type == token.NAME and
474             self.leaves[0].value == 'yield'
475         )
476
477     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
478         if not (
479             self.leaves and
480             self.leaves[-1].type == token.COMMA and
481             closing.type in CLOSING_BRACKETS
482         ):
483             return False
484
485         if closing.type == token.RSQB or closing.type == token.RBRACE:
486             self.leaves.pop()
487             return True
488
489         # For parens let's check if it's safe to remove the comma.  If the
490         # trailing one is the only one, we might mistakenly change a tuple
491         # into a different type by removing the comma.
492         depth = closing.bracket_depth + 1  # type: ignore
493         commas = 0
494         opening = closing.opening_bracket  # type: ignore
495         for _opening_index, leaf in enumerate(self.leaves):
496             if leaf is opening:
497                 break
498
499         else:
500             return False
501
502         for leaf in self.leaves[_opening_index + 1:]:
503             if leaf is closing:
504                 break
505
506             bracket_depth = leaf.bracket_depth  # type: ignore
507             if bracket_depth == depth and leaf.type == token.COMMA:
508                 commas += 1
509         if commas > 1:
510             self.leaves.pop()
511             return True
512
513         return False
514
515     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
516         """In a for loop, or comprehension, the variables are often unpacks.
517
518         To avoid splitting on the comma in this situation, we will increase
519         the depth of tokens between `for` and `in`.
520         """
521         if leaf.type == token.NAME and leaf.value == 'for':
522             self.has_for = True
523             self.bracket_tracker.depth += 1
524             self._for_loop_variable = True
525             return True
526
527         return False
528
529     def maybe_decrement_after_for_loop_variable(self, leaf: Leaf) -> bool:
530         # See `maybe_increment_for_loop_variable` above for explanation.
531         if self._for_loop_variable and leaf.type == token.NAME and leaf.value == 'in':
532             self.bracket_tracker.depth -= 1
533             self._for_loop_variable = False
534             return True
535
536         return False
537
538     def maybe_adapt_standalone_comment(self, comment: Leaf) -> bool:
539         """Hack a standalone comment to act as a trailing comment for line splitting.
540
541         If this line has brackets and a standalone `comment`, we need to adapt
542         it to be able to still reformat the line.
543
544         This is not perfect, the line to which the standalone comment gets
545         appended will appear "too long" when splitting.
546         """
547         if not (
548             comment.type == STANDALONE_COMMENT and
549             self.bracket_tracker.any_open_brackets()
550         ):
551             return False
552
553         comment.type = token.COMMENT
554         comment.prefix = '\n' + '    ' * (self.depth + 1)
555         return self.append_comment(comment)
556
557     def append_comment(self, comment: Leaf) -> bool:
558         if comment.type != token.COMMENT:
559             return False
560
561         try:
562             after = id(self.last_non_delimiter())
563         except LookupError:
564             comment.type = STANDALONE_COMMENT
565             comment.prefix = ''
566             return False
567
568         else:
569             if after in self.comments:
570                 self.comments[after].value += str(comment)
571             else:
572                 self.comments[after] = comment
573             return True
574
575     def last_non_delimiter(self) -> Leaf:
576         for i in range(len(self.leaves)):
577             last = self.leaves[-i - 1]
578             if not is_delimiter(last):
579                 return last
580
581         raise LookupError("No non-delimiters found")
582
583     def __str__(self) -> str:
584         if not self:
585             return '\n'
586
587         indent = '    ' * self.depth
588         leaves = iter(self.leaves)
589         first = next(leaves)
590         res = f'{first.prefix}{indent}{first.value}'
591         for leaf in leaves:
592             res += str(leaf)
593         for comment in self.comments.values():
594             res += str(comment)
595         return res + '\n'
596
597     def __bool__(self) -> bool:
598         return bool(self.leaves or self.comments)
599
600
601 @dataclass
602 class EmptyLineTracker:
603     """Provides a stateful method that returns the number of potential extra
604     empty lines needed before and after the currently processed line.
605
606     Note: this tracker works on lines that haven't been split yet.
607     """
608     previous_line: Optional[Line] = attrib(default=None)
609     previous_after: int = attrib(default=0)
610     previous_defs: List[int] = attrib(default=Factory(list))
611
612     def maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
613         """Returns the number of extra empty lines before and after the `current_line`.
614
615         This is for separating `def`, `async def` and `class` with extra empty lines
616         (two on module-level), as well as providing an extra empty line after flow
617         control keywords to make them more prominent.
618         """
619         before, after = self._maybe_empty_lines(current_line)
620         self.previous_after = after
621         self.previous_line = current_line
622         return before, after
623
624     def _maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
625         before = 0
626         depth = current_line.depth
627         while self.previous_defs and self.previous_defs[-1] >= depth:
628             self.previous_defs.pop()
629             before = (1 if depth else 2) - self.previous_after
630         is_decorator = current_line.is_decorator
631         if is_decorator or current_line.is_def or current_line.is_class:
632             if not is_decorator:
633                 self.previous_defs.append(depth)
634             if self.previous_line is None:
635                 # Don't insert empty lines before the first line in the file.
636                 return 0, 0
637
638             if self.previous_line and self.previous_line.is_decorator:
639                 # Don't insert empty lines between decorators.
640                 return 0, 0
641
642             newlines = 2
643             if current_line.depth:
644                 newlines -= 1
645             newlines -= self.previous_after
646             return newlines, 0
647
648         if current_line.is_flow_control:
649             return before, 1
650
651         if (
652             self.previous_line and
653             self.previous_line.is_import and
654             not current_line.is_import and
655             depth == self.previous_line.depth
656         ):
657             return (before or 1), 0
658
659         if (
660             self.previous_line and
661             self.previous_line.is_yield and
662             (not current_line.is_yield or depth != self.previous_line.depth)
663         ):
664             return (before or 1), 0
665
666         return before, 0
667
668
669 @dataclass
670 class LineGenerator(Visitor[Line]):
671     """Generates reformatted Line objects.  Empty lines are not emitted.
672
673     Note: destroys the tree it's visiting by mutating prefixes of its leaves
674     in ways that will no longer stringify to valid Python code on the tree.
675     """
676     current_line: Line = attrib(default=Factory(Line))
677     standalone_comments: List[Leaf] = attrib(default=Factory(list))
678
679     def line(self, indent: int = 0) -> Iterator[Line]:
680         """Generate a line.
681
682         If the line is empty, only emit if it makes sense.
683         If the line is too long, split it first and then generate.
684
685         If any lines were generated, set up a new current_line.
686         """
687         if not self.current_line:
688             self.current_line.depth += indent
689             return  # Line is empty, don't emit. Creating a new one unnecessary.
690
691         complete_line = self.current_line
692         self.current_line = Line(depth=complete_line.depth + indent)
693         yield complete_line
694
695     def visit_default(self, node: LN) -> Iterator[Line]:
696         if isinstance(node, Leaf):
697             for comment in generate_comments(node):
698                 if self.current_line.bracket_tracker.any_open_brackets():
699                     # any comment within brackets is subject to splitting
700                     self.current_line.append(comment)
701                 elif comment.type == token.COMMENT:
702                     # regular trailing comment
703                     self.current_line.append(comment)
704                     yield from self.line()
705
706                 else:
707                     # regular standalone comment, to be processed later (see
708                     # docstring in `generate_comments()`
709                     self.standalone_comments.append(comment)
710             normalize_prefix(node)
711             if node.type not in WHITESPACE:
712                 for comment in self.standalone_comments:
713                     yield from self.line()
714
715                     self.current_line.append(comment)
716                     yield from self.line()
717
718                 self.standalone_comments = []
719                 self.current_line.append(node)
720         yield from super().visit_default(node)
721
722     def visit_suite(self, node: Node) -> Iterator[Line]:
723         """Body of a statement after a colon."""
724         children = iter(node.children)
725         # Process newline before indenting.  It might contain an inline
726         # comment that should go right after the colon.
727         newline = next(children)
728         yield from self.visit(newline)
729         yield from self.line(+1)
730
731         for child in children:
732             yield from self.visit(child)
733
734         yield from self.line(-1)
735
736     def visit_stmt(self, node: Node, keywords: Set[str]) -> Iterator[Line]:
737         """Visit a statement.
738
739         The relevant Python language keywords for this statement are NAME leaves
740         within it.
741         """
742         for child in node.children:
743             if child.type == token.NAME and child.value in keywords:  # type: ignore
744                 yield from self.line()
745
746             yield from self.visit(child)
747
748     def visit_simple_stmt(self, node: Node) -> Iterator[Line]:
749         """A statement without nested statements."""
750         is_suite_like = node.parent and node.parent.type in STATEMENT
751         if is_suite_like:
752             yield from self.line(+1)
753             yield from self.visit_default(node)
754             yield from self.line(-1)
755
756         else:
757             yield from self.line()
758             yield from self.visit_default(node)
759
760     def visit_async_stmt(self, node: Node) -> Iterator[Line]:
761         yield from self.line()
762
763         children = iter(node.children)
764         for child in children:
765             yield from self.visit(child)
766
767             if child.type == token.NAME and child.value == 'async':  # type: ignore
768                 break
769
770         internal_stmt = next(children)
771         for child in internal_stmt.children:
772             yield from self.visit(child)
773
774     def visit_decorators(self, node: Node) -> Iterator[Line]:
775         for child in node.children:
776             yield from self.line()
777             yield from self.visit(child)
778
779     def visit_SEMI(self, leaf: Leaf) -> Iterator[Line]:
780         yield from self.line()
781
782     def visit_ENDMARKER(self, leaf: Leaf) -> Iterator[Line]:
783         yield from self.visit_default(leaf)
784         yield from self.line()
785
786     def __attrs_post_init__(self) -> None:
787         """You are in a twisty little maze of passages."""
788         v = self.visit_stmt
789         self.visit_if_stmt = partial(v, keywords={'if', 'else', 'elif'})
790         self.visit_while_stmt = partial(v, keywords={'while', 'else'})
791         self.visit_for_stmt = partial(v, keywords={'for', 'else'})
792         self.visit_try_stmt = partial(v, keywords={'try', 'except', 'else', 'finally'})
793         self.visit_except_clause = partial(v, keywords={'except'})
794         self.visit_funcdef = partial(v, keywords={'def'})
795         self.visit_with_stmt = partial(v, keywords={'with'})
796         self.visit_classdef = partial(v, keywords={'class'})
797         self.visit_async_funcdef = self.visit_async_stmt
798         self.visit_decorated = self.visit_decorators
799
800
801 BRACKET = {token.LPAR: token.RPAR, token.LSQB: token.RSQB, token.LBRACE: token.RBRACE}
802 OPENING_BRACKETS = set(BRACKET.keys())
803 CLOSING_BRACKETS = set(BRACKET.values())
804 BRACKETS = OPENING_BRACKETS | CLOSING_BRACKETS
805
806
807 def whitespace(leaf: Leaf) -> str:
808     """Return whitespace prefix if needed for the given `leaf`."""
809     NO = ''
810     SPACE = ' '
811     DOUBLESPACE = '  '
812     t = leaf.type
813     p = leaf.parent
814     v = leaf.value
815     if t == token.COLON:
816         return NO
817
818     if t == token.COMMA:
819         return NO
820
821     if t == token.RPAR:
822         return NO
823
824     if t == token.COMMENT:
825         return DOUBLESPACE
826
827     if t == STANDALONE_COMMENT:
828         return NO
829
830     if t in CLOSING_BRACKETS:
831         return NO
832
833     assert p is not None, f"INTERNAL ERROR: hand-made leaf without parent: {leaf!r}"
834     prev = leaf.prev_sibling
835     if not prev:
836         prevp = preceding_leaf(p)
837         if not prevp or prevp.type in OPENING_BRACKETS:
838             return NO
839
840         if prevp.type == token.EQUAL:
841             if prevp.parent and prevp.parent.type in {
842                 syms.typedargslist,
843                 syms.varargslist,
844                 syms.parameters,
845                 syms.arglist,
846                 syms.argument,
847             }:
848                 return NO
849
850         elif prevp.type == token.DOUBLESTAR:
851             if prevp.parent and prevp.parent.type in {
852                 syms.typedargslist,
853                 syms.varargslist,
854                 syms.parameters,
855                 syms.arglist,
856                 syms.dictsetmaker,
857             }:
858                 return NO
859
860         elif prevp.type == token.COLON:
861             if prevp.parent and prevp.parent.type == syms.subscript:
862                 return NO
863
864         elif prevp.parent and prevp.parent.type == syms.factor:
865             return NO
866
867     elif prev.type in OPENING_BRACKETS:
868         return NO
869
870     if p.type in {syms.parameters, syms.arglist}:
871         # untyped function signatures or calls
872         if t == token.RPAR:
873             return NO
874
875         if not prev or prev.type != token.COMMA:
876             return NO
877
878     if p.type == syms.varargslist:
879         # lambdas
880         if t == token.RPAR:
881             return NO
882
883         if prev and prev.type != token.COMMA:
884             return NO
885
886     elif p.type == syms.typedargslist:
887         # typed function signatures
888         if not prev:
889             return NO
890
891         if t == token.EQUAL:
892             if prev.type != syms.tname:
893                 return NO
894
895         elif prev.type == token.EQUAL:
896             # A bit hacky: if the equal sign has whitespace, it means we
897             # previously found it's a typed argument.  So, we're using that, too.
898             return prev.prefix
899
900         elif prev.type != token.COMMA:
901             return NO
902
903     elif p.type == syms.tname:
904         # type names
905         if not prev:
906             prevp = preceding_leaf(p)
907             if not prevp or prevp.type != token.COMMA:
908                 return NO
909
910     elif p.type == syms.trailer:
911         # attributes and calls
912         if t == token.LPAR or t == token.RPAR:
913             return NO
914
915         if not prev:
916             if t == token.DOT:
917                 prevp = preceding_leaf(p)
918                 if not prevp or prevp.type != token.NUMBER:
919                     return NO
920
921             elif t == token.LSQB:
922                 return NO
923
924         elif prev.type != token.COMMA:
925             return NO
926
927     elif p.type == syms.argument:
928         # single argument
929         if t == token.EQUAL:
930             return NO
931
932         if not prev:
933             prevp = preceding_leaf(p)
934             if not prevp or prevp.type == token.LPAR:
935                 return NO
936
937         elif prev.type == token.EQUAL or prev.type == token.DOUBLESTAR:
938             return NO
939
940     elif p.type == syms.decorator:
941         # decorators
942         return NO
943
944     elif p.type == syms.dotted_name:
945         if prev:
946             return NO
947
948         prevp = preceding_leaf(p)
949         if not prevp or prevp.type == token.AT or prevp.type == token.DOT:
950             return NO
951
952     elif p.type == syms.classdef:
953         if t == token.LPAR:
954             return NO
955
956         if prev and prev.type == token.LPAR:
957             return NO
958
959     elif p.type == syms.subscript:
960         # indexing
961         if not prev or prev.type == token.COLON:
962             return NO
963
964     elif p.type == syms.atom:
965         if prev and t == token.DOT:
966             # dots, but not the first one.
967             return NO
968
969     elif (
970         p.type == syms.listmaker or
971         p.type == syms.testlist_gexp or
972         p.type == syms.subscriptlist
973     ):
974         # list interior, including unpacking
975         if not prev:
976             return NO
977
978     elif p.type == syms.dictsetmaker:
979         # dict and set interior, including unpacking
980         if not prev:
981             return NO
982
983         if prev.type == token.DOUBLESTAR:
984             return NO
985
986     elif p.type == syms.factor or p.type == syms.star_expr:
987         # unary ops
988         if not prev:
989             prevp = preceding_leaf(p)
990             if not prevp or prevp.type in OPENING_BRACKETS:
991                 return NO
992
993             prevp_parent = prevp.parent
994             assert prevp_parent is not None
995             if prevp.type == token.COLON and prevp_parent.type in {
996                 syms.subscript, syms.sliceop
997             }:
998                 return NO
999
1000             elif prevp.type == token.EQUAL and prevp_parent.type == syms.argument:
1001                 return NO
1002
1003         elif t == token.NAME or t == token.NUMBER:
1004             return NO
1005
1006     elif p.type == syms.import_from:
1007         if t == token.DOT:
1008             if prev and prev.type == token.DOT:
1009                 return NO
1010
1011         elif t == token.NAME:
1012             if v == 'import':
1013                 return SPACE
1014
1015             if prev and prev.type == token.DOT:
1016                 return NO
1017
1018     elif p.type == syms.sliceop:
1019         return NO
1020
1021     return SPACE
1022
1023
1024 def preceding_leaf(node: Optional[LN]) -> Optional[Leaf]:
1025     """Returns the first leaf that precedes `node`, if any."""
1026     while node:
1027         res = node.prev_sibling
1028         if res:
1029             if isinstance(res, Leaf):
1030                 return res
1031
1032             try:
1033                 return list(res.leaves())[-1]
1034
1035             except IndexError:
1036                 return None
1037
1038         node = node.parent
1039     return None
1040
1041
1042 def is_delimiter(leaf: Leaf) -> int:
1043     """Returns the priority of the `leaf` delimiter. Returns 0 if not delimiter.
1044
1045     Higher numbers are higher priority.
1046     """
1047     if leaf.type == token.COMMA:
1048         return COMMA_PRIORITY
1049
1050     if leaf.type == token.NAME and leaf.value in LOGIC_OPERATORS:
1051         return LOGIC_PRIORITY
1052
1053     if leaf.type in COMPARATORS:
1054         return COMPARATOR_PRIORITY
1055
1056     if (
1057         leaf.type in MATH_OPERATORS and
1058         leaf.parent and
1059         leaf.parent.type not in {syms.factor, syms.star_expr}
1060     ):
1061         return MATH_PRIORITY
1062
1063     return 0
1064
1065
1066 def generate_comments(leaf: Leaf) -> Iterator[Leaf]:
1067     """Cleans the prefix of the `leaf` and generates comments from it, if any.
1068
1069     Comments in lib2to3 are shoved into the whitespace prefix.  This happens
1070     in `pgen2/driver.py:Driver.parse_tokens()`.  This was a brilliant implementation
1071     move because it does away with modifying the grammar to include all the
1072     possible places in which comments can be placed.
1073
1074     The sad consequence for us though is that comments don't "belong" anywhere.
1075     This is why this function generates simple parentless Leaf objects for
1076     comments.  We simply don't know what the correct parent should be.
1077
1078     No matter though, we can live without this.  We really only need to
1079     differentiate between inline and standalone comments.  The latter don't
1080     share the line with any code.
1081
1082     Inline comments are emitted as regular token.COMMENT leaves.  Standalone
1083     are emitted with a fake STANDALONE_COMMENT token identifier.
1084     """
1085     if not leaf.prefix:
1086         return
1087
1088     if '#' not in leaf.prefix:
1089         return
1090
1091     before_comment, content = leaf.prefix.split('#', 1)
1092     content = content.rstrip()
1093     if content and (content[0] not in {' ', '!', '#'}):
1094         content = ' ' + content
1095     is_standalone_comment = (
1096         '\n' in before_comment or '\n' in content or leaf.type == token.DEDENT
1097     )
1098     if not is_standalone_comment:
1099         # simple trailing comment
1100         yield Leaf(token.COMMENT, value='#' + content)
1101         return
1102
1103     for line in ('#' + content).split('\n'):
1104         line = line.lstrip()
1105         if not line.startswith('#'):
1106             continue
1107
1108         yield Leaf(STANDALONE_COMMENT, line)
1109
1110
1111 def split_line(line: Line, line_length: int, inner: bool = False) -> Iterator[Line]:
1112     """Splits a `line` into potentially many lines.
1113
1114     They should fit in the allotted `line_length` but might not be able to.
1115     `inner` signifies that there were a pair of brackets somewhere around the
1116     current `line`, possibly transitively. This means we can fallback to splitting
1117     by delimiters if the LHS/RHS don't yield any results.
1118     """
1119     line_str = str(line).strip('\n')
1120     if len(line_str) <= line_length and '\n' not in line_str:
1121         yield line
1122         return
1123
1124     if line.is_def:
1125         split_funcs = [left_hand_split]
1126     elif line.inside_brackets:
1127         split_funcs = [delimiter_split]
1128         if '\n' not in line_str:
1129             # Only attempt RHS if we don't have multiline strings or comments
1130             # on this line.
1131             split_funcs.append(right_hand_split)
1132     else:
1133         split_funcs = [right_hand_split]
1134     for split_func in split_funcs:
1135         # We are accumulating lines in `result` because we might want to abort
1136         # mission and return the original line in the end, or attempt a different
1137         # split altogether.
1138         result: List[Line] = []
1139         try:
1140             for l in split_func(line):
1141                 if str(l).strip('\n') == line_str:
1142                     raise CannotSplit("Split function returned an unchanged result")
1143
1144                 result.extend(split_line(l, line_length=line_length, inner=True))
1145         except CannotSplit as cs:
1146             continue
1147
1148         else:
1149             yield from result
1150             break
1151
1152     else:
1153         yield line
1154
1155
1156 def left_hand_split(line: Line) -> Iterator[Line]:
1157     """Split line into many lines, starting with the first matching bracket pair.
1158
1159     Note: this usually looks weird, only use this for function definitions.
1160     Prefer RHS otherwise.
1161     """
1162     head = Line(depth=line.depth)
1163     body = Line(depth=line.depth + 1, inside_brackets=True)
1164     tail = Line(depth=line.depth)
1165     tail_leaves: List[Leaf] = []
1166     body_leaves: List[Leaf] = []
1167     head_leaves: List[Leaf] = []
1168     current_leaves = head_leaves
1169     matching_bracket = None
1170     for leaf in line.leaves:
1171         if (
1172             current_leaves is body_leaves and
1173             leaf.type in CLOSING_BRACKETS and
1174             leaf.opening_bracket is matching_bracket  # type: ignore
1175         ):
1176             current_leaves = tail_leaves
1177         current_leaves.append(leaf)
1178         if current_leaves is head_leaves:
1179             if leaf.type in OPENING_BRACKETS:
1180                 matching_bracket = leaf
1181                 current_leaves = body_leaves
1182     # Since body is a new indent level, remove spurious leading whitespace.
1183     if body_leaves:
1184         normalize_prefix(body_leaves[0])
1185     # Build the new lines.
1186     for result, leaves in (
1187         (head, head_leaves), (body, body_leaves), (tail, tail_leaves)
1188     ):
1189         for leaf in leaves:
1190             result.append(leaf, preformatted=True)
1191             comment_after = line.comments.get(id(leaf))
1192             if comment_after:
1193                 result.append(comment_after, preformatted=True)
1194     # Check if the split succeeded.
1195     tail_len = len(str(tail))
1196     if not body:
1197         if tail_len == 0:
1198             raise CannotSplit("Splitting brackets produced the same line")
1199
1200         elif tail_len < 3:
1201             raise CannotSplit(
1202                 f"Splitting brackets on an empty body to save "
1203                 f"{tail_len} characters is not worth it"
1204             )
1205
1206     for result in (head, body, tail):
1207         if result:
1208             yield result
1209
1210
1211 def right_hand_split(line: Line) -> Iterator[Line]:
1212     """Split line into many lines, starting with the last matching bracket pair."""
1213     head = Line(depth=line.depth)
1214     body = Line(depth=line.depth + 1, inside_brackets=True)
1215     tail = Line(depth=line.depth)
1216     tail_leaves: List[Leaf] = []
1217     body_leaves: List[Leaf] = []
1218     head_leaves: List[Leaf] = []
1219     current_leaves = tail_leaves
1220     opening_bracket = None
1221     for leaf in reversed(line.leaves):
1222         if current_leaves is body_leaves:
1223             if leaf is opening_bracket:
1224                 current_leaves = head_leaves
1225         current_leaves.append(leaf)
1226         if current_leaves is tail_leaves:
1227             if leaf.type in CLOSING_BRACKETS:
1228                 opening_bracket = leaf.opening_bracket  # type: ignore
1229                 current_leaves = body_leaves
1230     tail_leaves.reverse()
1231     body_leaves.reverse()
1232     head_leaves.reverse()
1233     # Since body is a new indent level, remove spurious leading whitespace.
1234     if body_leaves:
1235         normalize_prefix(body_leaves[0])
1236     # Build the new lines.
1237     for result, leaves in (
1238         (head, head_leaves), (body, body_leaves), (tail, tail_leaves)
1239     ):
1240         for leaf in leaves:
1241             result.append(leaf, preformatted=True)
1242             comment_after = line.comments.get(id(leaf))
1243             if comment_after:
1244                 result.append(comment_after, preformatted=True)
1245     # Check if the split succeeded.
1246     tail_len = len(str(tail).strip('\n'))
1247     if not body:
1248         if tail_len == 0:
1249             raise CannotSplit("Splitting brackets produced the same line")
1250
1251         elif tail_len < 3:
1252             raise CannotSplit(
1253                 f"Splitting brackets on an empty body to save "
1254                 f"{tail_len} characters is not worth it"
1255             )
1256
1257     for result in (head, body, tail):
1258         if result:
1259             yield result
1260
1261
1262 def delimiter_split(line: Line) -> Iterator[Line]:
1263     """Split according to delimiters of the highest priority.
1264
1265     This kind of split doesn't increase indentation.
1266     """
1267     try:
1268         last_leaf = line.leaves[-1]
1269     except IndexError:
1270         raise CannotSplit("Line empty")
1271
1272     delimiters = line.bracket_tracker.delimiters
1273     try:
1274         delimiter_priority = line.bracket_tracker.max_priority(exclude={id(last_leaf)})
1275     except ValueError:
1276         raise CannotSplit("No delimiters found")
1277
1278     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1279     for leaf in line.leaves:
1280         current_line.append(leaf, preformatted=True)
1281         comment_after = line.comments.get(id(leaf))
1282         if comment_after:
1283             current_line.append(comment_after, preformatted=True)
1284         leaf_priority = delimiters.get(id(leaf))
1285         if leaf_priority == delimiter_priority:
1286             normalize_prefix(current_line.leaves[0])
1287             yield current_line
1288
1289             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1290     if current_line:
1291         if (
1292             delimiter_priority == COMMA_PRIORITY and
1293             current_line.leaves[-1].type != token.COMMA
1294         ):
1295             current_line.append(Leaf(token.COMMA, ','))
1296         normalize_prefix(current_line.leaves[0])
1297         yield current_line
1298
1299
1300 def is_import(leaf: Leaf) -> bool:
1301     """Returns True if the given leaf starts an import statement."""
1302     p = leaf.parent
1303     t = leaf.type
1304     v = leaf.value
1305     return bool(
1306         t == token.NAME and
1307         (
1308             (v == 'import' and p and p.type == syms.import_name) or
1309             (v == 'from' and p and p.type == syms.import_from)
1310         )
1311     )
1312
1313
1314 def normalize_prefix(leaf: Leaf) -> None:
1315     """Leave existing extra newlines for imports.  Remove everything else."""
1316     if is_import(leaf):
1317         spl = leaf.prefix.split('#', 1)
1318         nl_count = spl[0].count('\n')
1319         if len(spl) > 1:
1320             # Skip one newline since it was for a standalone comment.
1321             nl_count -= 1
1322         leaf.prefix = '\n' * nl_count
1323         return
1324
1325     leaf.prefix = ''
1326
1327
1328 PYTHON_EXTENSIONS = {'.py'}
1329 BLACKLISTED_DIRECTORIES = {
1330     'build', 'buck-out', 'dist', '_build', '.git', '.hg', '.mypy_cache', '.tox', '.venv'
1331 }
1332
1333
1334 def gen_python_files_in_dir(path: Path) -> Iterator[Path]:
1335     for child in path.iterdir():
1336         if child.is_dir():
1337             if child.name in BLACKLISTED_DIRECTORIES:
1338                 continue
1339
1340             yield from gen_python_files_in_dir(child)
1341
1342         elif child.suffix in PYTHON_EXTENSIONS:
1343             yield child
1344
1345
1346 @dataclass
1347 class Report:
1348     """Provides a reformatting counter."""
1349     change_count: int = attrib(default=0)
1350     same_count: int = attrib(default=0)
1351     failure_count: int = attrib(default=0)
1352
1353     def done(self, src: Path, changed: bool) -> None:
1354         """Increment the counter for successful reformatting. Write out a message."""
1355         if changed:
1356             out(f'reformatted {src}')
1357             self.change_count += 1
1358         else:
1359             out(f'{src} already well formatted, good job.', bold=False)
1360             self.same_count += 1
1361
1362     def failed(self, src: Path, message: str) -> None:
1363         """Increment the counter for failed reformatting. Write out a message."""
1364         err(f'error: cannot format {src}: {message}')
1365         self.failure_count += 1
1366
1367     @property
1368     def return_code(self) -> int:
1369         """Which return code should the app use considering the current state."""
1370         # According to http://tldp.org/LDP/abs/html/exitcodes.html starting with
1371         # 126 we have special returncodes reserved by the shell.
1372         if self.failure_count:
1373             return 123
1374
1375         elif self.change_count:
1376             return 1
1377
1378         return 0
1379
1380     def __str__(self) -> str:
1381         """A color report of the current state.
1382
1383         Use `click.unstyle` to remove colors.
1384         """
1385         report = []
1386         if self.change_count:
1387             s = 's' if self.change_count > 1 else ''
1388             report.append(
1389                 click.style(f'{self.change_count} file{s} reformatted', bold=True)
1390             )
1391         if self.same_count:
1392             s = 's' if self.same_count > 1 else ''
1393             report.append(f'{self.same_count} file{s} left unchanged')
1394         if self.failure_count:
1395             s = 's' if self.failure_count > 1 else ''
1396             report.append(
1397                 click.style(
1398                     f'{self.failure_count} file{s} failed to reformat', fg='red'
1399                 )
1400             )
1401         return ', '.join(report) + '.'
1402
1403
1404 def assert_equivalent(src: str, dst: str) -> None:
1405     """Raises AssertionError if `src` and `dst` aren't equivalent.
1406
1407     This is a temporary sanity check until Black becomes stable.
1408     """
1409
1410     import ast
1411     import traceback
1412
1413     def _v(node: ast.AST, depth: int = 0) -> Iterator[str]:
1414         """Simple visitor generating strings to compare ASTs by content."""
1415         yield f"{'  ' * depth}{node.__class__.__name__}("
1416
1417         for field in sorted(node._fields):
1418             try:
1419                 value = getattr(node, field)
1420             except AttributeError:
1421                 continue
1422
1423             yield f"{'  ' * (depth+1)}{field}="
1424
1425             if isinstance(value, list):
1426                 for item in value:
1427                     if isinstance(item, ast.AST):
1428                         yield from _v(item, depth + 2)
1429
1430             elif isinstance(value, ast.AST):
1431                 yield from _v(value, depth + 2)
1432
1433             else:
1434                 yield f"{'  ' * (depth+2)}{value!r},  # {value.__class__.__name__}"
1435
1436         yield f"{'  ' * depth})  # /{node.__class__.__name__}"
1437
1438     try:
1439         src_ast = ast.parse(src)
1440     except Exception as exc:
1441         raise AssertionError(f"cannot parse source: {exc}") from None
1442
1443     try:
1444         dst_ast = ast.parse(dst)
1445     except Exception as exc:
1446         log = dump_to_file(''.join(traceback.format_tb(exc.__traceback__)), dst)
1447         raise AssertionError(
1448             f"INTERNAL ERROR: Black produced invalid code: {exc}. "
1449             f"Please report a bug on https://github.com/ambv/black/issues.  "
1450             f"This invalid output might be helpful: {log}",
1451         ) from None
1452
1453     src_ast_str = '\n'.join(_v(src_ast))
1454     dst_ast_str = '\n'.join(_v(dst_ast))
1455     if src_ast_str != dst_ast_str:
1456         log = dump_to_file(diff(src_ast_str, dst_ast_str, 'src', 'dst'))
1457         raise AssertionError(
1458             f"INTERNAL ERROR: Black produced code that is not equivalent to "
1459             f"the source.  "
1460             f"Please report a bug on https://github.com/ambv/black/issues.  "
1461             f"This diff might be helpful: {log}",
1462         ) from None
1463
1464
1465 def assert_stable(src: str, dst: str, line_length: int) -> None:
1466     """Raises AssertionError if `dst` reformats differently the second time.
1467
1468     This is a temporary sanity check until Black becomes stable.
1469     """
1470     newdst = format_str(dst, line_length=line_length)
1471     if dst != newdst:
1472         log = dump_to_file(
1473             diff(src, dst, 'source', 'first pass'),
1474             diff(dst, newdst, 'first pass', 'second pass'),
1475         )
1476         raise AssertionError(
1477             f"INTERNAL ERROR: Black produced different code on the second pass "
1478             f"of the formatter.  "
1479             f"Please report a bug on https://github.com/ambv/black/issues.  "
1480             f"This diff might be helpful: {log}",
1481         ) from None
1482
1483
1484 def dump_to_file(*output: str) -> str:
1485     """Dumps `output` to a temporary file. Returns path to the file."""
1486     import tempfile
1487
1488     with tempfile.NamedTemporaryFile(
1489         mode='w', prefix='blk_', suffix='.log', delete=False
1490     ) as f:
1491         for lines in output:
1492             f.write(lines)
1493             f.write('\n')
1494     return f.name
1495
1496
1497 def diff(a: str, b: str, a_name: str, b_name: str) -> str:
1498     """Returns a udiff string between strings `a` and `b`."""
1499     import difflib
1500
1501     a_lines = [line + '\n' for line in a.split('\n')]
1502     b_lines = [line + '\n' for line in b.split('\n')]
1503     return ''.join(
1504         difflib.unified_diff(a_lines, b_lines, fromfile=a_name, tofile=b_name, n=5)
1505     )
1506
1507
1508 if __name__ == '__main__':
1509     main()