X-Git-Url: https://git.madduck.net/etc/vim.git/blobdiff_plain/a20a3eeb0f738d3434efe3be8932db11722757a4..661908cd0282ff464794a8193475693e9130b866:/tests/test_black.py diff --git a/tests/test_black.py b/tests/test_black.py index 226a119..a4d2382 100644 --- a/tests/test_black.py +++ b/tests/test_black.py @@ -150,7 +150,7 @@ class BlackTestCase(unittest.TestCase): tmp_file = Path(black.dump_to_file(source)) try: self.assertTrue(ff(tmp_file, write_back=black.WriteBack.YES)) - with open(tmp_file) as f: + with open(tmp_file, encoding="utf8") as f: actual = f.read() finally: os.unlink(tmp_file) @@ -174,7 +174,14 @@ class BlackTestCase(unittest.TestCase): sys.stdout = hold_stdout os.unlink(tmp_file) actual = actual.rstrip() + "\n" # the diff output has a trailing space - self.assertEqual(expected, actual) + if expected != actual: + dump = black.dump_to_file(actual) + msg = ( + f"Expected diff isn't equal to the actual. If you made changes " + f"to expression.py and this is an anticipated difference, " + f"overwrite tests/expression.diff with {dump}" + ) + self.assertEqual(expected, actual, msg) @patch("black.dump_to_file", dump_to_stderr) def test_fstring(self) -> None: @@ -390,6 +397,51 @@ class BlackTestCase(unittest.TestCase): f"AST print out is different. Actual version dumped to {log_name}", ) + def test_format_file_contents(self) -> None: + empty = "" + with self.assertRaises(black.NothingChanged): + black.format_file_contents(empty, line_length=ll, fast=False) + just_nl = "\n" + with self.assertRaises(black.NothingChanged): + black.format_file_contents(just_nl, line_length=ll, fast=False) + same = "l = [1, 2, 3]\n" + with self.assertRaises(black.NothingChanged): + black.format_file_contents(same, line_length=ll, fast=False) + different = "l = [1,2,3]" + expected = same + actual = black.format_file_contents(different, line_length=ll, fast=False) + self.assertEqual(expected, actual) + invalid = "return if you can" + with self.assertRaises(ValueError) as e: + black.format_file_contents(invalid, line_length=ll, fast=False) + self.assertEqual(str(e.exception), "Cannot parse: 1:7: return if you can") + + def test_endmarker(self) -> None: + n = black.lib2to3_parse("\n") + self.assertEqual(n.type, black.syms.file_input) + self.assertEqual(len(n.children), 1) + self.assertEqual(n.children[0].type, black.token.ENDMARKER) + + @unittest.skipIf(os.environ.get("SKIP_AST_PRINT"), "user set SKIP_AST_PRINT") + def test_assertFormatEqual(self) -> None: + out_lines = [] + err_lines = [] + + def out(msg: str, **kwargs: Any) -> None: + out_lines.append(msg) + + def err(msg: str, **kwargs: Any) -> None: + err_lines.append(msg) + + with patch("black.out", out), patch("black.err", err): + with self.assertRaises(AssertionError): + self.assertFormatEqual("l = [1, 2, 3]", "l = [1, 2, 3,]") + + out_str = "".join(out_lines) + self.assertTrue("Expected tree:" in out_str) + self.assertTrue("Actual tree:" in out_str) + self.assertEqual("".join(err_lines), "") + if __name__ == "__main__": unittest.main()