All patches and comments are welcome. Please squash your changes into logical
commits before using git-format-patch and git-send-email to send them to
patches@git.madduck.net.
If you read over the Git project's submission guidelines and adhere to them,
I'd be especially grateful.
3 from concurrent.futures import ThreadPoolExecutor
4 from contextlib import contextmanager
5 from functools import partial
6 from io import StringIO
8 from pathlib import Path
10 from tempfile import TemporaryDirectory
11 from typing import Any, List, Tuple, Iterator
13 from unittest.mock import patch
15 from click import unstyle
16 from click.testing import CliRunner
# Location of this test module; test data files live alongside it.
THIS_FILE = Path(__file__)
THIS_DIR = THIS_FILE.parent

# Shorthands binding black's formatting entry points to the suite-wide line
# length `ll` (defined elsewhere in this module).
ff = partial(black.format_file_in_place, line_length=ll, fast=True)
fs = partial(black.format_str, line_length=ll)

# Marker that read_data() removes from each line of a data file. The literal is
# built from two pieces — presumably so that this module's own source never
# contains the full marker text.
EMPTY_LINE = "# EMPTY LINE WITH WHITESPACE" + " (this comment will be removed)"
def dump_to_stderr(*output: str) -> str:
    """Return the given lines joined by newlines, framed by a newline on each side.

    Tests patch this in place of ``black.dump_to_file`` so the dumped text is
    returned inline rather than handled by black's own implementation.
    """
    body = "\n".join(output)
    return f"\n{body}\n"
