git.madduck.net Git - etc/vim.git/blob - tests/test_black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access, can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Add more files/directories to .gitignore (#191)
[etc/vim.git] / tests / test_black.py
1 #!/usr/bin/env python3
2 import asyncio
3 from concurrent.futures import ThreadPoolExecutor
4 from contextlib import contextmanager
5 from functools import partial
6 from io import StringIO
7 import os
8 from pathlib import Path
9 import sys
10 from tempfile import TemporaryDirectory
11 from typing import Any, List, Tuple, Iterator
12 import unittest
13 from unittest.mock import patch
14
15 from click import unstyle
16 from click.testing import CliRunner
17
18 import black
19
# Line length used by every formatting call in this file.
ll = 88
# Shortcuts for the two black entry points exercised below, with the line length
# (and, for in-place file formatting, fast mode) already bound.
ff = partial(black.format_file_in_place, line_length=ll, fast=True)
fs = partial(black.format_str, line_length=ll)
THIS_FILE = Path(__file__)
# read_data() resolves data file names relative to this directory.
THIS_DIR = THIS_FILE.parent
# Marker that read_data() removes from every line it reads.
# NOTE(review): presumably split into two literals so that this file, which
# test_self feeds through read_data(), never contains the marker verbatim - confirm.
EMPTY_LINE = "# EMPTY LINE WITH WHITESPACE" + " (this comment will be removed)"
26
27
def dump_to_stderr(*output: str) -> str:
    """Return `output` joined by newlines and wrapped in a leading/trailing newline.

    Despite the name, nothing is written anywhere: the text is only returned.
    Tests patch this in for ``black.dump_to_file``.
    """
    body = "\n".join(output)
    return f"\n{body}\n"
30
31
def read_data(name: str) -> Tuple[str, str]:
    """read_data('test_name') -> 'input', 'output'

    `name` is resolved relative to THIS_DIR; ".py" is appended unless it already
    ends in ".py", ".out" or ".diff".  Lines before a line consisting of
    "# output" form the input, lines after it form the output.  Every occurrence
    of EMPTY_LINE is removed, and both results are stripped and terminated with a
    single newline.
    """
    known_suffixes = (".py", ".out", ".diff")
    if not name.endswith(known_suffixes):
        name += ".py"
    with open(THIS_DIR / name, "r", encoding="utf8") as data_file:
        raw_lines = data_file.readlines()

    before_marker: List[str] = []
    after_marker: List[str] = []
    target = before_marker
    for raw in raw_lines:
        text = raw.replace(EMPTY_LINE, "")
        if text.rstrip() == "# output":
            # Everything from here on belongs to the expected output.
            target = after_marker
        else:
            target.append(text)

    if before_marker and not after_marker:
        # If there's no output marker, treat the entire file as already pre-formatted.
        after_marker = list(before_marker)

    source_text = "".join(before_marker).strip() + "\n"
    expected_text = "".join(after_marker).strip() + "\n"
    return source_text, expected_text
52
53
@contextmanager
def cache_dir(exists: bool = True) -> Iterator[Path]:
    """Patch ``black.CACHE_DIR`` to a location inside a fresh temporary directory.

    With `exists` true the temporary directory itself is used; otherwise a
    not-yet-created "new" subdirectory of it is.  The chosen path is yielded.
    """
    with TemporaryDirectory() as tmp:
        location = Path(tmp) if exists else Path(tmp) / "new"
        with patch("black.CACHE_DIR", location):
            yield location
62
63
@contextmanager
def event_loop(close: bool) -> Iterator[None]:
    """Run the managed body with a brand-new asyncio event loop installed.

    The loop that was current on entry is reinstated on exit, even when the body
    raises.  The temporary loop is closed afterwards only if `close` is true.
    """
    policy = asyncio.get_event_loop_policy()
    previous = policy.get_event_loop()
    fresh = policy.new_event_loop()
    asyncio.set_event_loop(fresh)
    try:
        yield

    finally:
        # Restore first, then dispose of the temporary loop if asked to.
        policy.set_event_loop(previous)
        if close:
            fresh.close()
78
class BlackTestCase(unittest.TestCase):
    """Formatting, reporting and cache tests for the ``black`` module."""

    # Never truncate the diff unittest prints when two values differ.
    maxDiff = None

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _print_tree(header: str, color: str, code: str) -> None:
        """Print `header` in `color`, then the DebugVisitor dump of `code`.

        This is a debugging aid only: any error raised while parsing or visiting
        is reported through ``black.err`` instead of being propagated, so it can
        never mask the assertion failure that triggered the dump.
        """
        black.out(header, fg=color)
        try:
            tree = black.lib2to3_parse(code)
            visitor: black.DebugVisitor[Any] = black.DebugVisitor()
            # visit() is a generator; exhausting it is what produces the output.
            list(visitor.visit(tree))
        except Exception as exc:
            black.err(str(exc))

    @contextmanager
    def _captured_output(self) -> Iterator[Tuple[List[str], List[str]]]:
        """Patch ``black.out`` / ``black.err`` to record messages instead of printing.

        Yields ``(out_lines, err_lines)``; each list receives the `msg` argument of
        every call made while the context is active.  Styling keywords are ignored.
        """
        out_lines: List[str] = []
        err_lines: List[str] = []

        def out(msg: str, **kwargs: Any) -> None:
            out_lines.append(msg)

        def err(msg: str, **kwargs: Any) -> None:
            err_lines.append(msg)

        with patch("black.out", out), patch("black.err", err):
            yield out_lines, err_lines

    def _check_format(self, name: str, *, equivalent: bool = True) -> None:
        """Format the data file `name` and verify the result.

        The formatted source must match the expected output from read_data() and
        be stable; unless `equivalent` is false it must also pass
        ``black.assert_equivalent`` against the original source.
        """
        source, expected = read_data(name)
        actual = fs(source)
        self.assertFormatEqual(expected, actual)
        if equivalent:
            black.assert_equivalent(source, actual)
        black.assert_stable(source, actual, line_length=ll)

    def _format_via_stdin(self, source: str, write_back: black.WriteBack) -> str:
        """Feed `source` to ``black.format_stdin_to_stdout`` and return its stdout.

        ``sys.stdin`` and ``sys.stdout`` are temporarily replaced by in-memory
        streams and always restored, even if formatting raises.
        """
        saved_stdin, saved_stdout = sys.stdin, sys.stdout
        try:
            sys.stdin, sys.stdout = StringIO(source), StringIO()
            sys.stdin.name = "<stdin>"
            black.format_stdin_to_stdout(
                line_length=ll, fast=True, write_back=write_back
            )
            sys.stdout.seek(0)
            return sys.stdout.read()
        finally:
            sys.stdin, sys.stdout = saved_stdin, saved_stdout

    def assertFormatEqual(self, expected: str, actual: str) -> None:
        """Assert `expected` == `actual`, dumping both syntax trees on mismatch.

        The tree dumps are skipped when the SKIP_AST_PRINT environment variable is
        set to a non-empty value.
        """
        if actual != expected and not os.environ.get("SKIP_AST_PRINT"):
            self._print_tree("Expected tree:", "green", expected)
            self._print_tree("Actual tree:", "red", actual)
        self.assertEqual(expected, actual)

    # --------------------------------------------------- formatting data files

    @patch("black.dump_to_file", dump_to_stderr)
    def test_self(self) -> None:
        """This test file formats to itself and is left unchanged in place."""
        self._check_format("test_black")
        self.assertFalse(ff(THIS_FILE))

    @patch("black.dump_to_file", dump_to_stderr)
    def test_black(self) -> None:
        """black.py formats to itself and is left unchanged in place."""
        self._check_format("../black")
        self.assertFalse(ff(THIS_DIR / ".." / "black.py"))

    def test_piping(self) -> None:
        """Formatting from stdin to stdout yields the expected, stable output."""
        source, expected = read_data("../black")
        actual = self._format_via_stdin(source, black.WriteBack.YES)
        self.assertFormatEqual(expected, actual)
        black.assert_equivalent(source, actual)
        black.assert_stable(source, actual, line_length=ll)

    def test_piping_diff(self) -> None:
        """Diff mode on stdin prints the diff stored in expression.diff."""
        source, _ = read_data("expression.py")
        expected, _ = read_data("expression.diff")
        actual = self._format_via_stdin(source, black.WriteBack.DIFF)
        actual = actual.rstrip() + "\n"  # the diff output has a trailing space
        self.assertEqual(expected, actual)

    @patch("black.dump_to_file", dump_to_stderr)
    def test_setup(self) -> None:
        """setup.py formats to itself and is left unchanged in place."""
        self._check_format("../setup")
        self.assertFalse(ff(THIS_DIR / ".." / "setup.py"))

    @patch("black.dump_to_file", dump_to_stderr)
    def test_function(self) -> None:
        self._check_format("function")

    @patch("black.dump_to_file", dump_to_stderr)
    def test_expression(self) -> None:
        self._check_format("expression")

    def test_expression_ff(self) -> None:
        """In-place formatting of a temporary file rewrites it as expected."""
        source, expected = read_data("expression")
        tmp_file = Path(black.dump_to_file(source))
        try:
            self.assertTrue(ff(tmp_file, write_back=black.WriteBack.YES))
            actual = tmp_file.read_text(encoding="utf8")
        finally:
            os.unlink(tmp_file)
        self.assertFormatEqual(expected, actual)
        # dump_to_file is only patched from here on: the real one was needed above
        # to create the temporary file.
        with patch("black.dump_to_file", dump_to_stderr):
            black.assert_equivalent(source, actual)
            black.assert_stable(source, actual, line_length=ll)

    def test_expression_diff(self) -> None:
        """Diff mode on a file prints the diff stored in expression.diff."""
        source, _ = read_data("expression.py")
        expected, _ = read_data("expression.diff")
        tmp_file = Path(black.dump_to_file(source))
        hold_stdout = sys.stdout
        try:
            sys.stdout = StringIO()
            self.assertTrue(ff(tmp_file, write_back=black.WriteBack.DIFF))
            sys.stdout.seek(0)
            actual = sys.stdout.read()
            # The stored diff was produced from stdin, so normalise the file name.
            actual = actual.replace(str(tmp_file), "<stdin>")
        finally:
            sys.stdout = hold_stdout
            os.unlink(tmp_file)
        actual = actual.rstrip() + "\n"  # the diff output has a trailing space
        if expected != actual:
            dump = black.dump_to_file(actual)
            msg = (
                f"Expected diff isn't equal to the actual. If you made changes "
                f"to expression.py and this is an anticipated difference, "
                f"overwrite tests/expression.diff with {dump}"
            )
            self.assertEqual(expected, actual, msg)

    @patch("black.dump_to_file", dump_to_stderr)
    def test_fstring(self) -> None:
        self._check_format("fstring")

    @patch("black.dump_to_file", dump_to_stderr)
    def test_string_quotes(self) -> None:
        self._check_format("string_quotes")

    @patch("black.dump_to_file", dump_to_stderr)
    def test_slices(self) -> None:
        self._check_format("slices")

    @patch("black.dump_to_file", dump_to_stderr)
    def test_comments(self) -> None:
        self._check_format("comments")

    @patch("black.dump_to_file", dump_to_stderr)
    def test_comments2(self) -> None:
        self._check_format("comments2")

    @patch("black.dump_to_file", dump_to_stderr)
    def test_comments3(self) -> None:
        self._check_format("comments3")

    @patch("black.dump_to_file", dump_to_stderr)
    def test_comments4(self) -> None:
        self._check_format("comments4")

    @patch("black.dump_to_file", dump_to_stderr)
    def test_comments5(self) -> None:
        self._check_format("comments5")

    @patch("black.dump_to_file", dump_to_stderr)
    def test_cantfit(self) -> None:
        self._check_format("cantfit")

    @patch("black.dump_to_file", dump_to_stderr)
    def test_import_spacing(self) -> None:
        self._check_format("import_spacing")

    @patch("black.dump_to_file", dump_to_stderr)
    def test_composition(self) -> None:
        self._check_format("composition")

    @patch("black.dump_to_file", dump_to_stderr)
    def test_empty_lines(self) -> None:
        self._check_format("empty_lines")

    @patch("black.dump_to_file", dump_to_stderr)
    def test_python2(self) -> None:
        # The equivalence check is deliberately skipped for this data file.
        # NOTE(review): presumably because Python 2 source cannot be parsed by the
        # running interpreter - confirm.
        self._check_format("python2", equivalent=False)

    @patch("black.dump_to_file", dump_to_stderr)
    def test_fmtonoff(self) -> None:
        self._check_format("fmtonoff")

    # ------------------------------------------------------------- unit tests

    def test_report(self) -> None:
        """Report tracks messages, summary text and return code across events."""
        report = black.Report()
        with self._captured_output() as (out_lines, err_lines):
            report.done(Path("f1"), black.Changed.NO)
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "f1 already well formatted, good job.")
            self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
            self.assertEqual(report.return_code, 0)
            report.done(Path("f2"), black.Changed.YES)
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "reformatted f2")
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
            )
            report.done(Path("f3"), black.Changed.CACHED)
            self.assertEqual(len(out_lines), 3)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(
                out_lines[-1], "f3 wasn't modified on disk since last run."
            )
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
            )
            self.assertEqual(report.return_code, 0)
            # In check mode a reformatted file alone makes the return code 1.
            report.check = True
            self.assertEqual(report.return_code, 1)
            report.check = False
            report.failed(Path("e1"), "boom")
            self.assertEqual(len(out_lines), 3)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
            self.assertEqual(
                unstyle(str(report)),
                "1 file reformatted, 2 files left unchanged, "
                "1 file failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f3"), black.Changed.YES)
            self.assertEqual(len(out_lines), 4)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(out_lines[-1], "reformatted f3")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, "
                "1 file failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.failed(Path("e2"), "boom")
            self.assertEqual(len(out_lines), 4)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, "
                "2 files failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f4"), black.Changed.NO)
            self.assertEqual(len(out_lines), 5)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(out_lines[-1], "f4 already well formatted, good job.")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 3 files left unchanged, "
                "2 files failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.check = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, "
                "2 files would fail to reformat.",
            )

    def test_is_python36(self) -> None:
        """is_python36() is checked on small snippets and on two data files."""
        node = black.lib2to3_parse("def f(*, arg): ...\n")
        self.assertFalse(black.is_python36(node))
        node = black.lib2to3_parse("def f(*, arg,): ...\n")
        self.assertTrue(black.is_python36(node))
        node = black.lib2to3_parse("def f(*, arg): f'string'\n")
        self.assertTrue(black.is_python36(node))
        source, expected = read_data("function")
        node = black.lib2to3_parse(source)
        self.assertTrue(black.is_python36(node))
        node = black.lib2to3_parse(expected)
        self.assertTrue(black.is_python36(node))
        source, expected = read_data("expression")
        node = black.lib2to3_parse(source)
        self.assertFalse(black.is_python36(node))
        node = black.lib2to3_parse(expected)
        self.assertFalse(black.is_python36(node))

    def test_debug_visitor(self) -> None:
        """DebugVisitor.show() output matches the stored debug_visitor.out."""
        source, _ = read_data("debug_visitor.py")
        expected, _ = read_data("debug_visitor.out")
        with self._captured_output() as (out_lines, _err_lines):
            black.DebugVisitor.show(source)
        actual = "\n".join(out_lines) + "\n"
        log_name = ""
        if expected != actual:
            # Keep the actual output around so the failure can be inspected.
            log_name = black.dump_to_file(*out_lines)
        self.assertEqual(
            expected,
            actual,
            f"AST print out is different. Actual version dumped to {log_name}",
        )

    def test_format_file_contents(self) -> None:
        """format_file_contents() on unchanged, changed and unparsable input."""
        empty = ""
        with self.assertRaises(black.NothingChanged):
            black.format_file_contents(empty, line_length=ll, fast=False)
        just_nl = "\n"
        with self.assertRaises(black.NothingChanged):
            black.format_file_contents(just_nl, line_length=ll, fast=False)
        same = "l = [1, 2, 3]\n"
        with self.assertRaises(black.NothingChanged):
            black.format_file_contents(same, line_length=ll, fast=False)
        different = "l = [1,2,3]"
        expected = same
        actual = black.format_file_contents(different, line_length=ll, fast=False)
        self.assertEqual(expected, actual)
        invalid = "return if you can"
        with self.assertRaises(ValueError) as e:
            black.format_file_contents(invalid, line_length=ll, fast=False)
        self.assertEqual(str(e.exception), "Cannot parse: 1:7: return if you can")

    def test_endmarker(self) -> None:
        """A lone newline parses to a file_input holding only an ENDMARKER."""
        n = black.lib2to3_parse("\n")
        self.assertEqual(n.type, black.syms.file_input)
        self.assertEqual(len(n.children), 1)
        self.assertEqual(n.children[0].type, black.token.ENDMARKER)

    @unittest.skipIf(os.environ.get("SKIP_AST_PRINT"), "user set SKIP_AST_PRINT")
    def test_assertFormatEqual(self) -> None:
        """A mismatch prints both tree headers and reports no errors."""
        with self._captured_output() as (out_lines, err_lines):
            with self.assertRaises(AssertionError):
                self.assertFormatEqual("l = [1, 2, 3]", "l = [1, 2, 3,]")

        out_str = "".join(out_lines)
        self.assertTrue("Expected tree:" in out_str)
        self.assertTrue("Actual tree:" in out_str)
        self.assertEqual("".join(err_lines), "")

    # ------------------------------------------------------------ cache tests

    def test_cache_broken_file(self) -> None:
        """A cache file that is not a pickle reads as empty and does not block runs."""
        with cache_dir() as workspace:
            cache_file = black.get_cache_file(black.DEFAULT_LINE_LENGTH)
            cache_file.write_text("this is not a pickle")
            self.assertEqual(black.read_cache(black.DEFAULT_LINE_LENGTH), {})
            src = (workspace / "test.py").resolve()
            src.write_text("print('hello')")
            result = CliRunner().invoke(black.main, [str(src)])
            self.assertEqual(result.exit_code, 0)
            cache = black.read_cache(black.DEFAULT_LINE_LENGTH)
            self.assertIn(src, cache)

    def test_cache_single_file_already_cached(self) -> None:
        """A file already recorded in the cache is left untouched by a run."""
        with cache_dir() as workspace:
            src = (workspace / "test.py").resolve()
            src.write_text("print('hello')")
            black.write_cache({}, [src], black.DEFAULT_LINE_LENGTH)
            result = CliRunner().invoke(black.main, [str(src)])
            self.assertEqual(result.exit_code, 0)
            self.assertEqual(src.read_text(), "print('hello')")

    @event_loop(close=False)
    def test_cache_multiple_files(self) -> None:
        """Only the uncached file is reformatted; both end up in the cache."""
        with cache_dir() as workspace, patch(
            "black.ProcessPoolExecutor", new=ThreadPoolExecutor
        ):
            one = (workspace / "one.py").resolve()
            one.write_text("print('hello')")
            two = (workspace / "two.py").resolve()
            two.write_text("print('hello')")
            black.write_cache({}, [one], black.DEFAULT_LINE_LENGTH)
            result = CliRunner().invoke(black.main, [str(workspace)])
            self.assertEqual(result.exit_code, 0)
            self.assertEqual(one.read_text(), "print('hello')")
            self.assertEqual(two.read_text(), 'print("hello")\n')
            cache = black.read_cache(black.DEFAULT_LINE_LENGTH)
            self.assertIn(one, cache)
            self.assertIn(two, cache)

    def test_no_cache_when_writeback_diff(self) -> None:
        """Running with --diff does not create a cache file."""
        with cache_dir() as workspace:
            src = (workspace / "test.py").resolve()
            src.write_text("print('hello')")
            result = CliRunner().invoke(black.main, [str(src), "--diff"])
            self.assertEqual(result.exit_code, 0)
            cache_file = black.get_cache_file(black.DEFAULT_LINE_LENGTH)
            self.assertFalse(cache_file.exists())

    def test_no_cache_when_stdin(self) -> None:
        """Formatting stdin ("-") does not create a cache file."""
        with cache_dir():
            result = CliRunner().invoke(black.main, ["-"], input="print('hello')")
            self.assertEqual(result.exit_code, 0)
            cache_file = black.get_cache_file(black.DEFAULT_LINE_LENGTH)
            self.assertFalse(cache_file.exists())

    def test_read_cache_no_cachefile(self) -> None:
        """Reading a cache that was never written yields an empty dict."""
        with cache_dir():
            self.assertEqual(black.read_cache(black.DEFAULT_LINE_LENGTH), {})

    def test_write_cache_read_cache(self) -> None:
        """A written cache entry reads back equal to get_cache_info() of the file."""
        with cache_dir() as workspace:
            src = (workspace / "test.py").resolve()
            src.touch()
            black.write_cache({}, [src], black.DEFAULT_LINE_LENGTH)
            cache = black.read_cache(black.DEFAULT_LINE_LENGTH)
            self.assertIn(src, cache)
            self.assertEqual(cache[src], black.get_cache_info(src))

    def test_filter_cached(self) -> None:
        """filter_cached() splits paths into stale/unknown ones and up-to-date ones."""
        with TemporaryDirectory() as workspace:
            path = Path(workspace)
            uncached = (path / "uncached").resolve()
            cached = (path / "cached").resolve()
            cached_but_changed = (path / "changed").resolve()
            uncached.touch()
            cached.touch()
            cached_but_changed.touch()
            # (0.0, 0) is a cache entry that cannot match the freshly touched file.
            cache = {cached: black.get_cache_info(cached), cached_but_changed: (0.0, 0)}
            todo, done = black.filter_cached(
                cache, [uncached, cached, cached_but_changed]
            )
            self.assertEqual(todo, [uncached, cached_but_changed])
            self.assertEqual(done, [cached])

    def test_write_cache_creates_directory_if_needed(self) -> None:
        """write_cache() creates the cache directory when it does not exist yet."""
        with cache_dir(exists=False) as workspace:
            self.assertFalse(workspace.exists())
            black.write_cache({}, [], black.DEFAULT_LINE_LENGTH)
            self.assertTrue(workspace.exists())

    @event_loop(close=False)
    def test_failed_formatting_does_not_get_cached(self) -> None:
        """A file that fails to format is not cached; a clean sibling is."""
        with cache_dir() as workspace, patch(
            "black.ProcessPoolExecutor", new=ThreadPoolExecutor
        ):
            failing = (workspace / "failing.py").resolve()
            failing.write_text("not actually python")
            clean = (workspace / "clean.py").resolve()
            clean.write_text('print("hello")\n')
            result = CliRunner().invoke(black.main, [str(workspace)])
            self.assertEqual(result.exit_code, 123)
            cache = black.read_cache(black.DEFAULT_LINE_LENGTH)
            self.assertNotIn(failing, cache)
            self.assertIn(clean, cache)

    def test_write_cache_write_fail(self) -> None:
        """write_cache() must not raise when opening the cache file fails."""
        with cache_dir(), patch.object(Path, "open") as mock:
            mock.side_effect = OSError
            # Passing means no exception escaped; there is nothing else to assert.
            black.write_cache({}, [], black.DEFAULT_LINE_LENGTH)

    def test_check_diff_use_together(self) -> None:
        """--diff combined with --check exits 1 only if a file would change."""
        with cache_dir():
            # Files which will be reformatted.
            src1 = (THIS_DIR / "string_quotes.py").resolve()
            result = CliRunner().invoke(black.main, [str(src1), "--diff", "--check"])
            self.assertEqual(result.exit_code, 1)

            # Files which will not be reformatted.
            src2 = (THIS_DIR / "composition.py").resolve()
            result = CliRunner().invoke(black.main, [str(src2), "--diff", "--check"])
            self.assertEqual(result.exit_code, 0)

            # Multi file command.
            result = CliRunner().invoke(
                black.main, [str(src1), str(src2), "--diff", "--check"]
            )
            self.assertEqual(result.exit_code, 1)

    def test_read_cache_line_lengths(self) -> None:
        """A cache written for one line length is not visible for another."""
        with cache_dir() as workspace:
            path = (workspace / "file.py").resolve()
            path.touch()
            black.write_cache({}, [path], 1)
            one = black.read_cache(1)
            self.assertIn(path, one)
            two = black.read_cache(2)
            self.assertNotIn(path, two)
649
# Run the whole suite with unittest's runner when executed as a script.
if __name__ == "__main__":
    unittest.main()