All patches and comments are welcome. Please squash your changes into logical
commits before using git-format-patch and git-send-email to send them to
patches@git.madduck.net.
If you read over the Git project's submission guidelines and adhere to them,
I'd be especially grateful.
3 from concurrent.futures import ThreadPoolExecutor
4 from contextlib import contextmanager
5 from functools import partial
6 from io import StringIO
8 from pathlib import Path
10 from tempfile import TemporaryDirectory
11 from typing import Any, List, Tuple, Iterator
13 from unittest.mock import patch
15 from click import unstyle
16 from click.testing import CliRunner
21 ff = partial(black.format_file_in_place, line_length=ll, fast=True)
22 fs = partial(black.format_str, line_length=ll)
23 THIS_FILE = Path(__file__)
24 THIS_DIR = THIS_FILE.parent
25 EMPTY_LINE = "# EMPTY LINE WITH WHITESPACE" + " (this comment will be removed)"
def dump_to_stderr(*output: str) -> str:
    """Replacement for ``black.dump_to_file`` that the tests patch in.

    Rather than producing a path, return the given pieces joined by
    newlines, with one extra newline added before and after the joined text.
    """
    body = "\n".join(output)
    return f"\n{body}\n"
32 def read_data(name: str) -> Tuple[str, str]:
33 """read_data('test_name') -> 'input', 'output'"""
34 if not name.endswith((".py", ".out", ".diff")):
36 _input: List[str] = []
37 _output: List[str] = []
38 with open(THIS_DIR / name, "r", encoding="utf8") as test:
39 lines = test.readlines()
42 line = line.replace(EMPTY_LINE, "")
43 if line.rstrip() == "# output":
48 if _input and not _output:
49 # If there's no output marker, treat the entire file as already pre-formatted.
51 return "".join(_input).strip() + "\n", "".join(_output).strip() + "\n"
55 def cache_dir(exists: bool = True) -> Iterator[Path]:
56 with TemporaryDirectory() as workspace:
57 cache_dir = Path(workspace)
59 cache_dir = cache_dir / "new"
60 with patch("black.CACHE_DIR", cache_dir):
65 def event_loop(close: bool) -> Iterator[None]:
66 policy = asyncio.get_event_loop_policy()
67 old_loop = policy.get_event_loop()
68 loop = policy.new_event_loop()
69 asyncio.set_event_loop(loop)
74 policy.set_event_loop(old_loop)
79 class BlackTestCase(unittest.TestCase):
82 def assertFormatEqual(self, expected: str, actual: str) -> None:
83 if actual != expected and not os.environ.get("SKIP_AST_PRINT"):
84 bdv: black.DebugVisitor[Any]
85 black.out("Expected tree:", fg="green")
87 exp_node = black.lib2to3_parse(expected)
88 bdv = black.DebugVisitor()
89 list(bdv.visit(exp_node))
90 except Exception as ve:
92 black.out("Actual tree:", fg="red")
94 exp_node = black.lib2to3_parse(actual)
95 bdv = black.DebugVisitor()
96 list(bdv.visit(exp_node))
97 except Exception as ve:
99 self.assertEqual(expected, actual)
101 @patch("black.dump_to_file", dump_to_stderr)
102 def test_self(self) -> None:
103 source, expected = read_data("test_black")
105 self.assertFormatEqual(expected, actual)
106 black.assert_equivalent(source, actual)
107 black.assert_stable(source, actual, line_length=ll)
108 self.assertFalse(ff(THIS_FILE))
110 @patch("black.dump_to_file", dump_to_stderr)
111 def test_black(self) -> None:
112 source, expected = read_data("../black")
114 self.assertFormatEqual(expected, actual)
115 black.assert_equivalent(source, actual)
116 black.assert_stable(source, actual, line_length=ll)
117 self.assertFalse(ff(THIS_DIR / ".." / "black.py"))
119 def test_piping(self) -> None:
120 source, expected = read_data("../black")
121 hold_stdin, hold_stdout = sys.stdin, sys.stdout
123 sys.stdin, sys.stdout = StringIO(source), StringIO()
124 sys.stdin.name = "<stdin>"
125 black.format_stdin_to_stdout(
126 line_length=ll, fast=True, write_back=black.WriteBack.YES
129 actual = sys.stdout.read()
131 sys.stdin, sys.stdout = hold_stdin, hold_stdout
132 self.assertFormatEqual(expected, actual)
133 black.assert_equivalent(source, actual)
134 black.assert_stable(source, actual, line_length=ll)
136 def test_piping_diff(self) -> None:
137 source, _ = read_data("expression.py")
138 expected, _ = read_data("expression.diff")
139 hold_stdin, hold_stdout = sys.stdin, sys.stdout
141 sys.stdin, sys.stdout = StringIO(source), StringIO()
142 sys.stdin.name = "<stdin>"
143 black.format_stdin_to_stdout(
144 line_length=ll, fast=True, write_back=black.WriteBack.DIFF
147 actual = sys.stdout.read()
149 sys.stdin, sys.stdout = hold_stdin, hold_stdout
150 actual = actual.rstrip() + "\n" # the diff output has a trailing space
151 self.assertEqual(expected, actual)
153 @patch("black.dump_to_file", dump_to_stderr)
154 def test_setup(self) -> None:
155 source, expected = read_data("../setup")
157 self.assertFormatEqual(expected, actual)
158 black.assert_equivalent(source, actual)
159 black.assert_stable(source, actual, line_length=ll)
160 self.assertFalse(ff(THIS_DIR / ".." / "setup.py"))
162 @patch("black.dump_to_file", dump_to_stderr)
163 def test_function(self) -> None:
164 source, expected = read_data("function")
166 self.assertFormatEqual(expected, actual)
167 black.assert_equivalent(source, actual)
168 black.assert_stable(source, actual, line_length=ll)
170 @patch("black.dump_to_file", dump_to_stderr)
171 def test_function2(self) -> None:
172 source, expected = read_data("function2")
174 self.assertFormatEqual(expected, actual)
175 black.assert_equivalent(source, actual)
176 black.assert_stable(source, actual, line_length=ll)
178 @patch("black.dump_to_file", dump_to_stderr)
179 def test_expression(self) -> None:
180 source, expected = read_data("expression")
182 self.assertFormatEqual(expected, actual)
183 black.assert_equivalent(source, actual)
184 black.assert_stable(source, actual, line_length=ll)
186 def test_expression_ff(self) -> None:
187 source, expected = read_data("expression")
188 tmp_file = Path(black.dump_to_file(source))
190 self.assertTrue(ff(tmp_file, write_back=black.WriteBack.YES))
191 with open(tmp_file, encoding="utf8") as f:
195 self.assertFormatEqual(expected, actual)
196 with patch("black.dump_to_file", dump_to_stderr):
197 black.assert_equivalent(source, actual)
198 black.assert_stable(source, actual, line_length=ll)
200 def test_expression_diff(self) -> None:
201 source, _ = read_data("expression.py")
202 expected, _ = read_data("expression.diff")
203 tmp_file = Path(black.dump_to_file(source))
204 hold_stdout = sys.stdout
206 sys.stdout = StringIO()
207 self.assertTrue(ff(tmp_file, write_back=black.WriteBack.DIFF))
209 actual = sys.stdout.read()
210 actual = actual.replace(str(tmp_file), "<stdin>")
212 sys.stdout = hold_stdout
214 actual = actual.rstrip() + "\n" # the diff output has a trailing space
215 if expected != actual:
216 dump = black.dump_to_file(actual)
218 f"Expected diff isn't equal to the actual. If you made changes "
219 f"to expression.py and this is an anticipated difference, "
220 f"overwrite tests/expression.diff with {dump}"
222 self.assertEqual(expected, actual, msg)
224 @patch("black.dump_to_file", dump_to_stderr)
225 def test_fstring(self) -> None:
226 source, expected = read_data("fstring")
228 self.assertFormatEqual(expected, actual)
229 black.assert_equivalent(source, actual)
230 black.assert_stable(source, actual, line_length=ll)
232 @patch("black.dump_to_file", dump_to_stderr)
233 def test_string_quotes(self) -> None:
234 source, expected = read_data("string_quotes")
236 self.assertFormatEqual(expected, actual)
237 black.assert_equivalent(source, actual)
238 black.assert_stable(source, actual, line_length=ll)
240 @patch("black.dump_to_file", dump_to_stderr)
241 def test_slices(self) -> None:
242 source, expected = read_data("slices")
244 self.assertFormatEqual(expected, actual)
245 black.assert_equivalent(source, actual)
246 black.assert_stable(source, actual, line_length=ll)
248 @patch("black.dump_to_file", dump_to_stderr)
249 def test_comments(self) -> None:
250 source, expected = read_data("comments")
252 self.assertFormatEqual(expected, actual)
253 black.assert_equivalent(source, actual)
254 black.assert_stable(source, actual, line_length=ll)
256 @patch("black.dump_to_file", dump_to_stderr)
257 def test_comments2(self) -> None:
258 source, expected = read_data("comments2")
260 self.assertFormatEqual(expected, actual)
261 black.assert_equivalent(source, actual)
262 black.assert_stable(source, actual, line_length=ll)
264 @patch("black.dump_to_file", dump_to_stderr)
265 def test_comments3(self) -> None:
266 source, expected = read_data("comments3")
268 self.assertFormatEqual(expected, actual)
269 black.assert_equivalent(source, actual)
270 black.assert_stable(source, actual, line_length=ll)
272 @patch("black.dump_to_file", dump_to_stderr)
273 def test_comments4(self) -> None:
274 source, expected = read_data("comments4")
276 self.assertFormatEqual(expected, actual)
277 black.assert_equivalent(source, actual)
278 black.assert_stable(source, actual, line_length=ll)
280 @patch("black.dump_to_file", dump_to_stderr)
281 def test_comments5(self) -> None:
282 source, expected = read_data("comments5")
284 self.assertFormatEqual(expected, actual)
285 black.assert_equivalent(source, actual)
286 black.assert_stable(source, actual, line_length=ll)
288 @patch("black.dump_to_file", dump_to_stderr)
289 def test_cantfit(self) -> None:
290 source, expected = read_data("cantfit")
292 self.assertFormatEqual(expected, actual)
293 black.assert_equivalent(source, actual)
294 black.assert_stable(source, actual, line_length=ll)
296 @patch("black.dump_to_file", dump_to_stderr)
297 def test_import_spacing(self) -> None:
298 source, expected = read_data("import_spacing")
300 self.assertFormatEqual(expected, actual)
301 black.assert_equivalent(source, actual)
302 black.assert_stable(source, actual, line_length=ll)
304 @patch("black.dump_to_file", dump_to_stderr)
305 def test_composition(self) -> None:
306 source, expected = read_data("composition")
308 self.assertFormatEqual(expected, actual)
309 black.assert_equivalent(source, actual)
310 black.assert_stable(source, actual, line_length=ll)
312 @patch("black.dump_to_file", dump_to_stderr)
313 def test_empty_lines(self) -> None:
314 source, expected = read_data("empty_lines")
316 self.assertFormatEqual(expected, actual)
317 black.assert_equivalent(source, actual)
318 black.assert_stable(source, actual, line_length=ll)
320 @patch("black.dump_to_file", dump_to_stderr)
321 def test_string_prefixes(self) -> None:
322 source, expected = read_data("string_prefixes")
324 self.assertFormatEqual(expected, actual)
325 black.assert_equivalent(source, actual)
326 black.assert_stable(source, actual, line_length=ll)
328 @patch("black.dump_to_file", dump_to_stderr)
329 def test_python2(self) -> None:
330 source, expected = read_data("python2")
332 self.assertFormatEqual(expected, actual)
333 # black.assert_equivalent(source, actual)
334 black.assert_stable(source, actual, line_length=ll)
336 @patch("black.dump_to_file", dump_to_stderr)
337 def test_python2_unicode_literals(self) -> None:
338 source, expected = read_data("python2_unicode_literals")
340 self.assertFormatEqual(expected, actual)
341 black.assert_stable(source, actual, line_length=ll)
343 @patch("black.dump_to_file", dump_to_stderr)
344 def test_fmtonoff(self) -> None:
345 source, expected = read_data("fmtonoff")
347 self.assertFormatEqual(expected, actual)
348 black.assert_equivalent(source, actual)
349 black.assert_stable(source, actual, line_length=ll)
351 @patch("black.dump_to_file", dump_to_stderr)
352 def test_remove_empty_parentheses_after_class(self) -> None:
353 source, expected = read_data("class_blank_parentheses")
355 self.assertFormatEqual(expected, actual)
356 black.assert_equivalent(source, actual)
357 black.assert_stable(source, actual, line_length=ll)
359 def test_report(self) -> None:
360 report = black.Report()
364 def out(msg: str, **kwargs: Any) -> None:
365 out_lines.append(msg)
367 def err(msg: str, **kwargs: Any) -> None:
368 err_lines.append(msg)
370 with patch("black.out", out), patch("black.err", err):
371 report.done(Path("f1"), black.Changed.NO)
372 self.assertEqual(len(out_lines), 1)
373 self.assertEqual(len(err_lines), 0)
374 self.assertEqual(out_lines[-1], "f1 already well formatted, good job.")
375 self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
376 self.assertEqual(report.return_code, 0)
377 report.done(Path("f2"), black.Changed.YES)
378 self.assertEqual(len(out_lines), 2)
379 self.assertEqual(len(err_lines), 0)
380 self.assertEqual(out_lines[-1], "reformatted f2")
382 unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
384 report.done(Path("f3"), black.Changed.CACHED)
385 self.assertEqual(len(out_lines), 3)
386 self.assertEqual(len(err_lines), 0)
388 out_lines[-1], "f3 wasn't modified on disk since last run."
391 unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
393 self.assertEqual(report.return_code, 0)
395 self.assertEqual(report.return_code, 1)
397 report.failed(Path("e1"), "boom")
398 self.assertEqual(len(out_lines), 3)
399 self.assertEqual(len(err_lines), 1)
400 self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
402 unstyle(str(report)),
403 "1 file reformatted, 2 files left unchanged, "
404 "1 file failed to reformat.",
406 self.assertEqual(report.return_code, 123)
407 report.done(Path("f3"), black.Changed.YES)
408 self.assertEqual(len(out_lines), 4)
409 self.assertEqual(len(err_lines), 1)
410 self.assertEqual(out_lines[-1], "reformatted f3")
412 unstyle(str(report)),
413 "2 files reformatted, 2 files left unchanged, "
414 "1 file failed to reformat.",
416 self.assertEqual(report.return_code, 123)
417 report.failed(Path("e2"), "boom")
418 self.assertEqual(len(out_lines), 4)
419 self.assertEqual(len(err_lines), 2)
420 self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
422 unstyle(str(report)),
423 "2 files reformatted, 2 files left unchanged, "
424 "2 files failed to reformat.",
426 self.assertEqual(report.return_code, 123)
427 report.done(Path("f4"), black.Changed.NO)
428 self.assertEqual(len(out_lines), 5)
429 self.assertEqual(len(err_lines), 2)
430 self.assertEqual(out_lines[-1], "f4 already well formatted, good job.")
432 unstyle(str(report)),
433 "2 files reformatted, 3 files left unchanged, "
434 "2 files failed to reformat.",
436 self.assertEqual(report.return_code, 123)
439 unstyle(str(report)),
440 "2 files would be reformatted, 3 files would be left unchanged, "
441 "2 files would fail to reformat.",
def test_is_python36(self) -> None:
    """Check ``black.is_python36`` on small snippets and on two fixtures."""

    def check(code: str, expected: bool) -> None:
        # Parse the code and assert the detector's verdict matches `expected`.
        tree = black.lib2to3_parse(code)
        if expected:
            self.assertTrue(black.is_python36(tree))
        else:
            self.assertFalse(black.is_python36(tree))

    check("def f(*, arg): ...\n", False)
    # Trailing comma after a keyword-only argument is expected to be flagged.
    check("def f(*, arg,): ...\n", True)
    # An f-string in the body is expected to be flagged.
    check("def f(*, arg): f'string'\n", True)
    # Both the unformatted input and the expected output of each fixture.
    for fixture, expected in (("function", True), ("expression", False)):
        for text in read_data(fixture):
            check(text, expected)
462 def test_get_future_imports(self) -> None:
463 node = black.lib2to3_parse("\n")
464 self.assertEqual(set(), black.get_future_imports(node))
465 node = black.lib2to3_parse("from __future__ import black\n")
466 self.assertEqual({"black"}, black.get_future_imports(node))
467 node = black.lib2to3_parse("from __future__ import multiple, imports\n")
468 self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
469 node = black.lib2to3_parse("from __future__ import (parenthesized, imports)\n")
470 self.assertEqual({"parenthesized", "imports"}, black.get_future_imports(node))
471 node = black.lib2to3_parse(
472 "from __future__ import multiple\nfrom __future__ import imports\n"
474 self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
475 node = black.lib2to3_parse("# comment\nfrom __future__ import black\n")
476 self.assertEqual({"black"}, black.get_future_imports(node))
477 node = black.lib2to3_parse('"""docstring"""\nfrom __future__ import black\n')
478 self.assertEqual({"black"}, black.get_future_imports(node))
479 node = black.lib2to3_parse("some(other, code)\nfrom __future__ import black\n")
480 self.assertEqual(set(), black.get_future_imports(node))
481 node = black.lib2to3_parse("from some.module import black\n")
482 self.assertEqual(set(), black.get_future_imports(node))
484 def test_debug_visitor(self) -> None:
485 source, _ = read_data("debug_visitor.py")
486 expected, _ = read_data("debug_visitor.out")
490 def out(msg: str, **kwargs: Any) -> None:
491 out_lines.append(msg)
493 def err(msg: str, **kwargs: Any) -> None:
494 err_lines.append(msg)
496 with patch("black.out", out), patch("black.err", err):
497 black.DebugVisitor.show(source)
498 actual = "\n".join(out_lines) + "\n"
500 if expected != actual:
501 log_name = black.dump_to_file(*out_lines)
505 f"AST print out is different. Actual version dumped to {log_name}",
508 def test_format_file_contents(self) -> None:
510 with self.assertRaises(black.NothingChanged):
511 black.format_file_contents(empty, line_length=ll, fast=False)
513 with self.assertRaises(black.NothingChanged):
514 black.format_file_contents(just_nl, line_length=ll, fast=False)
515 same = "l = [1, 2, 3]\n"
516 with self.assertRaises(black.NothingChanged):
517 black.format_file_contents(same, line_length=ll, fast=False)
518 different = "l = [1,2,3]"
520 actual = black.format_file_contents(different, line_length=ll, fast=False)
521 self.assertEqual(expected, actual)
522 invalid = "return if you can"
523 with self.assertRaises(ValueError) as e:
524 black.format_file_contents(invalid, line_length=ll, fast=False)
525 self.assertEqual(str(e.exception), "Cannot parse: 1:7: return if you can")
def test_endmarker(self) -> None:
    """A lone newline parses to a ``file_input`` node with one ENDMARKER child."""
    tree = black.lib2to3_parse("\n")
    self.assertEqual(tree.type, black.syms.file_input)
    children = tree.children
    self.assertEqual(len(children), 1)
    self.assertEqual(children[0].type, black.token.ENDMARKER)
533 @unittest.skipIf(os.environ.get("SKIP_AST_PRINT"), "user set SKIP_AST_PRINT")
534 def test_assertFormatEqual(self) -> None:
538 def out(msg: str, **kwargs: Any) -> None:
539 out_lines.append(msg)
541 def err(msg: str, **kwargs: Any) -> None:
542 err_lines.append(msg)
544 with patch("black.out", out), patch("black.err", err):
545 with self.assertRaises(AssertionError):
546 self.assertFormatEqual("l = [1, 2, 3]", "l = [1, 2, 3,]")
548 out_str = "".join(out_lines)
549 self.assertTrue("Expected tree:" in out_str)
550 self.assertTrue("Actual tree:" in out_str)
551 self.assertEqual("".join(err_lines), "")
def test_cache_broken_file(self) -> None:
    """A corrupt cache reads as empty, and a later run records the file."""
    line_length = black.DEFAULT_LINE_LENGTH
    with cache_dir() as workspace:
        # Overwrite the cache file with content that is not a pickle.
        black.get_cache_file(line_length).write_text("this is not a pickle")
        self.assertEqual(black.read_cache(line_length), {})

        target = (workspace / "test.py").resolve()
        target.write_text("print('hello')")
        outcome = CliRunner().invoke(black.main, [str(target)])
        self.assertEqual(outcome.exit_code, 0)
        self.assertIn(target, black.read_cache(line_length))
def test_cache_single_file_already_cached(self) -> None:
    """A file already present in the cache is not rewritten by a run."""
    with cache_dir() as workspace:
        target = (workspace / "test.py").resolve()
        target.write_text("print('hello')")
        black.write_cache({}, [target], black.DEFAULT_LINE_LENGTH)
        outcome = CliRunner().invoke(black.main, [str(target)])
        self.assertEqual(outcome.exit_code, 0)
        # Unformatted content survives because the cache marks it as done.
        self.assertEqual(target.read_text(), "print('hello')")
578 @event_loop(close=False)
579 def test_cache_multiple_files(self) -> None:
580 with cache_dir() as workspace, patch(
581 "black.ProcessPoolExecutor", new=ThreadPoolExecutor
583 one = (workspace / "one.py").resolve()
584 with one.open("w") as fobj:
585 fobj.write("print('hello')")
586 two = (workspace / "two.py").resolve()
587 with two.open("w") as fobj:
588 fobj.write("print('hello')")
589 black.write_cache({}, [one], black.DEFAULT_LINE_LENGTH)
590 result = CliRunner().invoke(black.main, [str(workspace)])
591 self.assertEqual(result.exit_code, 0)
592 with one.open("r") as fobj:
593 self.assertEqual(fobj.read(), "print('hello')")
594 with two.open("r") as fobj:
595 self.assertEqual(fobj.read(), 'print("hello")\n')
596 cache = black.read_cache(black.DEFAULT_LINE_LENGTH)
597 self.assertIn(one, cache)
598 self.assertIn(two, cache)
def test_no_cache_when_writeback_diff(self) -> None:
    """Running with ``--diff`` leaves no cache file behind."""
    with cache_dir() as workspace:
        target = (workspace / "test.py").resolve()
        target.write_text("print('hello')")
        outcome = CliRunner().invoke(black.main, [str(target), "--diff"])
        self.assertEqual(outcome.exit_code, 0)
        self.assertFalse(black.get_cache_file(black.DEFAULT_LINE_LENGTH).exists())
610 def test_no_cache_when_stdin(self) -> None:
612 result = CliRunner().invoke(black.main, ["-"], input="print('hello')")
613 self.assertEqual(result.exit_code, 0)
614 cache_file = black.get_cache_file(black.DEFAULT_LINE_LENGTH)
615 self.assertFalse(cache_file.exists())
617 def test_read_cache_no_cachefile(self) -> None:
619 self.assertEqual(black.read_cache(black.DEFAULT_LINE_LENGTH), {})
621 def test_write_cache_read_cache(self) -> None:
622 with cache_dir() as workspace:
623 src = (workspace / "test.py").resolve()
625 black.write_cache({}, [src], black.DEFAULT_LINE_LENGTH)
626 cache = black.read_cache(black.DEFAULT_LINE_LENGTH)
627 self.assertIn(src, cache)
628 self.assertEqual(cache[src], black.get_cache_info(src))
630 def test_filter_cached(self) -> None:
631 with TemporaryDirectory() as workspace:
632 path = Path(workspace)
633 uncached = (path / "uncached").resolve()
634 cached = (path / "cached").resolve()
635 cached_but_changed = (path / "changed").resolve()
638 cached_but_changed.touch()
639 cache = {cached: black.get_cache_info(cached), cached_but_changed: (0.0, 0)}
640 todo, done = black.filter_cached(
641 cache, [uncached, cached, cached_but_changed]
643 self.assertEqual(todo, [uncached, cached_but_changed])
644 self.assertEqual(done, [cached])
def test_write_cache_creates_directory_if_needed(self) -> None:
    """``write_cache`` creates the cache directory when it does not exist."""
    with cache_dir(exists=False) as missing_dir:
        self.assertFalse(missing_dir.exists())
        black.write_cache({}, [], black.DEFAULT_LINE_LENGTH)
        self.assertTrue(missing_dir.exists())
652 @event_loop(close=False)
653 def test_failed_formatting_does_not_get_cached(self) -> None:
654 with cache_dir() as workspace, patch(
655 "black.ProcessPoolExecutor", new=ThreadPoolExecutor
657 failing = (workspace / "failing.py").resolve()
658 with failing.open("w") as fobj:
659 fobj.write("not actually python")
660 clean = (workspace / "clean.py").resolve()
661 with clean.open("w") as fobj:
662 fobj.write('print("hello")\n')
663 result = CliRunner().invoke(black.main, [str(workspace)])
664 self.assertEqual(result.exit_code, 123)
665 cache = black.read_cache(black.DEFAULT_LINE_LENGTH)
666 self.assertNotIn(failing, cache)
667 self.assertIn(clean, cache)
def test_write_cache_write_fail(self) -> None:
    """``write_cache`` must not raise when opening the cache file fails."""
    # Every Path.open call raises OSError for the duration of the block.
    with cache_dir(), patch.object(Path, "open", side_effect=OSError):
        black.write_cache({}, [], black.DEFAULT_LINE_LENGTH)
674 def test_check_diff_use_together(self) -> None:
676 # Files which will be reformatted.
677 src1 = (THIS_DIR / "string_quotes.py").resolve()
678 result = CliRunner().invoke(black.main, [str(src1), "--diff", "--check"])
679 self.assertEqual(result.exit_code, 1)
681 # Files which will not be reformatted.
682 src2 = (THIS_DIR / "composition.py").resolve()
683 result = CliRunner().invoke(black.main, [str(src2), "--diff", "--check"])
684 self.assertEqual(result.exit_code, 0)
686 # Multi file command.
687 result = CliRunner().invoke(
688 black.main, [str(src1), str(src2), "--diff", "--check"]
690 self.assertEqual(result.exit_code, 1)
692 def test_no_files(self) -> None:
694 # Without an argument, black exits with error code 0.
695 result = CliRunner().invoke(black.main, [])
696 self.assertEqual(result.exit_code, 0)
def test_broken_symlink(self) -> None:
    """A dangling symlink in the target directory does not cause a failure."""
    with cache_dir() as workspace:
        dangling = workspace / "broken_link.py"
        dangling.symlink_to("nonexistent.py")
        outcome = CliRunner().invoke(black.main, [str(workspace.resolve())])
        self.assertEqual(outcome.exit_code, 0)
705 def test_read_cache_line_lengths(self) -> None:
706 with cache_dir() as workspace:
707 path = (workspace / "file.py").resolve()
709 black.write_cache({}, [path], 1)
710 one = black.read_cache(1)
711 self.assertIn(path, one)
712 two = black.read_cache(2)
713 self.assertNotIn(path, two)
716 if __name__ == "__main__":