All patches and comments are welcome. Please squash your changes into logical
commits before using git-format-patch and git-send-email to
patches@git.madduck.net.
If you read over the Git project's submission guidelines and adhere to them,
I'd be especially grateful.
3 from contextlib import contextmanager
4 from functools import partial
5 from io import StringIO
7 from pathlib import Path
9 from tempfile import TemporaryDirectory
10 from typing import Any, List, Tuple, Iterator
12 from unittest.mock import patch
14 from click import unstyle
15 from click.testing import CliRunner
# Module-wide fixtures. `ll` (the line length handed to the formatter) is defined
# outside the portion of the file shown here.
THIS_FILE = Path(__file__)
THIS_DIR = THIS_FILE.parent
# Formatter entry points with the line length pre-bound: `fs` formats a string,
# `ff` formats a file in place (with fast=True).
fs = partial(black.format_str, line_length=ll)
ff = partial(black.format_file_in_place, fast=True, line_length=ll)
# The marker is assembled from two pieces on purpose: read_data() strips EMPTY_LINE
# from every line it loads, and test_self loads this very module, so the complete
# marker text must never appear literally in this source file.
EMPTY_LINE = "# EMPTY LINE WITH WHITESPACE" + " (this comment will be removed)"
def dump_to_stderr(*output: str) -> str:
    """Test stand-in for ``black.dump_to_file``.

    Rather than producing a dump file and returning its path, return the text
    itself: the given pieces joined one per line, with a newline before and
    after. Despite the name, nothing is written to stderr here; presumably the
    returned text simply ends up inline in any failure message.
    """
    joined = "\n".join(output)
    return f"\n{joined}\n"
