git.madduck.net Git - etc/vim.git/blobdiff - src/black/__init__.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access, can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Enforce empty lines before classes/functions with sticky leading comments. (#3302)
[etc/vim.git] / src / black / __init__.py
index 5b8c9749119c233d48ee817e10f3ffe33c8486fe..d9fba41ebd34136560e09ae296e1a4a425dec8f5 100644 (file)
@@ -61,7 +61,7 @@ from black.handle_ipynb_magics import (
     unmask_cell,
 )
 from black.linegen import LN, LineGenerator, transform_line
-from black.lines import EmptyLineTracker, Line
+from black.lines import EmptyLineTracker, LinesBlock
 from black.mode import (
     FUTURE_FLAG_TO_FEATURE,
     VERSION_TO_FEATURES,
@@ -248,6 +248,12 @@ def validate_regex(
     ),
     default=[],
 )
+@click.option(
+    "-x",
+    "--skip-source-first-line",
+    is_flag=True,
+    help="Skip the first line of the source code.",
+)
 @click.option(
     "-S",
     "--skip-string-normalization",
@@ -428,6 +434,7 @@ def main(  # noqa: C901
     pyi: bool,
     ipynb: bool,
     python_cell_magics: Sequence[str],
+    skip_source_first_line: bool,
     skip_string_normalization: bool,
     skip_magic_trailing_comma: bool,
     experimental_string_processing: bool,
@@ -528,6 +535,7 @@ def main(  # noqa: C901
         line_length=line_length,
         is_pyi=pyi,
         is_ipynb=ipynb,
+        skip_source_first_line=skip_source_first_line,
         string_normalization=not skip_string_normalization,
         magic_trailing_comma=not skip_magic_trailing_comma,
         experimental_string_processing=experimental_string_processing,
@@ -790,7 +798,10 @@ def format_file_in_place(
         mode = replace(mode, is_ipynb=True)
 
     then = datetime.utcfromtimestamp(src.stat().st_mtime)
+    header = b""
     with open(src, "rb") as buf:
+        if mode.skip_source_first_line:
+            header = buf.readline()
         src_contents, encoding, newline = decode_bytes(buf.read())
     try:
         dst_contents = format_file_contents(src_contents, fast=fast, mode=mode)
@@ -800,6 +811,8 @@ def format_file_in_place(
         raise ValueError(
             f"File '{src}' cannot be parsed as valid Jupyter notebook."
         ) from None
+    src_contents = header.decode(encoding) + src_contents
+    dst_contents = header.decode(encoding) + dst_contents
 
     if write_back == WriteBack.YES:
         with open(src, "w", encoding=encoding, newline=newline) as f:
@@ -1062,7 +1075,7 @@ def format_str(src_contents: str, *, mode: Mode) -> str:
 
 def _format_str_once(src_contents: str, *, mode: Mode) -> str:
     src_node = lib2to3_parse(src_contents.lstrip(), mode.target_versions)
-    dst_contents = []
+    dst_blocks: List[LinesBlock] = []
     if mode.target_versions:
         versions = mode.target_versions
     else:
@@ -1071,22 +1084,25 @@ def _format_str_once(src_contents: str, *, mode: Mode) -> str:
 
     normalize_fmt_off(src_node, preview=mode.preview)
     lines = LineGenerator(mode=mode)
-    elt = EmptyLineTracker(is_pyi=mode.is_pyi)
-    empty_line = Line(mode=mode)
-    after = 0
+    elt = EmptyLineTracker(mode=mode)
     split_line_features = {
         feature
         for feature in {Feature.TRAILING_COMMA_IN_CALL, Feature.TRAILING_COMMA_IN_DEF}
         if supports_feature(versions, feature)
     }
+    block: Optional[LinesBlock] = None
     for current_line in lines.visit(src_node):
-        dst_contents.append(str(empty_line) * after)
-        before, after = elt.maybe_empty_lines(current_line)
-        dst_contents.append(str(empty_line) * before)
+        block = elt.maybe_empty_lines(current_line)
+        dst_blocks.append(block)
         for line in transform_line(
             current_line, mode=mode, features=split_line_features
         ):
-            dst_contents.append(str(line))
+            block.content_lines.append(str(line))
+    if dst_blocks:
+        dst_blocks[-1].after = 0
+    dst_contents = []
+    for block in dst_blocks:
+        dst_contents.extend(block.all_lines())
     return "".join(dst_contents)
 
 
@@ -1369,9 +1385,9 @@ def patch_click() -> None:
 
     for module in modules:
         if hasattr(module, "_verify_python3_env"):
-            module._verify_python3_env = lambda: None  # type: ignore
+            module._verify_python3_env = lambda: None
         if hasattr(module, "_verify_python_env"):
-            module._verify_python_env = lambda: None  # type: ignore
+            module._verify_python_env = lambda: None
 
 
 def patched_main() -> None: