X-Git-Url: https://git.madduck.net/etc/vim.git/blobdiff_plain/b8c1020b526b488a1a216d9a4d58fc8616b61a99..48dfda084a8c854a73c11b6c91c67deae5f86ca3:/src/black/__init__.py?ds=sidebyside

diff --git a/src/black/__init__.py b/src/black/__init__.py
index 6dbb765..e85ffb3 100644
--- a/src/black/__init__.py
+++ b/src/black/__init__.py
@@ -48,7 +48,20 @@ from appdirs import user_cache_dir
 from dataclasses import dataclass, field, replace
 import click
 import toml
-from typed_ast import ast3, ast27
+
+try:
+    from typed_ast import ast3, ast27
+except ImportError:
+    if sys.version_info < (3, 8):
+        print(
+            "The typed_ast package is not installed.\n"
+            "You can install it with `python3 -m pip install typed-ast`.",
+            file=sys.stderr,
+        )
+        sys.exit(1)
+    else:
+        ast3 = ast27 = ast
+
 from pathspec import PathSpec
 
 # lib2to3 fork
@@ -293,7 +306,11 @@ def find_pyproject_toml(path_search_start: Iterable[str]) -> Optional[str]:
     """Find the absolute filepath to a pyproject.toml if it exists"""
     path_project_root = find_project_root(path_search_start)
     path_pyproject_toml = path_project_root / "pyproject.toml"
-    return str(path_pyproject_toml) if path_pyproject_toml.is_file() else None
+    if path_pyproject_toml.is_file():
+        return str(path_pyproject_toml)
+
+    path_user_pyproject_toml = find_user_pyproject_toml()
+    return str(path_user_pyproject_toml) if path_user_pyproject_toml.is_file() else None
 
 
 def parse_pyproject_toml(path_config: str) -> Dict[str, Any]:
@@ -363,6 +380,17 @@ def target_version_option_callback(
     return [TargetVersion[val.upper()] for val in v]
 
 
+def validate_regex(
+    ctx: click.Context, param: click.Parameter, value: Optional[str]
+) -> Optional[Pattern]:
+    """Click option callback: compile `value`, passing `None` through unchanged."""
+    try:
+        return re_compile_maybe_verbose(value) if value is not None else None
+    except re.error as exc:
+        # Report the regex engine's message; "from None" drops the chained context.
+        raise click.BadParameter(f"Not a valid regular expression: {exc}") from None
+
+
 @click.command(context_settings=dict(help_option_names=["-h", "--help"]))
 @click.option("-c", "--code", type=str, help="Format the code passed in as a string.")
 @click.option(
@@ -441,6 +469,7 @@ def target_version_option_callback(
     "--include",
     type=str,
     default=DEFAULT_INCLUDES,
+    callback=validate_regex,
     help=(
         "A regular expression that matches files and directories that should be"
         " included on recursive searches.  An empty value means all files are included"
@@ -453,6 +482,7 @@ def target_version_option_callback(
     "--exclude",
     type=str,
     default=DEFAULT_EXCLUDES,
+    callback=validate_regex,
     help=(
         "A regular expression that matches files and directories that should be"
         " excluded on recursive searches.  An empty value means no paths are excluded."
@@ -461,9 +491,19 @@ def target_version_option_callback(
     ),
     show_default=True,
 )
+@click.option(
+    "--extend-exclude",
+    type=str,
+    callback=validate_regex,
+    help=(
+        "Like --exclude, but adds additional files and directories on top of the"
+        " excluded ones. (Useful if you simply want to add to the default)"
+    ),
+)
 @click.option(
     "--force-exclude",
     type=str,
+    callback=validate_regex,
     help=(
         "Like --exclude, but files and directories matching this regex will be "
         "excluded even when they are passed explicitly as arguments."
@@ -493,7 +533,7 @@ def target_version_option_callback(
     is_flag=True,
     help=(
         "Also emit messages to stderr about files that were not changed or were ignored"
-        " due to --exclude=."
+        " due to exclusion patterns."
     ),
 )
 @click.version_option(version=__version__)
@@ -535,9 +575,10 @@ def main(
     experimental_string_processing: bool,
     quiet: bool,
     verbose: bool,
-    include: str,
-    exclude: str,
-    force_exclude: Optional[str],
+    include: Pattern,
+    exclude: Pattern,
+    extend_exclude: Optional[Pattern],
+    force_exclude: Optional[Pattern],
     stdin_filename: Optional[str],
     src: Tuple[str, ...],
     config: Optional[str],
@@ -570,6 +611,7 @@ def main(
         verbose=verbose,
         include=include,
         exclude=exclude,
+        extend_exclude=extend_exclude,
         force_exclude=force_exclude,
         report=report,
         stdin_filename=stdin_filename,
@@ -608,30 +650,14 @@ def get_sources(
     src: Tuple[str, ...],
     quiet: bool,
     verbose: bool,
-    include: str,
-    exclude: str,
-    force_exclude: Optional[str],
+    include: Pattern[str],
+    exclude: Pattern[str],
+    extend_exclude: Optional[Pattern[str]],
+    force_exclude: Optional[Pattern[str]],
     report: "Report",
     stdin_filename: Optional[str],
 ) -> Set[Path]:
     """Compute the set of files to be formatted."""
-    try:
-        include_regex = re_compile_maybe_verbose(include)
-    except re.error:
-        err(f"Invalid regular expression for include given: {include!r}")
-        ctx.exit(2)
-    try:
-        exclude_regex = re_compile_maybe_verbose(exclude)
-    except re.error:
-        err(f"Invalid regular expression for exclude given: {exclude!r}")
-        ctx.exit(2)
-    try:
-        force_exclude_regex = (
-            re_compile_maybe_verbose(force_exclude) if force_exclude else None
-        )
-    except re.error:
-        err(f"Invalid regular expression for force_exclude given: {force_exclude!r}")
-        ctx.exit(2)
 
     root = find_project_root(src)
     sources: Set[Path] = set()
@@ -653,8 +679,8 @@ def get_sources(
 
             normalized_path = "/" + normalized_path
             # Hard-exclude any files that matches the `--force-exclude` regex.
-            if force_exclude_regex:
-                force_exclude_match = force_exclude_regex.search(normalized_path)
+            if force_exclude:
+                force_exclude_match = force_exclude.search(normalized_path)
             else:
                 force_exclude_match = None
             if force_exclude_match and force_exclude_match.group(0):
@@ -670,9 +696,10 @@ def get_sources(
                 gen_python_files(
                     p.iterdir(),
                     root,
-                    include_regex,
-                    exclude_regex,
-                    force_exclude_regex,
+                    include,
+                    exclude,
+                    extend_exclude,
+                    force_exclude,
                     report,
                     gitignore,
                 )
@@ -883,7 +910,7 @@ def format_file_in_place(
         dst_name = f"{src}\t{now} +0000"
         diff_contents = diff(src_contents, dst_contents, src_name, dst_name)
 
-        if write_back == write_back.COLOR_DIFF:
+        if write_back == WriteBack.COLOR_DIFF:
             diff_contents = color_diff(diff_contents)
 
         with lock or nullcontext():
@@ -1480,7 +1507,8 @@ class Line:
     comments: Dict[LeafID, List[Leaf]] = field(default_factory=dict)
     bracket_tracker: BracketTracker = field(default_factory=BracketTracker)
     inside_brackets: bool = False
-    should_explode: bool = False
+    should_split_rhs: bool = False
+    magic_trailing_comma: Optional[Leaf] = None
 
     def append(self, leaf: Leaf, preformatted: bool = False) -> None:
         """Add a new `leaf` to the end of the line.
@@ -1508,7 +1536,7 @@ class Line:
             self.bracket_tracker.mark(leaf)
             if self.mode.magic_trailing_comma:
                 if self.has_magic_trailing_comma(leaf):
-                    self.should_explode = True
+                    self.magic_trailing_comma = leaf
             elif self.has_magic_trailing_comma(leaf, ensure_removable=True):
                 self.remove_trailing_comma()
         if not self.append_comment(leaf):
@@ -1791,7 +1819,8 @@ class Line:
             mode=self.mode,
             depth=self.depth,
             inside_brackets=self.inside_brackets,
-            should_explode=self.should_explode,
+            should_split_rhs=self.should_split_rhs,
+            magic_trailing_comma=self.magic_trailing_comma,
         )
 
     def __str__(self) -> str:
@@ -2047,6 +2076,8 @@ class LineGenerator(Visitor[Line]):
 
     def visit_simple_stmt(self, node: Node) -> Iterator[Line]:
         """Visit a statement without nested statements."""
+        if first_child_is_arith(node):
+            wrap_in_parentheses(node, node.children[0], visible=False)
         is_suite_like = node.parent and node.parent.type in STATEMENT
         if is_suite_like:
             if self.mode.is_pyi and is_stub_body(node):
@@ -2581,6 +2612,8 @@ def is_split_before_delimiter(leaf: Leaf, previous: Optional[Leaf] = None) -> Pr
 
 
 FMT_OFF = {"# fmt: off", "# fmt:off", "# yapf: disable"}
+FMT_SKIP = {"# fmt: skip", "# fmt:skip"}
+FMT_PASS = {*FMT_OFF, *FMT_SKIP}
 FMT_ON = {"# fmt: on", "# fmt:on", "# yapf: enable"}
 
 
@@ -2708,7 +2741,8 @@ def transform_line(
     transformers: List[Transformer]
     if (
         not line.contains_uncollapsable_type_comments()
-        and not line.should_explode
+        and not line.should_split_rhs
+        and not line.magic_trailing_comma
         and (
             is_line_short_enough(line, line_length=mode.line_length, line_str=line_str)
             or line.contains_unsplittable_type_ignore()
@@ -4382,7 +4416,8 @@ class StringParenWrapper(CustomSplitMapMixin, BaseStringSplitter):
             mode=line.mode,
             depth=line.depth + 1,
             inside_brackets=True,
-            should_explode=line.should_explode,
+            should_split_rhs=line.should_split_rhs,
+            magic_trailing_comma=line.magic_trailing_comma,
         )
         string_leaf = Leaf(token.STRING, string_value)
         insert_str_child(string_leaf)
@@ -5003,8 +5038,8 @@ def bracket_split_build_line(
         result.append(leaf, preformatted=True)
         for comment_after in original.comments_after(leaf):
             result.append(comment_after, preformatted=True)
-    if is_body and should_split_body_explode(result, opening_bracket):
-        result.should_explode = True
+    if is_body and should_split_line(result, opening_bracket):
+        result.should_split_rhs = True
     return result
 
 
@@ -5362,10 +5397,7 @@ def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None:
             check_lpar = True
 
         if check_lpar:
-            if is_walrus_assignment(child):
-                pass
-
-            elif child.type == syms.atom:
+            if child.type == syms.atom:
                 if maybe_make_parens_invisible_in_atom(child, parent=node):
                     wrap_in_parentheses(node, child, visible=False)
             elif is_one_tuple(child):
@@ -5404,58 +5436,80 @@ def convert_one_fmt_off_pair(node: Node) -> bool:
     for leaf in node.leaves():
         previous_consumed = 0
         for comment in list_comments(leaf.prefix, is_endmarker=False):
-            if comment.value in FMT_OFF:
-                # We only want standalone comments. If there's no previous leaf or
-                # the previous leaf is indentation, it's a standalone comment in
-                # disguise.
-                if comment.type != STANDALONE_COMMENT:
-                    prev = preceding_leaf(leaf)
-                    if prev and prev.type not in WHITESPACE:
+            if comment.value not in FMT_PASS:
+                previous_consumed = comment.consumed
+                continue
+            # We only want standalone comments. If there's no previous leaf or
+            # the previous leaf is indentation, it's a standalone comment in
+            # disguise.
+            if comment.value in FMT_PASS and comment.type != STANDALONE_COMMENT:
+                prev = preceding_leaf(leaf)
+                if prev:
+                    if comment.value in FMT_OFF and prev.type not in WHITESPACE:
+                        continue
+                    if comment.value in FMT_SKIP and prev.type in WHITESPACE:
                         continue
 
-                ignored_nodes = list(generate_ignored_nodes(leaf))
-                if not ignored_nodes:
-                    continue
-
-                first = ignored_nodes[0]  # Can be a container node with the `leaf`.
-                parent = first.parent
-                prefix = first.prefix
-                first.prefix = prefix[comment.consumed :]
-                hidden_value = (
-                    comment.value + "\n" + "".join(str(n) for n in ignored_nodes)
-                )
-                if hidden_value.endswith("\n"):
-                    # That happens when one of the `ignored_nodes` ended with a NEWLINE
-                    # leaf (possibly followed by a DEDENT).
-                    hidden_value = hidden_value[:-1]
-                first_idx: Optional[int] = None
-                for ignored in ignored_nodes:
-                    index = ignored.remove()
-                    if first_idx is None:
-                        first_idx = index
-                assert parent is not None, "INTERNAL ERROR: fmt: on/off handling (1)"
-                assert first_idx is not None, "INTERNAL ERROR: fmt: on/off handling (2)"
-                parent.insert_child(
-                    first_idx,
-                    Leaf(
-                        STANDALONE_COMMENT,
-                        hidden_value,
-                        prefix=prefix[:previous_consumed] + "\n" * comment.newlines,
-                    ),
-                )
-                return True
+            ignored_nodes = list(generate_ignored_nodes(leaf, comment))
+            if not ignored_nodes:
+                continue
 
-            previous_consumed = comment.consumed
+            first = ignored_nodes[0]  # Can be a container node with the `leaf`.
+            parent = first.parent
+            prefix = first.prefix
+            first.prefix = prefix[comment.consumed :]
+            hidden_value = "".join(str(n) for n in ignored_nodes)
+            if comment.value in FMT_OFF:
+                hidden_value = comment.value + "\n" + hidden_value
+            if comment.value in FMT_SKIP:
+                hidden_value += "  " + comment.value
+            if hidden_value.endswith("\n"):
+                # That happens when one of the `ignored_nodes` ended with a NEWLINE
+                # leaf (possibly followed by a DEDENT).
+                hidden_value = hidden_value[:-1]
+            first_idx: Optional[int] = None
+            for ignored in ignored_nodes:
+                index = ignored.remove()
+                if first_idx is None:
+                    first_idx = index
+            assert parent is not None, "INTERNAL ERROR: fmt: on/off handling (1)"
+            assert first_idx is not None, "INTERNAL ERROR: fmt: on/off handling (2)"
+            parent.insert_child(
+                first_idx,
+                Leaf(
+                    STANDALONE_COMMENT,
+                    hidden_value,
+                    prefix=prefix[:previous_consumed] + "\n" * comment.newlines,
+                ),
+            )
+            return True
 
     return False
 
 
-def generate_ignored_nodes(leaf: Leaf) -> Iterator[LN]:
+def generate_ignored_nodes(leaf: Leaf, comment: ProtoComment) -> Iterator[LN]:
     """Starting from the container of `leaf`, generate all leaves until `# fmt: on`.
 
+    If comment is skip, returns leaf only.
     Stops at the end of the block.
     """
     container: Optional[LN] = container_of(leaf)
+    if comment.value in FMT_SKIP:
+        prev_sibling = leaf.prev_sibling
+        if comment.value in leaf.prefix and prev_sibling is not None:
+            leaf.prefix = leaf.prefix.replace(comment.value, "")
+            siblings = [prev_sibling]
+            while (
+                "\n" not in prev_sibling.prefix
+                and prev_sibling.prev_sibling is not None
+            ):
+                prev_sibling = prev_sibling.prev_sibling
+                siblings.insert(0, prev_sibling)
+            for sibling in siblings:
+                yield sibling
+        elif leaf.parent is not None:
+            yield leaf.parent
+        return
     while container is not None and container.type != token.ENDMARKER:
         if is_fmt_on(container):
             return
@@ -5515,6 +5569,7 @@ def maybe_make_parens_invisible_in_atom(node: LN, parent: LN) -> bool:
     Returns whether the node should itself be wrapped in invisible parentheses.
 
     """
+
     if (
         node.type != syms.atom
         or is_empty_tuple(node)
@@ -5524,6 +5579,10 @@ def maybe_make_parens_invisible_in_atom(node: LN, parent: LN) -> bool:
     ):
         return False
 
+    if is_walrus_assignment(node):
+        if parent.type in [syms.annassign, syms.expr_stmt]:
+            return False
+
     first = node.children[0]
     last = node.children[-1]
     if first.type == token.LPAR and last.type == token.RPAR:
@@ -5585,6 +5644,17 @@ def unwrap_singleton_parenthesis(node: LN) -> Optional[LN]:
     return wrapped
 
 
+def first_child_is_arith(node: Node) -> bool:
+    """Return True if `node` has children and the first one is one of the
+    arith_expr, shift_expr, xor_expr or and_expr grammar symbols.
+    """
+    arith_like = {syms.arith_expr, syms.shift_expr, syms.xor_expr, syms.and_expr}
+    if not node.children:
+        return False
+    head = node.children[0]
+    return head.type in arith_like
+
+
 def wrap_in_parentheses(parent: Node, child: LN, *, visible: bool = True) -> None:
     """Wrap `child` in parentheses.
 
@@ -5786,7 +5856,7 @@ def ensure_visible(leaf: Leaf) -> None:
         leaf.value = ")"
 
 
-def should_split_body_explode(line: Line, opening_bracket: Leaf) -> bool:
+def should_split_line(line: Line, opening_bracket: Leaf) -> bool:
     """Should `line` be immediately split with `delimiter_split()` after RHS?"""
 
     if not (opening_bracket.parent and opening_bracket.value in "[{("):
@@ -5922,7 +5992,7 @@ def generate_trailers_to_omit(line: Line, line_length: int) -> Iterator[Set[Leaf
     """
 
     omit: Set[LeafID] = set()
-    if not line.should_explode:
+    if not line.magic_trailing_comma:
         yield omit
 
     length = 4 * line.depth
@@ -5944,8 +6014,7 @@ def generate_trailers_to_omit(line: Line, line_length: int) -> Iterator[Set[Leaf
             elif leaf.type in CLOSING_BRACKETS:
                 prev = line.leaves[index - 1] if index > 0 else None
                 if (
-                    line.should_explode
-                    and prev
+                    prev
                     and prev.type == token.COMMA
                     and not is_one_tuple_between(
                         leaf.opening_bracket, leaf, line.leaves
@@ -5972,8 +6041,7 @@ def generate_trailers_to_omit(line: Line, line_length: int) -> Iterator[Set[Leaf
                 yield omit
 
             if (
-                line.should_explode
-                and prev
+                prev
                 and prev.type == token.COMMA
                 and not is_one_tuple_between(leaf.opening_bracket, leaf, line.leaves)
             ):
@@ -6071,17 +6139,27 @@ def normalize_path_maybe_ignore(
     return normalized_path
 
 
+def path_is_excluded(normalized_path: str, pattern: Optional[Pattern[str]]) -> bool:
+    """Return True if `pattern` is set and finds a non-empty match in the path."""
+    if not pattern:
+        return False
+    found = pattern.search(normalized_path)
+    return found is not None and found.group(0) != ""
+
+
 def gen_python_files(
     paths: Iterable[Path],
     root: Path,
     include: Optional[Pattern[str]],
     exclude: Pattern[str],
+    extend_exclude: Optional[Pattern[str]],
     force_exclude: Optional[Pattern[str]],
     report: "Report",
     gitignore: PathSpec,
 ) -> Iterator[Path]:
     """Generate all files under `path` whose paths are not excluded by the
-    `exclude_regex` or `force_exclude` regexes, but are included by the `include` regex.
+    `exclude_regex`, `extend_exclude`, or `force_exclude` regexes,
+    but are included by the `include` regex.
 
     Symbolic links pointing outside of the `root` directory are ignored.
 
@@ -6098,20 +6176,22 @@ def gen_python_files(
             report.path_ignored(child, "matches the .gitignore file content")
             continue
 
-        # Then ignore with `--exclude` and `--force-exclude` options.
+        # Then ignore with `--exclude` `--extend-exclude` and `--force-exclude` options.
         normalized_path = "/" + normalized_path
         if child.is_dir():
             normalized_path += "/"
 
-        exclude_match = exclude.search(normalized_path) if exclude else None
-        if exclude_match and exclude_match.group(0):
+        if path_is_excluded(normalized_path, exclude):
             report.path_ignored(child, "matches the --exclude regular expression")
             continue
 
-        force_exclude_match = (
-            force_exclude.search(normalized_path) if force_exclude else None
-        )
-        if force_exclude_match and force_exclude_match.group(0):
+        if path_is_excluded(normalized_path, extend_exclude):
+            report.path_ignored(
+                child, "matches the --extend-exclude regular expression"
+            )
+            continue
+
+        if path_is_excluded(normalized_path, force_exclude):
             report.path_ignored(child, "matches the --force-exclude regular expression")
             continue
 
@@ -6121,6 +6201,7 @@ def gen_python_files(
                 root,
                 include,
                 exclude,
+                extend_exclude,
                 force_exclude,
                 report,
                 gitignore,
@@ -6171,6 +6252,22 @@ def find_project_root(srcs: Iterable[str]) -> Path:
     return directory
 
 
+@lru_cache()
+def find_user_pyproject_toml() -> Path:
+    r"""Return the resolved path of the user-level configuration file for black.
+
+    That is ~\.black on Windows; elsewhere "black" under $XDG_CONFIG_HOME, falling
+    back to ~/.config when the variable is unset or empty. Existence is not checked.
+    """
+    if sys.platform == "win32":
+        user_config_path = Path.home() / ".black"
+    else:
+        # An empty XDG_CONFIG_HOME must be treated like an unset one.
+        config_root = os.environ.get("XDG_CONFIG_HOME") or "~/.config"
+        user_config_path = Path(config_root).expanduser() / "black"
+    return user_config_path.resolve()
+
+
 @dataclass
 class Report:
     """Provides a reformatting counter. Can be rendered with `str(report)`."""
@@ -6272,7 +6369,12 @@ def parse_ast(src: str) -> Union[ast.AST, ast3.AST, ast27.AST]:
                 return ast3.parse(src, filename, feature_version=feature_version)
             except SyntaxError:
                 continue
-
+    if ast27.__name__ == "ast":
+        raise SyntaxError(
+            "The requested source code has invalid Python 3 syntax.\n"
+            "If you are trying to format Python 2 files please reinstall Black"
+            " with the 'python2' extra: `python3 -m pip install black[python2]`."
+        )
     return ast27.parse(src)
 
 
@@ -6399,14 +6501,14 @@ def assert_stable(src: str, dst: str, mode: Mode) -> None:
 
 
 @mypyc_attr(patchable=True)
-def dump_to_file(*output: str) -> str:
+def dump_to_file(*output: str, ensure_final_newline: bool = True) -> str:
     """Dump `output` to a temporary file. Return path to the file."""
     with tempfile.NamedTemporaryFile(
         mode="w", prefix="blk_", suffix=".log", delete=False, encoding="utf8"
     ) as f:
         for lines in output:
             f.write(lines)
-            if lines and lines[-1] != "\n":
+            if ensure_final_newline and lines and lines[-1] != "\n":
                 f.write("\n")
     return f.name
 
@@ -6424,11 +6526,20 @@ def diff(a: str, b: str, a_name: str, b_name: str) -> str:
     """Return a unified diff string between strings `a` and `b`."""
     import difflib
 
-    a_lines = [line + "\n" for line in a.splitlines()]
-    b_lines = [line + "\n" for line in b.splitlines()]
-    return "".join(
-        difflib.unified_diff(a_lines, b_lines, fromfile=a_name, tofile=b_name, n=5)
-    )
+    a_lines = [line for line in a.splitlines(keepends=True)]
+    b_lines = [line for line in b.splitlines(keepends=True)]
+    diff_lines = []
+    for line in difflib.unified_diff(
+        a_lines, b_lines, fromfile=a_name, tofile=b_name, n=5
+    ):
+        # Work around https://bugs.python.org/issue2142
+        # See https://www.gnu.org/software/diffutils/manual/html_node/Incomplete-Lines.html
+        if line[-1] == "\n":
+            diff_lines.append(line)
+        else:
+            diff_lines.append(line + "\n")
+            diff_lines.append("\\ No newline at end of file\n")
+    return "".join(diff_lines)
 
 
 def cancel(tasks: Iterable["asyncio.Task[Any]"]) -> None:
@@ -6605,7 +6716,7 @@ def can_omit_invisible_parens(
 
     penultimate = line.leaves[-2]
     last = line.leaves[-1]
-    if line.should_explode:
+    if line.magic_trailing_comma:
         try:
             penultimate, last = last_two_except(line.leaves, omit=omit_on_explode)
         except LookupError:
@@ -6632,7 +6743,7 @@ def can_omit_invisible_parens(
             # unnecessary.
             return True
 
-        if line.should_explode and penultimate.type == token.COMMA:
+        if line.magic_trailing_comma and penultimate.type == token.COMMA:
             # The rightmost non-omitted bracket pair is the one we want to explode on.
             return True