git.madduck.net Git - etc/vim.git/blobdiff - src/black/__init__.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access, can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Document experimental string processing and docstring indentation (#2106)
[etc/vim.git] / src / black / __init__.py
index a8f4f89a6bbe8536f9e254a57b2fc658879c841e..5c9ab7516508b78ee23eebb49fbe934912548554 100644 (file)
@@ -48,7 +48,20 @@ from appdirs import user_cache_dir
 from dataclasses import dataclass, field, replace
 import click
 import toml
 from dataclasses import dataclass, field, replace
 import click
 import toml
-from typed_ast import ast3, ast27
+
+try:
+    from typed_ast import ast3, ast27
+except ImportError:
+    if sys.version_info < (3, 8):
+        print(
+            "The typed_ast package is not installed.\n"
+            "You can install it with `python3 -m pip install typed-ast`.",
+            file=sys.stderr,
+        )
+        sys.exit(1)
+    else:
+        ast3 = ast27 = ast
+
 from pathspec import PathSpec
 
 # lib2to3 fork
 from pathspec import PathSpec
 
 # lib2to3 fork
@@ -69,7 +82,7 @@ if TYPE_CHECKING:
     import colorama  # noqa: F401
 
 DEFAULT_LINE_LENGTH = 88
     import colorama  # noqa: F401
 
 DEFAULT_LINE_LENGTH = 88
-DEFAULT_EXCLUDES = r"/(\.direnv|\.eggs|\.git|\.hg|\.mypy_cache|\.nox|\.tox|\.venv|\.svn|_build|buck-out|build|dist)/"  # noqa: B950
+DEFAULT_EXCLUDES = r"/(\.direnv|\.eggs|\.git|\.hg|\.mypy_cache|\.nox|\.tox|\.venv|venv|\.svn|_build|buck-out|build|dist)/"  # noqa: B950
 DEFAULT_INCLUDES = r"\.pyi?$"
 CACHE_DIR = Path(user_cache_dir("black", version=__version__))
 STDIN_PLACEHOLDER = "__BLACK_STDIN_FILENAME__"
 DEFAULT_INCLUDES = r"\.pyi?$"
 CACHE_DIR = Path(user_cache_dir("black", version=__version__))
 STDIN_PLACEHOLDER = "__BLACK_STDIN_FILENAME__"
@@ -289,11 +302,15 @@ def supports_feature(target_versions: Set[TargetVersion], feature: Feature) -> b
     return all(feature in VERSION_TO_FEATURES[version] for version in target_versions)
 
 
     return all(feature in VERSION_TO_FEATURES[version] for version in target_versions)
 
 
-def find_pyproject_toml(path_search_start: Iterable[str]) -> Optional[str]:
+def find_pyproject_toml(path_search_start: Tuple[str, ...]) -> Optional[str]:
     """Find the absolute filepath to a pyproject.toml if it exists"""
     path_project_root = find_project_root(path_search_start)
     path_pyproject_toml = path_project_root / "pyproject.toml"
     """Find the absolute filepath to a pyproject.toml if it exists"""
     path_project_root = find_project_root(path_search_start)
     path_pyproject_toml = path_project_root / "pyproject.toml"
-    return str(path_pyproject_toml) if path_pyproject_toml.is_file() else None
+    if path_pyproject_toml.is_file():
+        return str(path_pyproject_toml)
+
+    path_user_pyproject_toml = find_user_pyproject_toml()
+    return str(path_user_pyproject_toml) if path_user_pyproject_toml.is_file() else None
 
 
 def parse_pyproject_toml(path_config: str) -> Dict[str, Any]:
 
 
 def parse_pyproject_toml(path_config: str) -> Dict[str, Any]:
@@ -428,8 +445,8 @@ def validate_regex(
     "--check",
     is_flag=True,
     help=(
     "--check",
     is_flag=True,
     help=(
-        "Don't write the files back, just return the status.  Return code 0 means"
-        " nothing would change.  Return code 1 means some files would be reformatted."
+        "Don't write the files back, just return the status. Return code 0 means"
+        " nothing would change. Return code 1 means some files would be reformatted."
         " Return code 123 means there was an internal error."
     ),
 )
         " Return code 123 means there was an internal error."
     ),
 )
@@ -455,9 +472,9 @@ def validate_regex(
     callback=validate_regex,
     help=(
         "A regular expression that matches files and directories that should be"
     callback=validate_regex,
     help=(
         "A regular expression that matches files and directories that should be"
-        " included on recursive searches.  An empty value means all files are included"
-        " regardless of the name.  Use forward slashes for directories on all platforms"
-        " (Windows, too).  Exclusions are calculated first, inclusions later."
+        " included on recursive searches. An empty value means all files are included"
+        " regardless of the name. Use forward slashes for directories on all platforms"
+        " (Windows, too). Exclusions are calculated first, inclusions later."
     ),
     show_default=True,
 )
     ),
     show_default=True,
 )
