git.madduck.net Git - etc/vim.git/commitdiff

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access, can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Preserve line endings when formatting a file in place (#288)
author Zsolt Dollenstein <zsol.zsol@gmail.com>
Mon, 4 Jun 2018 22:52:06 +0000 (00:52 +0200)
committer Łukasz Langa <lukasz@langa.pl>
Mon, 4 Jun 2018 22:52:06 +0000 (15:52 -0700)
README.md
black.py
docs/reference/reference_functions.rst
tests/test_black.py

index b1eae742cd73b2b0efe25a0ca8f37433f7b16f74..beba56cc2ebff85a2ac7268e1aee841981ef7129 100644 (file)
--- a/README.md
+++ b/README.md
@@ -720,6 +720,8 @@ More details can be found in [CONTRIBUTING](CONTRIBUTING.md).
 * fixed stdin handling not working correctly if an old version of Click was
   used (#276)
 
+* *Black* now preserves line endings when formatting a file in place (#258)
+
 
 ### 18.5b1
 
index e59a1e5b6f4a291d2c46252ad419c93a6435a019..36e49b026a53492ba635621353c2e37ffa2b9520 100644 (file)
--- a/black.py
+++ b/black.py
@@ -4,6 +4,7 @@ from asyncio.base_events import BaseEventLoop
 from concurrent.futures import Executor, ProcessPoolExecutor
 from enum import Enum, Flag
 from functools import partial, wraps
+import io
 import keyword
 import logging
 from multiprocessing import Manager
@@ -465,8 +466,9 @@ def format_file_in_place(
     """
     if src.suffix == ".pyi":
         mode |= FileMode.PYI
-    with tokenize.open(src) as src_buffer:
-        src_contents = src_buffer.read()
+
+    with open(src, "rb") as buf:
+        newline, encoding, src_contents = prepare_input(buf.read())
     try:
         dst_contents = format_file_contents(
             src_contents, line_length=line_length, fast=fast, mode=mode
@@ -475,7 +477,7 @@ def format_file_in_place(
         return False
 
     if write_back == write_back.YES:
-        with open(src, "w", encoding=src_buffer.encoding) as f:
+        with open(src, "w", encoding=encoding, newline=newline) as f:
             f.write(dst_contents)
     elif write_back == write_back.DIFF:
         src_name = f"{src}  (original)"
@@ -484,7 +486,14 @@ def format_file_in_place(
         if lock:
             lock.acquire()
         try:
-            sys.stdout.write(diff_contents)
+            f = io.TextIOWrapper(
+                sys.stdout.buffer,
+                encoding=encoding,
+                newline=newline,
+                write_through=True,
+            )
+            f.write(diff_contents)
+            f.detach()
         finally:
             if lock:
                 lock.release()
@@ -503,7 +512,7 @@ def format_stdin_to_stdout(
     `line_length`, `fast`, `is_pyi`, and `force_py36` arguments are passed to
     :func:`format_file_contents`.
     """
-    src = sys.stdin.read()
+    newline, encoding, src = prepare_input(sys.stdin.buffer.read())
     dst = src
     try:
         dst = format_file_contents(src, line_length=line_length, fast=fast, mode=mode)
@@ -514,11 +523,25 @@ def format_stdin_to_stdout(
 
     finally:
         if write_back == WriteBack.YES:
-            sys.stdout.write(dst)
+            f = io.TextIOWrapper(
+                sys.stdout.buffer,
+                encoding=encoding,
+                newline=newline,
+                write_through=True,
+            )
+            f.write(dst)
+            f.detach()
         elif write_back == WriteBack.DIFF:
             src_name = "<stdin>  (original)"
             dst_name = "<stdin>  (formatted)"
-            sys.stdout.write(diff(src, dst, src_name, dst_name))
+            f = io.TextIOWrapper(
+                sys.stdout.buffer,
+                encoding=encoding,
+                newline=newline,
+                write_through=True,
+            )
+            f.write(diff(src, dst, src_name, dst_name))
+            f.detach()
 
 
 def format_file_contents(
@@ -579,6 +602,19 @@ def format_str(
     return dst_contents
 
 
+def prepare_input(src: bytes) -> Tuple[str, str, str]:
+    """Analyze `src` and return a tuple of (newline, encoding, decoded_contents)
+
+    Where `newline` is either CRLF or LF, and `decoded_contents` is decoded with
+    universal newlines (i.e. only LF).
+    """
+    srcbuf = io.BytesIO(src)
+    encoding, lines = tokenize.detect_encoding(srcbuf.readline)
+    newline = "\r\n" if b"\r\n" == lines[0][-2:] else "\n"
+    srcbuf.seek(0)
+    return newline, encoding, io.TextIOWrapper(srcbuf, encoding).read()
+
+
 GRAMMARS = [
     pygram.python_grammar_no_print_statement_no_exec_statement,
     pygram.python_grammar_no_print_statement,
@@ -590,8 +626,7 @@ def lib2to3_parse(src_txt: str) -> Node:
     """Given a string with source, return the lib2to3 Node."""
     grammar = pygram.python_grammar_no_print_statement
     if src_txt[-1] != "\n":
-        nl = "\r\n" if "\r\n" in src_txt[:1024] else "\n"
-        src_txt += nl
+        src_txt += "\n"
     for grammar in GRAMMARS:
         drv = driver.Driver(grammar, pytree.convert)
         try:
index 133f249d2b3eeb46397c768acfc9313096027ff1..a4f00dba0819fce2aaef68369812d7ed0892a3e8 100644 (file)
@@ -61,6 +61,8 @@ Parsing
 
 .. autofunction:: black.lib2to3_unparse
 
+.. autofunction:: black.prepare_input
+
 Split functions
 ---------------
 
index adf5ede63715fd4d58f0c624eb99385e3db6a5a5..1f93e6afdf7c2172621db8924b24e2d598b11b04 100644 (file)
@@ -3,7 +3,7 @@ import asyncio
 from concurrent.futures import ThreadPoolExecutor
 from contextlib import contextmanager
 from functools import partial
-from io import StringIO
+from io import BytesIO, TextIOWrapper
 import os
 from pathlib import Path
 import sys
@@ -121,8 +121,9 @@ class BlackTestCase(unittest.TestCase):
         source, expected = read_data("../black")
         hold_stdin, hold_stdout = sys.stdin, sys.stdout
         try:
-            sys.stdin, sys.stdout = StringIO(source), StringIO()
-            sys.stdin.name = "<stdin>"
+            sys.stdin = TextIOWrapper(BytesIO(source.encode("utf8")), encoding="utf8")
+            sys.stdout = TextIOWrapper(BytesIO(), encoding="utf8")
+            sys.stdin.buffer.name = "<stdin>"  # type: ignore
             black.format_stdin_to_stdout(
                 line_length=ll, fast=True, write_back=black.WriteBack.YES
             )
@@ -139,8 +140,9 @@ class BlackTestCase(unittest.TestCase):
         expected, _ = read_data("expression.diff")
         hold_stdin, hold_stdout = sys.stdin, sys.stdout
         try:
-            sys.stdin, sys.stdout = StringIO(source), StringIO()
-            sys.stdin.name = "<stdin>"
+            sys.stdin = TextIOWrapper(BytesIO(source.encode("utf8")), encoding="utf8")
+            sys.stdout = TextIOWrapper(BytesIO(), encoding="utf8")
+            sys.stdin.buffer.name = "<stdin>"  # type: ignore
             black.format_stdin_to_stdout(
                 line_length=ll, fast=True, write_back=black.WriteBack.DIFF
             )
@@ -204,7 +206,7 @@ class BlackTestCase(unittest.TestCase):
         tmp_file = Path(black.dump_to_file(source))
         hold_stdout = sys.stdout
         try:
-            sys.stdout = StringIO()
+            sys.stdout = TextIOWrapper(BytesIO(), encoding="utf8")
             self.assertTrue(ff(tmp_file, write_back=black.WriteBack.DIFF))
             sys.stdout.seek(0)
             actual = sys.stdout.read()
@@ -1108,6 +1110,18 @@ class BlackTestCase(unittest.TestCase):
             result = CliRunner().invoke(black.main, ["-", option, "**()(!!*)"])
             self.assertEqual(result.exit_code, 2)
 
+    def test_preserves_line_endings(self) -> None:
+        with TemporaryDirectory() as workspace:
+            test_file = Path(workspace) / "test.py"
+            for nl in ["\n", "\r\n"]:
+                contents = nl.join(["def f(  ):", "    pass"])
+                test_file.write_bytes(contents.encode())
+                ff(test_file, write_back=black.WriteBack.YES)
+                updated_contents: bytes = test_file.read_bytes()
+                self.assertIn(nl.encode(), updated_contents)  # type: ignore
+                if nl == "\n":
+                    self.assertNotIn(b"\r\n", updated_contents)  # type: ignore
+
 
 if __name__ == "__main__":
     unittest.main()