node.insert_child(index, Node(syms.atom, [lpar, operand, rpar]))
yield from self.visit_default(node)
def visit_STRING(self, leaf: Leaf) -> Iterator[Line]:
    """Re-indent a docstring leaf in place, then visit it like any other leaf.

    A STRING leaf is treated as a docstring when its parent, together with
    the parent's previous siblings, matches the sequence
    (start of children, NEWLINE, INDENT, simple_stmt) and the leaf is a
    multi-line string.  In that case the text between the triple quotes is
    passed through `fix_docstring`, written back to `leaf.value`, and the
    quotes are normalized.  All leaves, docstring or not, are then handed
    to `visit_default`.
    """
    # Check if it's a docstring
    if prev_siblings_are(
        leaf.parent, [None, token.NEWLINE, token.INDENT, syms.simple_stmt]
    ) and is_multiline_string(leaf):
        value = leaf.value
        # A string prefix (r, u, b, f in any case/combination) shifts the
        # opening quotes to the right.  Slicing at a fixed offset of 3 would
        # then hand one of the quote characters to fix_docstring as if it
        # were docstring text, so measure the prefix first.
        prefix_len = len(value) - len(value.lstrip("rRuUbBfF"))
        body_start = prefix_len + 3
        # NOTE(review): this is one space per depth level as written; confirm
        # it matches the indent width used when lines are rendered.
        prefix = " " * self.current_line.depth
        docstring = fix_docstring(value[body_start:-3], prefix)
        leaf.value = value[:body_start] + docstring + value[-3:]
        normalize_string_quotes(leaf)

    yield from self.visit_default(leaf)
+
def __post_init__(self) -> None:
    """You are in a twisty little maze of passages."""
    # NOTE(review): only the lines below are visible in this excerpt.  `v` is
    # bound to the `visit_stmt` attribute but not used by anything shown here;
    # the remainder of the body may have been elided -- confirm against the
    # full file before changing this method.
    v = self.visit_stmt
    return None
def prev_siblings_are(node: Optional[LN], tokens: List[Optional[NodeType]]) -> bool:
    """Return whether `node` and its previous siblings match `tokens` by type.

    The list is matched right to left: `node` is compared against the last
    element, `node.prev_sibling` against the one before it, and so on.  A
    `None` entry matches only the absence of a sibling, so putting `None`
    first anchors the match at the start of the parent's children.  An empty
    list matches anything.
    """
    current = node
    for expected in reversed(tokens):
        if expected is None:
            # Anchor: succeeds only if we have run out of siblings here.
            return current is None
        if not current or current.type != expected:
            return False
        current = current.prev_sibling
    return True
+
+
def child_towards(ancestor: Node, descendant: LN) -> Optional[LN]:
    """Return the child of `ancestor` that contains `descendant`."""
    # NOTE(review): as visible in this excerpt the function returns
    # `descendant` unchanged and never consults `ancestor`, which does not
    # match the docstring; the walk up the tree has presumably been elided
    # here -- confirm against the full file.
    node: Optional[LN] = descendant
    return node
-def assert_equivalent(src: str, dst: str) -> None:
- """Raise AssertionError if `src` and `dst` aren't equivalent."""
-
- def _v(node: Union[ast.AST, ast3.AST, ast27.AST], depth: int = 0) -> Iterator[str]:
- """Simple visitor generating strings to compare ASTs by content."""
def _stringify_ast(
    node: Union[ast.AST, ast3.AST, ast27.AST], depth: int = 0
) -> Iterator[str]:
    """Yield a line-by-line textual dump of `node` for comparing ASTs by content.

    Each node produces an opening line with its class name, one line per
    field name (fields visited in sorted order), the recursive dump of any
    AST children, the repr of any other value, and a closing line.  Lines
    are prefixed according to `depth`.

    Three things are deliberately left out of, or folded in, the dump:
    TypeIgnore nodes are emitted without fields; tuples directly inside the
    `targets` of a Delete node are flattened into their elements; and for
    string values of `ast.Constant` nodes, runs of spaces/tabs following a
    newline are folded so that re-indented docstrings compare equal.
    """
    ast_types = (ast.AST, ast3.AST, ast27.AST)
    delete_types = (ast.Delete, ast3.Delete, ast27.Delete)
    tuple_types = (ast.Tuple, ast3.Tuple, ast27.Tuple)

    node = _fixup_ast_constants(node)

    yield f"{' ' * depth}{node.__class__.__name__}("

    # TypeIgnore has only one field 'lineno' which breaks this comparison
    type_ignore_classes = (ast3.TypeIgnore, ast27.TypeIgnore)
    if sys.version_info >= (3, 8):
        type_ignore_classes += (ast.TypeIgnore,)

    field_names = sorted(node._fields)
    if isinstance(node, type_ignore_classes):
        field_names = []

    for field in field_names:  # noqa: F402
        try:
            value = getattr(node, field)
        except AttributeError:
            continue

        yield f"{' ' * (depth+1)}{field}="

        if isinstance(value, list):
            # Ignore nested tuples within del statements, because we may insert
            # parentheses and they change the AST.
            flatten_tuples = field == "targets" and isinstance(node, delete_types)
            for item in value:
                if flatten_tuples and isinstance(item, tuple_types):
                    for element in item.elts:
                        yield from _stringify_ast(element, depth + 2)
                elif isinstance(item, ast_types):
                    yield from _stringify_ast(item, depth + 2)
            continue

        if isinstance(value, ast_types):
            yield from _stringify_ast(value, depth + 2)
            continue

        # Constant strings may be indented across newlines, if they are
        # docstrings; fold spaces after newlines when comparing
        is_str_constant = (
            isinstance(node, ast.Constant)
            and field == "value"
            and isinstance(value, str)
        )
        normalized = re.sub(r"\n[ \t]+", "\n ", value) if is_str_constant else value
        yield f"{' ' * (depth+2)}{normalized!r}, # {value.__class__.__name__}"

    yield f"{' ' * depth}) # /{node.__class__.__name__}"
+def assert_equivalent(src: str, dst: str) -> None:
+ """Raise AssertionError if `src` and `dst` aren't equivalent."""
try:
src_ast = parse_ast(src)
except Exception as exc:
f" helpful: {log}"
) from None
- src_ast_str = "\n".join(_v(src_ast))
- dst_ast_str = "\n".join(_v(dst_ast))
+ src_ast_str = "\n".join(_stringify_ast(src_ast))
+ dst_ast_str = "\n".join(_stringify_ast(dst_ast))
if src_ast_str != dst_ast_str:
log = dump_to_file(diff(src_ast_str, dst_ast_str, "src", "dst"))
raise AssertionError(
main()
def fix_docstring(docstring: str, prefix: str) -> str:
    """Re-indent the text of a docstring so continuation lines start at `prefix`.

    `docstring` is the text between the triple quotes.  The first line is
    stripped; every following line has the common indentation removed and
    `prefix` put in its place.  Lines that end up empty stay empty, except
    the last one, which still receives `prefix` so that the closing quotes
    line up with the rest of the docstring.

    The number of lines is always preserved, and tabs are expanded only
    within leading whitespace, so text after the indentation is returned
    unchanged apart from trailing-whitespace removal.

    Returns the re-indented text, or "" for an empty docstring.
    """
    # https://www.python.org/dev/peps/pep-0257/#handling-docstring-indentation
    if not docstring:
        return ""
    # Split on "\n" only: unlike splitlines() this keeps the empty final
    # line of text ending in a newline, and does not break lines on other
    # separator characters that are part of the docstring's content.
    lines = []
    for raw in docstring.split("\n"):
        text = raw.lstrip()
        leading = raw[: len(raw) - len(text)]
        # Convert tabs to spaces in the indentation only; tabs further into
        # the line are part of the docstring's value and must survive.
        lines.append(leading.expandtabs() + text)
    # Determine minimum indentation (first line doesn't count, and neither
    # do blank lines).  With no indented text at all there is nothing to cut.
    indent = min(
        (len(line) - len(line.lstrip()) for line in lines[1:] if line.strip()),
        default=0,
    )
    # Remove indentation (first line is special):
    trimmed = [lines[0].strip()]
    last_idx = len(lines) - 1
    for idx, line in enumerate(lines[1:], start=1):
        stripped_line = line[indent:].rstrip()
        if stripped_line or idx == last_idx:
            trimmed.append(prefix + stripped_line)
        else:
            trimmed.append("")
    # Return a single string:
    return "\n".join(trimmed)
+
+
# Script entry point: call patched_main() only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    patched_main()