32 def read_data(name: str) -> Tuple[str, str]:
33 """read_data('test_name') -> 'input', 'output'"""
34 if not name.endswith((".py", ".out", ".diff")):
36 _input: List[str] = []
37 _output: List[str] = []
38 with open(THIS_DIR / name, "r", encoding="utf8") as test:
39 lines = test.readlines()
42 line = line.replace(EMPTY_LINE, "")
43 if line.rstrip() == "# output":
48 if _input and not _output:
49 # If there's no output marker, treat the entire file as already pre-formatted.
51 return "".join(_input).strip() + "\n", "".join(_output).strip() + "\n"
55 def cache_dir(exists: bool = True) -> Iterator[Path]:
56 with TemporaryDirectory() as workspace:
57 cache_dir = Path(workspace)
59 cache_dir = cache_dir / "new"
60 with patch("black.CACHE_DIR", cache_dir):
65 def event_loop(close: bool) -> Iterator[None]:
66 policy = asyncio.get_event_loop_policy()
67 old_loop = policy.get_event_loop()
68 loop = policy.new_event_loop()
69 asyncio.set_event_loop(loop)
74 policy.set_event_loop(old_loop)
79 class BlackTestCase(unittest.TestCase):
82 def assertFormatEqual(self, expected: str, actual: str) -> None:
83 if actual != expected and not os.environ.get("SKIP_AST_PRINT"):
84 bdv: black.DebugVisitor[Any]
85 black.out("Expected tree:", fg="green")
87 exp_node = black.lib2to3_parse(expected)
88 bdv = black.DebugVisitor()
89 list(bdv.visit(exp_node))
90 except Exception as ve:
92 black.out("Actual tree:", fg="red")
94 exp_node = black.lib2to3_parse(actual)
95 bdv = black.DebugVisitor()
96 list(bdv.visit(exp_node))
97 except Exception as ve:
99 self.assertEqual(expected, actual)
101 @patch("black.dump_to_file", dump_to_stderr)
102 def test_self(self) -> None:
103 source, expected = read_data("test_black")
105 self.assertFormatEqual(expected, actual)
106 black.assert_equivalent(source, actual)
107 black.assert_stable(source, actual, line_length=ll)
108 self.assertFalse(ff(THIS_FILE))
110 @patch("black.dump_to_file", dump_to_stderr)
111 def test_black(self) -> None:
112 source, expected = read_data("../black")
114 self.assertFormatEqual(expected, actual)
115 black.assert_equivalent(source, actual)
116 black.assert_stable(source, actual, line_length=ll)
117 self.assertFalse(ff(THIS_DIR / ".." / "black.py"))
119 def test_piping(self) -> None:
120 source, expected = read_data("../black")
121 hold_stdin, hold_stdout = sys.stdin, sys.stdout
123 sys.stdin, sys.stdout = StringIO(source), StringIO()
124 sys.stdin.name = "<stdin>"
125 black.format_stdin_to_stdout(
126 line_length=ll, fast=True, write_back=black.WriteBack.YES
129 actual = sys.stdout.read()
131 sys.stdin, sys.stdout = hold_stdin, hold_stdout
132 self.assertFormatEqual(expected, actual)
133 black.assert_equivalent(source, actual)
134 black.assert_stable(source, actual, line_length=ll)
136 def test_piping_diff(self) -> None:
137 source, _ = read_data("expression.py")
138 expected, _ = read_data("expression.diff")
139 hold_stdin, hold_stdout = sys.stdin, sys.stdout
141 sys.stdin, sys.stdout = StringIO(source), StringIO()
142 sys.stdin.name = "<stdin>"
143 black.format_stdin_to_stdout(
144 line_length=ll, fast=True, write_back=black.WriteBack.DIFF
147 actual = sys.stdout.read()
149 sys.stdin, sys.stdout = hold_stdin, hold_stdout
150 actual = actual.rstrip() + "\n" # the diff output has a trailing space
151 self.assertEqual(expected, actual)
153 @patch("black.dump_to_file", dump_to_stderr)
154 def test_setup(self) -> None:
155 source, expected = read_data("../setup")
157 self.assertFormatEqual(expected, actual)
158 black.assert_equivalent(source, actual)
159 black.assert_stable(source, actual, line_length=ll)
160 self.assertFalse(ff(THIS_DIR / ".." / "setup.py"))
162 @patch("black.dump_to_file", dump_to_stderr)
163 def test_function(self) -> None:
164 source, expected = read_data("function")
166 self.assertFormatEqual(expected, actual)
167 black.assert_equivalent(source, actual)
168 black.assert_stable(source, actual, line_length=ll)
170 @patch("black.dump_to_file", dump_to_stderr)
171 def test_expression(self) -> None:
172 source, expected = read_data("expression")
174 self.assertFormatEqual(expected, actual)
175 black.assert_equivalent(source, actual)
176 black.assert_stable(source, actual, line_length=ll)
178 def test_expression_ff(self) -> None:
179 source, expected = read_data("expression")
180 tmp_file = Path(black.dump_to_file(source))
182 self.assertTrue(ff(tmp_file, write_back=black.WriteBack.YES))
183 with open(tmp_file, encoding="utf8") as f:
187 self.assertFormatEqual(expected, actual)
188 with patch("black.dump_to_file", dump_to_stderr):
189 black.assert_equivalent(source, actual)
190 black.assert_stable(source, actual, line_length=ll)
192 def test_expression_diff(self) -> None:
193 source, _ = read_data("expression.py")
194 expected, _ = read_data("expression.diff")
195 tmp_file = Path(black.dump_to_file(source))
196 hold_stdout = sys.stdout
198 sys.stdout = StringIO()
199 self.assertTrue(ff(tmp_file, write_back=black.WriteBack.DIFF))
201 actual = sys.stdout.read()
202 actual = actual.replace(str(tmp_file), "<stdin>")
204 sys.stdout = hold_stdout
206 actual = actual.rstrip() + "\n" # the diff output has a trailing space
207 if expected != actual:
208 dump = black.dump_to_file(actual)
210 f"Expected diff isn't equal to the actual. If you made changes "
211 f"to expression.py and this is an anticipated difference, "
212 f"overwrite tests/expression.diff with {dump}"
214 self.assertEqual(expected, actual, msg)
216 @patch("black.dump_to_file", dump_to_stderr)
217 def test_fstring(self) -> None:
218 source, expected = read_data("fstring")
220 self.assertFormatEqual(expected, actual)
221 black.assert_equivalent(source, actual)
222 black.assert_stable(source, actual, line_length=ll)
224 @patch("black.dump_to_file", dump_to_stderr)
225 def test_string_quotes(self) -> None:
226 source, expected = read_data("string_quotes")
228 self.assertFormatEqual(expected, actual)
229 black.assert_equivalent(source, actual)
230 black.assert_stable(source, actual, line_length=ll)
232 @patch("black.dump_to_file", dump_to_stderr)
233 def test_comments(self) -> None:
234 source, expected = read_data("comments")
236 self.assertFormatEqual(expected, actual)
237 black.assert_equivalent(source, actual)
238 black.assert_stable(source, actual, line_length=ll)
240 @patch("black.dump_to_file", dump_to_stderr)
241 def test_comments2(self) -> None:
242 source, expected = read_data("comments2")
244 self.assertFormatEqual(expected, actual)
245 black.assert_equivalent(source, actual)
246 black.assert_stable(source, actual, line_length=ll)
248 @patch("black.dump_to_file", dump_to_stderr)
249 def test_comments3(self) -> None:
250 source, expected = read_data("comments3")
252 self.assertFormatEqual(expected, actual)
253 black.assert_equivalent(source, actual)
254 black.assert_stable(source, actual, line_length=ll)
256 @patch("black.dump_to_file", dump_to_stderr)
257 def test_comments4(self) -> None:
258 source, expected = read_data("comments4")
260 self.assertFormatEqual(expected, actual)
261 black.assert_equivalent(source, actual)
262 black.assert_stable(source, actual, line_length=ll)
264 @patch("black.dump_to_file", dump_to_stderr)
265 def test_comments5(self) -> None:
266 source, expected = read_data("comments5")
268 self.assertFormatEqual(expected, actual)
269 black.assert_equivalent(source, actual)
270 black.assert_stable(source, actual, line_length=ll)
272 @patch("black.dump_to_file", dump_to_stderr)
273 def test_cantfit(self) -> None:
274 source, expected = read_data("cantfit")
276 self.assertFormatEqual(expected, actual)
277 black.assert_equivalent(source, actual)
278 black.assert_stable(source, actual, line_length=ll)
280 @patch("black.dump_to_file", dump_to_stderr)
281 def test_import_spacing(self) -> None:
282 source, expected = read_data("import_spacing")
284 self.assertFormatEqual(expected, actual)
285 black.assert_equivalent(source, actual)
286 black.assert_stable(source, actual, line_length=ll)
288 @patch("black.dump_to_file", dump_to_stderr)
289 def test_composition(self) -> None:
290 source, expected = read_data("composition")
292 self.assertFormatEqual(expected, actual)
293 black.assert_equivalent(source, actual)
294 black.assert_stable(source, actual, line_length=ll)
296 @patch("black.dump_to_file", dump_to_stderr)
297 def test_empty_lines(self) -> None:
298 source, expected = read_data("empty_lines")
300 self.assertFormatEqual(expected, actual)
301 black.assert_equivalent(source, actual)
302 black.assert_stable(source, actual, line_length=ll)
304 @patch("black.dump_to_file", dump_to_stderr)
305 def test_python2(self) -> None:
306 source, expected = read_data("python2")
308 self.assertFormatEqual(expected, actual)
309 # black.assert_equivalent(source, actual)
310 black.assert_stable(source, actual, line_length=ll)
312 @patch("black.dump_to_file", dump_to_stderr)
313 def test_fmtonoff(self) -> None:
314 source, expected = read_data("fmtonoff")
316 self.assertFormatEqual(expected, actual)
317 black.assert_equivalent(source, actual)
318 black.assert_stable(source, actual, line_length=ll)
320 def test_report(self) -> None:
321 report = black.Report()
325 def out(msg: str, **kwargs: Any) -> None:
326 out_lines.append(msg)
328 def err(msg: str, **kwargs: Any) -> None:
329 err_lines.append(msg)
331 with patch("black.out", out), patch("black.err", err):
332 report.done(Path("f1"), black.Changed.NO)
333 self.assertEqual(len(out_lines), 1)
334 self.assertEqual(len(err_lines), 0)
335 self.assertEqual(out_lines[-1], "f1 already well formatted, good job.")
336 self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
337 self.assertEqual(report.return_code, 0)
338 report.done(Path("f2"), black.Changed.YES)
339 self.assertEqual(len(out_lines), 2)
340 self.assertEqual(len(err_lines), 0)
341 self.assertEqual(out_lines[-1], "reformatted f2")
343 unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
345 report.done(Path("f3"), black.Changed.CACHED)
346 self.assertEqual(len(out_lines), 3)
347 self.assertEqual(len(err_lines), 0)
349 out_lines[-1], "f3 wasn't modified on disk since last run."
352 unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
354 self.assertEqual(report.return_code, 0)
356 self.assertEqual(report.return_code, 1)
358 report.failed(Path("e1"), "boom")
359 self.assertEqual(len(out_lines), 3)
360 self.assertEqual(len(err_lines), 1)
361 self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
363 unstyle(str(report)),
364 "1 file reformatted, 2 files left unchanged, "
365 "1 file failed to reformat.",
367 self.assertEqual(report.return_code, 123)
368 report.done(Path("f3"), black.Changed.YES)
369 self.assertEqual(len(out_lines), 4)
370 self.assertEqual(len(err_lines), 1)
371 self.assertEqual(out_lines[-1], "reformatted f3")
373 unstyle(str(report)),
374 "2 files reformatted, 2 files left unchanged, "
375 "1 file failed to reformat.",
377 self.assertEqual(report.return_code, 123)
378 report.failed(Path("e2"), "boom")
379 self.assertEqual(len(out_lines), 4)
380 self.assertEqual(len(err_lines), 2)
381 self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
383 unstyle(str(report)),
384 "2 files reformatted, 2 files left unchanged, "
385 "2 files failed to reformat.",
387 self.assertEqual(report.return_code, 123)
388 report.done(Path("f4"), black.Changed.NO)
389 self.assertEqual(len(out_lines), 5)
390 self.assertEqual(len(err_lines), 2)
391 self.assertEqual(out_lines[-1], "f4 already well formatted, good job.")
393 unstyle(str(report)),
394 "2 files reformatted, 3 files left unchanged, "
395 "2 files failed to reformat.",
397 self.assertEqual(report.return_code, 123)
400 unstyle(str(report)),
401 "2 files would be reformatted, 3 files would be left unchanged, "
402 "2 files would fail to reformat.",
def test_is_python36(self) -> None:
    """Check black.is_python36 on inline snippets and on both halves of data files."""
    # (source snippet, whether it should be reported as Python 3.6+ only)
    inline_cases = (
        ("def f(*, arg): ...\n", False),
        # A trailing comma after keyword-only arguments.
        ("def f(*, arg,): ...\n", True),
        # An f-string in the body.
        ("def f(*, arg): f'string'\n", True),
    )
    for snippet, wants_36 in inline_cases:
        check = self.assertTrue if wants_36 else self.assertFalse
        check(black.is_python36(black.lib2to3_parse(snippet)))

    # Both the unformatted input and the expected output of each data file
    # must get the same verdict.
    for data_name, wants_36 in (("function", True), ("expression", False)):
        check = self.assertTrue if wants_36 else self.assertFalse
        for text in read_data(data_name):
            check(black.is_python36(black.lib2to3_parse(text)))
423 def test_debug_visitor(self) -> None:
424 source, _ = read_data("debug_visitor.py")
425 expected, _ = read_data("debug_visitor.out")
429 def out(msg: str, **kwargs: Any) -> None:
430 out_lines.append(msg)
432 def err(msg: str, **kwargs: Any) -> None:
433 err_lines.append(msg)
435 with patch("black.out", out), patch("black.err", err):
436 black.DebugVisitor.show(source)
437 actual = "\n".join(out_lines) + "\n"
439 if expected != actual:
440 log_name = black.dump_to_file(*out_lines)
444 f"AST print out is different. Actual version dumped to {log_name}",
447 def test_format_file_contents(self) -> None:
449 with self.assertRaises(black.NothingChanged):
450 black.format_file_contents(empty, line_length=ll, fast=False)
452 with self.assertRaises(black.NothingChanged):
453 black.format_file_contents(just_nl, line_length=ll, fast=False)
454 same = "l = [1, 2, 3]\n"
455 with self.assertRaises(black.NothingChanged):
456 black.format_file_contents(same, line_length=ll, fast=False)
457 different = "l = [1,2,3]"
459 actual = black.format_file_contents(different, line_length=ll, fast=False)
460 self.assertEqual(expected, actual)
461 invalid = "return if you can"
462 with self.assertRaises(ValueError) as e:
463 black.format_file_contents(invalid, line_length=ll, fast=False)
464 self.assertEqual(str(e.exception), "Cannot parse: 1:7: return if you can")
def test_endmarker(self) -> None:
    """A newline-only source parses to a file_input whose sole child is ENDMARKER."""
    tree = black.lib2to3_parse("\n")
    self.assertEqual(tree.type, black.syms.file_input)
    kids = tree.children
    self.assertEqual(len(kids), 1)
    self.assertEqual(kids[0].type, black.token.ENDMARKER)
472 @unittest.skipIf(os.environ.get("SKIP_AST_PRINT"), "user set SKIP_AST_PRINT")
473 def test_assertFormatEqual(self) -> None:
477 def out(msg: str, **kwargs: Any) -> None:
478 out_lines.append(msg)
480 def err(msg: str, **kwargs: Any) -> None:
481 err_lines.append(msg)
483 with patch("black.out", out), patch("black.err", err):
484 with self.assertRaises(AssertionError):
485 self.assertFormatEqual("l = [1, 2, 3]", "l = [1, 2, 3,]")
487 out_str = "".join(out_lines)
488 self.assertTrue("Expected tree:" in out_str)
489 self.assertTrue("Actual tree:" in out_str)
490 self.assertEqual("".join(err_lines), "")
def test_cache_broken_file(self) -> None:
    """An unreadable cache file reads as empty; a later run records the file."""
    with cache_dir() as workspace:
        line_length = black.DEFAULT_LINE_LENGTH

        # Corrupt the cache with content that is not a pickle.
        broken = black.get_cache_file(line_length)
        with broken.open("w") as handle:
            handle.write("this is not a pickle")
        self.assertEqual(black.read_cache(line_length), {})

        target = (workspace / "test.py").resolve()
        target.write_text("print('hello')")
        outcome = CliRunner().invoke(black.main, [str(target)])
        self.assertEqual(outcome.exit_code, 0)
        self.assertIn(target, black.read_cache(line_length))
def test_cache_single_file_already_cached(self) -> None:
    """A file already present in the cache is not rewritten by black.main."""
    with cache_dir() as workspace:
        target = (workspace / "test.py").resolve()
        unformatted = "print('hello')"
        target.write_text(unformatted)

        # Mark the file as cached before running black on it.
        black.write_cache({}, [target], black.DEFAULT_LINE_LENGTH)
        outcome = CliRunner().invoke(black.main, [str(target)])
        self.assertEqual(outcome.exit_code, 0)
        self.assertEqual(target.read_text(), unformatted)
517 @event_loop(close=False)
518 def test_cache_multiple_files(self) -> None:
519 with cache_dir() as workspace, patch(
520 "black.ProcessPoolExecutor", new=ThreadPoolExecutor
522 one = (workspace / "one.py").resolve()
523 with one.open("w") as fobj:
524 fobj.write("print('hello')")
525 two = (workspace / "two.py").resolve()
526 with two.open("w") as fobj:
527 fobj.write("print('hello')")
528 black.write_cache({}, [one], black.DEFAULT_LINE_LENGTH)
529 result = CliRunner().invoke(black.main, [str(workspace)])
530 self.assertEqual(result.exit_code, 0)
531 with one.open("r") as fobj:
532 self.assertEqual(fobj.read(), "print('hello')")
533 with two.open("r") as fobj:
534 self.assertEqual(fobj.read(), 'print("hello")\n')
535 cache = black.read_cache(black.DEFAULT_LINE_LENGTH)
536 self.assertIn(one, cache)
537 self.assertIn(two, cache)
def test_no_cache_when_writeback_diff(self) -> None:
    """Running with --diff must not create a cache file."""
    with cache_dir() as workspace:
        target = (workspace / "test.py").resolve()
        target.write_text("print('hello')")
        outcome = CliRunner().invoke(black.main, [str(target), "--diff"])
        self.assertEqual(outcome.exit_code, 0)
        cache_path = black.get_cache_file(black.DEFAULT_LINE_LENGTH)
        self.assertFalse(cache_path.exists())
549 def test_no_cache_when_stdin(self) -> None:
551 result = CliRunner().invoke(black.main, ["-"], input="print('hello')")
552 self.assertEqual(result.exit_code, 0)
553 cache_file = black.get_cache_file(black.DEFAULT_LINE_LENGTH)
554 self.assertFalse(cache_file.exists())
556 def test_read_cache_no_cachefile(self) -> None:
558 self.assertEqual(black.read_cache(black.DEFAULT_LINE_LENGTH), {})
560 def test_write_cache_read_cache(self) -> None:
561 with cache_dir() as workspace:
562 src = (workspace / "test.py").resolve()
564 black.write_cache({}, [src], black.DEFAULT_LINE_LENGTH)
565 cache = black.read_cache(black.DEFAULT_LINE_LENGTH)
566 self.assertIn(src, cache)
567 self.assertEqual(cache[src], black.get_cache_info(src))
569 def test_filter_cached(self) -> None:
570 with TemporaryDirectory() as workspace:
571 path = Path(workspace)
572 uncached = (path / "uncached").resolve()
573 cached = (path / "cached").resolve()
574 cached_but_changed = (path / "changed").resolve()
577 cached_but_changed.touch()
578 cache = {cached: black.get_cache_info(cached), cached_but_changed: (0.0, 0)}
579 todo, done = black.filter_cached(
580 cache, [uncached, cached, cached_but_changed]
582 self.assertEqual(todo, [uncached, cached_but_changed])
583 self.assertEqual(done, [cached])
def test_write_cache_creates_directory_if_needed(self) -> None:
    """write_cache creates the cache directory when it does not exist yet."""
    with cache_dir(exists=False) as cache_root:
        self.assertFalse(cache_root.exists())
        black.write_cache({}, [], black.DEFAULT_LINE_LENGTH)
        self.assertTrue(cache_root.exists())
591 @event_loop(close=False)
592 def test_failed_formatting_does_not_get_cached(self) -> None:
593 with cache_dir() as workspace, patch(
594 "black.ProcessPoolExecutor", new=ThreadPoolExecutor
596 failing = (workspace / "failing.py").resolve()
597 with failing.open("w") as fobj:
598 fobj.write("not actually python")
599 clean = (workspace / "clean.py").resolve()
600 with clean.open("w") as fobj:
601 fobj.write('print("hello")\n')
602 result = CliRunner().invoke(black.main, [str(workspace)])
603 self.assertEqual(result.exit_code, 123)
604 cache = black.read_cache(black.DEFAULT_LINE_LENGTH)
605 self.assertNotIn(failing, cache)
606 self.assertIn(clean, cache)
def test_write_cache_write_fail(self) -> None:
    """write_cache must not propagate an OSError raised by Path.open.

    NOTE(review): this assumes write_cache opens the cache file through
    Path.open; the test passes as long as no exception escapes.
    """
    with cache_dir(), patch.object(Path, "open", side_effect=OSError):
        black.write_cache({}, [], black.DEFAULT_LINE_LENGTH)
613 def test_check_diff_use_together(self) -> None:
615 # Files which will be reformatted.
616 src1 = (THIS_DIR / "string_quotes.py").resolve()
617 result = CliRunner().invoke(black.main, [str(src1), "--diff", "--check"])
618 self.assertEqual(result.exit_code, 1)
620 # Files which will not be reformatted.
621 src2 = (THIS_DIR / "composition.py").resolve()
622 result = CliRunner().invoke(black.main, [str(src2), "--diff", "--check"])
623 self.assertEqual(result.exit_code, 0)
625 # Multi file command.
626 result = CliRunner().invoke(
627 black.main, [str(src1), str(src2), "--diff", "--check"]
629 self.assertEqual(result.exit_code, 1)
631 def test_read_cache_line_lengths(self) -> None:
632 with cache_dir() as workspace:
633 path = (workspace / "file.py").resolve()
635 black.write_cache({}, [path], 1)
636 one = black.read_cache(1)
637 self.assertIn(path, one)
638 two = black.read_cache(2)
639 self.assertNotIn(path, two)
642 if __name__ == "__main__":