All patches and comments are welcome. Please squash your changes into logical
commits before using git-format-patch and git-send-email to send them to
patches@git.madduck.net.
If you read the Git project's submission guidelines and adhere to them,
I would be especially grateful.
3 from concurrent.futures import ThreadPoolExecutor
4 from contextlib import contextmanager
5 from functools import partial
6 from io import StringIO
8 from pathlib import Path
10 from tempfile import TemporaryDirectory
11 from typing import Any, List, Tuple, Iterator
13 from unittest.mock import patch
15 from click import unstyle
16 from click.testing import CliRunner
# Shorthands used throughout the tests. Both are bound to the line length
# ``ll`` (defined elsewhere in this module); ``ff`` additionally passes
# ``fast=True`` to black.format_file_in_place.
ff = partial(black.format_file_in_place, line_length=ll, fast=True)
fs = partial(black.format_str, line_length=ll)
# Location of this test module; test data files are resolved relative to it.
THIS_FILE = Path(__file__)
THIS_DIR = THIS_FILE.parent
# Marker text that read_data() strips from every line of a test data file.
# It is built by concatenation, presumably so that this literal itself does
# not match the marker when this file is read as data.
EMPTY_LINE = "# EMPTY LINE WITH WHITESPACE" + " (this comment will be removed)"
def dump_to_stderr(*output: str) -> str:
    """Return the given strings joined by newlines, framed by a leading and trailing newline.

    The tests patch this in as ``black.dump_to_file``, so the dumped content
    comes back inline as a string. The real function returns a path, which
    other tests use as a temporary file.

    Args:
        *output: Text fragments to join; may be empty.

    Returns:
        ``"\\n" + "\\n".join(output) + "\\n"``.
    """
    return "\n" + "\n".join(output) + "\n"
