git.madduck.net Git - etc/vim.git/blobdiff - src/black/__init__.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. I'd be especially grateful if you read over the Git project's submission guidelines and adhered to them.

SSH access, as well as push access, can be arranged individually.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Set `is_pyi` if `stdin_filename` ends with `.pyi` (#2169)
[etc/vim.git] / src / black / __init__.py
index efa82f41160ac12f1665052fc7ccfc459d71cdc0..49d088b531d0bc48f3340f576db6659ad0da8a03 100644 (file)
@@ -273,9 +273,9 @@ class Mode:
     target_versions: Set[TargetVersion] = field(default_factory=set)
     line_length: int = DEFAULT_LINE_LENGTH
     string_normalization: bool = True
+    is_pyi: bool = False
     magic_trailing_comma: bool = True
     experimental_string_processing: bool = False
-    is_pyi: bool = False
 
     def get_cache_key(self) -> str:
         if self.target_versions:
@@ -290,6 +290,8 @@ class Mode:
             str(self.line_length),
             str(int(self.string_normalization)),
             str(int(self.is_pyi)),
+            str(int(self.magic_trailing_comma)),
+            str(int(self.experimental_string_processing)),
         ]
         return ".".join(parts)
 
@@ -309,8 +311,17 @@ def find_pyproject_toml(path_search_start: Tuple[str, ...]) -> Optional[str]:
     if path_pyproject_toml.is_file():
         return str(path_pyproject_toml)
 
-    path_user_pyproject_toml = find_user_pyproject_toml()
-    return str(path_user_pyproject_toml) if path_user_pyproject_toml.is_file() else None
+    try:
+        path_user_pyproject_toml = find_user_pyproject_toml()
+        return (
+            str(path_user_pyproject_toml)
+            if path_user_pyproject_toml.is_file()
+            else None
+        )
+    except PermissionError as e:
+        # We do not have access to the user-level config directory, so ignore it.
+        err(f"Ignoring user configuration directory due to {e!r}")
+        return None
 
 
 def parse_pyproject_toml(path_config: str) -> Dict[str, Any]:
