git.madduck.net Git - etc/vim.git/blobdiff - src/black/__init__.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access, can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Set `is_pyi` if `stdin_filename` ends with `.pyi` (#2169)
[etc/vim.git] / src / black / __init__.py
index e21e2af5bd1394ca0441d1ccf133a0d68f6c0515..49d088b531d0bc48f3340f576db6659ad0da8a03 100644 (file)
@@ -48,7 +48,20 @@ from appdirs import user_cache_dir
 from dataclasses import dataclass, field, replace
 import click
 import toml
-from typed_ast import ast3, ast27
+
+try:
+    from typed_ast import ast3, ast27
+except ImportError:
+    if sys.version_info < (3, 8):
+        print(
+            "The typed_ast package is not installed.\n"
+            "You can install it with `python3 -m pip install typed-ast`.",
+            file=sys.stderr,
+        )
+        sys.exit(1)
+    else:
+        ast3 = ast27 = ast
+
 from pathspec import PathSpec
 
 # lib2to3 fork
@@ -69,7 +82,7 @@ if TYPE_CHECKING:
     import colorama  # noqa: F401
 
 DEFAULT_LINE_LENGTH = 88
-DEFAULT_EXCLUDES = r"/(\.direnv|\.eggs|\.git|\.hg|\.mypy_cache|\.nox|\.tox|\.venv|\.svn|_build|buck-out|build|dist)/"  # noqa: B950
+DEFAULT_EXCLUDES = r"/(\.direnv|\.eggs|\.git|\.hg|\.mypy_cache|\.nox|\.tox|\.venv|venv|\.svn|_build|buck-out|build|dist)/"  # noqa: B950
 DEFAULT_INCLUDES = r"\.pyi?$"
 CACHE_DIR = Path(user_cache_dir("black", version=__version__))
 STDIN_PLACEHOLDER = "__BLACK_STDIN_FILENAME__"
@@ -260,9 +273,9 @@ class Mode:
     target_versions: Set[TargetVersion] = field(default_factory=set)
     line_length: int = DEFAULT_LINE_LENGTH
     string_normalization: bool = True
+    is_pyi: bool = False
     magic_trailing_comma: bool = True
     experimental_string_processing: bool = False
-    is_pyi: bool = False
 
     def get_cache_key(self) -> str:
         if self.target_versions:
@@ -277,6 +290,8 @@ class Mode:
             str(self.line_length),
             str(int(self.string_normalization)),
             str(int(self.is_pyi)),
+            str(int(self.magic_trailing_comma)),
+            str(int(self.experimental_string_processing)),
         ]
         return ".".join(parts)
 
@@ -289,11 +304,24 @@ def supports_feature(target_versions: Set[TargetVersion], feature: Feature) -> b
     return all(feature in VERSION_TO_FEATURES[version] for version in target_versions)
 
 
-def find_pyproject_toml(path_search_start: Iterable[str]) -> Optional[str]:
+def find_pyproject_toml(path_search_start: Tuple[str, ...]) -> Optional[str]:
     """Find the absolute filepath to a pyproject.toml if it exists"""
     path_project_root = find_project_root(path_search_start)
     path_pyproject_toml = path_project_root / "pyproject.toml"
-    return str(path_pyproject_toml) if path_pyproject_toml.is_file() else None
+    if path_pyproject_toml.is_file():
+        return str(path_pyproject_toml)
+
+    try:
+        path_user_pyproject_toml = find_user_pyproject_toml()
+        return (
+            str(path_user_pyproject_toml)
+            if path_user_pyproject_toml.is_file()
+            else None
+        )
+    except PermissionError as e:
+        # We do not have access to the user-level config directory, so ignore it.
+        err(f"Ignoring user configuration directory due to {e!r}")
+        return None
 
 
 def parse_pyproject_toml(path_config: str) -> Dict[str, Any]:
@@ -363,6 +391,17 @@ def target_version_option_callback(
     return [TargetVersion[val.upper()] for val in v]
 
 
+def validate_regex(
+    ctx: click.Context,
+    param: click.Parameter,
+    value: Optional[str],
+) -> Optional[Pattern]:
+    try:
+        return re_compile_maybe_verbose(value) if value is not None else None
+    except re.error:
+        raise click.BadParameter("Not a valid regular expression")
+
+
 @click.command(context_settings=dict(help_option_names=["-h", "--help"]))
 @click.option("-c", "--code", type=str, help="Format the code passed in as a string.")
 @click.option(
@@ -417,8 +456,8 @@ def target_version_option_callback(
     "--check",
     is_flag=True,
     help=(
-        "Don't write the files back, just return the status.  Return code 0 means"
-        " nothing would change.  Return code 1 means some files would be reformatted."
+        "Don't write the files back, just return the status. Return code 0 means"
+        " nothing would change. Return code 1 means some files would be reformatted."
         " Return code 123 means there was an internal error."
     ),
 )
@@ -441,11 +480,12 @@ def target_version_option_callback(
     "--include",
     type=str,
     default=DEFAULT_INCLUDES,
+    callback=validate_regex,
     help=(
         "A regular expression that matches files and directories that should be"
-        " included on recursive searches.  An empty value means all files are included"
-        " regardless of the name.  Use forward slashes for directories on all platforms"
-        " (Windows, too).  Exclusions are calculated first, inclusions later."
+        " included on recursive searches. An empty value means all files are included"
+        " regardless of the name. Use forward slashes for directories on all platforms"
+        " (Windows, too). Exclusions are calculated first, inclusions later."
     ),
     show_default=True,
 )
@@ -453,10 +493,11 @@ def target_version_option_callback(
     "--exclude",
     type=str,
     default=DEFAULT_EXCLUDES,
+    callback=validate_regex,
     help=(
         "A regular expression that matches files and directories that should be"
-        " excluded on recursive searches.  An empty value means no paths are excluded."
-        " Use forward slashes for directories on all platforms (Windows, too). "
+        " excluded on recursive searches. An empty value means no paths are excluded."
+        " Use forward slashes for directories on all platforms (Windows, too)."
         " Exclusions are calculated first, inclusions later."
     ),
     show_default=True,
@@ -464,6 +505,7 @@ def target_version_option_callback(
 @click.option(
     "--extend-exclude",
     type=str,
+    callback=validate_regex,
     help=(
         "Like --exclude, but adds additional files and directories on top of the"
         " excluded ones. (Useful if you simply want to add to the default)"
@@ -472,6 +514,7 @@ def target_version_option_callback(
 @click.option(
     "--force-exclude",
     type=str,
+    callback=validate_regex,
     help=(
         "Like --exclude, but files and directories matching this regex will be "
         "excluded even when they are passed explicitly as arguments."
@@ -543,10 +586,10 @@ def main(
     experimental_string_processing: bool,
     quiet: bool,
     verbose: bool,
-    include: str,
-    exclude: str,
-    extend_exclude: Optional[str],
-    force_exclude: Optional[str],
+    include: Pattern,
+    exclude: Pattern,
+    extend_exclude: Optional[Pattern],
+    force_exclude: Optional[Pattern],
     stdin_filename: Optional[str],
     src: Tuple[str, ...],
     config: Optional[str],
@@ -612,39 +655,21 @@ def main(
     ctx.exit(report.return_code)
 
 
-def test_regex(
-    ctx: click.Context,
-    regex_name: str,
-    regex: Optional[str],
-) -> Optional[Pattern]:
-    try:
-        return re_compile_maybe_verbose(regex) if regex is not None else None
-    except re.error:
-        err(f"Invalid regular expression for {regex_name} given: {regex!r}")
-        ctx.exit(2)
-
-
 def get_sources(
     *,
     ctx: click.Context,
     src: Tuple[str, ...],
     quiet: bool,
     verbose: bool,
-    include: str,
-    exclude: str,
-    extend_exclude: Optional[str],
-    force_exclude: Optional[str],
+    include: Pattern[str],
+    exclude: Pattern[str],
+    extend_exclude: Optional[Pattern[str]],
+    force_exclude: Optional[Pattern[str]],
     report: "Report",
     stdin_filename: Optional[str],
 ) -> Set[Path]:
     """Compute the set of files to be formatted."""
 
-    include_regex = test_regex(ctx, "include", include)
-    exclude_regex = test_regex(ctx, "exclude", exclude)
-    assert exclude_regex is not None
-    extend_exclude_regex = test_regex(ctx, "extend_exclude", extend_exclude)
-    force_exclude_regex = test_regex(ctx, "force_exclude", force_exclude)
-
     root = find_project_root(src)
     sources: Set[Path] = set()
     path_empty(src, "No Path provided. Nothing to do 😴", quiet, verbose, ctx)
@@ -665,8 +690,8 @@ def get_sources(
 
             normalized_path = "/" + normalized_path
             # Hard-exclude any files that matches the `--force-exclude` regex.
-            if force_exclude_regex:
-                force_exclude_match = force_exclude_regex.search(normalized_path)
+            if force_exclude:
+                force_exclude_match = force_exclude.search(normalized_path)
             else:
                 force_exclude_match = None
             if force_exclude_match and force_exclude_match.group(0):
@@ -682,10 +707,10 @@ def get_sources(
                 gen_python_files(
                     p.iterdir(),
                     root,
-                    include_regex,
-                    exclude_regex,
-                    extend_exclude_regex,
-                    force_exclude_regex,
+                    include,
+                    exclude,
+                    extend_exclude,
+                    force_exclude,
                     report,
                     gitignore,
                 )
@@ -730,6 +755,8 @@ def reformat_one(
             is_stdin = False
 
         if is_stdin:
+            if src.suffix == ".pyi":
+                mode = replace(mode, is_pyi=True)
             if format_stdin_to_stdout(fast=fast, write_back=write_back, mode=mode):
                 changed = Changed.YES
         else:
@@ -770,7 +797,7 @@ def reformat_many(
     except (ImportError, OSError):
         # we arrive here if the underlying system does not support multi-processing
         # like in AWS Lambda or Termux, in which case we gracefully fallback to
-        # a ThreadPollExecutor with just a single worker (more workers would not do us
+        # a ThreadPoolExecutor with just a single worker (more workers would not do us
         # any good due to the Global Interpreter Lock)
         executor = ThreadPoolExecutor(max_workers=1)
 
@@ -833,7 +860,7 @@ async def schedule_formatting(
         ): src
         for src in sorted(sources)
     }
-    pending: Iterable["asyncio.Future[bool]"] = tasks.keys()
+    pending = tasks.keys()
     try:
         loop.add_signal_handler(signal.SIGINT, cancel, pending)
         loop.add_signal_handler(signal.SIGTERM, cancel, pending)
@@ -1002,7 +1029,17 @@ def format_file_contents(src_contents: str, *, fast: bool, mode: Mode) -> FileCo
 
     if not fast:
         assert_equivalent(src_contents, dst_contents)
-        assert_stable(src_contents, dst_contents, mode=mode)
+
+        # Forced second pass to work around optional trailing commas (becoming
+        # forced trailing commas on pass 2) interacting differently with optional
+        # parentheses.  Admittedly ugly.
+        dst_contents_pass2 = format_str(dst_contents, mode=mode)
+        if dst_contents != dst_contents_pass2:
+            dst_contents = dst_contents_pass2
+            assert_equivalent(src_contents, dst_contents, pass_num=2)
+            assert_stable(src_contents, dst_contents, mode=mode)
+        # Note: no need to explicitly call `assert_stable` if `dst_contents` was
+        # the same as `dst_contents_pass2`.
     return dst_contents
 
 
@@ -2139,16 +2176,41 @@ class LineGenerator(Visitor[Line]):
             # We're ignoring docstrings with backslash newline escapes because changing
             # indentation of those changes the AST representation of the code.
             prefix = get_string_prefix(leaf.value)
-            lead_len = len(prefix) + 3
-            tail_len = -3
-            indent = " " * 4 * self.current_line.depth
-            docstring = fix_docstring(leaf.value[lead_len:tail_len], indent)
+            docstring = leaf.value[len(prefix) :]  # Remove the prefix
+            quote_char = docstring[0]
+            # A natural way to remove the outer quotes is to do:
+            #   docstring = docstring.strip(quote_char)
+            # but that breaks on """""x""" (which is '""x').
+            # So we actually need to remove the first character and the next two
+            # characters but only if they are the same as the first.
+            quote_len = 1 if docstring[1] != quote_char else 3
+            docstring = docstring[quote_len:-quote_len]
+
+            if is_multiline_string(leaf):
+                indent = " " * 4 * self.current_line.depth
+                docstring = fix_docstring(docstring, indent)
+            else:
+                docstring = docstring.strip()
+
             if docstring:
-                if leaf.value[lead_len - 1] == docstring[0]:
+                # Add some padding if the docstring starts / ends with a quote mark.
+                if docstring[0] == quote_char:
                     docstring = " " + docstring
-                if leaf.value[tail_len + 1] == docstring[-1]:
-                    docstring = docstring + " "
-            leaf.value = leaf.value[0:lead_len] + docstring + leaf.value[tail_len:]
+                if docstring[-1] == quote_char:
+                    docstring += " "
+                if docstring[-1] == "\\":
+                    backslash_count = len(docstring) - len(docstring.rstrip("\\"))
+                    if backslash_count % 2:
+                        # Odd number of tailing backslashes, add some padding to
+                        # avoid escaping the closing string quote.
+                        docstring += " "
+            else:
+                # Add some padding if the docstring is empty.
+                docstring = " "
+
+            # We could enforce triple quotes at this point.
+            quote = quote_char * quote_len
+            leaf.value = prefix + quote + docstring + quote
 
         yield from self.visit_default(leaf)
 
@@ -2695,6 +2757,13 @@ def make_comment(content: str) -> str:
 
     if content[0] == "#":
         content = content[1:]
+    NON_BREAKING_SPACE = " "
+    if (
+        content
+        and content[0] == NON_BREAKING_SPACE
+        and not content.lstrip().startswith("type:")
+    ):
+        content = " " + content[1:]  # Replace NBSP by a simple space
     if content and content[0] not in " !:#'%":
         content = " " + content
     return "#" + content
@@ -5311,15 +5380,10 @@ def normalize_numeric_literal(leaf: Leaf) -> None:
 
 def format_hex(text: str) -> str:
     """
-    Formats a hexadecimal string like "0x12b3"
-
-    Uses lowercase because of similarity between "B" and "8", which
-    can cause security issues.
-    see: https://github.com/psf/black/issues/1692
+    Formats a hexadecimal string like "0x12B3"
     """
-
     before, after = text[:2], text[2:]
-    return f"{before}{after.lower()}"
+    return f"{before}{after.upper()}"
 
 
 def format_scientific_notation(text: str) -> str:
@@ -5566,7 +5630,15 @@ def maybe_make_parens_invisible_in_atom(node: LN, parent: LN) -> bool:
         return False
 
     if is_walrus_assignment(node):
-        if parent.type in [syms.annassign, syms.expr_stmt]:
+        if parent.type in [
+            syms.annassign,
+            syms.expr_stmt,
+            syms.assert_stmt,
+            syms.return_stmt,
+            # these ones aren't useful to end users, but they do please fuzzers
+            syms.for_stmt,
+            syms.del_stmt,
+        ]:
             return False
 
     first = node.children[0]
@@ -6092,7 +6164,7 @@ def get_future_imports(node: Node) -> Set[str]:
 
 @lru_cache()
 def get_gitignore(root: Path) -> PathSpec:
-    """ Return a PathSpec matching gitignore content if present."""
+    """Return a PathSpec matching gitignore content if present."""
     gitignore = root / ".gitignore"
     lines: List[str] = []
     if gitignore.is_file():
@@ -6200,7 +6272,7 @@ def gen_python_files(
 
 
 @lru_cache()
-def find_project_root(srcs: Iterable[str]) -> Path:
+def find_project_root(srcs: Tuple[str, ...]) -> Path:
     """Return a directory containing .git, .hg, or pyproject.toml.
 
     That directory will be a common parent of all files and directories
@@ -6238,6 +6310,22 @@ def find_project_root(srcs: Iterable[str]) -> Path:
     return directory
 
 
+@lru_cache()
+def find_user_pyproject_toml() -> Path:
+    r"""Return the path to the top-level user configuration for black.
+
+    This looks for ~\.black on Windows and ~/.config/black on Linux and other
+    Unix systems.
+    """
+    if sys.platform == "win32":
+        # Windows
+        user_config_path = Path.home() / ".black"
+    else:
+        config_root = os.environ.get("XDG_CONFIG_HOME", "~/.config")
+        user_config_path = Path(config_root).expanduser() / "black"
+    return user_config_path.resolve()
+
+
 @dataclass
 class Report:
     """Provides a reformatting counter. Can be rendered with `str(report)`."""
@@ -6339,7 +6427,12 @@ def parse_ast(src: str) -> Union[ast.AST, ast3.AST, ast27.AST]:
                 return ast3.parse(src, filename, feature_version=feature_version)
             except SyntaxError:
                 continue
-
+    if ast27.__name__ == "ast":
+        raise SyntaxError(
+            "The requested source code has invalid Python 3 syntax.\n"
+            "If you are trying to format Python 2 files please reinstall Black"
+            " with the 'python2' extra: `python3 -m pip install black[python2]`."
+        )
     return ast27.parse(src)
 
 
@@ -6405,12 +6498,22 @@ def _stringify_ast(
             # Constant strings may be indented across newlines, if they are
             # docstrings; fold spaces after newlines when comparing. Similarly,
             # trailing and leading space may be removed.
+            # Note that when formatting Python 2 code, at least with Windows
+            # line-endings, docstrings can end up here as bytes instead of
+            # str so make sure that we handle both cases.
             if (
                 isinstance(node, ast.Constant)
                 and field == "value"
-                and isinstance(value, str)
+                and isinstance(value, (str, bytes))
             ):
-                normalized = re.sub(r" *\n[ \t]*", "\n", value).strip()
+                lineend = "\n" if isinstance(value, str) else b"\n"
+                # To normalize, we strip any leading and trailing space from
+                # each line...
+                stripped = [line.strip() for line in value.splitlines()]
+                normalized = lineend.join(stripped)  # type: ignore[attr-defined]
+                # ...and remove any blank lines at the beginning and end of
+                # the whole string
+                normalized = normalized.strip()
             else:
                 normalized = value
             yield f"{'  ' * (depth+2)}{normalized!r},  # {value.__class__.__name__}"
@@ -6418,7 +6521,7 @@ def _stringify_ast(
     yield f"{'  ' * depth})  # /{node.__class__.__name__}"
 
 
-def assert_equivalent(src: str, dst: str) -> None:
+def assert_equivalent(src: str, dst: str, *, pass_num: int = 1) -> None:
     """Raise AssertionError if `src` and `dst` aren't equivalent."""
     try:
         src_ast = parse_ast(src)
@@ -6433,9 +6536,9 @@ def assert_equivalent(src: str, dst: str) -> None:
     except Exception as exc:
         log = dump_to_file("".join(traceback.format_tb(exc.__traceback__)), dst)
         raise AssertionError(
-            f"INTERNAL ERROR: Black produced invalid code: {exc}. Please report a bug"
-            " on https://github.com/psf/black/issues.  This invalid output might be"
-            f" helpful: {log}"
+            f"INTERNAL ERROR: Black produced invalid code on pass {pass_num}: {exc}. "
+            "Please report a bug on https://github.com/psf/black/issues.  "
+            f"This invalid output might be helpful: {log}"
         ) from None
 
     src_ast_str = "\n".join(_stringify_ast(src_ast))
@@ -6444,8 +6547,8 @@ def assert_equivalent(src: str, dst: str) -> None:
         log = dump_to_file(diff(src_ast_str, dst_ast_str, "src", "dst"))
         raise AssertionError(
             "INTERNAL ERROR: Black produced code that is not equivalent to the"
-            " source.  Please report a bug on https://github.com/psf/black/issues. "
-            f" This diff might be helpful: {log}"
+            f" source on pass {pass_num}.  Please report a bug on "
+            f"https://github.com/psf/black/issues.  This diff might be helpful: {log}"
         ) from None
 
 
@@ -6911,11 +7014,6 @@ def patched_main() -> None:
 
 
 def is_docstring(leaf: Leaf) -> bool:
-    if not is_multiline_string(leaf):
-        # For the purposes of docstring re-indentation, we don't need to do anything
-        # with single-line docstrings.
-        return False
-
     if prev_siblings_are(
         leaf.parent, [None, token.NEWLINE, token.INDENT, syms.simple_stmt]
     ):