git.madduck.net Git - etc/vim.git/blob - tests/test_black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

951b2988631979e1e8de4a13e0da92a29741ce5b
[etc/vim.git] / tests / test_black.py
1 #!/usr/bin/env python3
2 import asyncio
3 from concurrent.futures import ThreadPoolExecutor
4 from contextlib import contextmanager
5 from functools import partial
6 from io import StringIO
7 import os
8 from pathlib import Path
9 import sys
10 from tempfile import TemporaryDirectory
11 from typing import Any, List, Tuple, Iterator
12 import unittest
13 from unittest.mock import patch
14
15 from click import unstyle
16 from click.testing import CliRunner
17
18 import black
19
# Line length used by every formatting check in this module.
ll = 88
# `ff` wraps black.format_file_in_place (with fast=True) and `fs` wraps
# black.format_str; both are pre-bound to `ll`.
ff = partial(black.format_file_in_place, fast=True, line_length=ll)
fs = partial(black.format_str, line_length=ll)
THIS_FILE = Path(__file__)
THIS_DIR = THIS_FILE.parent
# Marker text that read_data() deletes from every line of a data file.  It is
# split across two literals so this file never contains the full marker:
# test_self() passes this file through read_data(), which would otherwise strip
# it out of the source.  Keep it split.
EMPTY_LINE = "# EMPTY LINE WITH WHITESPACE" + " (this comment will be removed)"
26
27
def dump_to_stderr(*output: str) -> str:
    """Return the `output` strings joined by newlines, framed by newlines.

    The tests patch ``black.dump_to_file`` with this function.  Instead of
    writing the given text to a file, it simply hands the text back.
    """
    body = "\n".join(output)
    return f"\n{body}\n"
30
31
def read_data(name: str) -> Tuple[str, str]:
    """Load the data file `name` from THIS_DIR and return (input, output).

    A ``.py`` suffix is appended unless `name` already ends in ``.py``,
    ``.out`` or ``.diff``.  The EMPTY_LINE marker is removed from every line.
    A line that reads ``# output`` (trailing whitespace allowed) switches
    from the input half to the output half; the marker line itself is
    dropped.  If the output half ends up empty while the input half is not,
    the input is reused as the expected output, i.e. the data is treated as
    already formatted.  Both halves are stripped and end with one newline.
    """
    if not name.endswith((".py", ".out", ".diff")):
        name += ".py"
    with open(THIS_DIR / name, "r", encoding="utf8") as test:
        lines = test.readlines()
    before: List[str] = []
    after: List[str] = []
    target = before
    for raw in lines:
        cleaned = raw.replace(EMPTY_LINE, "")
        if cleaned.rstrip() == "# output":
            target = after
        else:
            target.append(cleaned)
    if before and not after:
        # No output section: the input is expected to be formatted already.
        after = list(before)
    return "".join(before).strip() + "\n", "".join(after).strip() + "\n"
52
53
@contextmanager
def cache_dir(exists: bool = True) -> Iterator[Path]:
    """Point ``black.CACHE_DIR`` at a throwaway directory while active.

    Yields the directory that was patched in.  With ``exists=False`` that is
    a not-yet-created ``new`` subdirectory of a fresh temporary directory.
    Otherwise it is the temporary directory itself.
    """
    with TemporaryDirectory() as workspace:
        root = Path(workspace)
        target = root if exists else root / "new"
        with patch("black.CACHE_DIR", target):
            yield target
62
63
@contextmanager
def event_loop(close: bool) -> Iterator[None]:
    """Install a fresh asyncio event loop for the duration of the block.

    The loop that was current beforehand is reinstated on exit.  The fresh
    loop is closed on exit only when `close` is true; with ``close=False``
    it is left open.  Usable as a decorator as well as in a ``with``.
    """
    policy = asyncio.get_event_loop_policy()
    previous = policy.get_event_loop()
    fresh = policy.new_event_loop()
    policy.set_event_loop(fresh)
    try:
        yield
    finally:
        policy.set_event_loop(previous)
        if close:
            fresh.close()
