X-Git-Url: https://git.madduck.net/etc/vim.git/blobdiff_plain/4c352ad4be70c72ba9b949d3afb7c242522d058e..e5452a6b676c161d01ae0ac6cbb5a7cc4c395745:/tests/test_black.py diff --git a/tests/test_black.py b/tests/test_black.py index adf5ede..1f93e6a 100644 --- a/tests/test_black.py +++ b/tests/test_black.py @@ -3,7 +3,7 @@ import asyncio from concurrent.futures import ThreadPoolExecutor from contextlib import contextmanager from functools import partial -from io import StringIO +from io import BytesIO, TextIOWrapper import os from pathlib import Path import sys @@ -121,8 +121,9 @@ class BlackTestCase(unittest.TestCase): source, expected = read_data("../black") hold_stdin, hold_stdout = sys.stdin, sys.stdout try: - sys.stdin, sys.stdout = StringIO(source), StringIO() - sys.stdin.name = "" + sys.stdin = TextIOWrapper(BytesIO(source.encode("utf8")), encoding="utf8") + sys.stdout = TextIOWrapper(BytesIO(), encoding="utf8") + sys.stdin.buffer.name = "" # type: ignore black.format_stdin_to_stdout( line_length=ll, fast=True, write_back=black.WriteBack.YES ) @@ -139,8 +140,9 @@ class BlackTestCase(unittest.TestCase): expected, _ = read_data("expression.diff") hold_stdin, hold_stdout = sys.stdin, sys.stdout try: - sys.stdin, sys.stdout = StringIO(source), StringIO() - sys.stdin.name = "" + sys.stdin = TextIOWrapper(BytesIO(source.encode("utf8")), encoding="utf8") + sys.stdout = TextIOWrapper(BytesIO(), encoding="utf8") + sys.stdin.buffer.name = "" # type: ignore black.format_stdin_to_stdout( line_length=ll, fast=True, write_back=black.WriteBack.DIFF ) @@ -204,7 +206,7 @@ class BlackTestCase(unittest.TestCase): tmp_file = Path(black.dump_to_file(source)) hold_stdout = sys.stdout try: - sys.stdout = StringIO() + sys.stdout = TextIOWrapper(BytesIO(), encoding="utf8") self.assertTrue(ff(tmp_file, write_back=black.WriteBack.DIFF)) sys.stdout.seek(0) actual = sys.stdout.read() @@ -1108,6 +1110,18 @@ class BlackTestCase(unittest.TestCase): result = CliRunner().invoke(black.main, ["-", option, "**()(!!*)"]) self.assertEqual(result.exit_code, 2) + def test_preserves_line_endings(self) -> None: + with TemporaryDirectory() as workspace: + test_file = Path(workspace) / "test.py" + for nl in ["\n", "\r\n"]: + contents = nl.join(["def f( ):", " pass"]) + test_file.write_bytes(contents.encode()) + ff(test_file, write_back=black.WriteBack.YES) + updated_contents: bytes = test_file.read_bytes() + self.assertIn(nl.encode(), updated_contents) # type: ignore + if nl == "\n": + self.assertNotIn(b"\r\n", updated_contents) # type: ignore + if __name__ == "__main__": unittest.main()