@@ -468,8 +485,8 @@ def validate_regex(
     callback=validate_regex,
     help=(
         "A regular expression that matches files and directories that should be"
     callback=validate_regex,
     help=(
         "A regular expression that matches files and directories that should be"
-        " excluded on recursive searches.  An empty value means no paths are excluded."
-        " Use forward slashes for directories on all platforms (Windows, too). "
+        " excluded on recursive searches. An empty value means no paths are excluded."
+        " Use forward slashes for directories on all platforms (Windows, too)."
         " Exclusions are calculated first, inclusions later."
     ),
     show_default=True,
         " Exclusions are calculated first, inclusions later."
     ),
     show_default=True,
@@ -767,7 +784,7 @@ def reformat_many(
     except (ImportError, OSError):
         # we arrive here if the underlying system does not support multi-processing
         # like in AWS Lambda or Termux, in which case we gracefully fallback to
     except (ImportError, OSError):
         # we arrive here if the underlying system does not support multi-processing
         # like in AWS Lambda or Termux, in which case we gracefully fallback to
-        # a ThreadPollExecutor with just a single worker (more workers would not do us
+        # a ThreadPoolExecutor with just a single worker (more workers would not do us
         # any good due to the Global Interpreter Lock)
         executor = ThreadPoolExecutor(max_workers=1)
 
         # any good due to the Global Interpreter Lock)
         executor = ThreadPoolExecutor(max_workers=1)
 
@@ -830,7 +847,7 @@ async def schedule_formatting(
         ): src
         for src in sorted(sources)
     }
         ): src
         for src in sorted(sources)
     }
-    pending: Iterable["asyncio.Future[bool]"] = tasks.keys()
+    pending = tasks.keys()
     try:
         loop.add_signal_handler(signal.SIGINT, cancel, pending)
         loop.add_signal_handler(signal.SIGTERM, cancel, pending)
     try:
         loop.add_signal_handler(signal.SIGINT, cancel, pending)
         loop.add_signal_handler(signal.SIGTERM, cancel, pending)
@@ -2136,16 +2153,35 @@ class LineGenerator(Visitor[Line]):
             # We're ignoring docstrings with backslash newline escapes because changing
             # indentation of those changes the AST representation of the code.
             prefix = get_string_prefix(leaf.value)
             # We're ignoring docstrings with backslash newline escapes because changing
             # indentation of those changes the AST representation of the code.
             prefix = get_string_prefix(leaf.value)
-            lead_len = len(prefix) + 3
-            tail_len = -3
-            indent = " " * 4 * self.current_line.depth
-            docstring = fix_docstring(leaf.value[lead_len:tail_len], indent)
+            docstring = leaf.value[len(prefix) :]  # Remove the prefix
+            quote_char = docstring[0]
+            # A natural way to remove the outer quotes is to do:
+            #   docstring = docstring.strip(quote_char)
+            # but that breaks on """""x""" (which is '""x').
+            # So we actually need to remove the first character and the next two
+            # characters but only if they are the same as the first.
+            quote_len = 1 if docstring[1] != quote_char else 3
+            docstring = docstring[quote_len:-quote_len]
+
+            if is_multiline_string(leaf):
+                indent = " " * 4 * self.current_line.depth
+                docstring = fix_docstring(docstring, indent)
+            else:
+                docstring = docstring.strip()
+
             if docstring:
             if docstring:
-                if leaf.value[lead_len - 1] == docstring[0]:
+                # Add some padding if the docstring starts / ends with a quote mark.
+                if docstring[0] == quote_char:
                     docstring = " " + docstring
                     docstring = " " + docstring
-                if leaf.value[tail_len + 1] == docstring[-1]:
+                if docstring[-1] == quote_char:
                     docstring = docstring + " "
                     docstring = docstring + " "
-            leaf.value = leaf.value[0:lead_len] + docstring + leaf.value[tail_len:]
+            else:
+                # Add some padding if the docstring is empty.
+                docstring = " "
+
+            # We could enforce triple quotes at this point.
+            quote = quote_char * quote_len
+            leaf.value = prefix + quote + docstring + quote
 
         yield from self.visit_default(leaf)
 
 
         yield from self.visit_default(leaf)
 
@@ -2692,6 +2728,13 @@ def make_comment(content: str) -> str:
 
     if content[0] == "#":
         content = content[1:]
 
     if content[0] == "#":
         content = content[1:]
+    NON_BREAKING_SPACE = " "
+    if (
+        content
+        and content[0] == NON_BREAKING_SPACE
+        and not content.lstrip().startswith("type:")
+    ):
+        content = " " + content[1:]  # Replace NBSP by a simple space
     if content and content[0] not in " !:#'%":
         content = " " + content
     return "#" + content
     if content and content[0] not in " !:#'%":
         content = " " + content
     return "#" + content
