All patches and comments are welcome. Please squash your changes into logical
commits before using git-format-patch and git-send-email to send them to
patches@git.madduck.net.
If you'd read over the Git project's submission guidelines and adhered to them,
I'd be especially grateful.
3 from concurrent.futures import ThreadPoolExecutor
4 from contextlib import contextmanager
5 from functools import partial
6 from io import StringIO
8 from pathlib import Path
10 from tempfile import TemporaryDirectory
11 from typing import Any, List, Tuple, Iterator
13 from unittest.mock import patch
15 from click import unstyle
16 from click.testing import CliRunner
21 ff = partial(black.format_file_in_place, line_length=ll, fast=True)
22 fs = partial(black.format_str, line_length=ll)
23 THIS_FILE = Path(__file__)
24 THIS_DIR = THIS_FILE.parent
25 EMPTY_LINE = "# EMPTY LINE WITH WHITESPACE" + " (this comment will be removed)"
# Join the given strings with newlines and return the result wrapped in one
# leading and one trailing newline.
# NOTE: despite its name, this helper does not write to any stream itself; it
# only builds and returns the text. Elsewhere in this file it is passed to
# patch("black.dump_to_file", ...) as the replacement callable.
28 def dump_to_stderr(*output: str) -> str:
29 return "\n" + "\n".join(output) + "\n"
32 def read_data(name: str) -> Tuple[str, str]:
33 """read_data('test_name') -> 'input', 'output'"""
34 if not name.endswith((".py", ".out", ".diff")):
36 _input: List[str] = []
37 _output: List[str] = []
38 with open(THIS_DIR / name, "r", encoding="utf8") as test:
39 lines = test.readlines()
42 line = line.replace(EMPTY_LINE, "")
43 if line.rstrip() == "# output":
48 if _input and not _output:
49 # If there's no output marker, treat the entire file as already pre-formatted.
51 return "".join(_input).strip() + "\n", "".join(_output).strip() + "\n"
55 def cache_dir(exists: bool = True) -> Iterator[Path]:
56 with TemporaryDirectory() as workspace:
57 cache_dir = Path(workspace)
59 cache_dir = cache_dir / "new"
60 cache_file = cache_dir / "cache.pkl"
61 with patch("black.CACHE_DIR", cache_dir), patch("black.CACHE_FILE", cache_file):
66 def event_loop(close: bool) -> Iterator[None]:
67 policy = asyncio.get_event_loop_policy()
68 old_loop = policy.get_event_loop()
69 loop = policy.new_event_loop()
70 asyncio.set_event_loop(loop)
75 policy.set_event_loop(old_loop)
80 class BlackTestCase(unittest.TestCase):
83 def assertFormatEqual(self, expected: str, actual: str) -> None:
84 if actual != expected and not os.environ.get("SKIP_AST_PRINT"):
85 bdv: black.DebugVisitor[Any]
86 black.out("Expected tree:", fg="green")
88 exp_node = black.lib2to3_parse(expected)
89 bdv = black.DebugVisitor()
90 list(bdv.visit(exp_node))
91 except Exception as ve:
93 black.out("Actual tree:", fg="red")
95 exp_node = black.lib2to3_parse(actual)
96 bdv = black.DebugVisitor()
97 list(bdv.visit(exp_node))
98 except Exception as ve:
100 self.assertEqual(expected, actual)
102 @patch("black.dump_to_file", dump_to_stderr)
103 def test_self(self) -> None:
104 source, expected = read_data("test_black")
106 self.assertFormatEqual(expected, actual)
107 black.assert_equivalent(source, actual)
108 black.assert_stable(source, actual, line_length=ll)
109 self.assertFalse(ff(THIS_FILE))
111 @patch("black.dump_to_file", dump_to_stderr)
112 def test_black(self) -> None:
113 source, expected = read_data("../black")
115 self.assertFormatEqual(expected, actual)
116 black.assert_equivalent(source, actual)
117 black.assert_stable(source, actual, line_length=ll)
118 self.assertFalse(ff(THIS_DIR / ".." / "black.py"))
120 def test_piping(self) -> None:
121 source, expected = read_data("../black")
122 hold_stdin, hold_stdout = sys.stdin, sys.stdout
124 sys.stdin, sys.stdout = StringIO(source), StringIO()
125 sys.stdin.name = "<stdin>"
126 black.format_stdin_to_stdout(
127 line_length=ll, fast=True, write_back=black.WriteBack.YES
130 actual = sys.stdout.read()
132 sys.stdin, sys.stdout = hold_stdin, hold_stdout
133 self.assertFormatEqual(expected, actual)
134 black.assert_equivalent(source, actual)
135 black.assert_stable(source, actual, line_length=ll)
137 def test_piping_diff(self) -> None:
138 source, _ = read_data("expression.py")
139 expected, _ = read_data("expression.diff")
140 hold_stdin, hold_stdout = sys.stdin, sys.stdout
142 sys.stdin, sys.stdout = StringIO(source), StringIO()
143 sys.stdin.name = "<stdin>"
144 black.format_stdin_to_stdout(
145 line_length=ll, fast=True, write_back=black.WriteBack.DIFF
148 actual = sys.stdout.read()
150 sys.stdin, sys.stdout = hold_stdin, hold_stdout
151 actual = actual.rstrip() + "\n" # the diff output has a trailing space
152 self.assertEqual(expected, actual)
154 @patch("black.dump_to_file", dump_to_stderr)
155 def test_setup(self) -> None:
156 source, expected = read_data("../setup")
158 self.assertFormatEqual(expected, actual)
159 black.assert_equivalent(source, actual)
160 black.assert_stable(source, actual, line_length=ll)
161 self.assertFalse(ff(THIS_DIR / ".." / "setup.py"))
163 @patch("black.dump_to_file", dump_to_stderr)
164 def test_function(self) -> None:
165 source, expected = read_data("function")
167 self.assertFormatEqual(expected, actual)
168 black.assert_equivalent(source, actual)
169 black.assert_stable(source, actual, line_length=ll)
171 @patch("black.dump_to_file", dump_to_stderr)
172 def test_expression(self) -> None:
173 source, expected = read_data("expression")
175 self.assertFormatEqual(expected, actual)
176 black.assert_equivalent(source, actual)
177 black.assert_stable(source, actual, line_length=ll)
179 def test_expression_ff(self) -> None:
180 source, expected = read_data("expression")
181 tmp_file = Path(black.dump_to_file(source))
183 self.assertTrue(ff(tmp_file, write_back=black.WriteBack.YES))
184 with open(tmp_file, encoding="utf8") as f:
188 self.assertFormatEqual(expected, actual)
189 with patch("black.dump_to_file", dump_to_stderr):
190 black.assert_equivalent(source, actual)
191 black.assert_stable(source, actual, line_length=ll)
193 def test_expression_diff(self) -> None:
194 source, _ = read_data("expression.py")
195 expected, _ = read_data("expression.diff")
196 tmp_file = Path(black.dump_to_file(source))
197 hold_stdout = sys.stdout
199 sys.stdout = StringIO()
200 self.assertTrue(ff(tmp_file, write_back=black.WriteBack.DIFF))
202 actual = sys.stdout.read()
203 actual = actual.replace(tmp_file.name, "<stdin>")
205 sys.stdout = hold_stdout
207 actual = actual.rstrip() + "\n" # the diff output has a trailing space
208 if expected != actual:
209 dump = black.dump_to_file(actual)
211 f"Expected diff isn't equal to the actual. If you made changes "
212 f"to expression.py and this is an anticipated difference, "
213 f"overwrite tests/expression.diff with {dump}"
215 self.assertEqual(expected, actual, msg)
217 @patch("black.dump_to_file", dump_to_stderr)
218 def test_fstring(self) -> None:
219 source, expected = read_data("fstring")
221 self.assertFormatEqual(expected, actual)
222 black.assert_equivalent(source, actual)
223 black.assert_stable(source, actual, line_length=ll)
225 @patch("black.dump_to_file", dump_to_stderr)
226 def test_string_quotes(self) -> None:
227 source, expected = read_data("string_quotes")
229 self.assertFormatEqual(expected, actual)
230 black.assert_equivalent(source, actual)
231 black.assert_stable(source, actual, line_length=ll)
233 @patch("black.dump_to_file", dump_to_stderr)
234 def test_comments(self) -> None:
235 source, expected = read_data("comments")
237 self.assertFormatEqual(expected, actual)
238 black.assert_equivalent(source, actual)
239 black.assert_stable(source, actual, line_length=ll)
241 @patch("black.dump_to_file", dump_to_stderr)
242 def test_comments2(self) -> None:
243 source, expected = read_data("comments2")
245 self.assertFormatEqual(expected, actual)
246 black.assert_equivalent(source, actual)
247 black.assert_stable(source, actual, line_length=ll)
249 @patch("black.dump_to_file", dump_to_stderr)
250 def test_comments3(self) -> None:
251 source, expected = read_data("comments3")
253 self.assertFormatEqual(expected, actual)
254 black.assert_equivalent(source, actual)
255 black.assert_stable(source, actual, line_length=ll)
257 @patch("black.dump_to_file", dump_to_stderr)
258 def test_comments4(self) -> None:
259 source, expected = read_data("comments4")
261 self.assertFormatEqual(expected, actual)
262 black.assert_equivalent(source, actual)
263 black.assert_stable(source, actual, line_length=ll)
265 @patch("black.dump_to_file", dump_to_stderr)
266 def test_cantfit(self) -> None:
267 source, expected = read_data("cantfit")
269 self.assertFormatEqual(expected, actual)
270 black.assert_equivalent(source, actual)
271 black.assert_stable(source, actual, line_length=ll)
273 @patch("black.dump_to_file", dump_to_stderr)
274 def test_import_spacing(self) -> None:
275 source, expected = read_data("import_spacing")
277 self.assertFormatEqual(expected, actual)
278 black.assert_equivalent(source, actual)
279 black.assert_stable(source, actual, line_length=ll)
281 @patch("black.dump_to_file", dump_to_stderr)
282 def test_composition(self) -> None:
283 source, expected = read_data("composition")
285 self.assertFormatEqual(expected, actual)
286 black.assert_equivalent(source, actual)
287 black.assert_stable(source, actual, line_length=ll)
289 @patch("black.dump_to_file", dump_to_stderr)
290 def test_empty_lines(self) -> None:
291 source, expected = read_data("empty_lines")
293 self.assertFormatEqual(expected, actual)
294 black.assert_equivalent(source, actual)
295 black.assert_stable(source, actual, line_length=ll)
297 @patch("black.dump_to_file", dump_to_stderr)
298 def test_python2(self) -> None:
299 source, expected = read_data("python2")
301 self.assertFormatEqual(expected, actual)
302 # black.assert_equivalent(source, actual)
303 black.assert_stable(source, actual, line_length=ll)
305 @patch("black.dump_to_file", dump_to_stderr)
306 def test_fmtonoff(self) -> None:
307 source, expected = read_data("fmtonoff")
309 self.assertFormatEqual(expected, actual)
310 black.assert_equivalent(source, actual)
311 black.assert_stable(source, actual, line_length=ll)
313 def test_report(self) -> None:
314 report = black.Report()
318 def out(msg: str, **kwargs: Any) -> None:
319 out_lines.append(msg)
321 def err(msg: str, **kwargs: Any) -> None:
322 err_lines.append(msg)
324 with patch("black.out", out), patch("black.err", err):
325 report.done(Path("f1"), black.Changed.NO)
326 self.assertEqual(len(out_lines), 1)
327 self.assertEqual(len(err_lines), 0)
328 self.assertEqual(out_lines[-1], "f1 already well formatted, good job.")
329 self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
330 self.assertEqual(report.return_code, 0)
331 report.done(Path("f2"), black.Changed.YES)
332 self.assertEqual(len(out_lines), 2)
333 self.assertEqual(len(err_lines), 0)
334 self.assertEqual(out_lines[-1], "reformatted f2")
336 unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
338 report.done(Path("f3"), black.Changed.CACHED)
339 self.assertEqual(len(out_lines), 3)
340 self.assertEqual(len(err_lines), 0)
342 out_lines[-1], "f3 wasn't modified on disk since last run."
345 unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
347 self.assertEqual(report.return_code, 0)
349 self.assertEqual(report.return_code, 1)
351 report.failed(Path("e1"), "boom")
352 self.assertEqual(len(out_lines), 3)
353 self.assertEqual(len(err_lines), 1)
354 self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
356 unstyle(str(report)),
357 "1 file reformatted, 2 files left unchanged, "
358 "1 file failed to reformat.",
360 self.assertEqual(report.return_code, 123)
361 report.done(Path("f3"), black.Changed.YES)
362 self.assertEqual(len(out_lines), 4)
363 self.assertEqual(len(err_lines), 1)
364 self.assertEqual(out_lines[-1], "reformatted f3")
366 unstyle(str(report)),
367 "2 files reformatted, 2 files left unchanged, "
368 "1 file failed to reformat.",
370 self.assertEqual(report.return_code, 123)
371 report.failed(Path("e2"), "boom")
372 self.assertEqual(len(out_lines), 4)
373 self.assertEqual(len(err_lines), 2)
374 self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
376 unstyle(str(report)),
377 "2 files reformatted, 2 files left unchanged, "
378 "2 files failed to reformat.",
380 self.assertEqual(report.return_code, 123)
381 report.done(Path("f4"), black.Changed.NO)
382 self.assertEqual(len(out_lines), 5)
383 self.assertEqual(len(err_lines), 2)
384 self.assertEqual(out_lines[-1], "f4 already well formatted, good job.")
386 unstyle(str(report)),
387 "2 files reformatted, 3 files left unchanged, "
388 "2 files failed to reformat.",
390 self.assertEqual(report.return_code, 123)
393 unstyle(str(report)),
394 "2 files would be reformatted, 3 files would be left unchanged, "
395 "2 files would fail to reformat.",
# Check black.is_python36() against small inline snippets and against the
# "function" and "expression" test data files.
398 def test_is_python36(self) -> None:
# A keyword-only argument with no trailing comma is expected to report False.
399 node = black.lib2to3_parse("def f(*, arg): ...\n")
400 self.assertFalse(black.is_python36(node))
# The same signature with a trailing comma after the argument is expected to
# report True.
401 node = black.lib2to3_parse("def f(*, arg,): ...\n")
402 self.assertTrue(black.is_python36(node))
# An f-string literal in the function body is expected to report True.
403 node = black.lib2to3_parse("def f(*, arg): f'string'\n")
404 self.assertTrue(black.is_python36(node))
# Both halves of the "function" data file (input and expected output) are
# expected to report True.
405 source, expected = read_data("function")
406 node = black.lib2to3_parse(source)
407 self.assertTrue(black.is_python36(node))
408 node = black.lib2to3_parse(expected)
409 self.assertTrue(black.is_python36(node))
# Both halves of the "expression" data file are expected to report False.
410 source, expected = read_data("expression")
411 node = black.lib2to3_parse(source)
412 self.assertFalse(black.is_python36(node))
413 node = black.lib2to3_parse(expected)
414 self.assertFalse(black.is_python36(node))
416 def test_debug_visitor(self) -> None:
417 source, _ = read_data("debug_visitor.py")
418 expected, _ = read_data("debug_visitor.out")
422 def out(msg: str, **kwargs: Any) -> None:
423 out_lines.append(msg)
425 def err(msg: str, **kwargs: Any) -> None:
426 err_lines.append(msg)
428 with patch("black.out", out), patch("black.err", err):
429 black.DebugVisitor.show(source)
430 actual = "\n".join(out_lines) + "\n"
432 if expected != actual:
433 log_name = black.dump_to_file(*out_lines)
437 f"AST print out is different. Actual version dumped to {log_name}",
440 def test_format_file_contents(self) -> None:
442 with self.assertRaises(black.NothingChanged):
443 black.format_file_contents(empty, line_length=ll, fast=False)
445 with self.assertRaises(black.NothingChanged):
446 black.format_file_contents(just_nl, line_length=ll, fast=False)
447 same = "l = [1, 2, 3]\n"
448 with self.assertRaises(black.NothingChanged):
449 black.format_file_contents(same, line_length=ll, fast=False)
450 different = "l = [1,2,3]"
452 actual = black.format_file_contents(different, line_length=ll, fast=False)
453 self.assertEqual(expected, actual)
454 invalid = "return if you can"
455 with self.assertRaises(ValueError) as e:
456 black.format_file_contents(invalid, line_length=ll, fast=False)
457 self.assertEqual(str(e.exception), "Cannot parse: 1:7: return if you can")
# Parse a source consisting of a single newline and check the shape of the
# resulting tree.
459 def test_endmarker(self) -> None:
460 n = black.lib2to3_parse("\n")
# The root node must be of type file_input ...
461 self.assertEqual(n.type, black.syms.file_input)
# ... with exactly one child, which must be the ENDMARKER token.
462 self.assertEqual(len(n.children), 1)
463 self.assertEqual(n.children[0].type, black.token.ENDMARKER)
465 @unittest.skipIf(os.environ.get("SKIP_AST_PRINT"), "user set SKIP_AST_PRINT")
466 def test_assertFormatEqual(self) -> None:
470 def out(msg: str, **kwargs: Any) -> None:
471 out_lines.append(msg)
473 def err(msg: str, **kwargs: Any) -> None:
474 err_lines.append(msg)
476 with patch("black.out", out), patch("black.err", err):
477 with self.assertRaises(AssertionError):
478 self.assertFormatEqual("l = [1, 2, 3]", "l = [1, 2, 3,]")
480 out_str = "".join(out_lines)
481 self.assertTrue("Expected tree:" in out_str)
482 self.assertTrue("Actual tree:" in out_str)
483 self.assertEqual("".join(err_lines), "")
# Check that a cache file with unreadable content is treated as an empty cache
# and does not stop a later run from recording the formatted file.
485 def test_cache_broken_file(self) -> None:
486 with cache_dir() as workspace:
# Put text that is not a pickle into the cache file.
487 with black.CACHE_FILE.open("w") as fobj:
488 fobj.write("this is not a pickle")
# Reading the broken cache must yield an empty dict.
489 self.assertEqual(black.read_cache(), {})
490 src = (workspace / "test.py").resolve()
491 with src.open("w") as fobj:
492 fobj.write("print('hello')")
# Run the command-line entry point on the file; it must exit with code 0.
493 result = CliRunner().invoke(black.main, [str(src)])
494 self.assertEqual(result.exit_code, 0)
# After the run, the resolved source path must be present in the cache.
495 cache = black.read_cache()
496 self.assertIn(src, cache)
# Check that a file already recorded in the cache is left untouched on disk
# when black.main is invoked on it.
498 def test_cache_single_file_already_cached(self) -> None:
499 with cache_dir() as workspace:
500 src = (workspace / "test.py").resolve()
501 with src.open("w") as fobj:
502 fobj.write("print('hello')")
# Record the file in the cache before running the command.
503 black.write_cache({}, [src])
504 result = CliRunner().invoke(black.main, [str(src)])
505 self.assertEqual(result.exit_code, 0)
# The content must still be exactly what was written above (single quotes, no
# trailing newline) -- presumably because the cache entry makes the run skip
# the file; confirm against black's cache handling.
506 with src.open("r") as fobj:
507 self.assertEqual(fobj.read(), "print('hello')")
509 @event_loop(close=False)
510 def test_cache_multiple_files(self) -> None:
511 with cache_dir() as workspace, patch(
512 "black.ProcessPoolExecutor", new=ThreadPoolExecutor
514 one = (workspace / "one.py").resolve()
515 with one.open("w") as fobj:
516 fobj.write("print('hello')")
517 two = (workspace / "two.py").resolve()
518 with two.open("w") as fobj:
519 fobj.write("print('hello')")
520 black.write_cache({}, [one])
521 result = CliRunner().invoke(black.main, [str(workspace)])
522 self.assertEqual(result.exit_code, 0)
523 with one.open("r") as fobj:
524 self.assertEqual(fobj.read(), "print('hello')")
525 with two.open("r") as fobj:
526 self.assertEqual(fobj.read(), 'print("hello")\n')
527 cache = black.read_cache()
528 self.assertIn(one, cache)
529 self.assertIn(two, cache)
# Check that running black.main with --diff does not create the cache file.
531 def test_no_cache_when_writeback_diff(self) -> None:
532 with cache_dir() as workspace:
533 src = (workspace / "test.py").resolve()
534 with src.open("w") as fobj:
535 fobj.write("print('hello')")
# Invoke the command-line entry point in diff mode; it must exit with code 0.
536 result = CliRunner().invoke(black.main, [str(src), "--diff"])
537 self.assertEqual(result.exit_code, 0)
# No cache file may exist after the run.
538 self.assertFalse(black.CACHE_FILE.exists())
540 def test_no_cache_when_stdin(self) -> None:
542 result = CliRunner().invoke(black.main, ["-"], input="print('hello')")
543 self.assertEqual(result.exit_code, 0)
544 self.assertFalse(black.CACHE_FILE.exists())
546 def test_read_cache_no_cachefile(self) -> None:
548 self.assertEqual(black.read_cache(), {})
550 def test_write_cache_read_cache(self) -> None:
551 with cache_dir() as workspace:
552 src = (workspace / "test.py").resolve()
554 black.write_cache({}, [src])
555 cache = black.read_cache()
556 self.assertIn(src, cache)
557 self.assertEqual(cache[src], black.get_cache_info(src))
559 def test_filter_cached(self) -> None:
560 with TemporaryDirectory() as workspace:
561 path = Path(workspace)
562 uncached = (path / "uncached").resolve()
563 cached = (path / "cached").resolve()
564 cached_but_changed = (path / "changed").resolve()
567 cached_but_changed.touch()
568 cache = {cached: black.get_cache_info(cached), cached_but_changed: (0.0, 0)}
569 todo, done = black.filter_cached(
570 cache, [uncached, cached, cached_but_changed]
572 self.assertEqual(todo, [uncached, cached_but_changed])
573 self.assertEqual(done, [cached])
# Check that black.write_cache() creates the cache directory when it is missing.
575 def test_write_cache_creates_directory_if_needed(self) -> None:
# exists=False asks the cache_dir helper for a directory path that has not
# been created yet; the first assertion verifies that precondition.
576 with cache_dir(exists=False) as workspace:
577 self.assertFalse(workspace.exists())
# Writing an empty cache with no sources must still create the directory.
578 black.write_cache({}, [])
579 self.assertTrue(workspace.exists())
581 @event_loop(close=False)
582 def test_failed_formatting_does_not_get_cached(self) -> None:
583 with cache_dir() as workspace, patch(
584 "black.ProcessPoolExecutor", new=ThreadPoolExecutor
586 failing = (workspace / "failing.py").resolve()
587 with failing.open("w") as fobj:
588 fobj.write("not actually python")
589 clean = (workspace / "clean.py").resolve()
590 with clean.open("w") as fobj:
591 fobj.write('print("hello")\n')
592 result = CliRunner().invoke(black.main, [str(workspace)])
593 self.assertEqual(result.exit_code, 123)
594 cache = black.read_cache()
595 self.assertNotIn(failing, cache)
596 self.assertIn(clean, cache)
# Check that black.write_cache() tolerates a failure to open a file.
# There is no explicit assertion: the test passes only if the OSError raised by
# the mocked Path.open does not propagate out of write_cache().
598 def test_write_cache_write_fail(self) -> None:
# Replace Path.open for every Path instance with a mock that raises OSError.
599 with cache_dir(), patch.object(Path, "open") as mock:
600 mock.side_effect = OSError
601 black.write_cache({}, [])
603 def test_check_diff_use_together(self) -> None:
605 # Files which will be reformatted.
606 src1 = (THIS_DIR / "string_quotes.py").resolve()
607 result = CliRunner().invoke(black.main, [str(src1), "--diff", "--check"])
608 self.assertEqual(result.exit_code, 1)
610 # Files which will not be reformatted.
611 src2 = (THIS_DIR / "composition.py").resolve()
612 result = CliRunner().invoke(black.main, [str(src2), "--diff", "--check"])
613 self.assertEqual(result.exit_code, 0)
615 # Multi file command.
616 result = CliRunner().invoke(
617 black.main, [str(src1), str(src2), "--diff", "--check"]
619 self.assertEqual(result.exit_code, 1)
622 if __name__ == "__main__":