All patches and comments are welcome. Please squash your changes into logical
commits before using git-format-patch and git-send-email to send them to
patches@git.madduck.net.
If you read over the Git project's submission guidelines and adhere to them,
I'd be especially grateful.
   3 from contextlib import contextmanager
 
   4 from functools import partial
 
   5 from io import StringIO
 
   7 from pathlib import Path
 
   9 from tempfile import TemporaryDirectory
 
  10 from typing import Any, List, Tuple, Iterator
 
  12 from unittest.mock import patch
 
  14 from click import unstyle
 
  15 from click.testing import CliRunner
 
  20 ff = partial(black.format_file_in_place, line_length=ll, fast=True)
 
  21 fs = partial(black.format_str, line_length=ll)
 
  22 THIS_FILE = Path(__file__)
 
  23 THIS_DIR = THIS_FILE.parent
 
  24 EMPTY_LINE = "# EMPTY LINE WITH WHITESPACE" + " (this comment will be removed)"
 
  27 def dump_to_stderr(*output: str) -> str:
 
  28     return "\n" + "\n".join(output) + "\n"
 
  31 def read_data(name: str) -> Tuple[str, str]:
 
  32     """read_data('test_name') -> 'input', 'output'"""
 
  33     if not name.endswith((".py", ".out", ".diff")):
 
  35     _input: List[str] = []
 
  36     _output: List[str] = []
 
  37     with open(THIS_DIR / name, "r", encoding="utf8") as test:
 
  38         lines = test.readlines()
 
  41         line = line.replace(EMPTY_LINE, "")
 
  42         if line.rstrip() == "# output":
 
  47     if _input and not _output:
 
  48         # If there's no output marker, treat the entire file as already pre-formatted.
 
  50     return "".join(_input).strip() + "\n", "".join(_output).strip() + "\n"
 
  54 def cache_dir(exists: bool = True) -> Iterator[Path]:
 
  55     with TemporaryDirectory() as workspace:
 
  56         cache_dir = Path(workspace)
 
  58             cache_dir = cache_dir / "new"
 
  59         cache_file = cache_dir / "cache.pkl"
 
  60         with patch("black.CACHE_DIR", cache_dir), patch("black.CACHE_FILE", cache_file):
 
  65 def event_loop(close: bool) -> Iterator[None]:
 
  66     policy = asyncio.get_event_loop_policy()
 
  67     old_loop = policy.get_event_loop()
 
  68     loop = policy.new_event_loop()
 
  69     asyncio.set_event_loop(loop)
 
  74         policy.set_event_loop(old_loop)
 
  79 class BlackTestCase(unittest.TestCase):
 
  82     def assertFormatEqual(self, expected: str, actual: str) -> None:
 
  83         if actual != expected and not os.environ.get("SKIP_AST_PRINT"):
 
  84             bdv: black.DebugVisitor[Any]
 
  85             black.out("Expected tree:", fg="green")
 
  87                 exp_node = black.lib2to3_parse(expected)
 
  88                 bdv = black.DebugVisitor()
 
  89                 list(bdv.visit(exp_node))
 
  90             except Exception as ve:
 
  92             black.out("Actual tree:", fg="red")
 
  94                 exp_node = black.lib2to3_parse(actual)
 
  95                 bdv = black.DebugVisitor()
 
  96                 list(bdv.visit(exp_node))
 
  97             except Exception as ve:
 
  99         self.assertEqual(expected, actual)
 
 101     @patch("black.dump_to_file", dump_to_stderr)
 
 102     def test_self(self) -> None:
 
 103         source, expected = read_data("test_black")
 
 105         self.assertFormatEqual(expected, actual)
 
 106         black.assert_equivalent(source, actual)
 
 107         black.assert_stable(source, actual, line_length=ll)
 
 108         self.assertFalse(ff(THIS_FILE))
 
 110     @patch("black.dump_to_file", dump_to_stderr)
 
 111     def test_black(self) -> None:
 
 112         source, expected = read_data("../black")
 
 114         self.assertFormatEqual(expected, actual)
 
 115         black.assert_equivalent(source, actual)
 
 116         black.assert_stable(source, actual, line_length=ll)
 
 117         self.assertFalse(ff(THIS_DIR / ".." / "black.py"))
 
 119     def test_piping(self) -> None:
 
 120         source, expected = read_data("../black")
 
 121         hold_stdin, hold_stdout = sys.stdin, sys.stdout
 
 123             sys.stdin, sys.stdout = StringIO(source), StringIO()
 
 124             sys.stdin.name = "<stdin>"
 
 125             black.format_stdin_to_stdout(
 
 126                 line_length=ll, fast=True, write_back=black.WriteBack.YES
 
 129             actual = sys.stdout.read()
 
 131             sys.stdin, sys.stdout = hold_stdin, hold_stdout
 
 132         self.assertFormatEqual(expected, actual)
 
 133         black.assert_equivalent(source, actual)
 
 134         black.assert_stable(source, actual, line_length=ll)
 
 136     def test_piping_diff(self) -> None:
 
 137         source, _ = read_data("expression.py")
 
 138         expected, _ = read_data("expression.diff")
 
 139         hold_stdin, hold_stdout = sys.stdin, sys.stdout
 
 141             sys.stdin, sys.stdout = StringIO(source), StringIO()
 
 142             sys.stdin.name = "<stdin>"
 
 143             black.format_stdin_to_stdout(
 
 144                 line_length=ll, fast=True, write_back=black.WriteBack.DIFF
 
 147             actual = sys.stdout.read()
 
 149             sys.stdin, sys.stdout = hold_stdin, hold_stdout
 
 150         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
 
 151         self.assertEqual(expected, actual)
 
 153     @patch("black.dump_to_file", dump_to_stderr)
 
 154     def test_setup(self) -> None:
 
 155         source, expected = read_data("../setup")
 
 157         self.assertFormatEqual(expected, actual)
 
 158         black.assert_equivalent(source, actual)
 
 159         black.assert_stable(source, actual, line_length=ll)
 
 160         self.assertFalse(ff(THIS_DIR / ".." / "setup.py"))
 
 162     @patch("black.dump_to_file", dump_to_stderr)
 
 163     def test_function(self) -> None:
 
 164         source, expected = read_data("function")
 
 166         self.assertFormatEqual(expected, actual)
 
 167         black.assert_equivalent(source, actual)
 
 168         black.assert_stable(source, actual, line_length=ll)
 
 170     @patch("black.dump_to_file", dump_to_stderr)
 
 171     def test_expression(self) -> None:
 
 172         source, expected = read_data("expression")
 
 174         self.assertFormatEqual(expected, actual)
 
 175         black.assert_equivalent(source, actual)
 
 176         black.assert_stable(source, actual, line_length=ll)
 
 178     def test_expression_ff(self) -> None:
 
 179         source, expected = read_data("expression")
 
 180         tmp_file = Path(black.dump_to_file(source))
 
 182             self.assertTrue(ff(tmp_file, write_back=black.WriteBack.YES))
 
 183             with open(tmp_file, encoding="utf8") as f:
 
 187         self.assertFormatEqual(expected, actual)
 
 188         with patch("black.dump_to_file", dump_to_stderr):
 
 189             black.assert_equivalent(source, actual)
 
 190             black.assert_stable(source, actual, line_length=ll)
 
 192     def test_expression_diff(self) -> None:
 
 193         source, _ = read_data("expression.py")
 
 194         expected, _ = read_data("expression.diff")
 
 195         tmp_file = Path(black.dump_to_file(source))
 
 196         hold_stdout = sys.stdout
 
 198             sys.stdout = StringIO()
 
 199             self.assertTrue(ff(tmp_file, write_back=black.WriteBack.DIFF))
 
 201             actual = sys.stdout.read()
 
 202             actual = actual.replace(tmp_file.name, "<stdin>")
 
 204             sys.stdout = hold_stdout
 
 206         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
 
 207         if expected != actual:
 
 208             dump = black.dump_to_file(actual)
 
 210                 f"Expected diff isn't equal to the actual. If you made changes "
 
 211                 f"to expression.py and this is an anticipated difference, "
 
 212                 f"overwrite tests/expression.diff with {dump}"
 
 214             self.assertEqual(expected, actual, msg)
 
 216     @patch("black.dump_to_file", dump_to_stderr)
 
 217     def test_fstring(self) -> None:
 
 218         source, expected = read_data("fstring")
 
 220         self.assertFormatEqual(expected, actual)
 
 221         black.assert_equivalent(source, actual)
 
 222         black.assert_stable(source, actual, line_length=ll)
 
 224     @patch("black.dump_to_file", dump_to_stderr)
 
 225     def test_string_quotes(self) -> None:
 
 226         source, expected = read_data("string_quotes")
 
 228         self.assertFormatEqual(expected, actual)
 
 229         black.assert_equivalent(source, actual)
 
 230         black.assert_stable(source, actual, line_length=ll)
 
 232     @patch("black.dump_to_file", dump_to_stderr)
 
 233     def test_comments(self) -> None:
 
 234         source, expected = read_data("comments")
 
 236         self.assertFormatEqual(expected, actual)
 
 237         black.assert_equivalent(source, actual)
 
 238         black.assert_stable(source, actual, line_length=ll)
 
 240     @patch("black.dump_to_file", dump_to_stderr)
 
 241     def test_comments2(self) -> None:
 
 242         source, expected = read_data("comments2")
 
 244         self.assertFormatEqual(expected, actual)
 
 245         black.assert_equivalent(source, actual)
 
 246         black.assert_stable(source, actual, line_length=ll)
 
 248     @patch("black.dump_to_file", dump_to_stderr)
 
 249     def test_comments3(self) -> None:
 
 250         source, expected = read_data("comments3")
 
 252         self.assertFormatEqual(expected, actual)
 
 253         black.assert_equivalent(source, actual)
 
 254         black.assert_stable(source, actual, line_length=ll)
 
 256     @patch("black.dump_to_file", dump_to_stderr)
 
 257     def test_comments4(self) -> None:
 
 258         source, expected = read_data("comments4")
 
 260         self.assertFormatEqual(expected, actual)
 
 261         black.assert_equivalent(source, actual)
 
 262         black.assert_stable(source, actual, line_length=ll)
 
 264     @patch("black.dump_to_file", dump_to_stderr)
 
 265     def test_cantfit(self) -> None:
 
 266         source, expected = read_data("cantfit")
 
 268         self.assertFormatEqual(expected, actual)
 
 269         black.assert_equivalent(source, actual)
 
 270         black.assert_stable(source, actual, line_length=ll)
 
 272     @patch("black.dump_to_file", dump_to_stderr)
 
 273     def test_import_spacing(self) -> None:
 
 274         source, expected = read_data("import_spacing")
 
 276         self.assertFormatEqual(expected, actual)
 
 277         black.assert_equivalent(source, actual)
 
 278         black.assert_stable(source, actual, line_length=ll)
 
 280     @patch("black.dump_to_file", dump_to_stderr)
 
 281     def test_composition(self) -> None:
 
 282         source, expected = read_data("composition")
 
 284         self.assertFormatEqual(expected, actual)
 
 285         black.assert_equivalent(source, actual)
 
 286         black.assert_stable(source, actual, line_length=ll)
 
 288     @patch("black.dump_to_file", dump_to_stderr)
 
 289     def test_empty_lines(self) -> None:
 
 290         source, expected = read_data("empty_lines")
 
 292         self.assertFormatEqual(expected, actual)
 
 293         black.assert_equivalent(source, actual)
 
 294         black.assert_stable(source, actual, line_length=ll)
 
 296     @patch("black.dump_to_file", dump_to_stderr)
 
 297     def test_python2(self) -> None:
 
 298         source, expected = read_data("python2")
 
 300         self.assertFormatEqual(expected, actual)
 
 301         # black.assert_equivalent(source, actual)
 
 302         black.assert_stable(source, actual, line_length=ll)
 
 304     @patch("black.dump_to_file", dump_to_stderr)
 
 305     def test_fmtonoff(self) -> None:
 
 306         source, expected = read_data("fmtonoff")
 
 308         self.assertFormatEqual(expected, actual)
 
 309         black.assert_equivalent(source, actual)
 
 310         black.assert_stable(source, actual, line_length=ll)
 
 312     def test_report(self) -> None:
 
 313         report = black.Report()
 
 317         def out(msg: str, **kwargs: Any) -> None:
 
 318             out_lines.append(msg)
 
 320         def err(msg: str, **kwargs: Any) -> None:
 
 321             err_lines.append(msg)
 
 323         with patch("black.out", out), patch("black.err", err):
 
 324             report.done(Path("f1"), black.Changed.NO)
 
 325             self.assertEqual(len(out_lines), 1)
 
 326             self.assertEqual(len(err_lines), 0)
 
 327             self.assertEqual(out_lines[-1], "f1 already well formatted, good job.")
 
 328             self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
 
 329             self.assertEqual(report.return_code, 0)
 
 330             report.done(Path("f2"), black.Changed.YES)
 
 331             self.assertEqual(len(out_lines), 2)
 
 332             self.assertEqual(len(err_lines), 0)
 
 333             self.assertEqual(out_lines[-1], "reformatted f2")
 
 335                 unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
 
 337             report.done(Path("f3"), black.Changed.CACHED)
 
 338             self.assertEqual(len(out_lines), 3)
 
 339             self.assertEqual(len(err_lines), 0)
 
 341                 out_lines[-1], "f3 wasn't modified on disk since last run."
 
 344                 unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
 
 346             self.assertEqual(report.return_code, 0)
 
 348             self.assertEqual(report.return_code, 1)
 
 350             report.failed(Path("e1"), "boom")
 
 351             self.assertEqual(len(out_lines), 3)
 
 352             self.assertEqual(len(err_lines), 1)
 
 353             self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
 
 355                 unstyle(str(report)),
 
 356                 "1 file reformatted, 2 files left unchanged, "
 
 357                 "1 file failed to reformat.",
 
 359             self.assertEqual(report.return_code, 123)
 
 360             report.done(Path("f3"), black.Changed.YES)
 
 361             self.assertEqual(len(out_lines), 4)
 
 362             self.assertEqual(len(err_lines), 1)
 
 363             self.assertEqual(out_lines[-1], "reformatted f3")
 
 365                 unstyle(str(report)),
 
 366                 "2 files reformatted, 2 files left unchanged, "
 
 367                 "1 file failed to reformat.",
 
 369             self.assertEqual(report.return_code, 123)
 
 370             report.failed(Path("e2"), "boom")
 
 371             self.assertEqual(len(out_lines), 4)
 
 372             self.assertEqual(len(err_lines), 2)
 
 373             self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
 
 375                 unstyle(str(report)),
 
 376                 "2 files reformatted, 2 files left unchanged, "
 
 377                 "2 files failed to reformat.",
 
 379             self.assertEqual(report.return_code, 123)
 
 380             report.done(Path("f4"), black.Changed.NO)
 
 381             self.assertEqual(len(out_lines), 5)
 
 382             self.assertEqual(len(err_lines), 2)
 
 383             self.assertEqual(out_lines[-1], "f4 already well formatted, good job.")
 
 385                 unstyle(str(report)),
 
 386                 "2 files reformatted, 3 files left unchanged, "
 
 387                 "2 files failed to reformat.",
 
 389             self.assertEqual(report.return_code, 123)
 
 392                 unstyle(str(report)),
 
 393                 "2 files would be reformatted, 3 files would be left unchanged, "
 
 394                 "2 files would fail to reformat.",
 
 397     def test_is_python36(self) -> None:
 
 398         node = black.lib2to3_parse("def f(*, arg): ...\n")
 
 399         self.assertFalse(black.is_python36(node))
 
 400         node = black.lib2to3_parse("def f(*, arg,): ...\n")
 
 401         self.assertTrue(black.is_python36(node))
 
 402         node = black.lib2to3_parse("def f(*, arg): f'string'\n")
 
 403         self.assertTrue(black.is_python36(node))
 
 404         source, expected = read_data("function")
 
 405         node = black.lib2to3_parse(source)
 
 406         self.assertTrue(black.is_python36(node))
 
 407         node = black.lib2to3_parse(expected)
 
 408         self.assertTrue(black.is_python36(node))
 
 409         source, expected = read_data("expression")
 
 410         node = black.lib2to3_parse(source)
 
 411         self.assertFalse(black.is_python36(node))
 
 412         node = black.lib2to3_parse(expected)
 
 413         self.assertFalse(black.is_python36(node))
 
 415     def test_debug_visitor(self) -> None:
 
 416         source, _ = read_data("debug_visitor.py")
 
 417         expected, _ = read_data("debug_visitor.out")
 
 421         def out(msg: str, **kwargs: Any) -> None:
 
 422             out_lines.append(msg)
 
 424         def err(msg: str, **kwargs: Any) -> None:
 
 425             err_lines.append(msg)
 
 427         with patch("black.out", out), patch("black.err", err):
 
 428             black.DebugVisitor.show(source)
 
 429         actual = "\n".join(out_lines) + "\n"
 
 431         if expected != actual:
 
 432             log_name = black.dump_to_file(*out_lines)
 
 436             f"AST print out is different. Actual version dumped to {log_name}",
 
 439     def test_format_file_contents(self) -> None:
 
 441         with self.assertRaises(black.NothingChanged):
 
 442             black.format_file_contents(empty, line_length=ll, fast=False)
 
 444         with self.assertRaises(black.NothingChanged):
 
 445             black.format_file_contents(just_nl, line_length=ll, fast=False)
 
 446         same = "l = [1, 2, 3]\n"
 
 447         with self.assertRaises(black.NothingChanged):
 
 448             black.format_file_contents(same, line_length=ll, fast=False)
 
 449         different = "l = [1,2,3]"
 
 451         actual = black.format_file_contents(different, line_length=ll, fast=False)
 
 452         self.assertEqual(expected, actual)
 
 453         invalid = "return if you can"
 
 454         with self.assertRaises(ValueError) as e:
 
 455             black.format_file_contents(invalid, line_length=ll, fast=False)
 
 456         self.assertEqual(str(e.exception), "Cannot parse: 1:7: return if you can")
 
 458     def test_endmarker(self) -> None:
 
 459         n = black.lib2to3_parse("\n")
 
 460         self.assertEqual(n.type, black.syms.file_input)
 
 461         self.assertEqual(len(n.children), 1)
 
 462         self.assertEqual(n.children[0].type, black.token.ENDMARKER)
 
 464     @unittest.skipIf(os.environ.get("SKIP_AST_PRINT"), "user set SKIP_AST_PRINT")
 
 465     def test_assertFormatEqual(self) -> None:
 
 469         def out(msg: str, **kwargs: Any) -> None:
 
 470             out_lines.append(msg)
 
 472         def err(msg: str, **kwargs: Any) -> None:
 
 473             err_lines.append(msg)
 
 475         with patch("black.out", out), patch("black.err", err):
 
 476             with self.assertRaises(AssertionError):
 
 477                 self.assertFormatEqual("l = [1, 2, 3]", "l = [1, 2, 3,]")
 
 479         out_str = "".join(out_lines)
 
 480         self.assertTrue("Expected tree:" in out_str)
 
 481         self.assertTrue("Actual tree:" in out_str)
 
 482         self.assertEqual("".join(err_lines), "")
 
 484     def test_cache_broken_file(self) -> None:
 
 485         with cache_dir() as workspace:
 
 486             with black.CACHE_FILE.open("w") as fobj:
 
 487                 fobj.write("this is not a pickle")
 
 488             self.assertEqual(black.read_cache(), {})
 
 489             src = (workspace / "test.py").resolve()
 
 490             with src.open("w") as fobj:
 
 491                 fobj.write("print('hello')")
 
 492             result = CliRunner().invoke(black.main, [str(src)])
 
 493             self.assertEqual(result.exit_code, 0)
 
 494             cache = black.read_cache()
 
 495             self.assertIn(src, cache)
 
 497     def test_cache_single_file_already_cached(self) -> None:
 
 498         with cache_dir() as workspace:
 
 499             src = (workspace / "test.py").resolve()
 
 500             with src.open("w") as fobj:
 
 501                 fobj.write("print('hello')")
 
 502             black.write_cache({}, [src])
 
 503             result = CliRunner().invoke(black.main, [str(src)])
 
 504             self.assertEqual(result.exit_code, 0)
 
 505             with src.open("r") as fobj:
 
 506                 self.assertEqual(fobj.read(), "print('hello')")
 
 508     @event_loop(close=False)
 
 509     def test_cache_multiple_files(self) -> None:
 
 510         with cache_dir() as workspace:
 
 511             one = (workspace / "one.py").resolve()
 
 512             with one.open("w") as fobj:
 
 513                 fobj.write("print('hello')")
 
 514             two = (workspace / "two.py").resolve()
 
 515             with two.open("w") as fobj:
 
 516                 fobj.write("print('hello')")
 
 517             black.write_cache({}, [one])
 
 518             result = CliRunner().invoke(black.main, [str(workspace)])
 
 519             self.assertEqual(result.exit_code, 0)
 
 520             with one.open("r") as fobj:
 
 521                 self.assertEqual(fobj.read(), "print('hello')")
 
 522             with two.open("r") as fobj:
 
 523                 self.assertEqual(fobj.read(), 'print("hello")\n')
 
 524             cache = black.read_cache()
 
 525             self.assertIn(one, cache)
 
 526             self.assertIn(two, cache)
 
 528     def test_no_cache_when_writeback_diff(self) -> None:
 
 529         with cache_dir() as workspace:
 
 530             src = (workspace / "test.py").resolve()
 
 531             with src.open("w") as fobj:
 
 532                 fobj.write("print('hello')")
 
 533             result = CliRunner().invoke(black.main, [str(src), "--diff"])
 
 534             self.assertEqual(result.exit_code, 0)
 
 535             self.assertFalse(black.CACHE_FILE.exists())
 
 537     def test_no_cache_when_stdin(self) -> None:
 
 539             result = CliRunner().invoke(black.main, ["-"], input="print('hello')")
 
 540             self.assertEqual(result.exit_code, 0)
 
 541             self.assertFalse(black.CACHE_FILE.exists())
 
 543     def test_read_cache_no_cachefile(self) -> None:
 
 545             self.assertEqual(black.read_cache(), {})
 
 547     def test_write_cache_read_cache(self) -> None:
 
 548         with cache_dir() as workspace:
 
 549             src = (workspace / "test.py").resolve()
 
 551             black.write_cache({}, [src])
 
 552             cache = black.read_cache()
 
 553             self.assertIn(src, cache)
 
 554             self.assertEqual(cache[src], black.get_cache_info(src))
 
 556     def test_filter_cached(self) -> None:
 
 557         with TemporaryDirectory() as workspace:
 
 558             path = Path(workspace)
 
 559             uncached = (path / "uncached").resolve()
 
 560             cached = (path / "cached").resolve()
 
 561             cached_but_changed = (path / "changed").resolve()
 
 564             cached_but_changed.touch()
 
 565             cache = {cached: black.get_cache_info(cached), cached_but_changed: (0.0, 0)}
 
 566             todo, done = black.filter_cached(
 
 567                 cache, [uncached, cached, cached_but_changed]
 
 569             self.assertEqual(todo, [uncached, cached_but_changed])
 
 570             self.assertEqual(done, [cached])
 
 572     def test_write_cache_creates_directory_if_needed(self) -> None:
 
 573         with cache_dir(exists=False) as workspace:
 
 574             self.assertFalse(workspace.exists())
 
 575             black.write_cache({}, [])
 
 576             self.assertTrue(workspace.exists())
 
 578     @event_loop(close=False)
 
 579     def test_failed_formatting_does_not_get_cached(self) -> None:
 
 580         with cache_dir() as workspace:
 
 581             failing = (workspace / "failing.py").resolve()
 
 582             with failing.open("w") as fobj:
 
 583                 fobj.write("not actually python")
 
 584             clean = (workspace / "clean.py").resolve()
 
 585             with clean.open("w") as fobj:
 
 586                 fobj.write('print("hello")\n')
 
 587             result = CliRunner().invoke(black.main, [str(workspace)])
 
 588             self.assertEqual(result.exit_code, 123)
 
 589             cache = black.read_cache()
 
 590             self.assertNotIn(failing, cache)
 
 591             self.assertIn(clean, cache)
 
 593     def test_write_cache_write_fail(self) -> None:
 
 594         with cache_dir(), patch.object(Path, "open") as mock:
 
 595             mock.side_effect = OSError
 
 596             black.write_cache({}, [])
 
 599 if __name__ == "__main__":