git.madduck.net Git - etc/vim.git/blob - tests/test_black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access, can be arranged individually.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

6f0ffa364ac749e770e67549ee6e4420bfc8e739
[etc/vim.git] / tests / test_black.py
1 #!/usr/bin/env python3
2 import asyncio
3 from concurrent.futures import ThreadPoolExecutor
4 from contextlib import contextmanager
5 from functools import partial
6 from io import StringIO
7 import os
8 from pathlib import Path
9 import sys
10 from tempfile import TemporaryDirectory
11 from typing import Any, List, Tuple, Iterator
12 import unittest
13 from unittest.mock import patch
14
15 from click import unstyle
16 from click.testing import CliRunner
17
18 import black
19
THIS_FILE = Path(__file__)
THIS_DIR = THIS_FILE.parent
# Line length shared by every formatting helper in this module.
ll = 88
# fs formats a source string; ff formats a file on disk in place (fast mode).
fs = partial(black.format_str, line_length=ll)
ff = partial(black.format_file_in_place, line_length=ll, fast=True)
# read_data() strips this text from every fixture line. Presumably it is built
# from two literals so that this file, which test_self also reads as a fixture,
# never contains the complete marker itself.
EMPTY_LINE = "# EMPTY LINE WITH WHITESPACE" + " (this comment will be removed)"
26
27
def dump_to_stderr(*output: str) -> str:
    """Return *output* joined by newlines and framed by a newline on each side.

    Tests patch ``black.dump_to_file`` with this function. Despite the name it
    writes nothing anywhere; it only returns the text, presumably so the dumped
    content appears inline in black's failure messages instead of a file name.
    """
    body = "\n".join(output)
    return f"\n{body}\n"
30
31
def read_data(name: str) -> Tuple[str, str]:
    """Load fixture *name* from THIS_DIR and return its (source, expected) halves.

    ``.py`` is appended unless *name* already ends in ``.py``, ``.out`` or
    ``.diff``. The EMPTY_LINE marker text is removed from every line. A line
    whose rstripped text is ``# output`` switches collection from the source
    half to the expected half; the marker line itself is dropped. When the
    expected half ends up empty but the source half does not, the fixture is
    treated as already formatted and the expected half becomes a copy of the
    source. Both halves are whitespace-stripped and end in a single newline.
    """
    if not name.endswith((".py", ".out", ".diff")):
        name += ".py"
    with open(THIS_DIR / name, "r", encoding="utf8") as fixture:
        raw_lines = fixture.readlines()
    halves: List[List[str]] = [[], []]
    current = 0
    for raw in raw_lines:
        cleaned = raw.replace(EMPTY_LINE, "")
        if cleaned.rstrip() == "# output":
            current = 1
            continue

        halves[current].append(cleaned)
    source_lines, expected_lines = halves
    if source_lines and not expected_lines:
        expected_lines = list(source_lines)
    source = "".join(source_lines).strip() + "\n"
    expected = "".join(expected_lines).strip() + "\n"
    return source, expected
52
53
@contextmanager
def cache_dir(exists: bool = True) -> Iterator[Path]:
    """Redirect black's cache into a temporary directory for the managed block.

    ``black.CACHE_DIR`` and ``black.CACHE_FILE`` are patched to point inside a
    fresh ``TemporaryDirectory``. With ``exists=False`` the cache directory is
    a ``new`` subdirectory that has not been created. The (possibly missing)
    cache directory is yielded, and the whole temporary tree is removed on exit.
    """
    with TemporaryDirectory() as workspace:
        root = Path(workspace)
        target = root if exists else root / "new"
        target_file = target / "cache.pkl"
        with patch("black.CACHE_DIR", target), patch("black.CACHE_FILE", target_file):
            yield target
63
64
@contextmanager
def event_loop(close: bool) -> Iterator[None]:
    """Run the managed block with a freshly created event loop set as current.

    The loop that was current on entry is reinstated on exit. The fresh loop
    is closed on exit only when *close* is true. Callers that let black drive
    the loop pass ``close=False`` (presumably because black closes it itself;
    TODO confirm).
    """
    policy = asyncio.get_event_loop_policy()
    saved, replacement = policy.get_event_loop(), policy.new_event_loop()
    asyncio.set_event_loop(replacement)
    try:
        yield

    finally:
        policy.set_event_loop(saved)
        if close:
            replacement.close()
78
79
class BlackTestCase(unittest.TestCase):
    """Tests for black, largely driven by fixture files next to this module."""

    maxDiff = None

    def assertFormatEqual(self, expected: str, actual: str) -> None:
        """Assert that *actual* equals *expected*, dumping both trees if not.

        The tree dump is skipped when the SKIP_AST_PRINT environment variable is
        set to a non-empty value.
        """
        if actual != expected and not os.environ.get("SKIP_AST_PRINT"):
            self._show_tree("Expected tree:", "green", expected)
            self._show_tree("Actual tree:", "red", actual)
        self.assertEqual(expected, actual)

    @staticmethod
    def _show_tree(title: str, color: str, code: str) -> None:
        """Print *title* in *color*, then black's DebugVisitor dump of *code*.

        Exceptions raised while parsing or visiting are reported via
        ``black.err`` rather than propagated, so the caller still reaches its
        equality assertion.
        """
        bdv: black.DebugVisitor[Any]
        black.out(title, fg=color)
        try:
            tree = black.lib2to3_parse(code)
            bdv = black.DebugVisitor()
            list(bdv.visit(tree))
        except Exception as ve:
            black.err(str(ve))

    def _check_fixture(self, name: str, *, equivalent: bool = True) -> None:
        """Format fixture *name* with ``fs`` and check the result.

        The output must match the fixture's expected half, be AST-equivalent to
        the source (skipped when *equivalent* is false) and be stable.
        """
        source, expected = read_data(name)
        actual = fs(source)
        self.assertFormatEqual(expected, actual)
        if equivalent:
            black.assert_equivalent(source, actual)
        black.assert_stable(source, actual, line_length=ll)

    @staticmethod
    def _format_via_stdin(source: str, write_back: black.WriteBack) -> str:
        """Run ``black.format_stdin_to_stdout`` on *source*; return what it wrote.

        sys.stdin and sys.stdout are replaced by in-memory buffers for the call
        and restored afterwards, even if formatting raises.
        """
        saved_stdin, saved_stdout = sys.stdin, sys.stdout
        try:
            sys.stdin, sys.stdout = StringIO(source), StringIO()
            sys.stdin.name = "<stdin>"
            black.format_stdin_to_stdout(
                line_length=ll, fast=True, write_back=write_back
            )
            sys.stdout.seek(0)
            printed = sys.stdout.read()
        finally:
            sys.stdin, sys.stdout = saved_stdin, saved_stdout
        return printed

    @patch("black.dump_to_file", dump_to_stderr)
    def test_self(self) -> None:
        self._check_fixture("test_black")
        self.assertFalse(ff(THIS_FILE))

    @patch("black.dump_to_file", dump_to_stderr)
    def test_black(self) -> None:
        self._check_fixture("../black")
        self.assertFalse(ff(THIS_DIR / ".." / "black.py"))

    def test_piping(self) -> None:
        source, expected = read_data("../black")
        actual = self._format_via_stdin(source, black.WriteBack.YES)
        self.assertFormatEqual(expected, actual)
        black.assert_equivalent(source, actual)
        black.assert_stable(source, actual, line_length=ll)

    def test_piping_diff(self) -> None:
        source, _ = read_data("expression.py")
        expected, _ = read_data("expression.diff")
        actual = self._format_via_stdin(source, black.WriteBack.DIFF)
        actual = actual.rstrip() + "\n"  # the diff output has a trailing space
        self.assertEqual(expected, actual)

    @patch("black.dump_to_file", dump_to_stderr)
    def test_setup(self) -> None:
        self._check_fixture("../setup")
        self.assertFalse(ff(THIS_DIR / ".." / "setup.py"))

    @patch("black.dump_to_file", dump_to_stderr)
    def test_function(self) -> None:
        self._check_fixture("function")

    @patch("black.dump_to_file", dump_to_stderr)
    def test_expression(self) -> None:
        self._check_fixture("expression")

    def test_expression_ff(self) -> None:
        source, expected = read_data("expression")
        tmp_file = Path(black.dump_to_file(source))
        try:
            self.assertTrue(ff(tmp_file, write_back=black.WriteBack.YES))
            actual = tmp_file.read_text(encoding="utf8")
        finally:
            tmp_file.unlink()
        self.assertFormatEqual(expected, actual)
        with patch("black.dump_to_file", dump_to_stderr):
            black.assert_equivalent(source, actual)
            black.assert_stable(source, actual, line_length=ll)

    def test_expression_diff(self) -> None:
        source, _ = read_data("expression.py")
        expected, _ = read_data("expression.diff")
        tmp_file = Path(black.dump_to_file(source))
        saved_stdout = sys.stdout
        try:
            sys.stdout = StringIO()
            self.assertTrue(ff(tmp_file, write_back=black.WriteBack.DIFF))
            sys.stdout.seek(0)
            actual = sys.stdout.read().replace(str(tmp_file), "<stdin>")
        finally:
            sys.stdout = saved_stdout
            os.unlink(tmp_file)
        actual = actual.rstrip() + "\n"  # the diff output has a trailing space
        if expected == actual:
            return

        dump = black.dump_to_file(actual)
        msg = (
            f"Expected diff isn't equal to the actual. If you made changes "
            f"to expression.py and this is an anticipated difference, "
            f"overwrite tests/expression.diff with {dump}"
        )
        self.assertEqual(expected, actual, msg)

    @patch("black.dump_to_file", dump_to_stderr)
    def test_fstring(self) -> None:
        self._check_fixture("fstring")

    @patch("black.dump_to_file", dump_to_stderr)
    def test_string_quotes(self) -> None:
        self._check_fixture("string_quotes")

    @patch("black.dump_to_file", dump_to_stderr)
    def test_comments(self) -> None:
        self._check_fixture("comments")

    @patch("black.dump_to_file", dump_to_stderr)
    def test_comments2(self) -> None:
        self._check_fixture("comments2")

    @patch("black.dump_to_file", dump_to_stderr)
    def test_comments3(self) -> None:
        self._check_fixture("comments3")

    @patch("black.dump_to_file", dump_to_stderr)
    def test_comments4(self) -> None:
        self._check_fixture("comments4")

    @patch("black.dump_to_file", dump_to_stderr)
    def test_cantfit(self) -> None:
        self._check_fixture("cantfit")

    @patch("black.dump_to_file", dump_to_stderr)
    def test_import_spacing(self) -> None:
        self._check_fixture("import_spacing")

    @patch("black.dump_to_file", dump_to_stderr)
    def test_composition(self) -> None:
        self._check_fixture("composition")

    @patch("black.dump_to_file", dump_to_stderr)
    def test_empty_lines(self) -> None:
        self._check_fixture("empty_lines")

    @patch("black.dump_to_file", dump_to_stderr)
    def test_python2(self) -> None:
        # AST equivalence is deliberately not checked for the Python 2 fixture.
        self._check_fixture("python2", equivalent=False)

    @patch("black.dump_to_file", dump_to_stderr)
    def test_fmtonoff(self) -> None:
        self._check_fixture("fmtonoff")

    def test_report(self) -> None:
        report = black.Report()
        out_lines = []
        err_lines = []

        def out(msg: str, **kwargs: Any) -> None:
            out_lines.append(msg)

        def err(msg: str, **kwargs: Any) -> None:
            err_lines.append(msg)

        def counts(outs: int, errs: int) -> None:
            self.assertEqual(len(out_lines), outs)
            self.assertEqual(len(err_lines), errs)

        def summary() -> str:
            return unstyle(str(report))

        with patch("black.out", out), patch("black.err", err):
            report.done(Path("f1"), black.Changed.NO)
            counts(1, 0)
            self.assertEqual(out_lines[-1], "f1 already well formatted, good job.")
            self.assertEqual(summary(), "1 file left unchanged.")
            self.assertEqual(report.return_code, 0)
            report.done(Path("f2"), black.Changed.YES)
            counts(2, 0)
            self.assertEqual(out_lines[-1], "reformatted f2")
            self.assertEqual(summary(), "1 file reformatted, 1 file left unchanged.")
            report.done(Path("f3"), black.Changed.CACHED)
            counts(3, 0)
            self.assertEqual(
                out_lines[-1], "f3 wasn't modified on disk since last run."
            )
            self.assertEqual(summary(), "1 file reformatted, 2 files left unchanged.")
            self.assertEqual(report.return_code, 0)
            report.check = True
            self.assertEqual(report.return_code, 1)
            report.check = False
            report.failed(Path("e1"), "boom")
            counts(3, 1)
            self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
            self.assertEqual(
                summary(),
                "1 file reformatted, 2 files left unchanged, "
                "1 file failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f3"), black.Changed.YES)
            counts(4, 1)
            self.assertEqual(out_lines[-1], "reformatted f3")
            self.assertEqual(
                summary(),
                "2 files reformatted, 2 files left unchanged, "
                "1 file failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.failed(Path("e2"), "boom")
            counts(4, 2)
            self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
            self.assertEqual(
                summary(),
                "2 files reformatted, 2 files left unchanged, "
                "2 files failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f4"), black.Changed.NO)
            counts(5, 2)
            self.assertEqual(out_lines[-1], "f4 already well formatted, good job.")
            self.assertEqual(
                summary(),
                "2 files reformatted, 3 files left unchanged, "
                "2 files failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.check = True
            self.assertEqual(
                summary(),
                "2 files would be reformatted, 3 files would be left unchanged, "
                "2 files would fail to reformat.",
            )

    def test_is_python36(self) -> None:
        snippets = [
            ("def f(*, arg): ...\n", False),
            ("def f(*, arg,): ...\n", True),
            ("def f(*, arg): f'string'\n", True),
        ]
        for code, py36_only in snippets:
            node = black.lib2to3_parse(code)
            assertion = self.assertTrue if py36_only else self.assertFalse
            assertion(black.is_python36(node))
        for fixture, py36_only in (("function", True), ("expression", False)):
            assertion = self.assertTrue if py36_only else self.assertFalse
            for text in read_data(fixture):
                assertion(black.is_python36(black.lib2to3_parse(text)))

    def test_debug_visitor(self) -> None:
        source, _ = read_data("debug_visitor.py")
        expected, _ = read_data("debug_visitor.out")
        out_lines = []
        err_lines = []

        def out(msg: str, **kwargs: Any) -> None:
            out_lines.append(msg)

        def err(msg: str, **kwargs: Any) -> None:
            err_lines.append(msg)

        with patch("black.out", out), patch("black.err", err):
            black.DebugVisitor.show(source)
        actual = "\n".join(out_lines) + "\n"
        log_name = black.dump_to_file(*out_lines) if expected != actual else ""
        self.assertEqual(
            expected,
            actual,
            f"AST print out is different. Actual version dumped to {log_name}",
        )

    def test_format_file_contents(self) -> None:
        for unchanged in ("", "\n", "l = [1, 2, 3]\n"):
            with self.assertRaises(black.NothingChanged):
                black.format_file_contents(unchanged, line_length=ll, fast=False)
        different = "l = [1,2,3]"
        actual = black.format_file_contents(different, line_length=ll, fast=False)
        self.assertEqual("l = [1, 2, 3]\n", actual)
        invalid = "return if you can"
        with self.assertRaises(ValueError) as e:
            black.format_file_contents(invalid, line_length=ll, fast=False)
        self.assertEqual(str(e.exception), "Cannot parse: 1:7: return if you can")

    def test_endmarker(self) -> None:
        tree = black.lib2to3_parse("\n")
        self.assertEqual(tree.type, black.syms.file_input)
        self.assertEqual(len(tree.children), 1)
        self.assertEqual(tree.children[0].type, black.token.ENDMARKER)

    @unittest.skipIf(os.environ.get("SKIP_AST_PRINT"), "user set SKIP_AST_PRINT")
    def test_assertFormatEqual(self) -> None:
        out_lines = []
        err_lines = []

        def out(msg: str, **kwargs: Any) -> None:
            out_lines.append(msg)

        def err(msg: str, **kwargs: Any) -> None:
            err_lines.append(msg)

        with patch("black.out", out), patch("black.err", err):
            with self.assertRaises(AssertionError):
                self.assertFormatEqual("l = [1, 2, 3]", "l = [1, 2, 3,]")

        printed = "".join(out_lines)
        for heading in ("Expected tree:", "Actual tree:"):
            self.assertTrue(heading in printed)
        self.assertEqual("".join(err_lines), "")

    def test_cache_broken_file(self) -> None:
        with cache_dir() as workspace:
            black.CACHE_FILE.write_text("this is not a pickle")
            self.assertEqual(black.read_cache(), {})
            src = (workspace / "test.py").resolve()
            src.write_text("print('hello')")
            result = CliRunner().invoke(black.main, [str(src)])
            self.assertEqual(result.exit_code, 0)
            cache = black.read_cache()
            self.assertIn(src, cache)

    def test_cache_single_file_already_cached(self) -> None:
        with cache_dir() as workspace:
            src = (workspace / "test.py").resolve()
            src.write_text("print('hello')")
            black.write_cache({}, [src])
            result = CliRunner().invoke(black.main, [str(src)])
            self.assertEqual(result.exit_code, 0)
            self.assertEqual(src.read_text(), "print('hello')")

    @event_loop(close=False)
    def test_cache_multiple_files(self) -> None:
        with cache_dir() as workspace, patch(
            "black.ProcessPoolExecutor", new=ThreadPoolExecutor
        ):
            one = (workspace / "one.py").resolve()
            one.write_text("print('hello')")
            two = (workspace / "two.py").resolve()
            two.write_text("print('hello')")
            black.write_cache({}, [one])
            result = CliRunner().invoke(black.main, [str(workspace)])
            self.assertEqual(result.exit_code, 0)
            self.assertEqual(one.read_text(), "print('hello')")
            self.assertEqual(two.read_text(), 'print("hello")\n')
            cache = black.read_cache()
            self.assertIn(one, cache)
            self.assertIn(two, cache)

    def test_no_cache_when_writeback_diff(self) -> None:
        with cache_dir() as workspace:
            src = (workspace / "test.py").resolve()
            src.write_text("print('hello')")
            result = CliRunner().invoke(black.main, [str(src), "--diff"])
            self.assertEqual(result.exit_code, 0)
            self.assertFalse(black.CACHE_FILE.exists())

    def test_no_cache_when_stdin(self) -> None:
        with cache_dir():
            outcome = CliRunner().invoke(black.main, ["-"], input="print('hello')")
            self.assertEqual(outcome.exit_code, 0)
            self.assertFalse(black.CACHE_FILE.exists())

    def test_read_cache_no_cachefile(self) -> None:
        with cache_dir():
            self.assertEqual(black.read_cache(), {})

    def test_write_cache_read_cache(self) -> None:
        with cache_dir() as workspace:
            src = (workspace / "test.py").resolve()
            src.touch()
            black.write_cache({}, [src])
            cache = black.read_cache()
            self.assertIn(src, cache)
            self.assertEqual(cache[src], black.get_cache_info(src))

    def test_filter_cached(self) -> None:
        with TemporaryDirectory() as workspace:
            path = Path(workspace)
            uncached = (path / "uncached").resolve()
            cached = (path / "cached").resolve()
            cached_but_changed = (path / "changed").resolve()
            for target in (uncached, cached, cached_but_changed):
                target.touch()
            cache = {cached: black.get_cache_info(cached), cached_but_changed: (0.0, 0)}
            todo, done = black.filter_cached(
                cache, [uncached, cached, cached_but_changed]
            )
            self.assertEqual(todo, [uncached, cached_but_changed])
            self.assertEqual(done, [cached])

    def test_write_cache_creates_directory_if_needed(self) -> None:
        with cache_dir(exists=False) as workspace:
            self.assertFalse(workspace.exists())
            black.write_cache({}, [])
            self.assertTrue(workspace.exists())

    @event_loop(close=False)
    def test_failed_formatting_does_not_get_cached(self) -> None:
        with cache_dir() as workspace, patch(
            "black.ProcessPoolExecutor", new=ThreadPoolExecutor
        ):
            failing = (workspace / "failing.py").resolve()
            failing.write_text("not actually python")
            clean = (workspace / "clean.py").resolve()
            clean.write_text('print("hello")\n')
            result = CliRunner().invoke(black.main, [str(workspace)])
            self.assertEqual(result.exit_code, 123)
            cache = black.read_cache()
            self.assertNotIn(failing, cache)
            self.assertIn(clean, cache)

    def test_write_cache_write_fail(self) -> None:
        with cache_dir(), patch.object(Path, "open") as mock:
            mock.side_effect = OSError
            black.write_cache({}, [])

    def test_check_diff_use_together(self) -> None:
        with cache_dir():
            # string_quotes.py needs reformatting, so --check must exit with 1.
            src1 = (THIS_DIR / "string_quotes.py").resolve()
            result = CliRunner().invoke(black.main, [str(src1), "--diff", "--check"])
            self.assertEqual(result.exit_code, 1)

            # composition.py is already formatted, so --check must exit with 0.
            src2 = (THIS_DIR / "composition.py").resolve()
            result = CliRunner().invoke(black.main, [str(src2), "--diff", "--check"])
            self.assertEqual(result.exit_code, 0)

            # With both files at once, one file needing changes means exit 1.
            result = CliRunner().invoke(
                black.main, [str(src1), str(src2), "--diff", "--check"]
            )
            self.assertEqual(result.exit_code, 1)
620
621
if __name__ == "__main__":
    # Allow running this test module directly, e.g. `python tests/test_black.py`.
    unittest.main()