77
78
class BlackTestCase(unittest.TestCase):
    """Tests for black's formatter, report, CLI and cache handling."""

    maxDiff = None

    def assertFormatEqual(self, expected: str, actual: str) -> None:
        """Assert ``expected == actual``, printing both trees on mismatch.

        The trees are printed through black's own output helpers to ease
        debugging; set the SKIP_AST_PRINT environment variable to skip them.
        """
        if actual != expected and not os.environ.get("SKIP_AST_PRINT"):
            self._show_tree("Expected tree:", expected, "green")
            self._show_tree("Actual tree:", actual, "red")
        self.assertEqual(expected, actual)

    @staticmethod
    def _show_tree(label: str, source: str, color: str) -> None:
        """Print `label`, then black's DebugVisitor dump of `source`.

        Any exception raised while parsing or visiting is reported through
        black.err() instead of propagating, so the caller's assertion runs.
        """
        black.out(label, fg=color)
        try:
            node = black.lib2to3_parse(source)
            bdv: black.DebugVisitor[Any] = black.DebugVisitor()
            list(bdv.visit(node))
        except Exception as ve:
            black.err(str(ve))

    def _check_fixture(self, name: str) -> None:
        """Format data file `name` with `fs` and verify the result.

        The output is compared against the file's expected half (see
        read_data), then passed to black.assert_equivalent and
        black.assert_stable together with the original source.
        """
        source, expected = read_data(name)
        actual = fs(source)
        self.assertFormatEqual(expected, actual)
        black.assert_equivalent(source, actual)
        black.assert_stable(source, actual, line_length=ll)

    @patch("black.dump_to_file", dump_to_stderr)
    def test_self(self) -> None:
        self._check_fixture("test_black")
        self.assertFalse(ff(THIS_FILE))

    @patch("black.dump_to_file", dump_to_stderr)
    def test_black(self) -> None:
        self._check_fixture("../black")
        self.assertFalse(ff(THIS_DIR / ".." / "black.py"))

    def test_piping(self) -> None:
        source, expected = read_data("../black")
        hold_stdin, hold_stdout = sys.stdin, sys.stdout
        try:
            sys.stdin, sys.stdout = StringIO(source), StringIO()
            sys.stdin.name = "<stdin>"
            black.format_stdin_to_stdout(
                line_length=ll, fast=True, write_back=black.WriteBack.YES
            )
            sys.stdout.seek(0)
            actual = sys.stdout.read()
        finally:
            sys.stdin, sys.stdout = hold_stdin, hold_stdout
        self.assertFormatEqual(expected, actual)
        black.assert_equivalent(source, actual)
        black.assert_stable(source, actual, line_length=ll)

    def test_piping_diff(self) -> None:
        source, _ = read_data("expression.py")
        expected, _ = read_data("expression.diff")
        hold_stdin, hold_stdout = sys.stdin, sys.stdout
        try:
            sys.stdin, sys.stdout = StringIO(source), StringIO()
            sys.stdin.name = "<stdin>"
            black.format_stdin_to_stdout(
                line_length=ll, fast=True, write_back=black.WriteBack.DIFF
            )
            sys.stdout.seek(0)
            actual = sys.stdout.read()
        finally:
            sys.stdin, sys.stdout = hold_stdin, hold_stdout
        actual = actual.rstrip() + "\n"  # the diff output has a trailing space
        self.assertEqual(expected, actual)

    @patch("black.dump_to_file", dump_to_stderr)
    def test_setup(self) -> None:
        self._check_fixture("../setup")
        self.assertFalse(ff(THIS_DIR / ".." / "setup.py"))

    @patch("black.dump_to_file", dump_to_stderr)
    def test_function(self) -> None:
        self._check_fixture("function")

    @patch("black.dump_to_file", dump_to_stderr)
    def test_expression(self) -> None:
        self._check_fixture("expression")

    def test_expression_ff(self) -> None:
        source, expected = read_data("expression")
        tmp_file = Path(black.dump_to_file(source))
        try:
            self.assertTrue(ff(tmp_file, write_back=black.WriteBack.YES))
            with open(tmp_file, encoding="utf8") as f:
                actual = f.read()
        finally:
            os.unlink(tmp_file)
        self.assertFormatEqual(expected, actual)
        with patch("black.dump_to_file", dump_to_stderr):
            black.assert_equivalent(source, actual)
            black.assert_stable(source, actual, line_length=ll)

    def test_expression_diff(self) -> None:
        source, _ = read_data("expression.py")
        expected, _ = read_data("expression.diff")
        tmp_file = Path(black.dump_to_file(source))
        hold_stdout = sys.stdout
        try:
            sys.stdout = StringIO()
            self.assertTrue(ff(tmp_file, write_back=black.WriteBack.DIFF))
            sys.stdout.seek(0)
            actual = sys.stdout.read()
            actual = actual.replace(str(tmp_file), "<stdin>")
        finally:
            sys.stdout = hold_stdout
            os.unlink(tmp_file)
        actual = actual.rstrip() + "\n"  # the diff output has a trailing space
        if expected != actual:
            dump = black.dump_to_file(actual)
            msg = (
                "Expected diff isn't equal to the actual. If you made changes "
                "to expression.py and this is an anticipated difference, "
                f"overwrite tests/expression.diff with {dump}"
            )
            self.assertEqual(expected, actual, msg)

    @patch("black.dump_to_file", dump_to_stderr)
    def test_fstring(self) -> None:
        self._check_fixture("fstring")

    @patch("black.dump_to_file", dump_to_stderr)
    def test_string_quotes(self) -> None:
        self._check_fixture("string_quotes")

    @patch("black.dump_to_file", dump_to_stderr)
    def test_slices(self) -> None:
        self._check_fixture("slices")

    @patch("black.dump_to_file", dump_to_stderr)
    def test_comments(self) -> None:
        self._check_fixture("comments")

    @patch("black.dump_to_file", dump_to_stderr)
    def test_comments2(self) -> None:
        self._check_fixture("comments2")

    @patch("black.dump_to_file", dump_to_stderr)
    def test_comments3(self) -> None:
        self._check_fixture("comments3")

    @patch("black.dump_to_file", dump_to_stderr)
    def test_comments4(self) -> None:
        self._check_fixture("comments4")

    @patch("black.dump_to_file", dump_to_stderr)
    def test_comments5(self) -> None:
        self._check_fixture("comments5")

    @patch("black.dump_to_file", dump_to_stderr)
    def test_cantfit(self) -> None:
        self._check_fixture("cantfit")

    @patch("black.dump_to_file", dump_to_stderr)
    def test_import_spacing(self) -> None:
        self._check_fixture("import_spacing")

    @patch("black.dump_to_file", dump_to_stderr)
    def test_composition(self) -> None:
        self._check_fixture("composition")

    @patch("black.dump_to_file", dump_to_stderr)
    def test_empty_lines(self) -> None:
        self._check_fixture("empty_lines")

    @patch("black.dump_to_file", dump_to_stderr)
    def test_python2(self) -> None:
        source, expected = read_data("python2")
        actual = fs(source)
        self.assertFormatEqual(expected, actual)
        # black.assert_equivalent() is deliberately not run on this fixture,
        # presumably because it cannot handle Python 2 syntax -- TODO confirm.
        black.assert_stable(source, actual, line_length=ll)

    @patch("black.dump_to_file", dump_to_stderr)
    def test_fmtonoff(self) -> None:
        self._check_fixture("fmtonoff")

    @patch("black.dump_to_file", dump_to_stderr)
    def test_remove_empty_parentheses_after_class(self) -> None:
        self._check_fixture("class_blank_parentheses")

    def test_report(self) -> None:
        report = black.Report()
        out_lines = []
        err_lines = []

        def out(msg: str, **kwargs: Any) -> None:
            out_lines.append(msg)

        def err(msg: str, **kwargs: Any) -> None:
            err_lines.append(msg)

        with patch("black.out", out), patch("black.err", err):
            report.done(Path("f1"), black.Changed.NO)
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "f1 already well formatted, good job.")
            self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
            self.assertEqual(report.return_code, 0)
            report.done(Path("f2"), black.Changed.YES)
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "reformatted f2")
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
            )
            report.done(Path("f3"), black.Changed.CACHED)
            self.assertEqual(len(out_lines), 3)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(
                out_lines[-1], "f3 wasn't modified on disk since last run."
            )
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
            )
            self.assertEqual(report.return_code, 0)
            report.check = True
            self.assertEqual(report.return_code, 1)
            report.check = False
            report.failed(Path("e1"), "boom")
            self.assertEqual(len(out_lines), 3)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
            self.assertEqual(
                unstyle(str(report)),
                "1 file reformatted, 2 files left unchanged, "
                "1 file failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f3"), black.Changed.YES)
            self.assertEqual(len(out_lines), 4)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(out_lines[-1], "reformatted f3")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, "
                "1 file failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.failed(Path("e2"), "boom")
            self.assertEqual(len(out_lines), 4)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, "
                "2 files failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f4"), black.Changed.NO)
            self.assertEqual(len(out_lines), 5)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(out_lines[-1], "f4 already well formatted, good job.")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 3 files left unchanged, "
                "2 files failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.check = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, "
                "2 files would fail to reformat.",
            )

    def test_is_python36(self) -> None:
        node = black.lib2to3_parse("def f(*, arg): ...\n")
        self.assertFalse(black.is_python36(node))
        node = black.lib2to3_parse("def f(*, arg,): ...\n")
        self.assertTrue(black.is_python36(node))
        node = black.lib2to3_parse("def f(*, arg): f'string'\n")
        self.assertTrue(black.is_python36(node))
        source, expected = read_data("function")
        node = black.lib2to3_parse(source)
        self.assertTrue(black.is_python36(node))
        node = black.lib2to3_parse(expected)
        self.assertTrue(black.is_python36(node))
        source, expected = read_data("expression")
        node = black.lib2to3_parse(source)
        self.assertFalse(black.is_python36(node))
        node = black.lib2to3_parse(expected)
        self.assertFalse(black.is_python36(node))

    def test_debug_visitor(self) -> None:
        source, _ = read_data("debug_visitor.py")
        expected, _ = read_data("debug_visitor.out")
        out_lines = []
        err_lines = []

        def out(msg: str, **kwargs: Any) -> None:
            out_lines.append(msg)

        def err(msg: str, **kwargs: Any) -> None:
            err_lines.append(msg)

        with patch("black.out", out), patch("black.err", err):
            black.DebugVisitor.show(source)
        actual = "\n".join(out_lines) + "\n"
        log_name = ""
        if expected != actual:
            log_name = black.dump_to_file(*out_lines)
        self.assertEqual(
            expected,
            actual,
            f"AST print out is different. Actual version dumped to {log_name}",
        )

    def test_format_file_contents(self) -> None:
        empty = ""
        with self.assertRaises(black.NothingChanged):
            black.format_file_contents(empty, line_length=ll, fast=False)
        just_nl = "\n"
        with self.assertRaises(black.NothingChanged):
            black.format_file_contents(just_nl, line_length=ll, fast=False)
        same = "l = [1, 2, 3]\n"
        with self.assertRaises(black.NothingChanged):
            black.format_file_contents(same, line_length=ll, fast=False)
        different = "l = [1,2,3]"
        expected = same
        actual = black.format_file_contents(different, line_length=ll, fast=False)
        self.assertEqual(expected, actual)
        invalid = "return if you can"
        with self.assertRaises(ValueError) as e:
            black.format_file_contents(invalid, line_length=ll, fast=False)
        self.assertEqual(str(e.exception), "Cannot parse: 1:7: return if you can")

    def test_endmarker(self) -> None:
        n = black.lib2to3_parse("\n")
        self.assertEqual(n.type, black.syms.file_input)
        self.assertEqual(len(n.children), 1)
        self.assertEqual(n.children[0].type, black.token.ENDMARKER)

    @unittest.skipIf(os.environ.get("SKIP_AST_PRINT"), "user set SKIP_AST_PRINT")
    def test_assertFormatEqual(self) -> None:
        out_lines = []
        err_lines = []

        def out(msg: str, **kwargs: Any) -> None:
            out_lines.append(msg)

        def err(msg: str, **kwargs: Any) -> None:
            err_lines.append(msg)

        with patch("black.out", out), patch("black.err", err):
            with self.assertRaises(AssertionError):
                self.assertFormatEqual("l = [1, 2, 3]", "l = [1, 2, 3,]")

        out_str = "".join(out_lines)
        self.assertIn("Expected tree:", out_str)
        self.assertIn("Actual tree:", out_str)
        self.assertEqual("".join(err_lines), "")

    def test_cache_broken_file(self) -> None:
        with cache_dir() as workspace:
            cache_file = black.get_cache_file(black.DEFAULT_LINE_LENGTH)
            with cache_file.open("w") as fobj:
                fobj.write("this is not a pickle")
            self.assertEqual(black.read_cache(black.DEFAULT_LINE_LENGTH), {})
            src = (workspace / "test.py").resolve()
            with src.open("w") as fobj:
                fobj.write("print('hello')")
            result = CliRunner().invoke(black.main, [str(src)])
            self.assertEqual(result.exit_code, 0)
            cache = black.read_cache(black.DEFAULT_LINE_LENGTH)
            self.assertIn(src, cache)

    def test_cache_single_file_already_cached(self) -> None:
        with cache_dir() as workspace:
            src = (workspace / "test.py").resolve()
            with src.open("w") as fobj:
                fobj.write("print('hello')")
            black.write_cache({}, [src], black.DEFAULT_LINE_LENGTH)
            result = CliRunner().invoke(black.main, [str(src)])
            self.assertEqual(result.exit_code, 0)
            with src.open("r") as fobj:
                self.assertEqual(fobj.read(), "print('hello')")

    @event_loop(close=False)
    def test_cache_multiple_files(self) -> None:
        with cache_dir() as workspace, patch(
            "black.ProcessPoolExecutor", new=ThreadPoolExecutor
        ):
            one = (workspace / "one.py").resolve()
            with one.open("w") as fobj:
                fobj.write("print('hello')")
            two = (workspace / "two.py").resolve()
            with two.open("w") as fobj:
                fobj.write("print('hello')")
            black.write_cache({}, [one], black.DEFAULT_LINE_LENGTH)
            result = CliRunner().invoke(black.main, [str(workspace)])
            self.assertEqual(result.exit_code, 0)
            with one.open("r") as fobj:
                self.assertEqual(fobj.read(), "print('hello')")
            with two.open("r") as fobj:
                self.assertEqual(fobj.read(), 'print("hello")\n')
            cache = black.read_cache(black.DEFAULT_LINE_LENGTH)
            self.assertIn(one, cache)
            self.assertIn(two, cache)

    def test_no_cache_when_writeback_diff(self) -> None:
        with cache_dir() as workspace:
            src = (workspace / "test.py").resolve()
            with src.open("w") as fobj:
                fobj.write("print('hello')")
            result = CliRunner().invoke(black.main, [str(src), "--diff"])
            self.assertEqual(result.exit_code, 0)
            cache_file = black.get_cache_file(black.DEFAULT_LINE_LENGTH)
            self.assertFalse(cache_file.exists())

    def test_no_cache_when_stdin(self) -> None:
        with cache_dir():
            result = CliRunner().invoke(black.main, ["-"], input="print('hello')")
            self.assertEqual(result.exit_code, 0)
            cache_file = black.get_cache_file(black.DEFAULT_LINE_LENGTH)
            self.assertFalse(cache_file.exists())

    def test_read_cache_no_cachefile(self) -> None:
        with cache_dir():
            self.assertEqual(black.read_cache(black.DEFAULT_LINE_LENGTH), {})

    def test_write_cache_read_cache(self) -> None:
        with cache_dir() as workspace:
            src = (workspace / "test.py").resolve()
            src.touch()
            black.write_cache({}, [src], black.DEFAULT_LINE_LENGTH)
            cache = black.read_cache(black.DEFAULT_LINE_LENGTH)
            self.assertIn(src, cache)
            self.assertEqual(cache[src], black.get_cache_info(src))

    def test_filter_cached(self) -> None:
        with TemporaryDirectory() as workspace:
            path = Path(workspace)
            uncached = (path / "uncached").resolve()
            cached = (path / "cached").resolve()
            cached_but_changed = (path / "changed").resolve()
            uncached.touch()
            cached.touch()
            cached_but_changed.touch()
            cache = {cached: black.get_cache_info(cached), cached_but_changed: (0.0, 0)}
            todo, done = black.filter_cached(
                cache, [uncached, cached, cached_but_changed]
            )
            self.assertEqual(todo, [uncached, cached_but_changed])
            self.assertEqual(done, [cached])

    def test_write_cache_creates_directory_if_needed(self) -> None:
        with cache_dir(exists=False) as workspace:
            self.assertFalse(workspace.exists())
            black.write_cache({}, [], black.DEFAULT_LINE_LENGTH)
            self.assertTrue(workspace.exists())

    @event_loop(close=False)
    def test_failed_formatting_does_not_get_cached(self) -> None:
        with cache_dir() as workspace, patch(
            "black.ProcessPoolExecutor", new=ThreadPoolExecutor
        ):
            failing = (workspace / "failing.py").resolve()
            with failing.open("w") as fobj:
                fobj.write("not actually python")
            clean = (workspace / "clean.py").resolve()
            with clean.open("w") as fobj:
                fobj.write('print("hello")\n')
            result = CliRunner().invoke(black.main, [str(workspace)])
            self.assertEqual(result.exit_code, 123)
            cache = black.read_cache(black.DEFAULT_LINE_LENGTH)
            self.assertNotIn(failing, cache)
            self.assertIn(clean, cache)

    def test_write_cache_write_fail(self) -> None:
        # No explicit assertion: the test passes as long as write_cache() does
        # not raise when opening the cache file fails with OSError.
        with cache_dir(), patch.object(Path, "open") as mock:
            mock.side_effect = OSError
            black.write_cache({}, [], black.DEFAULT_LINE_LENGTH)

    def test_check_diff_use_together(self) -> None:
        with cache_dir():
            # Files which will be reformatted.
            src1 = (THIS_DIR / "string_quotes.py").resolve()
            result = CliRunner().invoke(black.main, [str(src1), "--diff", "--check"])
            self.assertEqual(result.exit_code, 1)

            # Files which will not be reformatted.
            src2 = (THIS_DIR / "composition.py").resolve()
            result = CliRunner().invoke(black.main, [str(src2), "--diff", "--check"])
            self.assertEqual(result.exit_code, 0)

            # Multi file command.
            result = CliRunner().invoke(
                black.main, [str(src1), str(src2), "--diff", "--check"]
            )
            self.assertEqual(result.exit_code, 1)

    def test_no_files(self) -> None:
        with cache_dir():
            # Without an argument, black exits with error code 0.
            result = CliRunner().invoke(black.main, [])
            self.assertEqual(result.exit_code, 0)

    def test_read_cache_line_lengths(self) -> None:
        with cache_dir() as workspace:
            path = (workspace / "file.py").resolve()
            path.touch()
            black.write_cache({}, [path], 1)
            one = black.read_cache(1)
            self.assertIn(path, one)
            two = black.read_cache(2)
            self.assertNotIn(path, two)
662
663
# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()