@@ -6089,7 +6132,7 @@ def get_future_imports(node: Node) -> Set[str]:
 
 @lru_cache()
 def get_gitignore(root: Path) -> PathSpec:
 
 @lru_cache()
 def get_gitignore(root: Path) -> PathSpec:
-    """ Return a PathSpec matching gitignore content if present."""
+    """Return a PathSpec matching gitignore content if present."""
     gitignore = root / ".gitignore"
     lines: List[str] = []
     if gitignore.is_file():
     gitignore = root / ".gitignore"
     lines: List[str] = []
     if gitignore.is_file():
@@ -6197,7 +6240,7 @@ def gen_python_files(
 
 
 @lru_cache()
 
 
 @lru_cache()
-def find_project_root(srcs: Iterable[str]) -> Path:
+def find_project_root(srcs: Tuple[str, ...]) -> Path:
     """Return a directory containing .git, .hg, or pyproject.toml.
 
     That directory will be a common parent of all files and directories
     """Return a directory containing .git, .hg, or pyproject.toml.
 
     That directory will be a common parent of all files and directories
@@ -6235,6 +6278,22 @@ def find_project_root(srcs: Iterable[str]) -> Path:
     return directory
 
 
     return directory
 
 
+@lru_cache()
+def find_user_pyproject_toml() -> Path:
+    r"""Return the path to the top-level user configuration for black.
+
+    This looks for ~\.black on Windows and ~/.config/black on Linux and other
+    Unix systems.
+    """
+    if sys.platform == "win32":
+        # Windows
+        user_config_path = Path.home() / ".black"
+    else:
+        config_root = os.environ.get("XDG_CONFIG_HOME", "~/.config")
+        user_config_path = Path(config_root).expanduser() / "black"
+    return user_config_path.resolve()
+
+
 @dataclass
 class Report:
     """Provides a reformatting counter. Can be rendered with `str(report)`."""
 @dataclass
 class Report:
     """Provides a reformatting counter. Can be rendered with `str(report)`."""
@@ -6336,7 +6395,12 @@ def parse_ast(src: str) -> Union[ast.AST, ast3.AST, ast27.AST]:
                 return ast3.parse(src, filename, feature_version=feature_version)
             except SyntaxError:
                 continue
                 return ast3.parse(src, filename, feature_version=feature_version)
             except SyntaxError:
                 continue
-
+    if ast27.__name__ == "ast":
+        raise SyntaxError(
+            "The requested source code has invalid Python 3 syntax.\n"
+            "If you are trying to format Python 2 files please reinstall Black"
+            " with the 'python2' extra: `python3 -m pip install black[python2]`."
+        )
     return ast27.parse(src)
 
 
     return ast27.parse(src)
 
 
@@ -6402,12 +6466,22 @@ def _stringify_ast(
             # Constant strings may be indented across newlines, if they are
             # docstrings; fold spaces after newlines when comparing. Similarly,
             # trailing and leading space may be removed.
             # Constant strings may be indented across newlines, if they are
             # docstrings; fold spaces after newlines when comparing. Similarly,
             # trailing and leading space may be removed.
+            # Note that when formatting Python 2 code, at least with Windows
+            # line-endings, docstrings can end up here as bytes instead of
+            # str so make sure that we handle both cases.
             if (
                 isinstance(node, ast.Constant)
                 and field == "value"
             if (
                 isinstance(node, ast.Constant)
                 and field == "value"
-                and isinstance(value, str)
+                and isinstance(value, (str, bytes))
             ):
             ):
-                normalized = re.sub(r" *\n[ \t]*", "\n", value).strip()
+                lineend = "\n" if isinstance(value, str) else b"\n"
+                # To normalize, we strip any leading and trailing space from
+                # each line...
+                stripped = [line.strip() for line in value.splitlines()]
+                normalized = lineend.join(stripped)  # type: ignore[attr-defined]
+                # ...and remove any blank lines at the beginning and end of
+                # the whole string
+                normalized = normalized.strip()
             else:
                 normalized = value
             yield f"{'  ' * (depth+2)}{normalized!r},  # {value.__class__.__name__}"
             else:
                 normalized = value
             yield f"{'  ' * (depth+2)}{normalized!r},  # {value.__class__.__name__}"
@@ -6908,11 +6982,6 @@ def patched_main() -> None:
 
 
 def is_docstring(leaf: Leaf) -> bool:
 
 
 def is_docstring(leaf: Leaf) -> bool:
-    if not is_multiline_string(leaf):
-        # For the purposes of docstring re-indentation, we don't need to do anything
-        # with single-line docstrings.
-        return False
-
     if prev_siblings_are(
         leaf.parent, [None, token.NEWLINE, token.INDENT, syms.simple_stmt]
     ):
     if prev_siblings_are(
         leaf.parent, [None, token.NEWLINE, token.INDENT, syms.simple_stmt]
     ):