git.madduck.net Git - etc/vim.git/blobdiff - src/black/__init__.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access, can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Bump peter-evans/find-comment from 2.0.1 to 2.1.0 (#3404)
[etc/vim.git] / src / black / __init__.py
index afd71e519162eeda6643aa3ff22ca9e9b060f6af..7d7ddbe093017e80915b450a2f87a8b785c332b8 100644 (file)
@@ -30,6 +30,7 @@ from typing import (
 import click
 from click.core import ParameterSource
 from mypy_extensions import mypyc_attr
 import click
 from click.core import ParameterSource
 from mypy_extensions import mypyc_attr
+from pathspec import PathSpec
 from pathspec.patterns.gitwildmatch import GitWildMatchPatternError
 
 from _black_version import version as __version__
 from pathspec.patterns.gitwildmatch import GitWildMatchPatternError
 
 from _black_version import version as __version__
@@ -61,7 +62,7 @@ from black.handle_ipynb_magics import (
     unmask_cell,
 )
 from black.linegen import LN, LineGenerator, transform_line
     unmask_cell,
 )
 from black.linegen import LN, LineGenerator, transform_line
-from black.lines import EmptyLineTracker, Line
+from black.lines import EmptyLineTracker, LinesBlock
 from black.mode import (
     FUTURE_FLAG_TO_FEATURE,
     VERSION_TO_FEATURES,
 from black.mode import (
     FUTURE_FLAG_TO_FEATURE,
     VERSION_TO_FEATURES,
@@ -497,8 +498,10 @@ def main(  # noqa: C901
             user_level_config = str(find_user_pyproject_toml())
             if config == user_level_config:
                 out(
             user_level_config = str(find_user_pyproject_toml())
             if config == user_level_config:
                 out(
-                    "Using configuration from user-level config at "
-                    f"'{user_level_config}'.",
+                    (
+                        "Using configuration from user-level config at "
+                        f"'{user_level_config}'."
+                    ),
                     fg="blue",
                 )
             elif config_source in (
                     fg="blue",
                 )
             elif config_source in (
@@ -625,6 +628,11 @@ def get_sources(
     sources: Set[Path] = set()
     root = ctx.obj["root"]
 
     sources: Set[Path] = set()
     root = ctx.obj["root"]
 
+    using_default_exclude = exclude is None
+    exclude = re_compile_maybe_verbose(DEFAULT_EXCLUDES) if exclude is None else exclude
+    gitignore: Optional[PathSpec] = None
+    root_gitignore = get_gitignore(root)
+
     for s in src:
         if s == "-" and stdin_filename:
             p = Path(stdin_filename)
     for s in src:
         if s == "-" and stdin_filename:
             p = Path(stdin_filename)
@@ -658,16 +666,11 @@ def get_sources(
 
             sources.add(p)
         elif p.is_dir():
 
             sources.add(p)
         elif p.is_dir():
-            if exclude is None:
-                exclude = re_compile_maybe_verbose(DEFAULT_EXCLUDES)
-                gitignore = get_gitignore(root)
-                p_gitignore = get_gitignore(p)
-                # No need to use p's gitignore if it is identical to root's gitignore
-                # (i.e. root and p point to the same directory).
-                if gitignore != p_gitignore:
-                    gitignore += p_gitignore
-            else:
-                gitignore = None
+            if using_default_exclude:
+                gitignore = {
+                    root: root_gitignore,
+                    root / p: get_gitignore(p),
+                }
             sources.update(
                 gen_python_files(
                     p.iterdir(),
             sources.update(
                 gen_python_files(
                     p.iterdir(),
@@ -914,7 +917,7 @@ def format_file_contents(src_contents: str, *, fast: bool, mode: Mode) -> FileCo
     valid by calling :func:`assert_equivalent` and :func:`assert_stable` on it.
     `mode` is passed to :func:`format_str`.
     """
     valid by calling :func:`assert_equivalent` and :func:`assert_stable` on it.
     `mode` is passed to :func:`format_str`.
     """
-    if not src_contents.strip():
+    if not mode.preview and not src_contents.strip():
         raise NothingChanged
 
     if mode.is_ipynb:
         raise NothingChanged
 
     if mode.is_ipynb:
@@ -1011,6 +1014,9 @@ def format_ipynb_string(src_contents: str, *, fast: bool, mode: Mode) -> FileCon
     Operate cell-by-cell, only on code cells, only for Python notebooks.
     If the ``.ipynb`` originally had a trailing newline, it'll be preserved.
     """
     Operate cell-by-cell, only on code cells, only for Python notebooks.
     If the ``.ipynb`` originally had a trailing newline, it'll be preserved.
     """
+    if mode.preview and not src_contents:
+        raise NothingChanged
+
     trailing_newline = src_contents[-1] == "\n"
     modified = False
     nb = json.loads(src_contents)
     trailing_newline = src_contents[-1] == "\n"
     modified = False
     nb = json.loads(src_contents)
@@ -1075,7 +1081,7 @@ def format_str(src_contents: str, *, mode: Mode) -> str:
 
 def _format_str_once(src_contents: str, *, mode: Mode) -> str:
     src_node = lib2to3_parse(src_contents.lstrip(), mode.target_versions)
 
 def _format_str_once(src_contents: str, *, mode: Mode) -> str:
     src_node = lib2to3_parse(src_contents.lstrip(), mode.target_versions)
-    dst_contents = []
+    dst_blocks: List[LinesBlock] = []
     if mode.target_versions:
         versions = mode.target_versions
     else:
     if mode.target_versions:
         versions = mode.target_versions
     else:
@@ -1084,22 +1090,32 @@ def _format_str_once(src_contents: str, *, mode: Mode) -> str:
 
     normalize_fmt_off(src_node, preview=mode.preview)
     lines = LineGenerator(mode=mode)
 
     normalize_fmt_off(src_node, preview=mode.preview)
     lines = LineGenerator(mode=mode)
-    elt = EmptyLineTracker(is_pyi=mode.is_pyi)
-    empty_line = Line(mode=mode)
-    after = 0
+    elt = EmptyLineTracker(mode=mode)
     split_line_features = {
         feature
         for feature in {Feature.TRAILING_COMMA_IN_CALL, Feature.TRAILING_COMMA_IN_DEF}
         if supports_feature(versions, feature)
     }
     split_line_features = {
         feature
         for feature in {Feature.TRAILING_COMMA_IN_CALL, Feature.TRAILING_COMMA_IN_DEF}
         if supports_feature(versions, feature)
     }
+    block: Optional[LinesBlock] = None
     for current_line in lines.visit(src_node):
     for current_line in lines.visit(src_node):
-        dst_contents.append(str(empty_line) * after)
-        before, after = elt.maybe_empty_lines(current_line)
-        dst_contents.append(str(empty_line) * before)
+        block = elt.maybe_empty_lines(current_line)
+        dst_blocks.append(block)
         for line in transform_line(
             current_line, mode=mode, features=split_line_features
         ):
         for line in transform_line(
             current_line, mode=mode, features=split_line_features
         ):
-            dst_contents.append(str(line))
+            block.content_lines.append(str(line))
+    if dst_blocks:
+        dst_blocks[-1].after = 0
+    dst_contents = []
+    for block in dst_blocks:
+        dst_contents.extend(block.all_lines())
+    if mode.preview and not dst_contents:
+        # Use decode_bytes to retrieve the correct source newline (CRLF or LF),
+        # and check if normalized_content has more than one line
+        normalized_content, _, newline = decode_bytes(src_contents.encode("utf-8"))
+        if "\n" in normalized_content:
+            return newline
+        return ""
     return "".join(dst_contents)
 
 
     return "".join(dst_contents)
 
 
@@ -1382,9 +1398,9 @@ def patch_click() -> None:
 
     for module in modules:
         if hasattr(module, "_verify_python3_env"):
 
     for module in modules:
         if hasattr(module, "_verify_python3_env"):
-            module._verify_python3_env = lambda: None  # type: ignore
+            module._verify_python3_env = lambda: None
         if hasattr(module, "_verify_python_env"):
         if hasattr(module, "_verify_python_env"):
-            module._verify_python_env = lambda: None  # type: ignore
+            module._verify_python_env = lambda: None
 
 
 def patched_main() -> None:
 
 
 def patched_main() -> None: