from concurrent.futures import ThreadPoolExecutor
from contextlib import contextmanager
from functools import partial
-from io import StringIO
+from io import BytesIO, TextIOWrapper
import os
from pathlib import Path
+import re
import sys
from tempfile import TemporaryDirectory
from typing import Any, List, Tuple, Iterator
import unittest
from unittest.mock import patch
-import re
from click import unstyle
from click.testing import CliRunner
import black
+
+# Shared test configuration: ``ll`` is the line length used throughout;
+# ``ff`` wraps ``black.format_file_in_place`` and ``fs`` wraps
+# ``black.format_str``, both pre-bound to that line length (``ff`` also
+# passes ``fast=True``).
ll = 88
ff = partial(black.format_file_in_place, line_length=ll, fast=True)
fs = partial(black.format_str, line_length=ll)
black.err(str(ve))
self.assertEqual(expected, actual)
+    @patch("black.dump_to_file", dump_to_stderr)
+    def test_empty(self) -> None:
+        """An empty source string formats to an empty, equivalent, stable result."""
+        empty_source = ""
+        formatted = fs(empty_source)
+        self.assertFormatEqual(empty_source, formatted)
+        black.assert_equivalent(empty_source, formatted)
+        black.assert_stable(empty_source, formatted, line_length=ll)
+
+    def test_empty_ff(self) -> None:
+        """Formatting an empty file in place reports no change and leaves it empty."""
+        tmp_file = Path(black.dump_to_file())
+        try:
+            changed = ff(tmp_file, write_back=black.WriteBack.YES)
+            self.assertFalse(changed)
+            actual = tmp_file.read_text(encoding="utf8")
+        finally:
+            # Always remove the temp file, even if an assertion above fails.
+            tmp_file.unlink()
+        self.assertFormatEqual("", actual)
+
@patch("black.dump_to_file", dump_to_stderr)
def test_self(self) -> None:
source, expected = read_data("test_black")
source, expected = read_data("../black")
hold_stdin, hold_stdout = sys.stdin, sys.stdout
try:
- sys.stdin, sys.stdout = StringIO(source), StringIO()
- sys.stdin.name = "<stdin>"
+ sys.stdin = TextIOWrapper(BytesIO(source.encode("utf8")), encoding="utf8")
+ sys.stdout = TextIOWrapper(BytesIO(), encoding="utf8")
+ sys.stdin.buffer.name = "<stdin>" # type: ignore
black.format_stdin_to_stdout(
line_length=ll, fast=True, write_back=black.WriteBack.YES
)
black.assert_stable(source, actual, line_length=ll)
def test_piping_diff(self) -> None:
+ diff_header = re.compile(
+ rf"(STDIN|STDOUT)\t\d\d\d\d-\d\d-\d\d "
+ rf"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
+ )
source, _ = read_data("expression.py")
expected, _ = read_data("expression.diff")
hold_stdin, hold_stdout = sys.stdin, sys.stdout
try:
- sys.stdin, sys.stdout = StringIO(source), StringIO()
- sys.stdin.name = "<stdin>"
+ sys.stdin = TextIOWrapper(BytesIO(source.encode("utf8")), encoding="utf8")
+ sys.stdout = TextIOWrapper(BytesIO(), encoding="utf8")
black.format_stdin_to_stdout(
line_length=ll, fast=True, write_back=black.WriteBack.DIFF
)
sys.stdout.seek(0)
actual = sys.stdout.read()
+ actual = diff_header.sub("[Deterministic header]", actual)
finally:
sys.stdin, sys.stdout = hold_stdin, hold_stdout
actual = actual.rstrip() + "\n" # the diff output has a trailing space
source, _ = read_data("expression.py")
expected, _ = read_data("expression.diff")
tmp_file = Path(black.dump_to_file(source))
+ diff_header = re.compile(
+ rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
+ rf"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
+ )
hold_stdout = sys.stdout
try:
- sys.stdout = StringIO()
+ sys.stdout = TextIOWrapper(BytesIO(), encoding="utf8")
self.assertTrue(ff(tmp_file, write_back=black.WriteBack.DIFF))
sys.stdout.seek(0)
actual = sys.stdout.read()
- actual = actual.replace(str(tmp_file), "<stdin>")
+ actual = diff_header.sub("[Deterministic header]", actual)
finally:
sys.stdout = hold_stdout
os.unlink(tmp_file)
cached_but_changed.touch()
cache = {cached: black.get_cache_info(cached), cached_but_changed: (0.0, 0)}
todo, done = black.filter_cached(
- cache, [uncached, cached, cached_but_changed]
+ cache, {uncached, cached, cached_but_changed}
)
- self.assertEqual(todo, [uncached, cached_but_changed])
- self.assertEqual(done, [cached])
+ self.assertEqual(todo, {uncached, cached_but_changed})
+ self.assertEqual(done, {cached})
def test_write_cache_creates_directory_if_needed(self) -> None:
mode = black.FileMode.AUTO_DETECT
result = CliRunner().invoke(black.main, ["-", option, "**()(!!*)"])
self.assertEqual(result.exit_code, 2)
+    def test_preserves_line_endings(self) -> None:
+        """In-place formatting keeps the file's original newline style.
+
+        The input is deliberately mis-formatted so that ``ff`` must rewrite the
+        file; otherwise the newline assertions would pass trivially on the
+        untouched original bytes.
+        """
+        with TemporaryDirectory() as workspace:
+            test_file = Path(workspace) / "test.py"
+            for nl in ["\n", "\r\n"]:
+                with self.subTest(nl=nl):
+                    contents = nl.join(["def f( ):", " pass"])
+                    test_file.write_bytes(contents.encode())
+                    # The file must actually be rewritten for the newline
+                    # checks below to mean anything.
+                    self.assertTrue(ff(test_file, write_back=black.WriteBack.YES))
+                    updated_contents: bytes = test_file.read_bytes()
+                    newline = nl.encode()
+                    self.assertIn(newline, updated_contents)  # type: ignore
+                    # After stripping every expected newline, no stray CR or LF
+                    # may remain; this rejects mixed endings in either direction.
+                    leftover = updated_contents.replace(newline, b"")
+                    self.assertNotIn(b"\r", leftover)  # type: ignore
+                    self.assertNotIn(b"\n", leftover)  # type: ignore
+
+# Running this module directly executes its test cases via unittest.
if __name__ == "__main__":
unittest.main()