git.madduck.net Git - etc/vim.git/blobdiff - black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access, can be arranged individually.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Split out Change Log (#1117)
[etc/vim.git] / black.py
index d837c89609f3e7237786b7da1657527382612c52..b7cacf7767855ec903578b5be7499d526f1d542f 100644 (file)
--- a/black.py
+++ b/black.py
@@ -37,9 +37,11 @@ from typing import (
     Union,
     cast,
 )
+from typing_extensions import Final
+from mypy_extensions import mypyc_attr
 
 from appdirs import user_cache_dir
-from attr import dataclass, evolve, Factory
+from dataclasses import dataclass, field, replace
 import click
 import toml
 from typed_ast import ast3, ast27
@@ -185,7 +187,7 @@ VERSION_TO_FEATURES: Dict[TargetVersion, Set[Feature]] = {
 
 @dataclass
 class FileMode:
-    target_versions: Set[TargetVersion] = Factory(set)
+    target_versions: Set[TargetVersion] = field(default_factory=set)
     line_length: int = DEFAULT_LINE_LENGTH
     string_normalization: bool = True
     is_pyi: bool = False
@@ -211,6 +213,23 @@ def supports_feature(target_versions: Set[TargetVersion], feature: Feature) -> b
     return all(feature in VERSION_TO_FEATURES[version] for version in target_versions)
 
 
+def find_pyproject_toml(path_search_start: str) -> Optional[str]:
+    """Locate pyproject.toml in the project root; return its path as str, else None."""
+    path_project_root = find_project_root(path_search_start)
+    path_pyproject_toml = path_project_root / "pyproject.toml"
+    return str(path_pyproject_toml) if path_pyproject_toml.is_file() else None
+
+
+def parse_pyproject_toml(path_config: str) -> Dict[str, Any]:
+    """Load `path_config` and return its [tool.black] table ({} when missing).
+
+    Keys lose "--" and map "-" to "_"; raises toml.TomlDecodeError on bad TOML.
+    """
+    pyproject_toml = toml.load(path_config)
+    config = pyproject_toml.get("tool", {}).get("black", {})
+    return {k.replace("--", "").replace("-", "_"): v for k, v in config.items()}
+
+
 def read_pyproject_toml(
     ctx: click.Context, param: click.Parameter, value: Union[str, int, bool, None]
 ) -> Optional[str]:
@@ -221,16 +240,12 @@ def read_pyproject_toml(
     """
     assert not isinstance(value, (int, bool)), "Invalid parameter type passed"
     if not value:
-        root = find_project_root(ctx.params.get("src", ()))
-        path = root / "pyproject.toml"
-        if path.is_file():
-            value = str(path)
-        else:
+        value = find_pyproject_toml(ctx.params.get("src", ()))
+        if value is None:
             return None
 
     try:
-        pyproject_toml = toml.load(value)
-        config = pyproject_toml.get("tool", {}).get("black", {})
+        config = parse_pyproject_toml(value)
     except (toml.TomlDecodeError, OSError) as e:
         raise click.FileError(
             filename=value, hint=f"Error reading configuration file: {e}"
@@ -241,12 +256,21 @@ def read_pyproject_toml(
 
     if ctx.default_map is None:
         ctx.default_map = {}
-    ctx.default_map.update(  # type: ignore  # bad types in .pyi
-        {k.replace("--", "").replace("-", "_"): v for k, v in config.items()}
-    )
+    ctx.default_map.update(config)  # type: ignore  # bad types in .pyi
     return value
 
 
+def target_version_option_callback(
+    c: click.Context, p: Union[click.Option, click.Parameter], v: Tuple[str, ...]
+) -> List[TargetVersion]:
+    """Convert --target-version values `v` into a list of TargetVersion members.
+
+    Each value is upper-cased and looked up by name; `c` and `p` are unused.
+    Kept as a named function: as a lambda, mypy mis-inferred its type for mypyc.
+    """
+    return [TargetVersion[val.upper()] for val in v]
+
+
 @click.command(context_settings=dict(help_option_names=["-h", "--help"]))
 @click.option("-c", "--code", type=str, help="Format the code passed in as a string.")
 @click.option(
@@ -261,7 +285,7 @@ def read_pyproject_toml(
     "-t",
     "--target-version",
     type=click.Choice([v.name.lower() for v in TargetVersion]),
-    callback=lambda c, p, v: [TargetVersion[val.upper()] for val in v],
+    callback=target_version_option_callback,
     multiple=True,
     help=(
         "Python versions that should be supported by Black's output. [default: "
@@ -388,7 +412,7 @@ def main(
     verbose: bool,
     include: str,
     exclude: str,
-    src: Tuple[str],
+    src: Tuple[str, ...],
     config: Optional[str],
 ) -> None:
     """The uncompromising code formatter."""
@@ -429,7 +453,7 @@ def main(
     except re.error:
         err(f"Invalid regular expression for exclude given: {exclude!r}")
         ctx.exit(2)
-    report = Report(check=check, quiet=quiet, verbose=verbose)
+    report = Report(check=check, diff=diff, quiet=quiet, verbose=verbose)
     root = find_project_root(src)
     sources: Set[Path] = set()
     path_empty(src, quiet, verbose, ctx)
@@ -470,7 +494,9 @@ def main(
     ctx.exit(report.return_code)
 
 
-def path_empty(src: Tuple[str], quiet: bool, verbose: bool, ctx: click.Context) -> None:
+def path_empty(
+    src: Tuple[str, ...], quiet: bool, verbose: bool, ctx: click.Context
+) -> None:
     """
     Exit if there is no `src` provided for formatting
     """
@@ -585,7 +611,7 @@ async def schedule_formatting(
         ): src
         for src in sorted(sources)
     }
-    pending: Iterable[asyncio.Future] = tasks.keys()
+    pending: Iterable["asyncio.Future[bool]"] = tasks.keys()
     try:
         loop.add_signal_handler(signal.SIGINT, cancel, pending)
         loop.add_signal_handler(signal.SIGTERM, cancel, pending)
@@ -629,7 +655,7 @@ def format_file_in_place(
     `mode` and `fast` options are passed to :func:`format_file_contents`.
     """
     if src.suffix == ".pyi":
-        mode = evolve(mode, is_pyi=True)
+        mode = replace(mode, is_pyi=True)
 
     then = datetime.utcfromtimestamp(src.stat().st_mtime)
     with open(src, "rb") as buf:
@@ -639,10 +665,10 @@ def format_file_in_place(
     except NothingChanged:
         return False
 
-    if write_back == write_back.YES:
+    if write_back == WriteBack.YES:
         with open(src, "w", encoding=encoding, newline=newline) as f:
             f.write(dst_contents)
-    elif write_back == write_back.DIFF:
+    elif write_back == WriteBack.DIFF:
         now = datetime.utcnow()
         src_name = f"{src}\t{then} +0000"
         dst_name = f"{src}\t{now} +0000"
@@ -745,11 +771,9 @@ def format_str(src_contents: str, *, mode: FileMode) -> FileContent:
         if supports_feature(versions, feature)
     }
     for current_line in lines.visit(src_node):
-        for _ in range(after):
-            dst_contents.append(str(empty_line))
+        dst_contents.append(str(empty_line) * after)
         before, after = elt.maybe_empty_lines(current_line)
-        for _ in range(before):
-            dst_contents.append(str(empty_line))
+        dst_contents.append(str(empty_line) * before)
         for line in split_line(
             current_line, line_length=mode.line_length, features=split_line_features
         ):
@@ -865,8 +889,16 @@ class Visitor(Generic[T]):
         if node.type < 256:
             name = token.tok_name[node.type]
         else:
-            name = type_repr(node.type)
-        yield from getattr(self, f"visit_{name}", self.visit_default)(node)
+            name = str(type_repr(node.type))
+        # We explicitly branch on whether a visitor exists (instead of
+        # using self.visit_default as the default arg to getattr) in order
+        # to save needing to create a bound method object and so mypyc can
+        # generate a native call to visit_default.
+        visitf = getattr(self, f"visit_{name}", None)
+        if visitf:
+            yield from visitf(node)
+        else:
+            yield from self.visit_default(node)
 
     def visit_default(self, node: LN) -> Iterator[T]:
         """Default `visit_*()` implementation. Recurses to children of `node`."""
@@ -911,8 +943,8 @@ class DebugVisitor(Visitor[T]):
         list(v.visit(code))
 
 
-WHITESPACE = {token.DEDENT, token.INDENT, token.NEWLINE}
-STATEMENT = {
+WHITESPACE: Final = {token.DEDENT, token.INDENT, token.NEWLINE}
+STATEMENT: Final = {
     syms.if_stmt,
     syms.while_stmt,
     syms.for_stmt,
@@ -922,10 +954,10 @@ STATEMENT = {
     syms.funcdef,
     syms.classdef,
 }
-STANDALONE_COMMENT = 153
+STANDALONE_COMMENT: Final = 153
 token.tok_name[STANDALONE_COMMENT] = "STANDALONE_COMMENT"
-LOGIC_OPERATORS = {"and", "or"}
-COMPARATORS = {
+LOGIC_OPERATORS: Final = {"and", "or"}
+COMPARATORS: Final = {
     token.LESS,
     token.GREATER,
     token.EQEQUAL,
@@ -933,7 +965,7 @@ COMPARATORS = {
     token.LESSEQUAL,
     token.GREATEREQUAL,
 }
-MATH_OPERATORS = {
+MATH_OPERATORS: Final = {
     token.VBAR,
     token.CIRCUMFLEX,
     token.AMPER,
@@ -949,23 +981,23 @@ MATH_OPERATORS = {
     token.TILDE,
     token.DOUBLESTAR,
 }
-STARS = {token.STAR, token.DOUBLESTAR}
-VARARGS_SPECIALS = STARS | {token.SLASH}
-VARARGS_PARENTS = {
+STARS: Final = {token.STAR, token.DOUBLESTAR}
+VARARGS_SPECIALS: Final = STARS | {token.SLASH}
+VARARGS_PARENTS: Final = {
     syms.arglist,
     syms.argument,  # double star in arglist
     syms.trailer,  # single argument to call
     syms.typedargslist,
     syms.varargslist,  # lambdas
 }
-UNPACKING_PARENTS = {
+UNPACKING_PARENTS: Final = {
     syms.atom,  # single element of a list or set literal
     syms.dictsetmaker,
     syms.listmaker,
     syms.testlist_gexp,
     syms.testlist_star_expr,
 }
-TEST_DESCENDANTS = {
+TEST_DESCENDANTS: Final = {
     syms.test,
     syms.lambdef,
     syms.or_test,
@@ -982,7 +1014,7 @@ TEST_DESCENDANTS = {
     syms.term,
     syms.power,
 }
-ASSIGNMENTS = {
+ASSIGNMENTS: Final = {
     "=",
     "+=",
     "-=",
@@ -998,13 +1030,13 @@ ASSIGNMENTS = {
     "**=",
     "//=",
 }
-COMPREHENSION_PRIORITY = 20
-COMMA_PRIORITY = 18
-TERNARY_PRIORITY = 16
-LOGIC_PRIORITY = 14
-STRING_PRIORITY = 12
-COMPARATOR_PRIORITY = 10
-MATH_PRIORITIES = {
+COMPREHENSION_PRIORITY: Final = 20
+COMMA_PRIORITY: Final = 18
+TERNARY_PRIORITY: Final = 16
+LOGIC_PRIORITY: Final = 14
+STRING_PRIORITY: Final = 12
+COMPARATOR_PRIORITY: Final = 10
+MATH_PRIORITIES: Final = {
     token.VBAR: 9,
     token.CIRCUMFLEX: 8,
     token.AMPER: 7,
@@ -1020,7 +1052,7 @@ MATH_PRIORITIES = {
     token.TILDE: 3,
     token.DOUBLESTAR: 2,
 }
-DOT_PRIORITY = 1
+DOT_PRIORITY: Final = 1
 
 
 @dataclass
@@ -1028,11 +1060,11 @@ class BracketTracker:
     """Keeps track of brackets on a line."""
 
     depth: int = 0
-    bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = Factory(dict)
-    delimiters: Dict[LeafID, Priority] = Factory(dict)
+    bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = field(default_factory=dict)
+    delimiters: Dict[LeafID, Priority] = field(default_factory=dict)
     previous: Optional[Leaf] = None
-    _for_loop_depths: List[int] = Factory(list)
-    _lambda_argument_depths: List[int] = Factory(list)
+    _for_loop_depths: List[int] = field(default_factory=list)
+    _lambda_argument_depths: List[int] = field(default_factory=list)
 
     def mark(self, leaf: Leaf) -> None:
         """Mark `leaf` with bracket-related metadata. Keep track of delimiters.
@@ -1160,9 +1192,10 @@ class Line:
     """Holds leaves and comments. Can be printed with `str(line)`."""
 
     depth: int = 0
-    leaves: List[Leaf] = Factory(list)
-    comments: Dict[LeafID, List[Leaf]] = Factory(dict)  # keys ordered like `leaves`
-    bracket_tracker: BracketTracker = Factory(BracketTracker)
+    leaves: List[Leaf] = field(default_factory=list)
+    # keys ordered like `leaves`
+    comments: Dict[LeafID, List[Leaf]] = field(default_factory=dict)
+    bracket_tracker: BracketTracker = field(default_factory=BracketTracker)
     inside_brackets: bool = False
     should_explode: bool = False
 
@@ -1383,7 +1416,10 @@ class Line:
         for leaf_id, comments in self.comments.items():
             for comment in comments:
                 if is_type_comment(comment):
-                    if leaf_id not in ignored_ids or comment_seen:
+                    if comment_seen or (
+                        not is_type_comment(comment, " ignore")
+                        and leaf_id not in ignored_ids
+                    ):
                         return True
 
                 comment_seen = True
@@ -1422,11 +1458,7 @@ class Line:
         return False
 
     def contains_multiline_strings(self) -> bool:
-        for leaf in self.leaves:
-            if is_multiline_string(leaf):
-                return True
-
-        return False
+        return any(is_multiline_string(leaf) for leaf in self.leaves)
 
     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
         """Remove trailing comma if there is one and it's safe."""
@@ -1565,7 +1597,7 @@ class EmptyLineTracker:
     is_pyi: bool = False
     previous_line: Optional[Line] = None
     previous_after: int = 0
-    previous_defs: List[int] = Factory(list)
+    previous_defs: List[int] = field(default_factory=list)
 
     def maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
         """Return the number of extra empty lines before and after the `current_line`.
@@ -1679,7 +1711,7 @@ class LineGenerator(Visitor[Line]):
 
     is_pyi: bool = False
     normalize_strings: bool = True
-    current_line: Line = Factory(Line)
+    current_line: Line = field(default_factory=Line)
     remove_u_prefix: bool = False
 
     def line(self, indent: int = 0) -> Iterator[Line]:
@@ -1728,13 +1760,13 @@ class LineGenerator(Visitor[Line]):
                 self.current_line.append(node)
         yield from super().visit_default(node)
 
-    def visit_INDENT(self, node: Node) -> Iterator[Line]:
+    def visit_INDENT(self, node: Leaf) -> Iterator[Line]:
         """Increase indentation level, maybe yield a line."""
         # In blib2to3 INDENT never holds comments.
         yield from self.line(+1)
         yield from self.visit_default(node)
 
-    def visit_DEDENT(self, node: Node) -> Iterator[Line]:
+    def visit_DEDENT(self, node: Leaf) -> Iterator[Line]:
         """Decrease indentation level, maybe yield a line."""
         # The current line might still wait for trailing comments.  At DEDENT time
         # there won't be any (they would be prefixes on the preceding NEWLINE).
@@ -1844,7 +1876,7 @@ class LineGenerator(Visitor[Line]):
             node.insert_child(index, Node(syms.atom, [lpar, operand, rpar]))
         yield from self.visit_default(node)
 
-    def __attrs_post_init__(self) -> None:
+    def __post_init__(self) -> None:
         """You are in a twisty little maze of passages."""
         v = self.visit_stmt
         Ø: Set[str] = set()
@@ -2462,7 +2494,7 @@ def left_hand_split(line: Line, features: Collection[Feature] = ()) -> Iterator[
     body_leaves: List[Leaf] = []
     head_leaves: List[Leaf] = []
     current_leaves = head_leaves
-    matching_bracket = None
+    matching_bracket: Optional[Leaf] = None
     for leaf in line.leaves:
         if (
             current_leaves is body_leaves
@@ -2505,8 +2537,8 @@ def right_hand_split(
     body_leaves: List[Leaf] = []
     head_leaves: List[Leaf] = []
     current_leaves = tail_leaves
-    opening_bracket = None
-    closing_bracket = None
+    opening_bracket: Optional[Leaf] = None
+    closing_bracket: Optional[Leaf] = None
     for leaf in reversed(line.leaves):
         if current_leaves is body_leaves:
             if leaf is opening_bracket:
@@ -2810,7 +2842,7 @@ def normalize_string_prefix(leaf: Leaf, remove_u_prefix: bool = False) -> None:
     match = re.match(r"^([furbFURB]*)(.*)$", leaf.value, re.DOTALL)
     assert match is not None, f"failed to match string {leaf.value!r}"
     orig_prefix = match.group(1)
-    new_prefix = orig_prefix.lower()
+    new_prefix = orig_prefix.replace("F", "f").replace("B", "b").replace("U", "u")
     if remove_u_prefix:
         new_prefix = new_prefix.replace("u", "")
     leaf.value = f"{new_prefix}{match.group(2)}"
@@ -2966,16 +2998,9 @@ def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None:
 
             if child.type == syms.atom:
                 if maybe_make_parens_invisible_in_atom(child, parent=node):
-                    lpar = Leaf(token.LPAR, "")
-                    rpar = Leaf(token.RPAR, "")
-                    index = child.remove() or 0
-                    node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
+                    wrap_in_parentheses(node, child, visible=False)
             elif is_one_tuple(child):
-                # wrap child in visible parentheses
-                lpar = Leaf(token.LPAR, "(")
-                rpar = Leaf(token.RPAR, ")")
-                child.remove()
-                node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
+                wrap_in_parentheses(node, child, visible=True)
             elif node.type == syms.import_from:
                 # "import from" nodes store parentheses directly as part of
                 # the statement
@@ -2990,15 +3015,7 @@ def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None:
                 break
 
             elif not (isinstance(child, Leaf) and is_multiline_string(child)):
-                # wrap child in invisible parentheses
-                lpar = Leaf(token.LPAR, "")
-                rpar = Leaf(token.RPAR, "")
-                index = child.remove() or 0
-                prefix = child.prefix
-                child.prefix = ""
-                new_child = Node(syms.atom, [lpar, child, rpar])
-                new_child.prefix = prefix
-                node.insert_child(index, new_child)
+                wrap_in_parentheses(node, child, visible=False)
 
         check_lpar = isinstance(child, Leaf) and child.value in parens_after
 
@@ -3042,7 +3059,7 @@ def convert_one_fmt_off_pair(node: Node) -> bool:
                     # That happens when one of the `ignored_nodes` ended with a NEWLINE
                     # leaf (possibly followed by a DEDENT).
                     hidden_value = hidden_value[:-1]
-                first_idx = None
+                first_idx: Optional[int] = None
                 for ignored in ignored_nodes:
                     index = ignored.remove()
                     if first_idx is None:
@@ -3071,9 +3088,14 @@ def generate_ignored_nodes(leaf: Leaf) -> Iterator[LN]:
     """
     container: Optional[LN] = container_of(leaf)
     while container is not None and container.type != token.ENDMARKER:
+        is_fmt_on = False
         for comment in list_comments(container.prefix, is_endmarker=False):
             if comment.value in FMT_ON:
-                return
+                is_fmt_on = True
+            elif comment.value in FMT_OFF:
+                is_fmt_on = False
+        if is_fmt_on:
+            return
 
         yield container
 
@@ -3158,6 +3180,24 @@ def unwrap_singleton_parenthesis(node: LN) -> Optional[LN]:
     return wrapped
 
 
+def wrap_in_parentheses(parent: Node, child: LN, *, visible: bool = True) -> None:
+    """Replace `child` under `parent` with an atom of (lpar, child, rpar).
+
+    The child's prefix is moved onto the new atom and cleared on the child.
+    The atom goes into `parent` at the index `child.remove()` returns (or 0).
+
+    If `visible` is False, both parenthesis leaves have an empty value.
+    """
+    lpar = Leaf(token.LPAR, "(" if visible else "")
+    rpar = Leaf(token.RPAR, ")" if visible else "")
+    prefix = child.prefix
+    child.prefix = ""
+    index = child.remove() or 0
+    new_child = Node(syms.atom, [lpar, child, rpar])
+    new_child.prefix = prefix
+    parent.insert_child(index, new_child)
+
+
 def is_one_tuple(node: LN) -> bool:
     """Return True if `node` holds a tuple with one element, with or without parens."""
     if node.type == syms.atom:
@@ -3390,8 +3430,8 @@ def generate_trailers_to_omit(line: Line, line_length: int) -> Iterator[Set[Leaf
     yield omit
 
     length = 4 * line.depth
-    opening_bracket = None
-    closing_bracket = None
+    opening_bracket: Optional[Leaf] = None
+    closing_bracket: Optional[Leaf] = None
     inner_brackets: Set[LeafID] = set()
     for index, leaf, leaf_length in enumerate_with_length(line, reversed=True):
         length += leaf_length
@@ -3562,7 +3602,7 @@ def find_project_root(srcs: Iterable[str]) -> Path:
         # Append a fake file so `parents` below returns `common_base_dir`, too.
         common_base /= "fake-file"
     for directory in common_base.parents:
-        if (directory / ".git").is_dir():
+        if (directory / ".git").exists():
             return directory
 
         if (directory / ".hg").is_dir():
@@ -3579,6 +3619,7 @@ class Report:
     """Provides a reformatting counter. Can be rendered with `str(report)`."""
 
     check: bool = False
+    diff: bool = False
     quiet: bool = False
     verbose: bool = False
     change_count: int = 0
@@ -3588,7 +3629,7 @@ class Report:
     def done(self, src: Path, changed: Changed) -> None:
         """Increment the counter for successful reformatting. Write out a message."""
         if changed is Changed.YES:
-            reformatted = "would reformat" if self.check else "reformatted"
+            reformatted = "would reformat" if self.check or self.diff else "reformatted"
             if self.verbose or not self.quiet:
                 out(f"{reformatted} {src}")
             self.change_count += 1
@@ -3634,7 +3675,7 @@ class Report:
 
         Use `click.unstyle` to remove colors.
         """
-        if self.check:
+        if self.check or self.diff:
             reformatted = "would be reformatted"
             unchanged = "would be left unchanged"
             failed = "would fail to reformat"
@@ -3704,7 +3745,7 @@ def assert_equivalent(src: str, dst: str) -> None:
 
         yield f"{'  ' * depth}{node.__class__.__name__}("
 
-        for field in sorted(node._fields):
+        for field in sorted(node._fields):  # noqa: F402
             # TypeIgnore has only one field 'lineno' which breaks this comparison
             type_ignore_classes = (ast3.TypeIgnore, ast27.TypeIgnore)
             if sys.version_info >= (3, 8):
@@ -3788,6 +3829,7 @@ def assert_stable(src: str, dst: str, mode: FileMode) -> None:
         ) from None
 
 
+@mypyc_attr(patchable=True)
 def dump_to_file(*output: str) -> str:
     """Dump `output` to a temporary file. Return path to the file."""
     with tempfile.NamedTemporaryFile(
@@ -3820,7 +3862,7 @@ def diff(a: str, b: str, a_name: str, b_name: str) -> str:
     )
 
 
-def cancel(tasks: Iterable[asyncio.Task]) -> None:
+def cancel(tasks: Iterable["asyncio.Task[Any]"]) -> None:
     """asyncio signal handler that cancels all `tasks` and reports to stderr."""
     err("Aborted!")
     for task in tasks: