All patches and comments are welcome. Please squash your changes into logical
commits before using git-format-patch and git-send-email to send them to
patches@git.madduck.net.
If you read over the Git project's submission guidelines and adhere to them,
I'd be especially grateful.
3 from concurrent.futures import ThreadPoolExecutor
4 from contextlib import contextmanager
5 from functools import partial
6 from io import StringIO
8 from pathlib import Path
10 from tempfile import TemporaryDirectory
11 from typing import Any, List, Tuple, Iterator
13 from unittest.mock import patch
15 from click import unstyle
16 from click.testing import CliRunner
THIS_FILE = Path(__file__)
THIS_DIR = THIS_FILE.parent
# Shorthands used throughout the tests: `ff` formats a file on disk (fast mode)
# and `fs` formats a source string; both use the suite-wide line length `ll`.
ff = partial(black.format_file_in_place, fast=True, line_length=ll)
fs = partial(black.format_str, line_length=ll)
# Marker that read_data() removes from fixture lines. It is deliberately built
# from two adjacent pieces, presumably so the full marker text never appears
# verbatim in this file (read_data() strips it wherever it occurs). Keep it split.
EMPTY_LINE = "# EMPTY LINE WITH WHITESPACE" " (this comment will be removed)"
def dump_to_stderr(*output: str) -> str:
    """Replacement for ``black.dump_to_file`` used by the tests.

    Despite the name, nothing is written anywhere. The pieces are joined with
    newlines and wrapped in a leading and a trailing newline, and that text is
    returned in place of a dump-file path.
    """
    body = "\n".join(output)
    return f"\n{body}\n"
# Load a fixture from THIS_DIR and return it as (input, expected output).
# A line whose rstrip() equals "# output" separates the two halves, and
# EMPTY_LINE markers are removed from lines. Each half is stripped and given
# exactly one trailing newline.
# NOTE(review): several lines of this function are not shown in this listing
# (e.g. the loop header and the branch bodies). Check against the full file.
32 def read_data(name: str) -> Tuple[str, str]:
33 """read_data('test_name') -> 'input', 'output'"""
# Names without one of these extensions presumably get ".py" appended
# (the body line is not shown).
34 if not name.endswith((".py", ".pyi", ".out", ".diff")):
36 _input: List[str] = []
37 _output: List[str] = []
38 with open(THIS_DIR / name, "r", encoding="utf8") as test:
39 lines = test.readlines()
42 line = line.replace(EMPTY_LINE, "")
43 if line.rstrip() == "# output":
48 if _input and not _output:
49 # If there's no output marker, treat the entire file as already pre-formatted.
51 return "".join(_input).strip() + "\n", "".join(_output).strip() + "\n"
# Point black.CACHE_DIR at a throw-away temporary directory while the caller's
# `with` block runs. Callers use `with cache_dir() as workspace`, so this is
# presumably a @contextmanager that yields the path (the decorator and yield
# lines are not shown in this listing).
# With exists=False the patched path is a "new" subdirectory of the temporary
# workspace that has not been created yet (the guarding `if` is not shown).
55 def cache_dir(exists: bool = True) -> Iterator[Path]:
56 with TemporaryDirectory() as workspace:
57 cache_dir = Path(workspace)
59 cache_dir = cache_dir / "new"
60 with patch("black.CACHE_DIR", cache_dir):
# Install a fresh asyncio event loop while the wrapped code runs, then put the
# previously current loop back. It is also used as a method decorator, e.g.
# @event_loop(close=False).
# NOTE(review): the try/yield/finally lines and the use of `close` are not
# shown here. Presumably the new loop is closed only when close=True.
# NOTE(review): the new loop is installed with asyncio.set_event_loop() but the
# old one is restored with policy.set_event_loop(). Consider using one API for both.
65 def event_loop(close: bool) -> Iterator[None]:
66 policy = asyncio.get_event_loop_policy()
67 old_loop = policy.get_event_loop()
68 loop = policy.new_event_loop()
69 asyncio.set_event_loop(loop)
74 policy.set_event_loop(old_loop)
# Unit tests for black, driven by unittest.
79 class BlackTestCase(unittest.TestCase):
# Like assertEqual, but on a mismatch (unless SKIP_AST_PRINT is set) it first
# prints the lib2to3 debug tree of each string through black.out, so the
# structural difference is visible.
82 def assertFormatEqual(self, expected: str, actual: str) -> None:
83 if actual != expected and not os.environ.get("SKIP_AST_PRINT"):
84 bdv: black.DebugVisitor[Any]
85 black.out("Expected tree:", fg="green")
# NOTE(review): the `try:` lines and the bodies of the `except` clauses are not
# shown in this listing. A parse/visit failure is caught and presumably reported.
87 exp_node = black.lib2to3_parse(expected)
88 bdv = black.DebugVisitor()
# list() drains the visit() iterator so every node gets visited.
89 list(bdv.visit(exp_node))
90 except Exception as ve:
92 black.out("Actual tree:", fg="red")
# NOTE(review): `exp_node` is reused for the *actual* tree. A separate name
# would read more clearly.
94 exp_node = black.lib2to3_parse(actual)
95 bdv = black.DebugVisitor()
96 list(bdv.visit(exp_node))
97 except Exception as ve:
99 self.assertEqual(expected, actual)
# Fixture "test_black" (presumably this very file): the formatted result must
# match, stay AST-equivalent and be stable. ff() must also report no change
# (falsy result) for THIS_FILE on disk.
# NOTE(review): the line assigning `actual` is not shown in this listing.
101 @patch("black.dump_to_file", dump_to_stderr)
102 def test_self(self) -> None:
103 source, expected = read_data("test_black")
105 self.assertFormatEqual(expected, actual)
106 black.assert_equivalent(source, actual)
107 black.assert_stable(source, actual, line_length=ll)
108 self.assertFalse(ff(THIS_FILE))
# Same checks as test_self, applied to ../black.py. The on-disk file must
# already be formatted (ff() returns falsy).
# NOTE(review): the line assigning `actual` is not shown in this listing.
110 @patch("black.dump_to_file", dump_to_stderr)
111 def test_black(self) -> None:
112 source, expected = read_data("../black")
114 self.assertFormatEqual(expected, actual)
115 black.assert_equivalent(source, actual)
116 black.assert_stable(source, actual, line_length=ll)
117 self.assertFalse(ff(THIS_DIR / ".." / "black.py"))
# Feed black.py's source through format_stdin_to_stdout, with sys.stdin and
# sys.stdout swapped for StringIO objects, then check what was written.
# NOTE(review): the try/finally and stdout-rewind lines are not shown in this
# listing. The real streams are put back before the assertions run.
119 def test_piping(self) -> None:
120 source, expected = read_data("../black")
121 hold_stdin, hold_stdout = sys.stdin, sys.stdout
123 sys.stdin, sys.stdout = StringIO(source), StringIO()
124 sys.stdin.name = "<stdin>"
125 black.format_stdin_to_stdout(
126 line_length=ll, fast=True, write_back=black.WriteBack.YES
129 actual = sys.stdout.read()
131 sys.stdin, sys.stdout = hold_stdin, hold_stdout
132 self.assertFormatEqual(expected, actual)
133 black.assert_equivalent(source, actual)
134 black.assert_stable(source, actual, line_length=ll)
# Pipe expression.py through format_stdin_to_stdout in DIFF mode and compare
# the captured output with the stored expression.diff.
# NOTE(review): the try/finally and stdout-rewind lines are not shown here.
136 def test_piping_diff(self) -> None:
137 source, _ = read_data("expression.py")
138 expected, _ = read_data("expression.diff")
139 hold_stdin, hold_stdout = sys.stdin, sys.stdout
141 sys.stdin, sys.stdout = StringIO(source), StringIO()
142 sys.stdin.name = "<stdin>"
143 black.format_stdin_to_stdout(
144 line_length=ll, fast=True, write_back=black.WriteBack.DIFF
147 actual = sys.stdout.read()
149 sys.stdin, sys.stdout = hold_stdin, hold_stdout
150 actual = actual.rstrip() + "\n" # the diff output has a trailing space
151 self.assertEqual(expected, actual)
# Same checks as test_self, applied to ../setup.py. The on-disk file must
# already be formatted.
# NOTE(review): the line assigning `actual` is not shown in this listing.
153 @patch("black.dump_to_file", dump_to_stderr)
154 def test_setup(self) -> None:
155 source, expected = read_data("../setup")
157 self.assertFormatEqual(expected, actual)
158 black.assert_equivalent(source, actual)
159 black.assert_stable(source, actual, line_length=ll)
160 self.assertFalse(ff(THIS_DIR / ".." / "setup.py"))
# Fixture "function": `actual` must equal the expected half, and the pair must
# pass black.assert_equivalent and black.assert_stable.
# NOTE(review): the line assigning `actual` is not shown in this listing.
162 @patch("black.dump_to_file", dump_to_stderr)
163 def test_function(self) -> None:
164 source, expected = read_data("function")
166 self.assertFormatEqual(expected, actual)
167 black.assert_equivalent(source, actual)
168 black.assert_stable(source, actual, line_length=ll)
# Fixture "function2": same format/equivalence/stability checks.
# NOTE(review): the line assigning `actual` is not shown in this listing.
170 @patch("black.dump_to_file", dump_to_stderr)
171 def test_function2(self) -> None:
172 source, expected = read_data("function2")
174 self.assertFormatEqual(expected, actual)
175 black.assert_equivalent(source, actual)
176 black.assert_stable(source, actual, line_length=ll)
# Fixture "expression": same format/equivalence/stability checks.
# NOTE(review): the line assigning `actual` is not shown in this listing.
178 @patch("black.dump_to_file", dump_to_stderr)
179 def test_expression(self) -> None:
180 source, expected = read_data("expression")
182 self.assertFormatEqual(expected, actual)
183 black.assert_equivalent(source, actual)
184 black.assert_stable(source, actual, line_length=ll)
# Format a temporary copy of the "expression" fixture in place via ff() and
# check that the file's new contents match the expected output.
# NOTE(review): the try/finally, the read of `actual` and the temp-file cleanup
# lines are not shown in this listing.
186 def test_expression_ff(self) -> None:
187 source, expected = read_data("expression")
# The real (unpatched) dump_to_file is used here to create the temp file.
188 tmp_file = Path(black.dump_to_file(source))
# A truthy result means the file was changed.
190 self.assertTrue(ff(tmp_file, write_back=black.WriteBack.YES))
191 with open(tmp_file, encoding="utf8") as f:
195 self.assertFormatEqual(expected, actual)
# dump_to_file is patched only here, since the real one was needed above.
196 with patch("black.dump_to_file", dump_to_stderr):
197 black.assert_equivalent(source, actual)
198 black.assert_stable(source, actual, line_length=ll)
# Run ff() in DIFF mode on a temp copy of expression.py and compare the printed
# diff with the stored expression.diff. On a mismatch, the actual diff is dumped
# to a file named in the failure message.
# NOTE(review): the try/finally, stdout-rewind, cleanup and `msg = (` lines are
# not shown in this listing.
200 def test_expression_diff(self) -> None:
201 source, _ = read_data("expression.py")
202 expected, _ = read_data("expression.diff")
203 tmp_file = Path(black.dump_to_file(source))
204 hold_stdout = sys.stdout
206 sys.stdout = StringIO()
207 self.assertTrue(ff(tmp_file, write_back=black.WriteBack.DIFF))
209 actual = sys.stdout.read()
# The stored diff names "<stdin>" in its header, so the temp path is replaced.
210 actual = actual.replace(str(tmp_file), "<stdin>")
212 sys.stdout = hold_stdout
214 actual = actual.rstrip() + "\n" # the diff output has a trailing space
215 if expected != actual:
216 dump = black.dump_to_file(actual)
# NOTE(review): the first two f-strings have no placeholders, so the `f`
# prefix is redundant there.
218 f"Expected diff isn't equal to the actual. If you made changes "
219 f"to expression.py and this is an anticipated difference, "
220 f"overwrite tests/expression.diff with {dump}"
222 self.assertEqual(expected, actual, msg)
# Fixture "fstring": same format/equivalence/stability checks.
# NOTE(review): the line assigning `actual` is not shown in this listing.
224 @patch("black.dump_to_file", dump_to_stderr)
225 def test_fstring(self) -> None:
226 source, expected = read_data("fstring")
228 self.assertFormatEqual(expected, actual)
229 black.assert_equivalent(source, actual)
230 black.assert_stable(source, actual, line_length=ll)
# Fixture "string_quotes": same format/equivalence/stability checks.
# NOTE(review): the line assigning `actual` is not shown in this listing.
232 @patch("black.dump_to_file", dump_to_stderr)
233 def test_string_quotes(self) -> None:
234 source, expected = read_data("string_quotes")
236 self.assertFormatEqual(expected, actual)
237 black.assert_equivalent(source, actual)
238 black.assert_stable(source, actual, line_length=ll)
# Fixture "slices": same format/equivalence/stability checks.
# NOTE(review): the line assigning `actual` is not shown in this listing.
240 @patch("black.dump_to_file", dump_to_stderr)
241 def test_slices(self) -> None:
242 source, expected = read_data("slices")
244 self.assertFormatEqual(expected, actual)
245 black.assert_equivalent(source, actual)
246 black.assert_stable(source, actual, line_length=ll)
# Fixture "comments": same format/equivalence/stability checks.
# NOTE(review): the line assigning `actual` is not shown in this listing.
248 @patch("black.dump_to_file", dump_to_stderr)
249 def test_comments(self) -> None:
250 source, expected = read_data("comments")
252 self.assertFormatEqual(expected, actual)
253 black.assert_equivalent(source, actual)
254 black.assert_stable(source, actual, line_length=ll)
# Fixture "comments2": same format/equivalence/stability checks.
# NOTE(review): the line assigning `actual` is not shown in this listing.
256 @patch("black.dump_to_file", dump_to_stderr)
257 def test_comments2(self) -> None:
258 source, expected = read_data("comments2")
260 self.assertFormatEqual(expected, actual)
261 black.assert_equivalent(source, actual)
262 black.assert_stable(source, actual, line_length=ll)
# Fixture "comments3": same format/equivalence/stability checks.
# NOTE(review): the line assigning `actual` is not shown in this listing.
264 @patch("black.dump_to_file", dump_to_stderr)
265 def test_comments3(self) -> None:
266 source, expected = read_data("comments3")
268 self.assertFormatEqual(expected, actual)
269 black.assert_equivalent(source, actual)
270 black.assert_stable(source, actual, line_length=ll)
# Fixture "comments4": same format/equivalence/stability checks.
# NOTE(review): the line assigning `actual` is not shown in this listing.
272 @patch("black.dump_to_file", dump_to_stderr)
273 def test_comments4(self) -> None:
274 source, expected = read_data("comments4")
276 self.assertFormatEqual(expected, actual)
277 black.assert_equivalent(source, actual)
278 black.assert_stable(source, actual, line_length=ll)
# Fixture "comments5": same format/equivalence/stability checks.
# NOTE(review): the line assigning `actual` is not shown in this listing.
280 @patch("black.dump_to_file", dump_to_stderr)
281 def test_comments5(self) -> None:
282 source, expected = read_data("comments5")
284 self.assertFormatEqual(expected, actual)
285 black.assert_equivalent(source, actual)
286 black.assert_stable(source, actual, line_length=ll)
# Fixture "cantfit": same format/equivalence/stability checks.
# NOTE(review): the line assigning `actual` is not shown in this listing.
288 @patch("black.dump_to_file", dump_to_stderr)
289 def test_cantfit(self) -> None:
290 source, expected = read_data("cantfit")
292 self.assertFormatEqual(expected, actual)
293 black.assert_equivalent(source, actual)
294 black.assert_stable(source, actual, line_length=ll)
# Fixture "import_spacing": same format/equivalence/stability checks.
# NOTE(review): the line assigning `actual` is not shown in this listing.
296 @patch("black.dump_to_file", dump_to_stderr)
297 def test_import_spacing(self) -> None:
298 source, expected = read_data("import_spacing")
300 self.assertFormatEqual(expected, actual)
301 black.assert_equivalent(source, actual)
302 black.assert_stable(source, actual, line_length=ll)
# Fixture "composition": same format/equivalence/stability checks.
# NOTE(review): the line assigning `actual` is not shown in this listing.
304 @patch("black.dump_to_file", dump_to_stderr)
305 def test_composition(self) -> None:
306 source, expected = read_data("composition")
308 self.assertFormatEqual(expected, actual)
309 black.assert_equivalent(source, actual)
310 black.assert_stable(source, actual, line_length=ll)
# Fixture "empty_lines": same format/equivalence/stability checks.
# NOTE(review): the line assigning `actual` is not shown in this listing.
312 @patch("black.dump_to_file", dump_to_stderr)
313 def test_empty_lines(self) -> None:
314 source, expected = read_data("empty_lines")
316 self.assertFormatEqual(expected, actual)
317 black.assert_equivalent(source, actual)
318 black.assert_stable(source, actual, line_length=ll)
# Fixture "string_prefixes": same format/equivalence/stability checks.
# NOTE(review): the line assigning `actual` is not shown in this listing.
320 @patch("black.dump_to_file", dump_to_stderr)
321 def test_string_prefixes(self) -> None:
322 source, expected = read_data("string_prefixes")
324 self.assertFormatEqual(expected, actual)
325 black.assert_equivalent(source, actual)
326 black.assert_stable(source, actual, line_length=ll)
# Fixture "python2": format and stability checks only. The AST-equivalence
# check is disabled (commented out below).
# NOTE(review): the line assigning `actual` is not shown in this listing.
328 @patch("black.dump_to_file", dump_to_stderr)
329 def test_python2(self) -> None:
330 source, expected = read_data("python2")
332 self.assertFormatEqual(expected, actual)
333 # black.assert_equivalent(source, actual)
334 black.assert_stable(source, actual, line_length=ll)
# Fixture "python2_unicode_literals": format and stability checks only. No
# AST-equivalence check is made here.
# NOTE(review): the line assigning `actual` is not shown in this listing.
336 @patch("black.dump_to_file", dump_to_stderr)
337 def test_python2_unicode_literals(self) -> None:
338 source, expected = read_data("python2_unicode_literals")
340 self.assertFormatEqual(expected, actual)
341 black.assert_stable(source, actual, line_length=ll)
@patch("black.dump_to_file", dump_to_stderr)
def test_stub(self) -> None:
    # Stub formatting: in PYI mode the fixture's input half must become its
    # expected half, and black.assert_stable must accept the result.
    pyi_mode = black.FileMode.PYI
    src, want = read_data("stub.pyi")
    got = fs(src, mode=pyi_mode)
    self.assertFormatEqual(want, got)
    black.assert_stable(src, got, line_length=ll, mode=pyi_mode)
# Fixture "fmtonoff": same format/equivalence/stability checks.
# NOTE(review): the line assigning `actual` is not shown in this listing.
351 @patch("black.dump_to_file", dump_to_stderr)
352 def test_fmtonoff(self) -> None:
353 source, expected = read_data("fmtonoff")
355 self.assertFormatEqual(expected, actual)
356 black.assert_equivalent(source, actual)
357 black.assert_stable(source, actual, line_length=ll)
# Fixture "class_blank_parentheses": same format/equivalence/stability checks.
# NOTE(review): the line assigning `actual` is not shown in this listing.
359 @patch("black.dump_to_file", dump_to_stderr)
360 def test_remove_empty_parentheses_after_class(self) -> None:
361 source, expected = read_data("class_blank_parentheses")
363 self.assertFormatEqual(expected, actual)
364 black.assert_equivalent(source, actual)
365 black.assert_stable(source, actual, line_length=ll)
# Fixture "class_methods_new_line": same format/equivalence/stability checks.
# NOTE(review): the line assigning `actual` is not shown in this listing.
367 @patch("black.dump_to_file", dump_to_stderr)
368 def test_new_line_between_class_and_code(self) -> None:
369 source, expected = read_data("class_methods_new_line")
371 self.assertFormatEqual(expected, actual)
372 black.assert_equivalent(source, actual)
373 black.assert_stable(source, actual, line_length=ll)
# Drive black.Report through a sequence of done()/failed() calls while
# black.out / black.err are captured. After each step, check the per-file
# message, the unstyled summary string and return_code.
# NOTE(review): the out_lines/err_lines initialisation and the opening
# `self.assertEqual(` / closing `)` lines are not shown in this listing.
375 def test_report(self) -> None:
376 report = black.Report()
380 def out(msg: str, **kwargs: Any) -> None:
381 out_lines.append(msg)
383 def err(msg: str, **kwargs: Any) -> None:
384 err_lines.append(msg)
386 with patch("black.out", out), patch("black.err", err):
387 report.done(Path("f1"), black.Changed.NO)
388 self.assertEqual(len(out_lines), 1)
389 self.assertEqual(len(err_lines), 0)
390 self.assertEqual(out_lines[-1], "f1 already well formatted, good job.")
391 self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
392 self.assertEqual(report.return_code, 0)
393 report.done(Path("f2"), black.Changed.YES)
394 self.assertEqual(len(out_lines), 2)
395 self.assertEqual(len(err_lines), 0)
396 self.assertEqual(out_lines[-1], "reformatted f2")
398 unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
# A CACHED result still counts as "left unchanged" in the summary.
400 report.done(Path("f3"), black.Changed.CACHED)
401 self.assertEqual(len(out_lines), 3)
402 self.assertEqual(len(err_lines), 0)
404 out_lines[-1], "f3 wasn't modified on disk since last run."
407 unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
409 self.assertEqual(report.return_code, 0)
# NOTE(review): return_code goes from 0 to 1 with no visible call in between.
# A statement that toggles report state (presumably check mode) is not shown.
411 self.assertEqual(report.return_code, 1)
# A failure is written to err and raises return_code to 123.
413 report.failed(Path("e1"), "boom")
414 self.assertEqual(len(out_lines), 3)
415 self.assertEqual(len(err_lines), 1)
416 self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
418 unstyle(str(report)),
419 "1 file reformatted, 2 files left unchanged, "
420 "1 file failed to reformat.",
422 self.assertEqual(report.return_code, 123)
423 report.done(Path("f3"), black.Changed.YES)
424 self.assertEqual(len(out_lines), 4)
425 self.assertEqual(len(err_lines), 1)
426 self.assertEqual(out_lines[-1], "reformatted f3")
428 unstyle(str(report)),
429 "2 files reformatted, 2 files left unchanged, "
430 "1 file failed to reformat.",
432 self.assertEqual(report.return_code, 123)
433 report.failed(Path("e2"), "boom")
434 self.assertEqual(len(out_lines), 4)
435 self.assertEqual(len(err_lines), 2)
436 self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
438 unstyle(str(report)),
439 "2 files reformatted, 2 files left unchanged, "
440 "2 files failed to reformat.",
442 self.assertEqual(report.return_code, 123)
443 report.done(Path("f4"), black.Changed.NO)
444 self.assertEqual(len(out_lines), 5)
445 self.assertEqual(len(err_lines), 2)
446 self.assertEqual(out_lines[-1], "f4 already well formatted, good job.")
448 unstyle(str(report)),
449 "2 files reformatted, 3 files left unchanged, "
450 "2 files failed to reformat.",
452 self.assertEqual(report.return_code, 123)
# NOTE(review): the summary switches to "would be" wording here. The statement
# that changes the report mode is not shown in this listing.
455 unstyle(str(report)),
456 "2 files would be reformatted, 3 files would be left unchanged, "
457 "2 files would fail to reformat.",
def test_is_python36(self) -> None:
    def detects_py36(code: str) -> bool:
        # Parse `code` and ask black whether it uses 3.6-only syntax.
        return black.is_python36(black.lib2to3_parse(code))

    # A bare keyword-only signature must not count as 3.6+. The same signature
    # with a trailing comma, or a body containing an f-string, must count.
    self.assertFalse(detects_py36("def f(*, arg): ...\n"))
    self.assertTrue(detects_py36("def f(*, arg,): ...\n"))
    self.assertTrue(detects_py36("def f(*, arg): f'string'\n"))
    # Both halves (input, expected output) of each fixture are checked, in order.
    for text in read_data("function"):
        self.assertTrue(detects_py36(text))
    for text in read_data("expression"):
        self.assertFalse(detects_py36(text))
# get_future_imports returns the names imported via `from __future__ import`.
# Single, multiple, parenthesised and repeated imports are collected. Leading
# comments or a docstring are allowed, but imports after other code and
# imports from other modules are ignored.
# NOTE(review): the closing `)` of the multi-line parse call is not shown.
478 def test_get_future_imports(self) -> None:
479 node = black.lib2to3_parse("\n")
480 self.assertEqual(set(), black.get_future_imports(node))
481 node = black.lib2to3_parse("from __future__ import black\n")
482 self.assertEqual({"black"}, black.get_future_imports(node))
483 node = black.lib2to3_parse("from __future__ import multiple, imports\n")
484 self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
485 node = black.lib2to3_parse("from __future__ import (parenthesized, imports)\n")
486 self.assertEqual({"parenthesized", "imports"}, black.get_future_imports(node))
487 node = black.lib2to3_parse(
488 "from __future__ import multiple\nfrom __future__ import imports\n"
490 self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
491 node = black.lib2to3_parse("# comment\nfrom __future__ import black\n")
492 self.assertEqual({"black"}, black.get_future_imports(node))
493 node = black.lib2to3_parse('"""docstring"""\nfrom __future__ import black\n')
494 self.assertEqual({"black"}, black.get_future_imports(node))
495 node = black.lib2to3_parse("some(other, code)\nfrom __future__ import black\n")
496 self.assertEqual(set(), black.get_future_imports(node))
497 node = black.lib2to3_parse("from some.module import black\n")
498 self.assertEqual(set(), black.get_future_imports(node))
# Render debug_visitor.py with DebugVisitor.show while black.out/err are
# captured, and compare the joined output with debug_visitor.out. On a
# mismatch, the actual output is dumped to a file named in the failure message.
# NOTE(review): the out_lines/err_lines initialisation and the assertion call
# wrapping the message are not shown in this listing.
500 def test_debug_visitor(self) -> None:
501 source, _ = read_data("debug_visitor.py")
502 expected, _ = read_data("debug_visitor.out")
506 def out(msg: str, **kwargs: Any) -> None:
507 out_lines.append(msg)
509 def err(msg: str, **kwargs: Any) -> None:
510 err_lines.append(msg)
512 with patch("black.out", out), patch("black.err", err):
513 black.DebugVisitor.show(source)
514 actual = "\n".join(out_lines) + "\n"
516 if expected != actual:
517 log_name = black.dump_to_file(*out_lines)
521 f"AST print out is different. Actual version dumped to {log_name}",
# format_file_contents raises NothingChanged for empty, newline-only and
# already-formatted input, and returns the formatted text otherwise. It raises
# ValueError with a "Cannot parse" message for invalid code.
# NOTE(review): the definitions of `empty`, `just_nl` and `expected` are not
# shown in this listing.
524 def test_format_file_contents(self) -> None:
526 with self.assertRaises(black.NothingChanged):
527 black.format_file_contents(empty, line_length=ll, fast=False)
529 with self.assertRaises(black.NothingChanged):
530 black.format_file_contents(just_nl, line_length=ll, fast=False)
531 same = "l = [1, 2, 3]\n"
532 with self.assertRaises(black.NothingChanged):
533 black.format_file_contents(same, line_length=ll, fast=False)
534 different = "l = [1,2,3]"
536 actual = black.format_file_contents(different, line_length=ll, fast=False)
537 self.assertEqual(expected, actual)
538 invalid = "return if you can"
539 with self.assertRaises(ValueError) as e:
540 black.format_file_contents(invalid, line_length=ll, fast=False)
541 self.assertEqual(str(e.exception), "Cannot parse: 1:7: return if you can")
def test_endmarker(self) -> None:
    # Parsing a lone newline must yield a file_input node whose only child is
    # the ENDMARKER token.
    root = black.lib2to3_parse("\n")
    self.assertEqual(root.type, black.syms.file_input)
    kids = root.children
    self.assertEqual(len(kids), 1)
    self.assertEqual(kids[0].type, black.token.ENDMARKER)
# On a mismatch, assertFormatEqual must print both debug trees through
# black.out (and nothing through black.err) before raising AssertionError.
# Skipped when SKIP_AST_PRINT is set, since the trees are not printed then.
# NOTE(review): the out_lines/err_lines initialisation is not shown here.
549 @unittest.skipIf(os.environ.get("SKIP_AST_PRINT"), "user set SKIP_AST_PRINT")
550 def test_assertFormatEqual(self) -> None:
554 def out(msg: str, **kwargs: Any) -> None:
555 out_lines.append(msg)
557 def err(msg: str, **kwargs: Any) -> None:
558 err_lines.append(msg)
560 with patch("black.out", out), patch("black.err", err):
561 with self.assertRaises(AssertionError):
562 self.assertFormatEqual("l = [1, 2, 3]", "l = [1, 2, 3,]")
564 out_str = "".join(out_lines)
565 self.assertTrue("Expected tree:" in out_str)
566 self.assertTrue("Actual tree:" in out_str)
567 self.assertEqual("".join(err_lines), "")
def test_cache_broken_file(self) -> None:
    # A cache file that is not a valid pickle must read back as an empty cache,
    # and after a normal run the formatted file must be present in the cache.
    mode = black.FileMode.AUTO_DETECT
    line_length = black.DEFAULT_LINE_LENGTH
    with cache_dir() as workspace:
        cache_path = black.get_cache_file(line_length, mode)
        with cache_path.open("w") as handle:
            handle.write("this is not a pickle")
        self.assertEqual(black.read_cache(line_length, mode), {})
        target = (workspace / "test.py").resolve()
        with target.open("w") as handle:
            handle.write("print('hello')")
        outcome = CliRunner().invoke(black.main, [str(target)])
        self.assertEqual(outcome.exit_code, 0)
        self.assertIn(target, black.read_cache(line_length, mode))
def test_cache_single_file_already_cached(self) -> None:
    # A file already recorded in the cache must be left untouched by a run,
    # even though its contents are not in black's style.
    mode = black.FileMode.AUTO_DETECT
    with cache_dir() as workspace:
        target = (workspace / "test.py").resolve()
        target.write_text("print('hello')")
        black.write_cache({}, [target], black.DEFAULT_LINE_LENGTH, mode)
        outcome = CliRunner().invoke(black.main, [str(target)])
        self.assertEqual(outcome.exit_code, 0)
        self.assertEqual(target.read_text(), "print('hello')")
# Of two identical unformatted files, the one already in the cache is left
# as-is and the other is reformatted. Afterwards both are in the cache.
# ProcessPoolExecutor is swapped for ThreadPoolExecutor, presumably so the
# work runs in-process under the patched cache directory.
# NOTE(review): the closing `):` of the multi-line `with` is not shown.
596 @event_loop(close=False)
597 def test_cache_multiple_files(self) -> None:
598 mode = black.FileMode.AUTO_DETECT
599 with cache_dir() as workspace, patch(
600 "black.ProcessPoolExecutor", new=ThreadPoolExecutor
602 one = (workspace / "one.py").resolve()
603 with one.open("w") as fobj:
604 fobj.write("print('hello')")
605 two = (workspace / "two.py").resolve()
606 with two.open("w") as fobj:
607 fobj.write("print('hello')")
608 black.write_cache({}, [one], black.DEFAULT_LINE_LENGTH, mode)
609 result = CliRunner().invoke(black.main, [str(workspace)])
610 self.assertEqual(result.exit_code, 0)
611 with one.open("r") as fobj:
612 self.assertEqual(fobj.read(), "print('hello')")
613 with two.open("r") as fobj:
614 self.assertEqual(fobj.read(), 'print("hello")\n')
615 cache = black.read_cache(black.DEFAULT_LINE_LENGTH, mode)
616 self.assertIn(one, cache)
617 self.assertIn(two, cache)
def test_no_cache_when_writeback_diff(self) -> None:
    # Running with --diff must not create a cache file.
    mode = black.FileMode.AUTO_DETECT
    with cache_dir() as workspace:
        target = (workspace / "test.py").resolve()
        target.write_text("print('hello')")
        outcome = CliRunner().invoke(black.main, [str(target), "--diff"])
        self.assertEqual(outcome.exit_code, 0)
        cache_path = black.get_cache_file(black.DEFAULT_LINE_LENGTH, mode)
        self.assertFalse(cache_path.exists())
# Formatting stdin ("-") must not create a cache file.
# NOTE(review): the enclosing `with` line (presumably `with cache_dir():`) is
# not shown in this listing.
630 def test_no_cache_when_stdin(self) -> None:
631 mode = black.FileMode.AUTO_DETECT
633 result = CliRunner().invoke(black.main, ["-"], input="print('hello')")
634 self.assertEqual(result.exit_code, 0)
635 cache_file = black.get_cache_file(black.DEFAULT_LINE_LENGTH, mode)
636 self.assertFalse(cache_file.exists())
# read_cache returns an empty dict when no cache file exists.
# NOTE(review): the enclosing `with` line (presumably `with cache_dir():`) is
# not shown in this listing.
638 def test_read_cache_no_cachefile(self) -> None:
639 mode = black.FileMode.AUTO_DETECT
641 self.assertEqual(black.read_cache(black.DEFAULT_LINE_LENGTH, mode), {})
# A path passed to write_cache reads back from read_cache with the value
# returned by get_cache_info.
# NOTE(review): the line creating the file on disk is not shown in this listing.
643 def test_write_cache_read_cache(self) -> None:
644 mode = black.FileMode.AUTO_DETECT
645 with cache_dir() as workspace:
646 src = (workspace / "test.py").resolve()
648 black.write_cache({}, [src], black.DEFAULT_LINE_LENGTH, mode)
649 cache = black.read_cache(black.DEFAULT_LINE_LENGTH, mode)
650 self.assertIn(src, cache)
651 self.assertEqual(cache[src], black.get_cache_info(src))
# filter_cached splits paths into (todo, done). Uncached paths and cached
# paths whose recorded info no longer matches go to todo. Paths whose recorded
# info matches get_cache_info go to done.
# NOTE(review): the lines creating `uncached`/`cached` on disk and the closing
# `)` of the call are not shown in this listing.
653 def test_filter_cached(self) -> None:
654 with TemporaryDirectory() as workspace:
655 path = Path(workspace)
656 uncached = (path / "uncached").resolve()
657 cached = (path / "cached").resolve()
658 cached_but_changed = (path / "changed").resolve()
661 cached_but_changed.touch()
662 cache = {cached: black.get_cache_info(cached), cached_but_changed: (0.0, 0)}
663 todo, done = black.filter_cached(
664 cache, [uncached, cached, cached_but_changed]
666 self.assertEqual(todo, [uncached, cached_but_changed])
667 self.assertEqual(done, [cached])
def test_write_cache_creates_directory_if_needed(self) -> None:
    # write_cache must create the cache directory when it does not exist yet.
    with cache_dir(exists=False) as cache_root:
        self.assertFalse(cache_root.exists())
        black.write_cache(
            {}, [], black.DEFAULT_LINE_LENGTH, black.FileMode.AUTO_DETECT
        )
        self.assertTrue(cache_root.exists())
# When one file fails to parse, the run exits with 123. The failing file must
# stay out of the cache while the clean file is added.
# NOTE(review): the closing `):` of the multi-line `with` is not shown.
676 @event_loop(close=False)
677 def test_failed_formatting_does_not_get_cached(self) -> None:
678 mode = black.FileMode.AUTO_DETECT
679 with cache_dir() as workspace, patch(
680 "black.ProcessPoolExecutor", new=ThreadPoolExecutor
682 failing = (workspace / "failing.py").resolve()
683 with failing.open("w") as fobj:
684 fobj.write("not actually python")
685 clean = (workspace / "clean.py").resolve()
686 with clean.open("w") as fobj:
687 fobj.write('print("hello")\n')
688 result = CliRunner().invoke(black.main, [str(workspace)])
689 self.assertEqual(result.exit_code, 123)
690 cache = black.read_cache(black.DEFAULT_LINE_LENGTH, mode)
691 self.assertNotIn(failing, cache)
692 self.assertIn(clean, cache)
def test_write_cache_write_fail(self) -> None:
    # write_cache must not let an OSError from opening the cache file escape.
    # Reaching the end of this test without an exception is the assertion.
    with cache_dir():
        with patch.object(Path, "open", side_effect=OSError):
            black.write_cache(
                {}, [], black.DEFAULT_LINE_LENGTH, black.FileMode.AUTO_DETECT
            )
# --diff combined with --check exits with 1 when any given file would be
# reformatted and 0 when none would. This holds for single and multiple files.
# NOTE(review): the enclosing `with` line (presumably `with cache_dir():`) and
# the closing `)` of the last invoke call are not shown in this listing.
700 @event_loop(close=False)
701 def test_check_diff_use_together(self) -> None:
703 # Files which will be reformatted.
704 src1 = (THIS_DIR / "string_quotes.py").resolve()
705 result = CliRunner().invoke(black.main, [str(src1), "--diff", "--check"])
706 self.assertEqual(result.exit_code, 1)
708 # Files which will not be reformatted.
709 src2 = (THIS_DIR / "composition.py").resolve()
710 result = CliRunner().invoke(black.main, [str(src2), "--diff", "--check"])
711 self.assertEqual(result.exit_code, 0)
713 # Multi file command.
714 result = CliRunner().invoke(
715 black.main, [str(src1), str(src2), "--diff", "--check"]
717 self.assertEqual(result.exit_code, 1, result.output)
# Invoking black with no path arguments must exit successfully.
# NOTE(review): the line right after the def (presumably `with cache_dir():`)
# is not shown in this listing.
719 def test_no_files(self) -> None:
721 # Without an argument, black exits with error code 0.
722 result = CliRunner().invoke(black.main, [])
723 self.assertEqual(result.exit_code, 0)
def test_broken_symlink(self) -> None:
    # A dangling symlink inside the target directory must not make the run fail.
    with cache_dir() as workspace:
        dangling = workspace / "broken_link.py"
        dangling.symlink_to("nonexistent.py")
        args = [str(workspace.resolve())]
        run = CliRunner().invoke(black.main, args)
        self.assertEqual(run.exit_code, 0)
# Caches are keyed by line length: an entry written for length 1 is visible
# when reading with length 1 but not with length 2.
# NOTE(review): the line creating `path` on disk is not shown in this listing.
732 def test_read_cache_line_lengths(self) -> None:
733 mode = black.FileMode.AUTO_DETECT
734 with cache_dir() as workspace:
735 path = (workspace / "file.py").resolve()
737 black.write_cache({}, [path], 1, mode)
738 one = black.read_cache(1, mode)
739 self.assertIn(path, one)
740 two = black.read_cache(2, mode)
741 self.assertNotIn(path, two)
# Formatting a .py file with --pyi applies stub formatting. The result is
# recorded only in the PYI-mode cache, not the default-mode cache.
# NOTE(review): the lines writing `contents` and reading `actual` are not
# shown in this listing.
743 def test_single_file_force_pyi(self) -> None:
744 reg_mode = black.FileMode.AUTO_DETECT
745 pyi_mode = black.FileMode.PYI
746 contents, expected = read_data("force_pyi")
747 with cache_dir() as workspace:
748 path = (workspace / "file.py").resolve()
749 with open(path, "w") as fh:
751 result = CliRunner().invoke(black.main, [str(path), "--pyi"])
752 self.assertEqual(result.exit_code, 0)
753 with open(path, "r") as fh:
755 # verify cache with --pyi is separate
756 pyi_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, pyi_mode)
757 self.assertIn(path, pyi_cache)
758 normal_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, reg_mode)
759 self.assertNotIn(path, normal_cache)
760 self.assertEqual(actual, expected)
# Multi-file variant of the --pyi test: every file gets stub formatting and
# lands only in the PYI-mode cache.
# NOTE(review): the `paths = [` / `]` lines, the `for path in paths:` headers
# and the write/read lines are not shown in this listing.
762 @event_loop(close=False)
763 def test_multi_file_force_pyi(self) -> None:
764 reg_mode = black.FileMode.AUTO_DETECT
765 pyi_mode = black.FileMode.PYI
766 contents, expected = read_data("force_pyi")
767 with cache_dir() as workspace:
769 (workspace / "file1.py").resolve(),
770 (workspace / "file2.py").resolve(),
773 with open(path, "w") as fh:
775 result = CliRunner().invoke(black.main, [str(p) for p in paths] + ["--pyi"])
776 self.assertEqual(result.exit_code, 0)
778 with open(path, "r") as fh:
780 self.assertEqual(actual, expected)
781 # verify cache with --pyi is separate
782 pyi_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, pyi_mode)
783 normal_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, reg_mode)
785 self.assertIn(path, pyi_cache)
786 self.assertNotIn(path, normal_cache)
def test_pipe_force_pyi(self) -> None:
    # Piping source through `black - -q --pyi` must exit cleanly and write the
    # stub-formatted code to stdout.
    source, expected = read_data("force_pyi")
    result = CliRunner().invoke(black.main, ["-", "-q", "--pyi"], input=source)
    # Include the CLI output as the message so a failing exit code is diagnosable.
    self.assertEqual(result.exit_code, 0, result.output)
    actual = result.output
    # assertFormatEqual takes (expected, actual); this order keeps the
    # "Expected tree" / "Actual tree" dumps correctly labelled.
    self.assertFormatEqual(expected, actual)
# Formatting a file with --py36 must give the expected output. The result is
# recorded only in the PYTHON36-mode cache, not the default-mode cache.
# NOTE(review): the lines writing `source` and reading `actual` are not shown
# in this listing.
795 def test_single_file_force_py36(self) -> None:
796 reg_mode = black.FileMode.AUTO_DETECT
797 py36_mode = black.FileMode.PYTHON36
798 source, expected = read_data("force_py36")
799 with cache_dir() as workspace:
800 path = (workspace / "file.py").resolve()
801 with open(path, "w") as fh:
803 result = CliRunner().invoke(black.main, [str(path), "--py36"])
804 self.assertEqual(result.exit_code, 0)
805 with open(path, "r") as fh:
807 # verify cache with --py36 is separate
808 py36_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, py36_mode)
809 self.assertIn(path, py36_cache)
810 normal_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, reg_mode)
811 self.assertNotIn(path, normal_cache)
812 self.assertEqual(actual, expected)
# Multi-file variant of the --py36 test: every file gets the expected output
# and lands only in the PYTHON36-mode cache.
# NOTE(review): the `paths = [` / `]` lines, the `for path in paths:` headers,
# the write/read lines and the invoke's closing `)` are not shown here.
814 @event_loop(close=False)
815 def test_multi_file_force_py36(self) -> None:
816 reg_mode = black.FileMode.AUTO_DETECT
817 py36_mode = black.FileMode.PYTHON36
818 source, expected = read_data("force_py36")
819 with cache_dir() as workspace:
821 (workspace / "file1.py").resolve(),
822 (workspace / "file2.py").resolve(),
825 with open(path, "w") as fh:
827 result = CliRunner().invoke(
828 black.main, [str(p) for p in paths] + ["--py36"]
830 self.assertEqual(result.exit_code, 0)
832 with open(path, "r") as fh:
834 self.assertEqual(actual, expected)
835 # verify cache with --py36 is separate
# NOTE(review): `pyi_cache` actually holds the PYTHON36-mode cache. The name
# looks copied from the --pyi test; `py36_cache` would be clearer.
836 pyi_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, py36_mode)
837 normal_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, reg_mode)
839 self.assertIn(path, pyi_cache)
840 self.assertNotIn(path, normal_cache)
def test_pipe_force_py36(self) -> None:
    # Piping source through `black - -q --py36` must exit cleanly and write the
    # expected formatted code to stdout.
    source, expected = read_data("force_py36")
    result = CliRunner().invoke(black.main, ["-", "-q", "--py36"], input=source)
    # Include the CLI output as the message so a failing exit code is diagnosable.
    self.assertEqual(result.exit_code, 0, result.output)
    actual = result.output
    # assertFormatEqual takes (expected, actual); this order keeps the
    # "Expected tree" / "Actual tree" dumps correctly labelled.
    self.assertFormatEqual(expected, actual)
850 if __name__ == "__main__":