32 def read_data(name: str) -> Tuple[str, str]:
33 """read_data('test_name') -> 'input', 'output'"""
34 if not name.endswith((".py", ".out", ".diff")):
36 _input: List[str] = []
37 _output: List[str] = []
38 with open(THIS_DIR / name, "r", encoding="utf8") as test:
39 lines = test.readlines()
42 line = line.replace(EMPTY_LINE, "")
43 if line.rstrip() == "# output":
48 if _input and not _output:
49 # If there's no output marker, treat the entire file as already pre-formatted.
51 return "".join(_input).strip() + "\n", "".join(_output).strip() + "\n"
55 def cache_dir(exists: bool = True) -> Iterator[Path]:
56 with TemporaryDirectory() as workspace:
57 cache_dir = Path(workspace)
59 cache_dir = cache_dir / "new"
60 cache_file = cache_dir / "cache.pkl"
61 with patch("black.CACHE_DIR", cache_dir), patch("black.CACHE_FILE", cache_file):
66 def event_loop(close: bool) -> Iterator[None]:
67 policy = asyncio.get_event_loop_policy()
68 old_loop = policy.get_event_loop()
69 loop = policy.new_event_loop()
70 asyncio.set_event_loop(loop)
75 policy.set_event_loop(old_loop)
80 class BlackTestCase(unittest.TestCase):
83 def assertFormatEqual(self, expected: str, actual: str) -> None:
84 if actual != expected and not os.environ.get("SKIP_AST_PRINT"):
85 bdv: black.DebugVisitor[Any]
86 black.out("Expected tree:", fg="green")
88 exp_node = black.lib2to3_parse(expected)
89 bdv = black.DebugVisitor()
90 list(bdv.visit(exp_node))
91 except Exception as ve:
93 black.out("Actual tree:", fg="red")
95 exp_node = black.lib2to3_parse(actual)
96 bdv = black.DebugVisitor()
97 list(bdv.visit(exp_node))
98 except Exception as ve:
100 self.assertEqual(expected, actual)
102 @patch("black.dump_to_file", dump_to_stderr)
103 def test_self(self) -> None:
104 source, expected = read_data("test_black")
106 self.assertFormatEqual(expected, actual)
107 black.assert_equivalent(source, actual)
108 black.assert_stable(source, actual, line_length=ll)
109 self.assertFalse(ff(THIS_FILE))
111 @patch("black.dump_to_file", dump_to_stderr)
112 def test_black(self) -> None:
113 source, expected = read_data("../black")
115 self.assertFormatEqual(expected, actual)
116 black.assert_equivalent(source, actual)
117 black.assert_stable(source, actual, line_length=ll)
118 self.assertFalse(ff(THIS_DIR / ".." / "black.py"))
120 def test_piping(self) -> None:
121 source, expected = read_data("../black")
122 hold_stdin, hold_stdout = sys.stdin, sys.stdout
124 sys.stdin, sys.stdout = StringIO(source), StringIO()
125 sys.stdin.name = "<stdin>"
126 black.format_stdin_to_stdout(
127 line_length=ll, fast=True, write_back=black.WriteBack.YES
130 actual = sys.stdout.read()
132 sys.stdin, sys.stdout = hold_stdin, hold_stdout
133 self.assertFormatEqual(expected, actual)
134 black.assert_equivalent(source, actual)
135 black.assert_stable(source, actual, line_length=ll)
137 def test_piping_diff(self) -> None:
138 source, _ = read_data("expression.py")
139 expected, _ = read_data("expression.diff")
140 hold_stdin, hold_stdout = sys.stdin, sys.stdout
142 sys.stdin, sys.stdout = StringIO(source), StringIO()
143 sys.stdin.name = "<stdin>"
144 black.format_stdin_to_stdout(
145 line_length=ll, fast=True, write_back=black.WriteBack.DIFF
148 actual = sys.stdout.read()
150 sys.stdin, sys.stdout = hold_stdin, hold_stdout
151 actual = actual.rstrip() + "\n" # the diff output has a trailing space
152 self.assertEqual(expected, actual)
154 @patch("black.dump_to_file", dump_to_stderr)
155 def test_setup(self) -> None:
156 source, expected = read_data("../setup")
158 self.assertFormatEqual(expected, actual)
159 black.assert_equivalent(source, actual)
160 black.assert_stable(source, actual, line_length=ll)
161 self.assertFalse(ff(THIS_DIR / ".." / "setup.py"))
163 @patch("black.dump_to_file", dump_to_stderr)
164 def test_function(self) -> None:
165 source, expected = read_data("function")
167 self.assertFormatEqual(expected, actual)
168 black.assert_equivalent(source, actual)
169 black.assert_stable(source, actual, line_length=ll)
171 @patch("black.dump_to_file", dump_to_stderr)
172 def test_expression(self) -> None:
173 source, expected = read_data("expression")
175 self.assertFormatEqual(expected, actual)
176 black.assert_equivalent(source, actual)
177 black.assert_stable(source, actual, line_length=ll)
179 def test_expression_ff(self) -> None:
180 source, expected = read_data("expression")
181 tmp_file = Path(black.dump_to_file(source))
183 self.assertTrue(ff(tmp_file, write_back=black.WriteBack.YES))
184 with open(tmp_file, encoding="utf8") as f:
188 self.assertFormatEqual(expected, actual)
189 with patch("black.dump_to_file", dump_to_stderr):
190 black.assert_equivalent(source, actual)
191 black.assert_stable(source, actual, line_length=ll)
193 def test_expression_diff(self) -> None:
194 source, _ = read_data("expression.py")
195 expected, _ = read_data("expression.diff")
196 tmp_file = Path(black.dump_to_file(source))
197 hold_stdout = sys.stdout
199 sys.stdout = StringIO()
200 self.assertTrue(ff(tmp_file, write_back=black.WriteBack.DIFF))
202 actual = sys.stdout.read()
203 actual = actual.replace(str(tmp_file), "<stdin>")
205 sys.stdout = hold_stdout
207 actual = actual.rstrip() + "\n" # the diff output has a trailing space
208 if expected != actual:
209 dump = black.dump_to_file(actual)
211 f"Expected diff isn't equal to the actual. If you made changes "
212 f"to expression.py and this is an anticipated difference, "
213 f"overwrite tests/expression.diff with {dump}"
215 self.assertEqual(expected, actual, msg)
217 @patch("black.dump_to_file", dump_to_stderr)
218 def test_fstring(self) -> None:
219 source, expected = read_data("fstring")
221 self.assertFormatEqual(expected, actual)
222 black.assert_equivalent(source, actual)
223 black.assert_stable(source, actual, line_length=ll)
225 @patch("black.dump_to_file", dump_to_stderr)
226 def test_string_quotes(self) -> None:
227 source, expected = read_data("string_quotes")
229 self.assertFormatEqual(expected, actual)
230 black.assert_equivalent(source, actual)
231 black.assert_stable(source, actual, line_length=ll)
233 @patch("black.dump_to_file", dump_to_stderr)
234 def test_comments(self) -> None:
235 source, expected = read_data("comments")
237 self.assertFormatEqual(expected, actual)
238 black.assert_equivalent(source, actual)
239 black.assert_stable(source, actual, line_length=ll)
241 @patch("black.dump_to_file", dump_to_stderr)
242 def test_comments2(self) -> None:
243 source, expected = read_data("comments2")
245 self.assertFormatEqual(expected, actual)
246 black.assert_equivalent(source, actual)
247 black.assert_stable(source, actual, line_length=ll)
249 @patch("black.dump_to_file", dump_to_stderr)
250 def test_comments3(self) -> None:
251 source, expected = read_data("comments3")
253 self.assertFormatEqual(expected, actual)
254 black.assert_equivalent(source, actual)
255 black.assert_stable(source, actual, line_length=ll)
257 @patch("black.dump_to_file", dump_to_stderr)
258 def test_comments4(self) -> None:
259 source, expected = read_data("comments4")
261 self.assertFormatEqual(expected, actual)
262 black.assert_equivalent(source, actual)
263 black.assert_stable(source, actual, line_length=ll)
265 @patch("black.dump_to_file", dump_to_stderr)
266 def test_comments5(self) -> None:
267 source, expected = read_data("comments5")
269 self.assertFormatEqual(expected, actual)
270 black.assert_equivalent(source, actual)
271 black.assert_stable(source, actual, line_length=ll)
273 @patch("black.dump_to_file", dump_to_stderr)
274 def test_cantfit(self) -> None:
275 source, expected = read_data("cantfit")
277 self.assertFormatEqual(expected, actual)
278 black.assert_equivalent(source, actual)
279 black.assert_stable(source, actual, line_length=ll)
281 @patch("black.dump_to_file", dump_to_stderr)
282 def test_import_spacing(self) -> None:
283 source, expected = read_data("import_spacing")
285 self.assertFormatEqual(expected, actual)
286 black.assert_equivalent(source, actual)
287 black.assert_stable(source, actual, line_length=ll)
289 @patch("black.dump_to_file", dump_to_stderr)
290 def test_composition(self) -> None:
291 source, expected = read_data("composition")
293 self.assertFormatEqual(expected, actual)
294 black.assert_equivalent(source, actual)
295 black.assert_stable(source, actual, line_length=ll)
297 @patch("black.dump_to_file", dump_to_stderr)
298 def test_empty_lines(self) -> None:
299 source, expected = read_data("empty_lines")
301 self.assertFormatEqual(expected, actual)
302 black.assert_equivalent(source, actual)
303 black.assert_stable(source, actual, line_length=ll)
305 @patch("black.dump_to_file", dump_to_stderr)
306 def test_python2(self) -> None:
307 source, expected = read_data("python2")
309 self.assertFormatEqual(expected, actual)
310 # black.assert_equivalent(source, actual)
311 black.assert_stable(source, actual, line_length=ll)
313 @patch("black.dump_to_file", dump_to_stderr)
314 def test_fmtonoff(self) -> None:
315 source, expected = read_data("fmtonoff")
317 self.assertFormatEqual(expected, actual)
318 black.assert_equivalent(source, actual)
319 black.assert_stable(source, actual, line_length=ll)
321 def test_report(self) -> None:
322 report = black.Report()
326 def out(msg: str, **kwargs: Any) -> None:
327 out_lines.append(msg)
329 def err(msg: str, **kwargs: Any) -> None:
330 err_lines.append(msg)
332 with patch("black.out", out), patch("black.err", err):
333 report.done(Path("f1"), black.Changed.NO)
334 self.assertEqual(len(out_lines), 1)
335 self.assertEqual(len(err_lines), 0)
336 self.assertEqual(out_lines[-1], "f1 already well formatted, good job.")
337 self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
338 self.assertEqual(report.return_code, 0)
339 report.done(Path("f2"), black.Changed.YES)
340 self.assertEqual(len(out_lines), 2)
341 self.assertEqual(len(err_lines), 0)
342 self.assertEqual(out_lines[-1], "reformatted f2")
344 unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
346 report.done(Path("f3"), black.Changed.CACHED)
347 self.assertEqual(len(out_lines), 3)
348 self.assertEqual(len(err_lines), 0)
350 out_lines[-1], "f3 wasn't modified on disk since last run."
353 unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
355 self.assertEqual(report.return_code, 0)
357 self.assertEqual(report.return_code, 1)
359 report.failed(Path("e1"), "boom")
360 self.assertEqual(len(out_lines), 3)
361 self.assertEqual(len(err_lines), 1)
362 self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
364 unstyle(str(report)),
365 "1 file reformatted, 2 files left unchanged, "
366 "1 file failed to reformat.",
368 self.assertEqual(report.return_code, 123)
369 report.done(Path("f3"), black.Changed.YES)
370 self.assertEqual(len(out_lines), 4)
371 self.assertEqual(len(err_lines), 1)
372 self.assertEqual(out_lines[-1], "reformatted f3")
374 unstyle(str(report)),
375 "2 files reformatted, 2 files left unchanged, "
376 "1 file failed to reformat.",
378 self.assertEqual(report.return_code, 123)
379 report.failed(Path("e2"), "boom")
380 self.assertEqual(len(out_lines), 4)
381 self.assertEqual(len(err_lines), 2)
382 self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
384 unstyle(str(report)),
385 "2 files reformatted, 2 files left unchanged, "
386 "2 files failed to reformat.",
388 self.assertEqual(report.return_code, 123)
389 report.done(Path("f4"), black.Changed.NO)
390 self.assertEqual(len(out_lines), 5)
391 self.assertEqual(len(err_lines), 2)
392 self.assertEqual(out_lines[-1], "f4 already well formatted, good job.")
394 unstyle(str(report)),
395 "2 files reformatted, 3 files left unchanged, "
396 "2 files failed to reformat.",
398 self.assertEqual(report.return_code, 123)
401 unstyle(str(report)),
402 "2 files would be reformatted, 3 files would be left unchanged, "
403 "2 files would fail to reformat.",
406 def test_is_python36(self) -> None:
407 node = black.lib2to3_parse("def f(*, arg): ...\n")
408 self.assertFalse(black.is_python36(node))
409 node = black.lib2to3_parse("def f(*, arg,): ...\n")
410 self.assertTrue(black.is_python36(node))
411 node = black.lib2to3_parse("def f(*, arg): f'string'\n")
412 self.assertTrue(black.is_python36(node))
413 source, expected = read_data("function")
414 node = black.lib2to3_parse(source)
415 self.assertTrue(black.is_python36(node))
416 node = black.lib2to3_parse(expected)
417 self.assertTrue(black.is_python36(node))
418 source, expected = read_data("expression")
419 node = black.lib2to3_parse(source)
420 self.assertFalse(black.is_python36(node))
421 node = black.lib2to3_parse(expected)
422 self.assertFalse(black.is_python36(node))
424 def test_debug_visitor(self) -> None:
425 source, _ = read_data("debug_visitor.py")
426 expected, _ = read_data("debug_visitor.out")
430 def out(msg: str, **kwargs: Any) -> None:
431 out_lines.append(msg)
433 def err(msg: str, **kwargs: Any) -> None:
434 err_lines.append(msg)
436 with patch("black.out", out), patch("black.err", err):
437 black.DebugVisitor.show(source)
438 actual = "\n".join(out_lines) + "\n"
440 if expected != actual:
441 log_name = black.dump_to_file(*out_lines)
445 f"AST print out is different. Actual version dumped to {log_name}",
448 def test_format_file_contents(self) -> None:
450 with self.assertRaises(black.NothingChanged):
451 black.format_file_contents(empty, line_length=ll, fast=False)
453 with self.assertRaises(black.NothingChanged):
454 black.format_file_contents(just_nl, line_length=ll, fast=False)
455 same = "l = [1, 2, 3]\n"
456 with self.assertRaises(black.NothingChanged):
457 black.format_file_contents(same, line_length=ll, fast=False)
458 different = "l = [1,2,3]"
460 actual = black.format_file_contents(different, line_length=ll, fast=False)
461 self.assertEqual(expected, actual)
462 invalid = "return if you can"
463 with self.assertRaises(ValueError) as e:
464 black.format_file_contents(invalid, line_length=ll, fast=False)
465 self.assertEqual(str(e.exception), "Cannot parse: 1:7: return if you can")
467 def test_endmarker(self) -> None:
468 n = black.lib2to3_parse("\n")
469 self.assertEqual(n.type, black.syms.file_input)
470 self.assertEqual(len(n.children), 1)
471 self.assertEqual(n.children[0].type, black.token.ENDMARKER)
473 @unittest.skipIf(os.environ.get("SKIP_AST_PRINT"), "user set SKIP_AST_PRINT")
474 def test_assertFormatEqual(self) -> None:
478 def out(msg: str, **kwargs: Any) -> None:
479 out_lines.append(msg)
481 def err(msg: str, **kwargs: Any) -> None:
482 err_lines.append(msg)
484 with patch("black.out", out), patch("black.err", err):
485 with self.assertRaises(AssertionError):
486 self.assertFormatEqual("l = [1, 2, 3]", "l = [1, 2, 3,]")
488 out_str = "".join(out_lines)
489 self.assertTrue("Expected tree:" in out_str)
490 self.assertTrue("Actual tree:" in out_str)
491 self.assertEqual("".join(err_lines), "")
493 def test_cache_broken_file(self) -> None:
494 with cache_dir() as workspace:
495 with black.CACHE_FILE.open("w") as fobj:
496 fobj.write("this is not a pickle")
497 self.assertEqual(black.read_cache(), {})
498 src = (workspace / "test.py").resolve()
499 with src.open("w") as fobj:
500 fobj.write("print('hello')")
501 result = CliRunner().invoke(black.main, [str(src)])
502 self.assertEqual(result.exit_code, 0)
503 cache = black.read_cache()
504 self.assertIn(src, cache)
506 def test_cache_single_file_already_cached(self) -> None:
507 with cache_dir() as workspace:
508 src = (workspace / "test.py").resolve()
509 with src.open("w") as fobj:
510 fobj.write("print('hello')")
511 black.write_cache({}, [src])
512 result = CliRunner().invoke(black.main, [str(src)])
513 self.assertEqual(result.exit_code, 0)
514 with src.open("r") as fobj:
515 self.assertEqual(fobj.read(), "print('hello')")
517 @event_loop(close=False)
518 def test_cache_multiple_files(self) -> None:
519 with cache_dir() as workspace, patch(
520 "black.ProcessPoolExecutor", new=ThreadPoolExecutor
522 one = (workspace / "one.py").resolve()
523 with one.open("w") as fobj:
524 fobj.write("print('hello')")
525 two = (workspace / "two.py").resolve()
526 with two.open("w") as fobj:
527 fobj.write("print('hello')")
528 black.write_cache({}, [one])
529 result = CliRunner().invoke(black.main, [str(workspace)])
530 self.assertEqual(result.exit_code, 0)
531 with one.open("r") as fobj:
532 self.assertEqual(fobj.read(), "print('hello')")
533 with two.open("r") as fobj:
534 self.assertEqual(fobj.read(), 'print("hello")\n')
535 cache = black.read_cache()
536 self.assertIn(one, cache)
537 self.assertIn(two, cache)
539 def test_no_cache_when_writeback_diff(self) -> None:
540 with cache_dir() as workspace:
541 src = (workspace / "test.py").resolve()
542 with src.open("w") as fobj:
543 fobj.write("print('hello')")
544 result = CliRunner().invoke(black.main, [str(src), "--diff"])
545 self.assertEqual(result.exit_code, 0)
546 self.assertFalse(black.CACHE_FILE.exists())
548 def test_no_cache_when_stdin(self) -> None:
550 result = CliRunner().invoke(black.main, ["-"], input="print('hello')")
551 self.assertEqual(result.exit_code, 0)
552 self.assertFalse(black.CACHE_FILE.exists())
554 def test_read_cache_no_cachefile(self) -> None:
556 self.assertEqual(black.read_cache(), {})
558 def test_write_cache_read_cache(self) -> None:
559 with cache_dir() as workspace:
560 src = (workspace / "test.py").resolve()
562 black.write_cache({}, [src])
563 cache = black.read_cache()
564 self.assertIn(src, cache)
565 self.assertEqual(cache[src], black.get_cache_info(src))
567 def test_filter_cached(self) -> None:
568 with TemporaryDirectory() as workspace:
569 path = Path(workspace)
570 uncached = (path / "uncached").resolve()
571 cached = (path / "cached").resolve()
572 cached_but_changed = (path / "changed").resolve()
575 cached_but_changed.touch()
576 cache = {cached: black.get_cache_info(cached), cached_but_changed: (0.0, 0)}
577 todo, done = black.filter_cached(
578 cache, [uncached, cached, cached_but_changed]
580 self.assertEqual(todo, [uncached, cached_but_changed])
581 self.assertEqual(done, [cached])
583 def test_write_cache_creates_directory_if_needed(self) -> None:
584 with cache_dir(exists=False) as workspace:
585 self.assertFalse(workspace.exists())
586 black.write_cache({}, [])
587 self.assertTrue(workspace.exists())
589 @event_loop(close=False)
590 def test_failed_formatting_does_not_get_cached(self) -> None:
591 with cache_dir() as workspace, patch(
592 "black.ProcessPoolExecutor", new=ThreadPoolExecutor
594 failing = (workspace / "failing.py").resolve()
595 with failing.open("w") as fobj:
596 fobj.write("not actually python")
597 clean = (workspace / "clean.py").resolve()
598 with clean.open("w") as fobj:
599 fobj.write('print("hello")\n')
600 result = CliRunner().invoke(black.main, [str(workspace)])
601 self.assertEqual(result.exit_code, 123)
602 cache = black.read_cache()
603 self.assertNotIn(failing, cache)
604 self.assertIn(clean, cache)
606 def test_write_cache_write_fail(self) -> None:
607 with cache_dir(), patch.object(Path, "open") as mock:
608 mock.side_effect = OSError
609 black.write_cache({}, [])
611 def test_check_diff_use_together(self) -> None:
613 # Files which will be reformatted.
614 src1 = (THIS_DIR / "string_quotes.py").resolve()
615 result = CliRunner().invoke(black.main, [str(src1), "--diff", "--check"])
616 self.assertEqual(result.exit_code, 1)
618 # Files which will not be reformatted.
619 src2 = (THIS_DIR / "composition.py").resolve()
620 result = CliRunner().invoke(black.main, [str(src2), "--diff", "--check"])
621 self.assertEqual(result.exit_code, 0)
623 # Multi file command.
624 result = CliRunner().invoke(
625 black.main, [str(src1), str(src2), "--diff", "--check"]
627 self.assertEqual(result.exit_code, 1)
630 if __name__ == "__main__":