@@ -744,6 +755,8 @@ def reformat_one(
             is_stdin = False
 
         if is_stdin:
+            if src.suffix == ".pyi":
+                mode = replace(mode, is_pyi=True)
             if format_stdin_to_stdout(fast=fast, write_back=write_back, mode=mode):
                 changed = Changed.YES
         else:
@@ -1016,7 +1029,17 @@ def format_file_contents(src_contents: str, *, fast: bool, mode: Mode) -> FileCo
 
     if not fast:
         assert_equivalent(src_contents, dst_contents)
-        assert_stable(src_contents, dst_contents, mode=mode)
+
+        # Forced second pass to work around optional trailing commas (becoming
+        # forced trailing commas on pass 2) interacting differently with optional
+        # parentheses.  Admittedly ugly.
+        dst_contents_pass2 = format_str(dst_contents, mode=mode)
+        if dst_contents != dst_contents_pass2:
+            dst_contents = dst_contents_pass2
+            assert_equivalent(src_contents, dst_contents, pass_num=2)
+            assert_stable(src_contents, dst_contents, mode=mode)
+        # Note: no need to explicitly call `assert_stable` if `dst_contents` was
+        # the same as `dst_contents_pass2`.
     return dst_contents
 
 
@@ -2174,7 +2197,13 @@ class LineGenerator(Visitor[Line]):
                 if docstring[0] == quote_char:
                     docstring = " " + docstring
                 if docstring[-1] == quote_char:
-                    docstring = docstring + " "
+                    docstring += " "
+                if docstring[-1] == "\\":
+                    backslash_count = len(docstring) - len(docstring.rstrip("\\"))
+                    if backslash_count % 2:
+                        # Odd number of trailing backslashes, add some padding to
+                        # avoid escaping the closing string quote.
+                        docstring += " "
             else:
                 # Add some padding if the docstring is empty.
                 docstring = " "
@@ -5351,15 +5380,10 @@ def normalize_numeric_literal(leaf: Leaf) -> None:
 
 def format_hex(text: str) -> str:
     """
-    Formats a hexadecimal string like "0x12b3"
-
-    Uses lowercase because of similarity between "B" and "8", which
-    can cause security issues.
-    see: https://github.com/psf/black/issues/1692
+    Formats a hexadecimal string like "0x12B3"
     """
-
     before, after = text[:2], text[2:]
-    return f"{before}{after.lower()}"
+    return f"{before}{after.upper()}"
 
 
 def format_scientific_notation(text: str) -> str:
@@ -5606,7 +5630,15 @@ def maybe_make_parens_invisible_in_atom(node: LN, parent: LN) -> bool:
         return False
 
     if is_walrus_assignment(node):
-        if parent.type in [syms.annassign, syms.expr_stmt]:
+        if parent.type in [
+            syms.annassign,
+            syms.expr_stmt,
+            syms.assert_stmt,
+            syms.return_stmt,
+            # these ones aren't useful to end users, but they do please fuzzers
+            syms.for_stmt,
+            syms.del_stmt,
+        ]:
             return False
 
     first = node.children[0]
@@ -6466,12 +6498,22 @@ def _stringify_ast(
             # Constant strings may be indented across newlines, if they are
             # docstrings; fold spaces after newlines when comparing. Similarly,
             # trailing and leading space may be removed.
+            # Note that when formatting Python 2 code, at least with Windows
+            # line-endings, docstrings can end up here as bytes instead of
+            # str so make sure that we handle both cases.
             if (
                 isinstance(node, ast.Constant)
                 and field == "value"
-                and isinstance(value, str)
+                and isinstance(value, (str, bytes))
             ):
-                normalized = re.sub(r" *\n[ \t]*", "\n", value).strip()
+                lineend = "\n" if isinstance(value, str) else b"\n"
+                # To normalize, we strip any leading and trailing space from
+                # each line...
+                stripped = [line.strip() for line in value.splitlines()]
+                normalized = lineend.join(stripped)  # type: ignore[attr-defined]
+                # ...and remove any blank lines at the beginning and end of
+                # the whole string
+                normalized = normalized.strip()
             else:
                 normalized = value
             yield f"{'  ' * (depth+2)}{normalized!r},  # {value.__class__.__name__}"
@@ -6479,7 +6521,7 @@ def _stringify_ast(
     yield f"{'  ' * depth})  # /{node.__class__.__name__}"
 
 
-def assert_equivalent(src: str, dst: str) -> None:
+def assert_equivalent(src: str, dst: str, *, pass_num: int = 1) -> None:
     """Raise AssertionError if `src` and `dst` aren't equivalent."""
     try:
         src_ast = parse_ast(src)
@@ -6494,9 +6536,9 @@ def assert_equivalent(src: str, dst: str) -> None:
     except Exception as exc:
         log = dump_to_file("".join(traceback.format_tb(exc.__traceback__)), dst)
         raise AssertionError(
-            f"INTERNAL ERROR: Black produced invalid code: {exc}. Please report a bug"
-            " on https://github.com/psf/black/issues.  This invalid output might be"
-            f" helpful: {log}"
+            f"INTERNAL ERROR: Black produced invalid code on pass {pass_num}: {exc}. "
+            "Please report a bug on https://github.com/psf/black/issues.  "
+            f"This invalid output might be helpful: {log}"
         ) from None
 
     src_ast_str = "\n".join(_stringify_ast(src_ast))
@@ -6505,8 +6547,8 @@ def assert_equivalent(src: str, dst: str) -> None:
         log = dump_to_file(diff(src_ast_str, dst_ast_str, "src", "dst"))
         raise AssertionError(
             "INTERNAL ERROR: Black produced code that is not equivalent to the"
-            " source.  Please report a bug on https://github.com/psf/black/issues. "
-            f" This diff might be helpful: {log}"
+            f" source on pass {pass_num}.  Please report a bug on "
+            f"https://github.com/psf/black/issues.  This diff might be helpful: {log}"
         ) from None