git.madduck.net Git - etc/vim.git/blobdiff - src/black/__init__.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access, can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Speed up test suite via distributed testing (#2196)
[etc/vim.git] / src / black / __init__.py
index 36716474e8c187adcb4acddf1a46d46c17603e3f..e47aa21d4352436c84db76ed6b8acbd0b8b07cac 100644 (file)
@@ -273,9 +273,9 @@ class Mode:
     target_versions: Set[TargetVersion] = field(default_factory=set)
     line_length: int = DEFAULT_LINE_LENGTH
     string_normalization: bool = True
     target_versions: Set[TargetVersion] = field(default_factory=set)
     line_length: int = DEFAULT_LINE_LENGTH
     string_normalization: bool = True
+    is_pyi: bool = False
     magic_trailing_comma: bool = True
     experimental_string_processing: bool = False
     magic_trailing_comma: bool = True
     experimental_string_processing: bool = False
-    is_pyi: bool = False
 
     def get_cache_key(self) -> str:
         if self.target_versions:
 
     def get_cache_key(self) -> str:
         if self.target_versions:
@@ -290,6 +290,8 @@ class Mode:
             str(self.line_length),
             str(int(self.string_normalization)),
             str(int(self.is_pyi)),
             str(self.line_length),
             str(int(self.string_normalization)),
             str(int(self.is_pyi)),
+            str(int(self.magic_trailing_comma)),
+            str(int(self.experimental_string_processing)),
         ]
         return ".".join(parts)
 
         ]
         return ".".join(parts)
 
@@ -309,8 +311,17 @@ def find_pyproject_toml(path_search_start: Tuple[str, ...]) -> Optional[str]:
     if path_pyproject_toml.is_file():
         return str(path_pyproject_toml)
 
     if path_pyproject_toml.is_file():
         return str(path_pyproject_toml)
 
-    path_user_pyproject_toml = find_user_pyproject_toml()
-    return str(path_user_pyproject_toml) if path_user_pyproject_toml.is_file() else None
+    try:
+        path_user_pyproject_toml = find_user_pyproject_toml()
+        return (
+            str(path_user_pyproject_toml)
+            if path_user_pyproject_toml.is_file()
+            else None
+        )
+    except PermissionError as e:
+        # We do not have access to the user-level config directory, so ignore it.
+        err(f"Ignoring user configuration directory due to {e!r}")
+        return None
 
 
 def parse_pyproject_toml(path_config: str) -> Dict[str, Any]:
 
 
 def parse_pyproject_toml(path_config: str) -> Dict[str, Any]:
@@ -481,15 +492,15 @@ def validate_regex(
 @click.option(
     "--exclude",
     type=str,
 @click.option(
     "--exclude",
     type=str,
-    default=DEFAULT_EXCLUDES,
     callback=validate_regex,
     help=(
         "A regular expression that matches files and directories that should be"
         " excluded on recursive searches. An empty value means no paths are excluded."
         " Use forward slashes for directories on all platforms (Windows, too)."
     callback=validate_regex,
     help=(
         "A regular expression that matches files and directories that should be"
         " excluded on recursive searches. An empty value means no paths are excluded."
         " Use forward slashes for directories on all platforms (Windows, too)."
-        " Exclusions are calculated first, inclusions later."
+        " Exclusions are calculated first, inclusions later. [default:"
+        f" {DEFAULT_EXCLUDES}]"
     ),
     ),
-    show_default=True,
+    show_default=False,
 )
 @click.option(
     "--extend-exclude",
 )
 @click.option(
     "--extend-exclude",
@@ -576,7 +587,7 @@ def main(
     quiet: bool,
     verbose: bool,
     include: Pattern,
     quiet: bool,
     verbose: bool,
     include: Pattern,
-    exclude: Pattern,
+    exclude: Optional[Pattern],
     extend_exclude: Optional[Pattern],
     force_exclude: Optional[Pattern],
     stdin_filename: Optional[str],
     extend_exclude: Optional[Pattern],
     force_exclude: Optional[Pattern],
     stdin_filename: Optional[str],
@@ -651,7 +662,7 @@ def get_sources(
     quiet: bool,
     verbose: bool,
     include: Pattern[str],
     quiet: bool,
     verbose: bool,
     include: Pattern[str],
-    exclude: Pattern[str],
+    exclude: Optional[Pattern[str]],
     extend_exclude: Optional[Pattern[str]],
     force_exclude: Optional[Pattern[str]],
     report: "Report",
     extend_exclude: Optional[Pattern[str]],
     force_exclude: Optional[Pattern[str]],
     report: "Report",
@@ -662,7 +673,12 @@ def get_sources(
     root = find_project_root(src)
     sources: Set[Path] = set()
     path_empty(src, "No Path provided. Nothing to do 😴", quiet, verbose, ctx)
     root = find_project_root(src)
     sources: Set[Path] = set()
     path_empty(src, "No Path provided. Nothing to do 😴", quiet, verbose, ctx)
-    gitignore = get_gitignore(root)
+
+    if exclude is None:
+        exclude = re_compile_maybe_verbose(DEFAULT_EXCLUDES)
+        gitignore = get_gitignore(root)
+    else:
+        gitignore = None
 
     for s in src:
         if s == "-" and stdin_filename:
 
     for s in src:
         if s == "-" and stdin_filename:
@@ -744,6 +760,8 @@ def reformat_one(
             is_stdin = False
 
         if is_stdin:
             is_stdin = False
 
         if is_stdin:
+            if src.suffix == ".pyi":
+                mode = replace(mode, is_pyi=True)
             if format_stdin_to_stdout(fast=fast, write_back=write_back, mode=mode):
                 changed = Changed.YES
         else:
             if format_stdin_to_stdout(fast=fast, write_back=write_back, mode=mode):
                 changed = Changed.YES
         else:
@@ -784,7 +802,7 @@ def reformat_many(
     except (ImportError, OSError):
         # we arrive here if the underlying system does not support multi-processing
         # like in AWS Lambda or Termux, in which case we gracefully fallback to
     except (ImportError, OSError):
         # we arrive here if the underlying system does not support multi-processing
         # like in AWS Lambda or Termux, in which case we gracefully fallback to
-        # a ThreadPollExecutor with just a single worker (more workers would not do us
+        # a ThreadPoolExecutor with just a single worker (more workers would not do us
         # any good due to the Global Interpreter Lock)
         executor = ThreadPoolExecutor(max_workers=1)
 
         # any good due to the Global Interpreter Lock)
         executor = ThreadPoolExecutor(max_workers=1)
 
@@ -1016,7 +1034,17 @@ def format_file_contents(src_contents: str, *, fast: bool, mode: Mode) -> FileCo
 
     if not fast:
         assert_equivalent(src_contents, dst_contents)
 
     if not fast:
         assert_equivalent(src_contents, dst_contents)
-        assert_stable(src_contents, dst_contents, mode=mode)
+
+        # Forced second pass to work around optional trailing commas (becoming
+        # forced trailing commas on pass 2) interacting differently with optional
+        # parentheses.  Admittedly ugly.
+        dst_contents_pass2 = format_str(dst_contents, mode=mode)
+        if dst_contents != dst_contents_pass2:
+            dst_contents = dst_contents_pass2
+            assert_equivalent(src_contents, dst_contents, pass_num=2)
+            assert_stable(src_contents, dst_contents, mode=mode)
+        # Note: no need to explicitly call `assert_stable` if `dst_contents` was
+        # the same as `dst_contents_pass2`.
     return dst_contents
 
 
     return dst_contents
 
 
@@ -2153,16 +2181,41 @@ class LineGenerator(Visitor[Line]):
             # We're ignoring docstrings with backslash newline escapes because changing
             # indentation of those changes the AST representation of the code.
             prefix = get_string_prefix(leaf.value)
             # We're ignoring docstrings with backslash newline escapes because changing
             # indentation of those changes the AST representation of the code.
             prefix = get_string_prefix(leaf.value)
-            lead_len = len(prefix) + 3
-            tail_len = -3
-            indent = " " * 4 * self.current_line.depth
-            docstring = fix_docstring(leaf.value[lead_len:tail_len], indent)
+            docstring = leaf.value[len(prefix) :]  # Remove the prefix
+            quote_char = docstring[0]
+            # A natural way to remove the outer quotes is to do:
+            #   docstring = docstring.strip(quote_char)
+            # but that breaks on """""x""" (which is '""x').
+            # So we actually need to remove the first character and the next two
+            # characters but only if they are the same as the first.
+            quote_len = 1 if docstring[1] != quote_char else 3
+            docstring = docstring[quote_len:-quote_len]
+
+            if is_multiline_string(leaf):
+                indent = " " * 4 * self.current_line.depth
+                docstring = fix_docstring(docstring, indent)
+            else:
+                docstring = docstring.strip()
+
             if docstring:
             if docstring:
-                if leaf.value[lead_len - 1] == docstring[0]:
+                # Add some padding if the docstring starts / ends with a quote mark.
+                if docstring[0] == quote_char:
                     docstring = " " + docstring
                     docstring = " " + docstring
-                if leaf.value[tail_len + 1] == docstring[-1]:
-                    docstring = docstring + " "
-            leaf.value = leaf.value[0:lead_len] + docstring + leaf.value[tail_len:]
+                if docstring[-1] == quote_char:
+                    docstring += " "
+                if docstring[-1] == "\\":
+                    backslash_count = len(docstring) - len(docstring.rstrip("\\"))
+                    if backslash_count % 2:
+                        # Odd number of tailing backslashes, add some padding to
+                        # avoid escaping the closing string quote.
+                        docstring += " "
+            else:
+                # Add some padding if the docstring is empty.
+                docstring = " "
+
+            # We could enforce triple quotes at this point.
+            quote = quote_char * quote_len
+            leaf.value = prefix + quote + docstring + quote
 
         yield from self.visit_default(leaf)
 
 
         yield from self.visit_default(leaf)
 
@@ -5332,15 +5385,10 @@ def normalize_numeric_literal(leaf: Leaf) -> None:
 
 def format_hex(text: str) -> str:
     """
 
 def format_hex(text: str) -> str:
     """
-    Formats a hexadecimal string like "0x12b3"
-
-    Uses lowercase because of similarity between "B" and "8", which
-    can cause security issues.
-    see: https://github.com/psf/black/issues/1692
+    Formats a hexadecimal string like "0x12B3"
     """
     """
-
     before, after = text[:2], text[2:]
     before, after = text[:2], text[2:]
-    return f"{before}{after.lower()}"
+    return f"{before}{after.upper()}"
 
 
 def format_scientific_notation(text: str) -> str:
 
 
 def format_scientific_notation(text: str) -> str:
@@ -5587,7 +5635,15 @@ def maybe_make_parens_invisible_in_atom(node: LN, parent: LN) -> bool:
         return False
 
     if is_walrus_assignment(node):
         return False
 
     if is_walrus_assignment(node):
-        if parent.type in [syms.annassign, syms.expr_stmt]:
+        if parent.type in [
+            syms.annassign,
+            syms.expr_stmt,
+            syms.assert_stmt,
+            syms.return_stmt,
+            # these ones aren't useful to end users, but they do please fuzzers
+            syms.for_stmt,
+            syms.del_stmt,
+        ]:
             return False
 
     first = node.children[0]
             return False
 
     first = node.children[0]
@@ -5710,6 +5766,13 @@ def is_simple_decorator_trailer(node: LN, last: bool = False) -> bool:
             and node.children[0].type == token.DOT
             and node.children[1].type == token.NAME
         )
             and node.children[0].type == token.DOT
             and node.children[1].type == token.NAME
         )
+        # last trailer can be an argument-less parentheses pair
+        or (
+            last
+            and len(node.children) == 2
+            and node.children[0].type == token.LPAR
+            and node.children[1].type == token.RPAR
+        )
         # last trailer can be arguments
         or (
             last
         # last trailer can be arguments
         or (
             last
@@ -6113,7 +6176,7 @@ def get_future_imports(node: Node) -> Set[str]:
 
 @lru_cache()
 def get_gitignore(root: Path) -> PathSpec:
 
 @lru_cache()
 def get_gitignore(root: Path) -> PathSpec:
-    """ Return a PathSpec matching gitignore content if present."""
+    """Return a PathSpec matching gitignore content if present."""
     gitignore = root / ".gitignore"
     lines: List[str] = []
     if gitignore.is_file():
     gitignore = root / ".gitignore"
     lines: List[str] = []
     if gitignore.is_file():
@@ -6157,12 +6220,12 @@ def path_is_excluded(
 def gen_python_files(
     paths: Iterable[Path],
     root: Path,
 def gen_python_files(
     paths: Iterable[Path],
     root: Path,
-    include: Optional[Pattern[str]],
+    include: Pattern[str],
     exclude: Pattern[str],
     extend_exclude: Optional[Pattern[str]],
     force_exclude: Optional[Pattern[str]],
     report: "Report",
     exclude: Pattern[str],
     extend_exclude: Optional[Pattern[str]],
     force_exclude: Optional[Pattern[str]],
     report: "Report",
-    gitignore: PathSpec,
+    gitignore: Optional[PathSpec],
 ) -> Iterator[Path]:
     """Generate all files under `path` whose paths are not excluded by the
     `exclude_regex`, `extend_exclude`, or `force_exclude` regexes,
 ) -> Iterator[Path]:
     """Generate all files under `path` whose paths are not excluded by the
     `exclude_regex`, `extend_exclude`, or `force_exclude` regexes,
@@ -6178,8 +6241,8 @@ def gen_python_files(
         if normalized_path is None:
             continue
 
         if normalized_path is None:
             continue
 
-        # First ignore files matching .gitignore
-        if gitignore.match_file(normalized_path):
+        # First ignore files matching .gitignore, if passed
+        if gitignore is not None and gitignore.match_file(normalized_path):
             report.path_ignored(child, "matches the .gitignore file content")
             continue
 
             report.path_ignored(child, "matches the .gitignore file content")
             continue
 
@@ -6447,12 +6510,22 @@ def _stringify_ast(
             # Constant strings may be indented across newlines, if they are
             # docstrings; fold spaces after newlines when comparing. Similarly,
             # trailing and leading space may be removed.
             # Constant strings may be indented across newlines, if they are
             # docstrings; fold spaces after newlines when comparing. Similarly,
             # trailing and leading space may be removed.
+            # Note that when formatting Python 2 code, at least with Windows
+            # line-endings, docstrings can end up here as bytes instead of
+            # str so make sure that we handle both cases.
             if (
                 isinstance(node, ast.Constant)
                 and field == "value"
             if (
                 isinstance(node, ast.Constant)
                 and field == "value"
-                and isinstance(value, str)
+                and isinstance(value, (str, bytes))
             ):
             ):
-                normalized = re.sub(r" *\n[ \t]*", "\n", value).strip()
+                lineend = "\n" if isinstance(value, str) else b"\n"
+                # To normalize, we strip any leading and trailing space from
+                # each line...
+                stripped = [line.strip() for line in value.splitlines()]
+                normalized = lineend.join(stripped)  # type: ignore[attr-defined]
+                # ...and remove any blank lines at the beginning and end of
+                # the whole string
+                normalized = normalized.strip()
             else:
                 normalized = value
             yield f"{'  ' * (depth+2)}{normalized!r},  # {value.__class__.__name__}"
             else:
                 normalized = value
             yield f"{'  ' * (depth+2)}{normalized!r},  # {value.__class__.__name__}"
@@ -6460,7 +6533,7 @@ def _stringify_ast(
     yield f"{'  ' * depth})  # /{node.__class__.__name__}"
 
 
     yield f"{'  ' * depth})  # /{node.__class__.__name__}"
 
 
-def assert_equivalent(src: str, dst: str) -> None:
+def assert_equivalent(src: str, dst: str, *, pass_num: int = 1) -> None:
     """Raise AssertionError if `src` and `dst` aren't equivalent."""
     try:
         src_ast = parse_ast(src)
     """Raise AssertionError if `src` and `dst` aren't equivalent."""
     try:
         src_ast = parse_ast(src)
@@ -6475,9 +6548,9 @@ def assert_equivalent(src: str, dst: str) -> None:
     except Exception as exc:
         log = dump_to_file("".join(traceback.format_tb(exc.__traceback__)), dst)
         raise AssertionError(
     except Exception as exc:
         log = dump_to_file("".join(traceback.format_tb(exc.__traceback__)), dst)
         raise AssertionError(
-            f"INTERNAL ERROR: Black produced invalid code: {exc}. Please report a bug"
-            " on https://github.com/psf/black/issues.  This invalid output might be"
-            f" helpful: {log}"
+            f"INTERNAL ERROR: Black produced invalid code on pass {pass_num}: {exc}. "
+            "Please report a bug on https://github.com/psf/black/issues.  "
+            f"This invalid output might be helpful: {log}"
         ) from None
 
     src_ast_str = "\n".join(_stringify_ast(src_ast))
         ) from None
 
     src_ast_str = "\n".join(_stringify_ast(src_ast))
@@ -6486,8 +6559,8 @@ def assert_equivalent(src: str, dst: str) -> None:
         log = dump_to_file(diff(src_ast_str, dst_ast_str, "src", "dst"))
         raise AssertionError(
             "INTERNAL ERROR: Black produced code that is not equivalent to the"
         log = dump_to_file(diff(src_ast_str, dst_ast_str, "src", "dst"))
         raise AssertionError(
             "INTERNAL ERROR: Black produced code that is not equivalent to the"
-            " source.  Please report a bug on https://github.com/psf/black/issues. "
-            f" This diff might be helpful: {log}"
+            f" source on pass {pass_num}.  Please report a bug on "
+            f"https://github.com/psf/black/issues.  This diff might be helpful: {log}"
         ) from None
 
 
         ) from None
 
 
@@ -6953,11 +7026,6 @@ def patched_main() -> None:
 
 
 def is_docstring(leaf: Leaf) -> bool:
 
 
 def is_docstring(leaf: Leaf) -> bool:
-    if not is_multiline_string(leaf):
-        # For the purposes of docstring re-indentation, we don't need to do anything
-        # with single-line docstrings.
-        return False
-
     if prev_siblings_are(
         leaf.parent, [None, token.NEWLINE, token.INDENT, syms.simple_stmt]
     ):
     if prev_siblings_are(
         leaf.parent, [None, token.NEWLINE, token.INDENT, syms.simple_stmt]
     ):