X-Git-Url: https://git.madduck.net/etc/vim.git/blobdiff_plain/51141f1af4aab0e0c7f71932ffe06482a884f1d5..5446a92f0161e398de765bf9532d8c76c5652333:/src/black/__init__.py?ds=inline

diff --git a/src/black/__init__.py b/src/black/__init__.py
index c836b2b..a8f4f89 100644
--- a/src/black/__init__.py
+++ b/src/black/__init__.py
@@ -363,6 +363,17 @@ def target_version_option_callback(
     return [TargetVersion[val.upper()] for val in v]
 
 
+def validate_regex(
+    ctx: click.Context, param: click.Parameter, value: Optional[str]
+) -> Optional[Pattern[str]]:
+    """Click callback: compile `value` as a regex; pass None through unchanged."""
+    try:
+        return re_compile_maybe_verbose(value) if value is not None else None
+    except re.error:
+        # `from None`: don't chain the low-level re.error onto the usage error.
+        raise click.BadParameter("Not a valid regular expression") from None
+
+
 @click.command(context_settings=dict(help_option_names=["-h", "--help"]))
 @click.option("-c", "--code", type=str, help="Format the code passed in as a string.")
 @click.option(
@@ -441,6 +452,7 @@ def target_version_option_callback(
     "--include",
     type=str,
     default=DEFAULT_INCLUDES,
+    callback=validate_regex,
     help=(
         "A regular expression that matches files and directories that should be"
         " included on recursive searches.  An empty value means all files are included"
@@ -453,6 +465,7 @@ def target_version_option_callback(
     "--exclude",
     type=str,
     default=DEFAULT_EXCLUDES,
+    callback=validate_regex,
     help=(
         "A regular expression that matches files and directories that should be"
         " excluded on recursive searches.  An empty value means no paths are excluded."
@@ -461,9 +474,19 @@ def target_version_option_callback(
     ),
     show_default=True,
 )
+@click.option(
+    "--extend-exclude",
+    type=str,
+    callback=validate_regex,
+    help=(
+        "Like --exclude, but adds additional files and directories on top of the"
+        " excluded ones. (Useful if you simply want to add to the default)"
+    ),
+)
 @click.option(
     "--force-exclude",
     type=str,
+    callback=validate_regex,
     help=(
         "Like --exclude, but files and directories matching this regex will be "
         "excluded even when they are passed explicitly as arguments."
@@ -493,7 +516,7 @@ def target_version_option_callback(
     is_flag=True,
     help=(
         "Also emit messages to stderr about files that were not changed or were ignored"
-        " due to --exclude=."
+        " due to exclusion patterns."
     ),
 )
 @click.version_option(version=__version__)
@@ -535,9 +558,10 @@ def main(
     experimental_string_processing: bool,
     quiet: bool,
     verbose: bool,
-    include: str,
-    exclude: str,
-    force_exclude: Optional[str],
+    include: Pattern,
+    exclude: Pattern,
+    extend_exclude: Optional[Pattern],
+    force_exclude: Optional[Pattern],
     stdin_filename: Optional[str],
     src: Tuple[str, ...],
     config: Optional[str],
@@ -570,6 +594,7 @@ def main(
         verbose=verbose,
         include=include,
         exclude=exclude,
+        extend_exclude=extend_exclude,
         force_exclude=force_exclude,
         report=report,
         stdin_filename=stdin_filename,
@@ -608,30 +633,14 @@ def get_sources(
     src: Tuple[str, ...],
     quiet: bool,
     verbose: bool,
-    include: str,
-    exclude: str,
-    force_exclude: Optional[str],
+    include: Pattern[str],
+    exclude: Pattern[str],
+    extend_exclude: Optional[Pattern[str]],
+    force_exclude: Optional[Pattern[str]],
     report: "Report",
     stdin_filename: Optional[str],
 ) -> Set[Path]:
     """Compute the set of files to be formatted."""
-    try:
-        include_regex = re_compile_maybe_verbose(include)
-    except re.error:
-        err(f"Invalid regular expression for include given: {include!r}")
-        ctx.exit(2)
-    try:
-        exclude_regex = re_compile_maybe_verbose(exclude)
-    except re.error:
-        err(f"Invalid regular expression for exclude given: {exclude!r}")
-        ctx.exit(2)
-    try:
-        force_exclude_regex = (
-            re_compile_maybe_verbose(force_exclude) if force_exclude else None
-        )
-    except re.error:
-        err(f"Invalid regular expression for force_exclude given: {force_exclude!r}")
-        ctx.exit(2)
 
     root = find_project_root(src)
     sources: Set[Path] = set()
@@ -653,8 +662,8 @@ def get_sources(
 
             normalized_path = "/" + normalized_path
             # Hard-exclude any files that matches the `--force-exclude` regex.
-            if force_exclude_regex:
-                force_exclude_match = force_exclude_regex.search(normalized_path)
+            if force_exclude:
+                force_exclude_match = force_exclude.search(normalized_path)
             else:
                 force_exclude_match = None
             if force_exclude_match and force_exclude_match.group(0):
@@ -670,9 +679,10 @@ def get_sources(
                 gen_python_files(
                     p.iterdir(),
                     root,
-                    include_regex,
-                    exclude_regex,
-                    force_exclude_regex,
+                    include,
+                    exclude,
+                    extend_exclude,
+                    force_exclude,
                     report,
                     gitignore,
                 )
@@ -883,7 +893,7 @@ def format_file_in_place(
         dst_name = f"{src}\t{now} +0000"
         diff_contents = diff(src_contents, dst_contents, src_name, dst_name)
 
-        if write_back == write_back.COLOR_DIFF:
+        if write_back == WriteBack.COLOR_DIFF:
             diff_contents = color_diff(diff_contents)
 
         with lock or nullcontext():
@@ -1480,7 +1490,7 @@ class Line:
     comments: Dict[LeafID, List[Leaf]] = field(default_factory=dict)
     bracket_tracker: BracketTracker = field(default_factory=BracketTracker)
     inside_brackets: bool = False
-    should_split: bool = False
+    should_split_rhs: bool = False
     magic_trailing_comma: Optional[Leaf] = None
 
     def append(self, leaf: Leaf, preformatted: bool = False) -> None:
@@ -1792,7 +1802,7 @@ class Line:
             mode=self.mode,
             depth=self.depth,
             inside_brackets=self.inside_brackets,
-            should_split=self.should_split,
+            should_split_rhs=self.should_split_rhs,
             magic_trailing_comma=self.magic_trailing_comma,
         )
 
@@ -2049,6 +2059,8 @@ class LineGenerator(Visitor[Line]):
 
     def visit_simple_stmt(self, node: Node) -> Iterator[Line]:
         """Visit a statement without nested statements."""
+        if first_child_is_arith(node):
+            wrap_in_parentheses(node, node.children[0], visible=False)
         is_suite_like = node.parent and node.parent.type in STATEMENT
         if is_suite_like:
             if self.mode.is_pyi and is_stub_body(node):
@@ -2712,7 +2724,8 @@ def transform_line(
     transformers: List[Transformer]
     if (
         not line.contains_uncollapsable_type_comments()
-        and not (line.should_split or line.magic_trailing_comma)
+        and not line.should_split_rhs
+        and not line.magic_trailing_comma
         and (
             is_line_short_enough(line, line_length=mode.line_length, line_str=line_str)
             or line.contains_unsplittable_type_ignore()
@@ -4386,7 +4399,7 @@ class StringParenWrapper(CustomSplitMapMixin, BaseStringSplitter):
             mode=line.mode,
             depth=line.depth + 1,
             inside_brackets=True,
-            should_split=line.should_split,
+            should_split_rhs=line.should_split_rhs,
             magic_trailing_comma=line.magic_trailing_comma,
         )
         string_leaf = Leaf(token.STRING, string_value)
@@ -5008,8 +5021,8 @@ def bracket_split_build_line(
         result.append(leaf, preformatted=True)
         for comment_after in original.comments_after(leaf):
             result.append(comment_after, preformatted=True)
-    if is_body and should_split(result, opening_bracket):
-        result.should_split = True
+    if is_body and should_split_line(result, opening_bracket):
+        result.should_split_rhs = True
     return result
 
 
@@ -5367,10 +5380,7 @@ def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None:
             check_lpar = True
 
         if check_lpar:
-            if is_walrus_assignment(child):
-                pass
-
-            elif child.type == syms.atom:
+            if child.type == syms.atom:
                 if maybe_make_parens_invisible_in_atom(child, parent=node):
                     wrap_in_parentheses(node, child, visible=False)
             elif is_one_tuple(child):
@@ -5542,6 +5552,7 @@ def maybe_make_parens_invisible_in_atom(node: LN, parent: LN) -> bool:
     Returns whether the node should itself be wrapped in invisible parentheses.
 
     """
+
     if (
         node.type != syms.atom
         or is_empty_tuple(node)
@@ -5551,6 +5562,10 @@ def maybe_make_parens_invisible_in_atom(node: LN, parent: LN) -> bool:
     ):
         return False
 
+    if is_walrus_assignment(node):
+        if parent.type in [syms.annassign, syms.expr_stmt]:
+            return False
+
     first = node.children[0]
     last = node.children[-1]
     if first.type == token.LPAR and last.type == token.RPAR:
@@ -5612,6 +5627,17 @@ def unwrap_singleton_parenthesis(node: LN) -> Optional[LN]:
     return wrapped
 
 
+def first_child_is_arith(node: Node) -> bool:
+    """Whether `node`'s first child is an arith/shift/xor/and binary expression.
+
+    visit_simple_stmt() wraps such a child in invisible parentheses.
+    """
+    if not node.children:
+        return False
+    arith_types = {syms.arith_expr, syms.shift_expr, syms.xor_expr, syms.and_expr}
+    return node.children[0].type in arith_types
+
+
 def wrap_in_parentheses(parent: Node, child: LN, *, visible: bool = True) -> None:
     """Wrap `child` in parentheses.
 
@@ -5813,7 +5839,7 @@ def ensure_visible(leaf: Leaf) -> None:
         leaf.value = ")"
 
 
-def should_split(line: Line, opening_bracket: Leaf) -> bool:
+def should_split_line(line: Line, opening_bracket: Leaf) -> bool:
     """Should `line` be immediately split with `delimiter_split()` after RHS?"""
 
     if not (opening_bracket.parent and opening_bracket.value in "[{("):
@@ -5949,7 +5975,7 @@ def generate_trailers_to_omit(line: Line, line_length: int) -> Iterator[Set[Leaf
     """
 
     omit: Set[LeafID] = set()
-    if not line.should_split and not line.magic_trailing_comma:
+    if not line.magic_trailing_comma:
         yield omit
 
     length = 4 * line.depth
@@ -5971,8 +5997,7 @@ def generate_trailers_to_omit(line: Line, line_length: int) -> Iterator[Set[Leaf
             elif leaf.type in CLOSING_BRACKETS:
                 prev = line.leaves[index - 1] if index > 0 else None
                 if (
-                    line.magic_trailing_comma
-                    and prev
+                    prev
                     and prev.type == token.COMMA
                     and not is_one_tuple_between(
                         leaf.opening_bracket, leaf, line.leaves
@@ -5999,8 +6024,7 @@ def generate_trailers_to_omit(line: Line, line_length: int) -> Iterator[Set[Leaf
                 yield omit
 
             if (
-                line.magic_trailing_comma
-                and prev
+                prev
                 and prev.type == token.COMMA
                 and not is_one_tuple_between(leaf.opening_bracket, leaf, line.leaves)
             ):
@@ -6098,17 +6122,27 @@ def normalize_path_maybe_ignore(
     return normalized_path
 
 
+def path_is_excluded(normalized_path: str, pattern: Optional[Pattern[str]]) -> bool:
+    """Whether `pattern` has a non-empty match in `normalized_path`."""
+    if not pattern:
+        return False
+    found = pattern.search(normalized_path)
+    return found is not None and bool(found.group(0))  # empty matches don't count
+
+
 def gen_python_files(
     paths: Iterable[Path],
     root: Path,
     include: Optional[Pattern[str]],
     exclude: Pattern[str],
+    extend_exclude: Optional[Pattern[str]],
     force_exclude: Optional[Pattern[str]],
     report: "Report",
     gitignore: PathSpec,
 ) -> Iterator[Path]:
     """Generate all files under `path` whose paths are not excluded by the
-    `exclude_regex` or `force_exclude` regexes, but are included by the `include` regex.
+    `exclude`, `extend_exclude`, or `force_exclude` regexes,
+    but are included by the `include` regex.
 
     Symbolic links pointing outside of the `root` directory are ignored.
 
@@ -6125,20 +6159,22 @@ def gen_python_files(
             report.path_ignored(child, "matches the .gitignore file content")
             continue
 
-        # Then ignore with `--exclude` and `--force-exclude` options.
+        # Then ignore with `--exclude` `--extend-exclude` and `--force-exclude` options.
         normalized_path = "/" + normalized_path
         if child.is_dir():
             normalized_path += "/"
 
-        exclude_match = exclude.search(normalized_path) if exclude else None
-        if exclude_match and exclude_match.group(0):
+        if path_is_excluded(normalized_path, exclude):
             report.path_ignored(child, "matches the --exclude regular expression")
             continue
 
-        force_exclude_match = (
-            force_exclude.search(normalized_path) if force_exclude else None
-        )
-        if force_exclude_match and force_exclude_match.group(0):
+        if path_is_excluded(normalized_path, extend_exclude):
+            report.path_ignored(
+                child, "matches the --extend-exclude regular expression"
+            )
+            continue
+
+        if path_is_excluded(normalized_path, force_exclude):
             report.path_ignored(child, "matches the --force-exclude regular expression")
             continue
 
@@ -6148,6 +6184,7 @@ def gen_python_files(
                 root,
                 include,
                 exclude,
+                extend_exclude,
                 force_exclude,
                 report,
                 gitignore,
@@ -6426,14 +6463,14 @@ def assert_stable(src: str, dst: str, mode: Mode) -> None:
 
 
 @mypyc_attr(patchable=True)
-def dump_to_file(*output: str) -> str:
+def dump_to_file(*output: str, ensure_final_newline: bool = True) -> str:
     """Dump `output` to a temporary file. Return path to the file."""
     with tempfile.NamedTemporaryFile(
         mode="w", prefix="blk_", suffix=".log", delete=False, encoding="utf8"
     ) as f:
         for lines in output:
             f.write(lines)
-            if lines and lines[-1] != "\n":
+            if ensure_final_newline and lines and lines[-1] != "\n":
                 f.write("\n")
     return f.name
 
@@ -6451,11 +6488,20 @@ def diff(a: str, b: str, a_name: str, b_name: str) -> str:
     """Return a unified diff string between strings `a` and `b`."""
     import difflib
 
-    a_lines = [line + "\n" for line in a.splitlines()]
-    b_lines = [line + "\n" for line in b.splitlines()]
-    return "".join(
-        difflib.unified_diff(a_lines, b_lines, fromfile=a_name, tofile=b_name, n=5)
-    )
+    a_lines = [line for line in a.splitlines(keepends=True)]
+    b_lines = [line for line in b.splitlines(keepends=True)]
+    diff_lines = []
+    for line in difflib.unified_diff(
+        a_lines, b_lines, fromfile=a_name, tofile=b_name, n=5
+    ):
+        # Work around https://bugs.python.org/issue2142
+        # See https://www.gnu.org/software/diffutils/manual/html_node/Incomplete-Lines.html
+        if line[-1] == "\n":
+            diff_lines.append(line)
+        else:
+            diff_lines.append(line + "\n")
+            diff_lines.append("\\ No newline at end of file\n")
+    return "".join(diff_lines)
 
 
 def cancel(tasks: Iterable["asyncio.Task[Any]"]) -> None:
@@ -6632,7 +6678,7 @@ def can_omit_invisible_parens(
 
     penultimate = line.leaves[-2]
     last = line.leaves[-1]
-    if line.should_split or line.magic_trailing_comma:
+    if line.magic_trailing_comma:
         try:
             penultimate, last = last_two_except(line.leaves, omit=omit_on_explode)
         except LookupError: