def read_data(name: str) -> Tuple[str, str]:
"""read_data('test_name') -> 'input', 'output'"""
- if not name.endswith((".py", ".out")):
+ if not name.endswith((".py", ".out", ".diff")):
name += ".py"
_input: List[str] = []
_output: List[str] = []
try:
sys.stdin, sys.stdout = StringIO(source), StringIO()
sys.stdin.name = "<stdin>"
- black.format_stdin_to_stdout(line_length=ll, fast=True, write_back=True)
+ black.format_stdin_to_stdout(
+ line_length=ll, fast=True, write_back=black.WriteBack.YES
+ )
sys.stdout.seek(0)
actual = sys.stdout.read()
finally:
black.assert_equivalent(source, actual)
black.assert_stable(source, actual, line_length=ll)
def test_piping_diff(self) -> None:
    """Pipe expression.py through stdin in DIFF mode and compare to the golden diff.

    sys.stdin / sys.stdout are replaced by in-memory streams for the duration
    of the call and restored in ``finally`` even if formatting raises.
    """
    source, _ = read_data("expression.py")
    golden, _ = read_data("expression.diff")
    saved_stdin, saved_stdout = sys.stdin, sys.stdout
    try:
        fake_stdin = StringIO(source)
        sys.stdin = fake_stdin
        sys.stdout = StringIO()
        # Give the in-memory stream the same name a real stdin reports.
        fake_stdin.name = "<stdin>"
        black.format_stdin_to_stdout(
            line_length=ll, fast=True, write_back=black.WriteBack.DIFF
        )
        sys.stdout.seek(0)
        produced = sys.stdout.read()
    finally:
        sys.stdin, sys.stdout = saved_stdin, saved_stdout
    # Normalize the tail: the diff output has a trailing space.
    produced = produced.rstrip() + "\n"
    self.assertEqual(golden, produced)
+
@patch("black.dump_to_file", dump_to_stderr)
def test_setup(self) -> None:
source, expected = read_data("../setup")
black.assert_equivalent(source, actual)
black.assert_stable(source, actual, line_length=ll)
def test_expression_ff(self) -> None:
    """Run ``ff`` with WriteBack.YES on a temp copy of expression and check the file.

    The temp file is always removed. The equivalence and stability checks run
    with ``black.dump_to_file`` patched to ``dump_to_stderr``.
    """
    source, expected = read_data("expression")
    scratch = Path(black.dump_to_file(source))
    try:
        changed = ff(scratch, write_back=black.WriteBack.YES)
        self.assertTrue(changed)
        # Read back what was written to disk (default encoding, like open()).
        actual = scratch.read_text()
    finally:
        scratch.unlink()
    self.assertFormatEqual(expected, actual)
    with patch("black.dump_to_file", dump_to_stderr):
        black.assert_equivalent(source, actual)
        black.assert_stable(source, actual, line_length=ll)
+
def test_expression_diff(self) -> None:
    """Run ``ff`` in DIFF mode on a temp copy of expression.py and compare to the golden diff.

    stdout is captured in a StringIO and restored afterwards; the temp file is
    always removed. On mismatch, the actual diff is dumped to a file whose path
    is put in the failure message so the golden file can be refreshed on purpose.
    """
    source, _ = read_data("expression.py")
    expected, _ = read_data("expression.diff")
    tmp_file = Path(black.dump_to_file(source))
    hold_stdout = sys.stdout
    try:
        sys.stdout = StringIO()
        self.assertTrue(ff(tmp_file, write_back=black.WriteBack.DIFF))
        sys.stdout.seek(0)
        actual = sys.stdout.read()
        # The golden diff refers to the file as "<stdin>", so replace the
        # random temp-file name. NOTE(review): assumes the diff header names
        # the file by basename (tmp_file.name) -- confirm against black's header.
        actual = actual.replace(tmp_file.name, "<stdin>")
    finally:
        sys.stdout = hold_stdout
        os.unlink(tmp_file)
    actual = actual.rstrip() + "\n"  # the diff output has a trailing space
    if expected != actual:
        dump = black.dump_to_file(actual)
        # Only the last fragment interpolates a value, so only it needs the
        # f-prefix (placeholder-less f-strings are flagged as F541).
        msg = (
            "Expected diff isn't equal to the actual. If you made changes "
            "to expression.py and this is an anticipated difference, "
            f"overwrite tests/expression.diff with {dump}."
        )
        self.assertEqual(expected, actual, msg)
+
@patch("black.dump_to_file", dump_to_stderr)
def test_fstring(self) -> None:
source, expected = read_data("fstring")
f"AST print out is different. Actual version dumped to {log_name}",
)
def test_format_file_contents(self) -> None:
    """Check format_file_contents on no-op, changed, and unparseable input.

    Already-formatted input ("", a lone newline, a formatted list) raises
    black.NothingChanged; input needing changes returns the formatted text;
    unparseable input raises ValueError with a "Cannot parse" message.
    """
    for unchanged in ("", "\n", "l = [1, 2, 3]\n"):
        with self.assertRaises(black.NothingChanged):
            black.format_file_contents(unchanged, line_length=ll, fast=False)

    formatted = black.format_file_contents("l = [1,2,3]", line_length=ll, fast=False)
    self.assertEqual("l = [1, 2, 3]\n", formatted)

    with self.assertRaises(ValueError) as ctx:
        black.format_file_contents("return if you can", line_length=ll, fast=False)
    self.assertEqual(str(ctx.exception), "Cannot parse: 1:7: return if you can")
+
def test_endmarker(self) -> None:
    """A lone newline parses to a file_input node whose only child is ENDMARKER."""
    tree = black.lib2to3_parse("\n")
    self.assertEqual(tree.type, black.syms.file_input)
    kids = tree.children
    self.assertEqual(len(kids), 1)
    self.assertEqual(kids[0].type, black.token.ENDMARKER)
+
@unittest.skipIf(os.environ.get("SKIP_AST_PRINT"), "user set SKIP_AST_PRINT")
def test_assertFormatEqual(self) -> None:
    """assertFormatEqual must fail on a mismatch and emit both trees via black.out.

    black.out and black.err are patched with collectors so the test can check
    that the "Expected tree:" / "Actual tree:" dumps went to ``out`` and that
    nothing was written to ``err``.
    """
    # Explicit element types: these lists are only appended to inside the
    # nested callbacks, so a type checker cannot infer them from usage.
    out_lines: List[str] = []
    err_lines: List[str] = []

    def out(msg: str, **kwargs: Any) -> None:
        out_lines.append(msg)

    def err(msg: str, **kwargs: Any) -> None:
        err_lines.append(msg)

    with patch("black.out", out), patch("black.err", err):
        with self.assertRaises(AssertionError):
            self.assertFormatEqual("l = [1, 2, 3]", "l = [1, 2, 3,]")

    out_str = "".join(out_lines)
    # assertIn reports both operands on failure, unlike assertTrue(x in y).
    self.assertIn("Expected tree:", out_str)
    self.assertIn("Actual tree:", out_str)
    self.assertEqual("".join(err_lines), "")
+
# Run the test module through unittest's command-line runner when executed as a script.
if __name__ == "__main__":
unittest.main()