# Load a test fixture from THIS_DIR and split it at its "# output" line into the
# (input, expected output) pair; each half is stripped and ends with one newline.
# NOTE(review): this copy has lost its indentation and several lines (presumably
# including a loop over `lines`); restore from upstream before relying on it.
31 def read_data(name: str) -> Tuple[str, str]:
32 """read_data('test_name') -> 'input', 'output'"""
# NOTE(review): the body of this guard is missing; presumably it appends a default
# ".py" suffix, given callers pass bare names such as "function" — confirm.
33 if not name.endswith((".py", ".out", ".diff")):
35 _input: List[str] = []
36 _output: List[str] = []
37 with open(THIS_DIR / name, "r", encoding="utf8") as test:
38 lines = test.readlines()
# Remove the EMPTY_LINE marker comment, keeping any whitespace around it.
41 line = line.replace(EMPTY_LINE, "")
# A "# output" line separates the input section from the expected output; the code
# that switches the collection target is missing from this copy.
42 if line.rstrip() == "# output":
47 if _input and not _output:
48 # If there's no output marker, treat the entire file as already pre-formatted.
50 return "".join(_input).strip() + "\n", "".join(_output).strip() + "\n"
# Redirect black.CACHE_DIR / black.CACHE_FILE into a fresh temporary directory for the
# duration of a `with cache_dir() as workspace:` block.
# NOTE(review): the decorator (presumably @contextmanager) and the `yield` are not
# present in this copy; the `with ... as workspace` usage below implies both.
54 def cache_dir(exists: bool = True) -> Iterator[Path]:
55 with TemporaryDirectory() as workspace:
56 cache_dir = Path(workspace)
# NOTE(review): the guarding `if` is missing; presumably only when exists=False is the
# cache pointed at a not-yet-created subdirectory (test_write_cache_creates_directory_if_needed
# relies on it not existing).
58 cache_dir = cache_dir / "new"
59 cache_file = cache_dir / "cache.pkl"
60 with patch("black.CACHE_DIR", cache_dir), patch("black.CACHE_FILE", cache_file):
# Install a brand-new asyncio event loop for the wrapped code, then restore the loop
# that was current before. Tests apply it as `@event_loop(close=False)`.
# NOTE(review): the decorator (presumably @contextmanager, whose result can also be
# used as a decorator), the try/yield/finally and the handling of `close` are missing
# from this copy.
65 def event_loop(close: bool) -> Iterator[None]:
66 policy = asyncio.get_event_loop_policy()
67 old_loop = policy.get_event_loop()
68 loop = policy.new_event_loop()
69 asyncio.set_event_loop(loop)
74 policy.set_event_loop(old_loop)
# Test suite covering black's string/file formatting, stdin piping, diff output,
# reporting, debug visitor and cache handling.
# NOTE(review): this copy is a web scrape with leaked line numbers, lost indentation
# and dropped lines (including the class attributes right after this header); it is
# not runnable as-is — restore from the upstream repository.
79 class BlackTestCase(unittest.TestCase):
# Compare formatter output with the expected text. On a mismatch, first print the
# debug tree of each side (skipped when SKIP_AST_PRINT is set), then fail via assertEqual.
82 def assertFormatEqual(self, expected: str, actual: str) -> None:
83 if actual != expected and not os.environ.get("SKIP_AST_PRINT"):
84 bdv: black.DebugVisitor[Any]
85 black.out("Expected tree:", fg="green")
# NOTE(review): the `try:` lines and the bodies of both `except` clauses are missing
# from this copy.
87 exp_node = black.lib2to3_parse(expected)
88 bdv = black.DebugVisitor()
# list() drains the visit so the visitor actually runs (visit appears to be lazy).
89 list(bdv.visit(exp_node))
90 except Exception as ve:
92 black.out("Actual tree:", fg="red")
# NOTE(review): `exp_node` is reused here but holds the *actual* tree.
94 exp_node = black.lib2to3_parse(actual)
95 bdv = black.DebugVisitor()
96 list(bdv.visit(exp_node))
97 except Exception as ve:
# The check that actually fails the test on a mismatch.
99 self.assertEqual(expected, actual)
# Black must reproduce this test module unchanged, and the file on disk must need no
# reformatting. dump_to_file is patched so black's checks get text back instead of a
# dump-file path.
101 @patch("black.dump_to_file", dump_to_stderr)
102 def test_self(self) -> None:
103 source, expected = read_data("test_black")
# NOTE(review): the line assigning `actual` is missing here (presumably fs(source)).
105 self.assertFormatEqual(expected, actual)
106 black.assert_equivalent(source, actual)
107 black.assert_stable(source, actual, line_length=ll)
# A falsy result presumably means the file needed no changes.
108 self.assertFalse(ff(THIS_FILE))
# black.py itself must already be formatted: round-tripping its source is a no-op and
# formatting the file on disk changes nothing.
110 @patch("black.dump_to_file", dump_to_stderr)
111 def test_black(self) -> None:
112 source, expected = read_data("../black")
# NOTE(review): the line assigning `actual` is missing here (presumably fs(source)).
114 self.assertFormatEqual(expected, actual)
115 black.assert_equivalent(source, actual)
116 black.assert_stable(source, actual, line_length=ll)
117 self.assertFalse(ff(THIS_DIR / ".." / "black.py"))
# Run black.py's source through format_stdin_to_stdout, with sys.stdin/sys.stdout
# temporarily swapped for StringIO buffers, and check the captured output.
119 def test_piping(self) -> None:
120 source, expected = read_data("../black")
121 hold_stdin, hold_stdout = sys.stdin, sys.stdout
# NOTE(review): the try/finally around the swap, the call's closing paren and
# (presumably) a sys.stdout.seek(0) before the read are missing from this copy.
123 sys.stdin, sys.stdout = StringIO(source), StringIO()
# Give the fake stdin the name a real one has (presumably read by the code under test).
124 sys.stdin.name = "<stdin>"
125 black.format_stdin_to_stdout(
126 line_length=ll, fast=True, write_back=black.WriteBack.YES
129 actual = sys.stdout.read()
# Restore the real streams before asserting so failures are reported normally.
131 sys.stdin, sys.stdout = hold_stdin, hold_stdout
132 self.assertFormatEqual(expected, actual)
133 black.assert_equivalent(source, actual)
134 black.assert_stable(source, actual, line_length=ll)
# Pipe the expression fixture through format_stdin_to_stdout in DIFF mode and compare
# the emitted diff with the stored expression.diff.
136 def test_piping_diff(self) -> None:
# Only the first element of each read_data() result is used.
137 source, _ = read_data("expression.py")
138 expected, _ = read_data("expression.diff")
139 hold_stdin, hold_stdout = sys.stdin, sys.stdout
# NOTE(review): the try/finally around the swap, the call's closing paren and
# (presumably) a sys.stdout.seek(0) before the read are missing from this copy.
141 sys.stdin, sys.stdout = StringIO(source), StringIO()
142 sys.stdin.name = "<stdin>"
143 black.format_stdin_to_stdout(
144 line_length=ll, fast=True, write_back=black.WriteBack.DIFF
147 actual = sys.stdout.read()
149 sys.stdin, sys.stdout = hold_stdin, hold_stdout
150 actual = actual.rstrip() + "\n" # the diff output has a trailing space
151 self.assertEqual(expected, actual)
# setup.py must already be formatted: round-tripping it is a no-op and formatting the
# file on disk changes nothing.
153 @patch("black.dump_to_file", dump_to_stderr)
154 def test_setup(self) -> None:
155 source, expected = read_data("../setup")
# NOTE(review): the line assigning `actual` is missing here (presumably fs(source)).
157 self.assertFormatEqual(expected, actual)
158 black.assert_equivalent(source, actual)
159 black.assert_stable(source, actual, line_length=ll)
160 self.assertFalse(ff(THIS_DIR / ".." / "setup.py"))
# Round-trip the "function" fixture: output must match its expected section and pass
# black.assert_equivalent / black.assert_stable.
162 @patch("black.dump_to_file", dump_to_stderr)
163 def test_function(self) -> None:
164 source, expected = read_data("function")
# NOTE(review): the line assigning `actual` is missing here (presumably fs(source)).
166 self.assertFormatEqual(expected, actual)
167 black.assert_equivalent(source, actual)
168 black.assert_stable(source, actual, line_length=ll)
# Round-trip the "expression" fixture: output must match its expected section and
# pass black.assert_equivalent / black.assert_stable.
170 @patch("black.dump_to_file", dump_to_stderr)
171 def test_expression(self) -> None:
172 source, expected = read_data("expression")
# NOTE(review): the line assigning `actual` is missing here (presumably fs(source)).
174 self.assertFormatEqual(expected, actual)
175 black.assert_equivalent(source, actual)
176 black.assert_stable(source, actual, line_length=ll)
# Format a temp copy of the expression fixture in place and check the rewritten file.
178 def test_expression_ff(self) -> None:
179 source, expected = read_data("expression")
# Uses the real dump_to_file to get a temp file; the stand-in is only patched in
# further down, around the equivalence/stability checks.
180 tmp_file = Path(black.dump_to_file(source))
# NOTE(review): the `try:` line is missing from this copy.
# A truthy result presumably means the file was changed on disk.
182 self.assertTrue(ff(tmp_file, write_back=black.WriteBack.YES))
183 with open(tmp_file, encoding="utf8") as f:
# NOTE(review): the read into `actual` and the tmp_file cleanup (presumably in a
# `finally:`) are missing from this copy.
187 self.assertFormatEqual(expected, actual)
188 with patch("black.dump_to_file", dump_to_stderr):
189 black.assert_equivalent(source, actual)
190 black.assert_stable(source, actual, line_length=ll)
# Run format_file_in_place in DIFF mode on a temp copy of the expression fixture and
# compare the printed diff with the stored expression.diff.
192 def test_expression_diff(self) -> None:
# Only the first element of each read_data() result is used.
193 source, _ = read_data("expression.py")
194 expected, _ = read_data("expression.diff")
195 tmp_file = Path(black.dump_to_file(source))
196 hold_stdout = sys.stdout
# NOTE(review): the try/finally around the stdout swap (which presumably also removes
# tmp_file) and a likely sys.stdout.seek(0) before the read are missing from this copy.
198 sys.stdout = StringIO()
199 self.assertTrue(ff(tmp_file, write_back=black.WriteBack.DIFF))
201 actual = sys.stdout.read()
# The diff header names the temp file; normalise it to "<stdin>", which the stored
# fixture presumably uses.
202 actual = actual.replace(tmp_file.name, "<stdin>")
204 sys.stdout = hold_stdout
206 actual = actual.rstrip() + "\n" # the diff output has a trailing space
207 if expected != actual:
# Dump the actual diff so it can be inspected or copied over the fixture.
208 dump = black.dump_to_file(actual)
# NOTE(review): the `msg = (` line and its closing paren are missing from this copy.
210 f"Expected diff isn't equal to the actual. If you made changes "
211 f"to expression.py and this is an anticipated difference, "
212 f"overwrite tests/expression.diff with {dump}"
214 self.assertEqual(expected, actual, msg)
# Round-trip the "fstring" fixture: output must match its expected section and pass
# black.assert_equivalent / black.assert_stable.
216 @patch("black.dump_to_file", dump_to_stderr)
217 def test_fstring(self) -> None:
218 source, expected = read_data("fstring")
# NOTE(review): the line assigning `actual` is missing here (presumably fs(source)).
220 self.assertFormatEqual(expected, actual)
221 black.assert_equivalent(source, actual)
222 black.assert_stable(source, actual, line_length=ll)
# Round-trip the "string_quotes" fixture: output must match its expected section and
# pass black.assert_equivalent / black.assert_stable.
224 @patch("black.dump_to_file", dump_to_stderr)
225 def test_string_quotes(self) -> None:
226 source, expected = read_data("string_quotes")
# NOTE(review): the line assigning `actual` is missing here (presumably fs(source)).
228 self.assertFormatEqual(expected, actual)
229 black.assert_equivalent(source, actual)
230 black.assert_stable(source, actual, line_length=ll)
# Round-trip the "comments" fixture: output must match its expected section and pass
# black.assert_equivalent / black.assert_stable.
232 @patch("black.dump_to_file", dump_to_stderr)
233 def test_comments(self) -> None:
234 source, expected = read_data("comments")
# NOTE(review): the line assigning `actual` is missing here (presumably fs(source)).
236 self.assertFormatEqual(expected, actual)
237 black.assert_equivalent(source, actual)
238 black.assert_stable(source, actual, line_length=ll)
# Round-trip the "comments2" fixture: output must match its expected section and pass
# black.assert_equivalent / black.assert_stable.
240 @patch("black.dump_to_file", dump_to_stderr)
241 def test_comments2(self) -> None:
242 source, expected = read_data("comments2")
# NOTE(review): the line assigning `actual` is missing here (presumably fs(source)).
244 self.assertFormatEqual(expected, actual)
245 black.assert_equivalent(source, actual)
246 black.assert_stable(source, actual, line_length=ll)
# Round-trip the "comments3" fixture: output must match its expected section and pass
# black.assert_equivalent / black.assert_stable.
248 @patch("black.dump_to_file", dump_to_stderr)
249 def test_comments3(self) -> None:
250 source, expected = read_data("comments3")
# NOTE(review): the line assigning `actual` is missing here (presumably fs(source)).
252 self.assertFormatEqual(expected, actual)
253 black.assert_equivalent(source, actual)
254 black.assert_stable(source, actual, line_length=ll)
# Round-trip the "comments4" fixture: output must match its expected section and pass
# black.assert_equivalent / black.assert_stable.
256 @patch("black.dump_to_file", dump_to_stderr)
257 def test_comments4(self) -> None:
258 source, expected = read_data("comments4")
# NOTE(review): the line assigning `actual` is missing here (presumably fs(source)).
260 self.assertFormatEqual(expected, actual)
261 black.assert_equivalent(source, actual)
262 black.assert_stable(source, actual, line_length=ll)
# Round-trip the "cantfit" fixture: output must match its expected section and pass
# black.assert_equivalent / black.assert_stable.
264 @patch("black.dump_to_file", dump_to_stderr)
265 def test_cantfit(self) -> None:
266 source, expected = read_data("cantfit")
# NOTE(review): the line assigning `actual` is missing here (presumably fs(source)).
268 self.assertFormatEqual(expected, actual)
269 black.assert_equivalent(source, actual)
270 black.assert_stable(source, actual, line_length=ll)
# Round-trip the "import_spacing" fixture: output must match its expected section and
# pass black.assert_equivalent / black.assert_stable.
272 @patch("black.dump_to_file", dump_to_stderr)
273 def test_import_spacing(self) -> None:
274 source, expected = read_data("import_spacing")
# NOTE(review): the line assigning `actual` is missing here (presumably fs(source)).
276 self.assertFormatEqual(expected, actual)
277 black.assert_equivalent(source, actual)
278 black.assert_stable(source, actual, line_length=ll)
# Round-trip the "composition" fixture: output must match its expected section and
# pass black.assert_equivalent / black.assert_stable.
280 @patch("black.dump_to_file", dump_to_stderr)
281 def test_composition(self) -> None:
282 source, expected = read_data("composition")
# NOTE(review): the line assigning `actual` is missing here (presumably fs(source)).
284 self.assertFormatEqual(expected, actual)
285 black.assert_equivalent(source, actual)
286 black.assert_stable(source, actual, line_length=ll)
# Round-trip the "empty_lines" fixture: output must match its expected section and
# pass black.assert_equivalent / black.assert_stable.
288 @patch("black.dump_to_file", dump_to_stderr)
289 def test_empty_lines(self) -> None:
290 source, expected = read_data("empty_lines")
# NOTE(review): the line assigning `actual` is missing here (presumably fs(source)).
292 self.assertFormatEqual(expected, actual)
293 black.assert_equivalent(source, actual)
294 black.assert_stable(source, actual, line_length=ll)
# Round-trip the "python2" fixture: output must match its expected section and pass
# black.assert_stable. The equivalence check is disabled below.
296 @patch("black.dump_to_file", dump_to_stderr)
297 def test_python2(self) -> None:
298 source, expected = read_data("python2")
# NOTE(review): the line assigning `actual` is missing here (presumably fs(source)).
300 self.assertFormatEqual(expected, actual)
# NOTE(review): presumably disabled because Python 2 source cannot be compared via the
# host interpreter's AST — confirm, then either document the reason or drop the line.
301 # black.assert_equivalent(source, actual)
302 black.assert_stable(source, actual, line_length=ll)
# Round-trip the "fmtonoff" fixture: output must match its expected section and pass
# black.assert_equivalent / black.assert_stable.
304 @patch("black.dump_to_file", dump_to_stderr)
305 def test_fmtonoff(self) -> None:
306 source, expected = read_data("fmtonoff")
# NOTE(review): the line assigning `actual` is missing here (presumably fs(source)).
308 self.assertFormatEqual(expected, actual)
309 black.assert_equivalent(source, actual)
310 black.assert_stable(source, actual, line_length=ll)
# Feed a black.Report a sequence of per-file outcomes and check, after each one, the
# captured out/err messages, the unstyled summary string and return_code.
# NOTE(review): this copy is missing the out_lines/err_lines initialisations, several
# `self.assertEqual(` opening lines and closing parens, and the lines that change the
# report's mode (see the notes below).
312 def test_report(self) -> None:
313 report = black.Report()
# Capture black.out / black.err messages instead of printing them.
317 def out(msg: str, **kwargs: Any) -> None:
318 out_lines.append(msg)
320 def err(msg: str, **kwargs: Any) -> None:
321 err_lines.append(msg)
323 with patch("black.out", out), patch("black.err", err):
324 report.done(Path("f1"), black.Changed.NO)
325 self.assertEqual(len(out_lines), 1)
326 self.assertEqual(len(err_lines), 0)
327 self.assertEqual(out_lines[-1], "f1 already well formatted, good job.")
328 self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
329 self.assertEqual(report.return_code, 0)
330 report.done(Path("f2"), black.Changed.YES)
331 self.assertEqual(len(out_lines), 2)
332 self.assertEqual(len(err_lines), 0)
333 self.assertEqual(out_lines[-1], "reformatted f2")
335 unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
# A CACHED outcome counts as "left unchanged" in the summary.
337 report.done(Path("f3"), black.Changed.CACHED)
338 self.assertEqual(len(out_lines), 3)
339 self.assertEqual(len(err_lines), 0)
341 out_lines[-1], "f3 wasn't modified on disk since last run."
344 unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
346 self.assertEqual(report.return_code, 0)
# NOTE(review): a line is missing here; after it the same report yields return_code 1,
# presumably because a check mode is switched on — confirm against upstream.
348 self.assertEqual(report.return_code, 1)
# Failures go to err, not out.
350 report.failed(Path("e1"), "boom")
351 self.assertEqual(len(out_lines), 3)
352 self.assertEqual(len(err_lines), 1)
353 self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
355 unstyle(str(report)),
356 "1 file reformatted, 2 files left unchanged, "
357 "1 file failed to reformat.",
# Once any file has failed, return_code stays 123 whatever else is reported.
359 self.assertEqual(report.return_code, 123)
360 report.done(Path("f3"), black.Changed.YES)
361 self.assertEqual(len(out_lines), 4)
362 self.assertEqual(len(err_lines), 1)
363 self.assertEqual(out_lines[-1], "reformatted f3")
365 unstyle(str(report)),
366 "2 files reformatted, 2 files left unchanged, "
367 "1 file failed to reformat.",
369 self.assertEqual(report.return_code, 123)
370 report.failed(Path("e2"), "boom")
371 self.assertEqual(len(out_lines), 4)
372 self.assertEqual(len(err_lines), 2)
373 self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
375 unstyle(str(report)),
376 "2 files reformatted, 2 files left unchanged, "
377 "2 files failed to reformat.",
379 self.assertEqual(report.return_code, 123)
380 report.done(Path("f4"), black.Changed.NO)
381 self.assertEqual(len(out_lines), 5)
382 self.assertEqual(len(err_lines), 2)
383 self.assertEqual(out_lines[-1], "f4 already well formatted, good job.")
385 unstyle(str(report)),
386 "2 files reformatted, 3 files left unchanged, "
387 "2 files failed to reformat.",
389 self.assertEqual(report.return_code, 123)
# NOTE(review): the lines switching the summary to its "would be ..." wording
# (presumably check mode) are missing from this copy.
392 unstyle(str(report)),
393 "2 files would be reformatted, 3 files would be left unchanged, "
394 "2 files would fail to reformat.",
def test_is_python36(self) -> None:
    """Check which snippets and fixtures is_python36() reports as 3.6-only.

    Keyword-only parameters alone are not flagged; a trailing comma after them
    is, and so is an f-string. For each fixture, both the input and the expected
    output are checked.
    """
    snippets = [
        ("def f(*, arg): ...\n", False),
        ("def f(*, arg,): ...\n", True),
        ("def f(*, arg): f'string'\n", True),
    ]
    for snippet, is_36 in snippets:
        check = self.assertTrue if is_36 else self.assertFalse
        check(black.is_python36(black.lib2to3_parse(snippet)))
    for fixture, is_36 in (("function", True), ("expression", False)):
        source, expected = read_data(fixture)
        check = self.assertTrue if is_36 else self.assertFalse
        for text in (source, expected):
            check(black.is_python36(black.lib2to3_parse(text)))
# Render the debug_visitor fixture through DebugVisitor.show with black.out captured,
# and compare against the stored .out file; on mismatch the actual text is dumped.
415 def test_debug_visitor(self) -> None:
# Only the first element of each read_data() result is used; for a file without a
# "# output" marker read_data treats the whole file as pre-formatted.
416 source, _ = read_data("debug_visitor.py")
417 expected, _ = read_data("debug_visitor.out")
# NOTE(review): the out_lines / err_lines initialisations are missing from this copy.
421 def out(msg: str, **kwargs: Any) -> None:
422 out_lines.append(msg)
424 def err(msg: str, **kwargs: Any) -> None:
425 err_lines.append(msg)
427 with patch("black.out", out), patch("black.err", err):
428 black.DebugVisitor.show(source)
429 actual = "\n".join(out_lines) + "\n"
431 if expected != actual:
432 log_name = black.dump_to_file(*out_lines)
# NOTE(review): the assertion call wrapping this message is only partly present here.
436 f"AST print out is different. Actual version dumped to {log_name}",
# Pin format_file_contents' contract: NothingChanged when nothing needs reformatting,
# the formatted text otherwise, and ValueError("Cannot parse: ...") for invalid source.
# NOTE(review): the definitions of `empty`, `just_nl` and `expected` are missing from
# this copy.
439 def test_format_file_contents(self) -> None:
441 with self.assertRaises(black.NothingChanged):
442 black.format_file_contents(empty, line_length=ll, fast=False)
444 with self.assertRaises(black.NothingChanged):
445 black.format_file_contents(just_nl, line_length=ll, fast=False)
446 same = "l = [1, 2, 3]\n"
447 with self.assertRaises(black.NothingChanged):
448 black.format_file_contents(same, line_length=ll, fast=False)
449 different = "l = [1,2,3]"
451 actual = black.format_file_contents(different, line_length=ll, fast=False)
452 self.assertEqual(expected, actual)
453 invalid = "return if you can"
454 with self.assertRaises(ValueError) as e:
455 black.format_file_contents(invalid, line_length=ll, fast=False)
# "1:7" presumably reads line:column, pointing at the unexpected `if`.
456 self.assertEqual(str(e.exception), "Cannot parse: 1:7: return if you can")
def test_endmarker(self) -> None:
    """A lone newline parses to a file_input node whose only child is ENDMARKER."""
    tree = black.lib2to3_parse("\n")
    self.assertEqual(tree.type, black.syms.file_input)
    children = tree.children
    self.assertEqual(len(children), 1)
    self.assertEqual(children[0].type, black.token.ENDMARKER)
# assertFormatEqual must fail on differing inputs and print both debug trees to out
# (nothing to err). Skipped when SKIP_AST_PRINT is set, because assertFormatEqual then
# prints no trees.
464 @unittest.skipIf(os.environ.get("SKIP_AST_PRINT"), "user set SKIP_AST_PRINT")
465 def test_assertFormatEqual(self) -> None:
# NOTE(review): the out_lines / err_lines initialisations are missing from this copy.
469 def out(msg: str, **kwargs: Any) -> None:
470 out_lines.append(msg)
472 def err(msg: str, **kwargs: Any) -> None:
473 err_lines.append(msg)
475 with patch("black.out", out), patch("black.err", err):
476 with self.assertRaises(AssertionError):
477 self.assertFormatEqual("l = [1, 2, 3]", "l = [1, 2, 3,]")
479 out_str = "".join(out_lines)
480 self.assertTrue("Expected tree:" in out_str)
481 self.assertTrue("Actual tree:" in out_str)
482 self.assertEqual("".join(err_lines), "")
def test_cache_broken_file(self) -> None:
    """A corrupt cache reads as empty and is replaced by the next run."""
    with cache_dir() as workspace:
        black.CACHE_FILE.write_text("this is not a pickle")
        self.assertEqual(black.read_cache(), {})
        target = (workspace / "test.py").resolve()
        target.write_text("print('hello')")
        outcome = CliRunner().invoke(black.main, [str(target)])
        self.assertEqual(outcome.exit_code, 0)
        self.assertIn(target, black.read_cache())
def test_cache_single_file_already_cached(self) -> None:
    """A file already recorded in the cache is not touched by black."""
    with cache_dir() as workspace:
        target = (workspace / "test.py").resolve()
        target.write_text("print('hello')")
        black.write_cache({}, [target])
        outcome = CliRunner().invoke(black.main, [str(target)])
        self.assertEqual(outcome.exit_code, 0)
        # The single quotes survive because the cached file was skipped.
        self.assertEqual(target.read_text(), "print('hello')")
@event_loop(close=False)
def test_cache_multiple_files(self) -> None:
    """Only files missing from the cache are reformatted; all end up cached."""
    with cache_dir() as workspace:
        one = (workspace / "one.py").resolve()
        one.write_text("print('hello')")
        two = (workspace / "two.py").resolve()
        two.write_text("print('hello')")
        # Seed the cache with `one` only, so black should skip it.
        black.write_cache({}, [one])
        outcome = CliRunner().invoke(black.main, [str(workspace)])
        self.assertEqual(outcome.exit_code, 0)
        self.assertEqual(one.read_text(), "print('hello')")
        self.assertEqual(two.read_text(), 'print("hello")\n')
        cached = black.read_cache()
        self.assertIn(one, cached)
        self.assertIn(two, cached)
def test_no_cache_when_writeback_diff(self) -> None:
    """Running with --diff must not create a cache file."""
    with cache_dir() as workspace:
        target = (workspace / "test.py").resolve()
        target.write_text("print('hello')")
        outcome = CliRunner().invoke(black.main, [str(target), "--diff"])
        self.assertEqual(outcome.exit_code, 0)
        self.assertFalse(black.CACHE_FILE.exists())
# Formatting stdin ("-") must not create a cache file.
537 def test_no_cache_when_stdin(self) -> None:
# NOTE(review): a line is missing here; presumably `with cache_dir():`, so that
# CACHE_FILE below points at a temporary location rather than the real cache.
539 result = CliRunner().invoke(black.main, ["-"], input="print('hello')")
540 self.assertEqual(result.exit_code, 0)
541 self.assertFalse(black.CACHE_FILE.exists())
# With no cache file present, read_cache() returns an empty dict.
543 def test_read_cache_no_cachefile(self) -> None:
# NOTE(review): the line entering a temporary cache (presumably `with cache_dir():`)
# is missing from this copy.
545 self.assertEqual(black.read_cache(), {})
# Round-trip: an entry written by write_cache() is read back by read_cache() with the
# value black.get_cache_info() reports for that file.
547 def test_write_cache_read_cache(self) -> None:
548 with cache_dir() as workspace:
549 src = (workspace / "test.py").resolve()
# NOTE(review): a line is missing here; presumably it creates `src` on disk (e.g.
# src.touch()), since get_cache_info() presumably stats the file.
551 black.write_cache({}, [src])
552 cache = black.read_cache()
553 self.assertIn(src, cache)
554 self.assertEqual(cache[src], black.get_cache_info(src))
# filter_cached() splits paths into (todo, done): only a path whose cache entry still
# matches get_cache_info() is done; uncached and stale paths stay todo. Input order is
# kept within each list.
556 def test_filter_cached(self) -> None:
557 with TemporaryDirectory() as workspace:
558 path = Path(workspace)
559 uncached = (path / "uncached").resolve()
560 cached = (path / "cached").resolve()
561 cached_but_changed = (path / "changed").resolve()
# NOTE(review): lines are missing here; presumably they touch `uncached` and `cached`.
564 cached_but_changed.touch()
# (0.0, 0) is a deliberately stale entry (presumably mtime, size) that cannot match.
565 cache = {cached: black.get_cache_info(cached), cached_but_changed: (0.0, 0)}
566 todo, done = black.filter_cached(
567 cache, [uncached, cached, cached_but_changed]
569 self.assertEqual(todo, [uncached, cached_but_changed])
570 self.assertEqual(done, [cached])
def test_write_cache_creates_directory_if_needed(self) -> None:
    """write_cache() must create a cache directory that does not exist yet."""
    with cache_dir(exists=False) as missing_dir:
        self.assertFalse(missing_dir.exists())
        black.write_cache({}, [])
        self.assertTrue(missing_dir.exists())
@event_loop(close=False)
def test_failed_formatting_does_not_get_cached(self) -> None:
    """A file black cannot format stays out of the cache; a clean one goes in."""
    with cache_dir() as workspace:
        failing = (workspace / "failing.py").resolve()
        failing.write_text("not actually python")
        clean = (workspace / "clean.py").resolve()
        clean.write_text('print("hello")\n')
        outcome = CliRunner().invoke(black.main, [str(workspace)])
        # 123 is the status reported once any file failed (cf. test_report).
        self.assertEqual(outcome.exit_code, 123)
        cached = black.read_cache()
        self.assertNotIn(failing, cached)
        self.assertIn(clean, cached)
def test_write_cache_write_fail(self) -> None:
    """An OSError raised while opening the cache file must not escape write_cache()."""
    failing_open = patch.object(Path, "open", side_effect=OSError)
    with cache_dir(), failing_open:
        black.write_cache({}, [])
599 if __name__ == "__main__":