git.madduck.net Git - etc/vim.git/blob - tests/test_black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access, can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Fix docstrings of visit_stmt and normalize_invisible_parens
[etc/vim.git] / tests / test_black.py
1 #!/usr/bin/env python3
2 import asyncio
3 from concurrent.futures import ThreadPoolExecutor
4 from contextlib import contextmanager
5 from functools import partial
6 from io import StringIO
7 import os
8 from pathlib import Path
9 import sys
10 from tempfile import TemporaryDirectory
11 from typing import Any, List, Tuple, Iterator
12 import unittest
13 from unittest.mock import patch
14
15 from click import unstyle
16 from click.testing import CliRunner
17
18 import black
19
# Line length passed to every formatting call in this module.
ll = 88
# black.format_file_in_place pre-bound to the line length above, with fast=True.
ff = partial(black.format_file_in_place, line_length=ll, fast=True)
# black.format_str pre-bound to the line length above.
fs = partial(black.format_str, line_length=ll)
# Location of this test module and of the directory holding the fixture files.
THIS_FILE = Path(__file__)
THIS_DIR = THIS_FILE.parent
# Marker text that read_data() removes from every fixture line.
# NOTE(review): presumably assembled from two literals so that this file never
# contains the marker verbatim - confirm.
EMPTY_LINE = "# EMPTY LINE WITH WHITESPACE" + " (this comment will be removed)"
26
27
def dump_to_stderr(*output: str) -> str:
    """Join `output` with newlines and wrap the result in newlines.

    Used in this module as the replacement passed to `patch("black.dump_to_file")`.
    """
    body = "\n".join(output)
    return f"\n{body}\n"
30
31
def read_data(name: str) -> Tuple[str, str]:
    """read_data('test_name') -> 'input', 'output'

    The fixture is read from THIS_DIR; `.py` is appended to `name` unless it
    already ends in `.py`, `.out` or `.diff`. Lines before a `# output` marker
    line make up the input, lines after it the output. EMPTY_LINE is removed
    from every line. Both results are stripped and end in a single newline.
    """
    known_suffixes = (".py", ".out", ".diff")
    if not name.endswith(known_suffixes):
        name = name + ".py"
    with open(THIS_DIR / name, "r", encoding="utf8") as test:
        raw_lines = test.readlines()
    _input: List[str] = []
    _output: List[str] = []
    past_marker = False
    for raw in raw_lines:
        line = raw.replace(EMPTY_LINE, "")
        if line.rstrip() == "# output":
            # The marker line itself belongs to neither part.
            past_marker = True
            continue

        target = _output if past_marker else _input
        target.append(line)
    if _input and not _output:
        # If there's no output marker, treat the entire file as already pre-formatted.
        _output = list(_input)
    return "".join(_input).strip() + "\n", "".join(_output).strip() + "\n"
52
53
@contextmanager
def cache_dir(exists: bool = True) -> Iterator[Path]:
    """Patch `black.CACHE_DIR` to a temporary location and yield that location.

    With `exists=False` the yielded path is a `new` child of the temporary
    directory, which this helper does not create.
    """
    with TemporaryDirectory() as workspace:
        location = Path(workspace)
        if not exists:
            location = location / "new"
        with patch("black.CACHE_DIR", location):
            yield location
62
63
@contextmanager
def event_loop(close: bool) -> Iterator[None]:
    """Install a new asyncio event loop for the duration of the block.

    The loop that was current on entry is put back on exit; the temporary loop
    is closed on exit only when `close` is true.
    """
    policy = asyncio.get_event_loop_policy()
    previous = policy.get_event_loop()
    temporary = policy.new_event_loop()
    asyncio.set_event_loop(temporary)
    try:
        yield

    finally:
        policy.set_event_loop(previous)
        if close:
            temporary.close()
77
78
79 class BlackTestCase(unittest.TestCase):
80     maxDiff = None
81
82     def assertFormatEqual(self, expected: str, actual: str) -> None:
83         if actual != expected and not os.environ.get("SKIP_AST_PRINT"):
84             bdv: black.DebugVisitor[Any]
85             black.out("Expected tree:", fg="green")
86             try:
87                 exp_node = black.lib2to3_parse(expected)
88                 bdv = black.DebugVisitor()
89                 list(bdv.visit(exp_node))
90             except Exception as ve:
91                 black.err(str(ve))
92             black.out("Actual tree:", fg="red")
93             try:
94                 exp_node = black.lib2to3_parse(actual)
95                 bdv = black.DebugVisitor()
96                 list(bdv.visit(exp_node))
97             except Exception as ve:
98                 black.err(str(ve))
99         self.assertEqual(expected, actual)
100
101     @patch("black.dump_to_file", dump_to_stderr)
102     def test_self(self) -> None:
103         source, expected = read_data("test_black")
104         actual = fs(source)
105         self.assertFormatEqual(expected, actual)
106         black.assert_equivalent(source, actual)
107         black.assert_stable(source, actual, line_length=ll)
108         self.assertFalse(ff(THIS_FILE))
109
110     @patch("black.dump_to_file", dump_to_stderr)
111     def test_black(self) -> None:
112         source, expected = read_data("../black")
113         actual = fs(source)
114         self.assertFormatEqual(expected, actual)
115         black.assert_equivalent(source, actual)
116         black.assert_stable(source, actual, line_length=ll)
117         self.assertFalse(ff(THIS_DIR / ".." / "black.py"))
118
119     def test_piping(self) -> None:
120         source, expected = read_data("../black")
121         hold_stdin, hold_stdout = sys.stdin, sys.stdout
122         try:
123             sys.stdin, sys.stdout = StringIO(source), StringIO()
124             sys.stdin.name = "<stdin>"
125             black.format_stdin_to_stdout(
126                 line_length=ll, fast=True, write_back=black.WriteBack.YES
127             )
128             sys.stdout.seek(0)
129             actual = sys.stdout.read()
130         finally:
131             sys.stdin, sys.stdout = hold_stdin, hold_stdout
132         self.assertFormatEqual(expected, actual)
133         black.assert_equivalent(source, actual)
134         black.assert_stable(source, actual, line_length=ll)
135
136     def test_piping_diff(self) -> None:
137         source, _ = read_data("expression.py")
138         expected, _ = read_data("expression.diff")
139         hold_stdin, hold_stdout = sys.stdin, sys.stdout
140         try:
141             sys.stdin, sys.stdout = StringIO(source), StringIO()
142             sys.stdin.name = "<stdin>"
143             black.format_stdin_to_stdout(
144                 line_length=ll, fast=True, write_back=black.WriteBack.DIFF
145             )
146             sys.stdout.seek(0)
147             actual = sys.stdout.read()
148         finally:
149             sys.stdin, sys.stdout = hold_stdin, hold_stdout
150         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
151         self.assertEqual(expected, actual)
152
153     @patch("black.dump_to_file", dump_to_stderr)
154     def test_setup(self) -> None:
155         source, expected = read_data("../setup")
156         actual = fs(source)
157         self.assertFormatEqual(expected, actual)
158         black.assert_equivalent(source, actual)
159         black.assert_stable(source, actual, line_length=ll)
160         self.assertFalse(ff(THIS_DIR / ".." / "setup.py"))
161
162     @patch("black.dump_to_file", dump_to_stderr)
163     def test_function(self) -> None:
164         source, expected = read_data("function")
165         actual = fs(source)
166         self.assertFormatEqual(expected, actual)
167         black.assert_equivalent(source, actual)
168         black.assert_stable(source, actual, line_length=ll)
169
170     @patch("black.dump_to_file", dump_to_stderr)
171     def test_function2(self) -> None:
172         source, expected = read_data("function2")
173         actual = fs(source)
174         self.assertFormatEqual(expected, actual)
175         black.assert_equivalent(source, actual)
176         black.assert_stable(source, actual, line_length=ll)
177
178     @patch("black.dump_to_file", dump_to_stderr)
179     def test_expression(self) -> None:
180         source, expected = read_data("expression")
181         actual = fs(source)
182         self.assertFormatEqual(expected, actual)
183         black.assert_equivalent(source, actual)
184         black.assert_stable(source, actual, line_length=ll)
185
186     def test_expression_ff(self) -> None:
187         source, expected = read_data("expression")
188         tmp_file = Path(black.dump_to_file(source))
189         try:
190             self.assertTrue(ff(tmp_file, write_back=black.WriteBack.YES))
191             with open(tmp_file, encoding="utf8") as f:
192                 actual = f.read()
193         finally:
194             os.unlink(tmp_file)
195         self.assertFormatEqual(expected, actual)
196         with patch("black.dump_to_file", dump_to_stderr):
197             black.assert_equivalent(source, actual)
198             black.assert_stable(source, actual, line_length=ll)
199
200     def test_expression_diff(self) -> None:
201         source, _ = read_data("expression.py")
202         expected, _ = read_data("expression.diff")
203         tmp_file = Path(black.dump_to_file(source))
204         hold_stdout = sys.stdout
205         try:
206             sys.stdout = StringIO()
207             self.assertTrue(ff(tmp_file, write_back=black.WriteBack.DIFF))
208             sys.stdout.seek(0)
209             actual = sys.stdout.read()
210             actual = actual.replace(str(tmp_file), "<stdin>")
211         finally:
212             sys.stdout = hold_stdout
213             os.unlink(tmp_file)
214         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
215         if expected != actual:
216             dump = black.dump_to_file(actual)
217             msg = (
218                 f"Expected diff isn't equal to the actual. If you made changes "
219                 f"to expression.py and this is an anticipated difference, "
220                 f"overwrite tests/expression.diff with {dump}"
221             )
222             self.assertEqual(expected, actual, msg)
223
224     @patch("black.dump_to_file", dump_to_stderr)
225     def test_fstring(self) -> None:
226         source, expected = read_data("fstring")
227         actual = fs(source)
228         self.assertFormatEqual(expected, actual)
229         black.assert_equivalent(source, actual)
230         black.assert_stable(source, actual, line_length=ll)
231
232     @patch("black.dump_to_file", dump_to_stderr)
233     def test_string_quotes(self) -> None:
234         source, expected = read_data("string_quotes")
235         actual = fs(source)
236         self.assertFormatEqual(expected, actual)
237         black.assert_equivalent(source, actual)
238         black.assert_stable(source, actual, line_length=ll)
239
240     @patch("black.dump_to_file", dump_to_stderr)
241     def test_slices(self) -> None:
242         source, expected = read_data("slices")
243         actual = fs(source)
244         self.assertFormatEqual(expected, actual)
245         black.assert_equivalent(source, actual)
246         black.assert_stable(source, actual, line_length=ll)
247
248     @patch("black.dump_to_file", dump_to_stderr)
249     def test_comments(self) -> None:
250         source, expected = read_data("comments")
251         actual = fs(source)
252         self.assertFormatEqual(expected, actual)
253         black.assert_equivalent(source, actual)
254         black.assert_stable(source, actual, line_length=ll)
255
256     @patch("black.dump_to_file", dump_to_stderr)
257     def test_comments2(self) -> None:
258         source, expected = read_data("comments2")
259         actual = fs(source)
260         self.assertFormatEqual(expected, actual)
261         black.assert_equivalent(source, actual)
262         black.assert_stable(source, actual, line_length=ll)
263
264     @patch("black.dump_to_file", dump_to_stderr)
265     def test_comments3(self) -> None:
266         source, expected = read_data("comments3")
267         actual = fs(source)
268         self.assertFormatEqual(expected, actual)
269         black.assert_equivalent(source, actual)
270         black.assert_stable(source, actual, line_length=ll)
271
272     @patch("black.dump_to_file", dump_to_stderr)
273     def test_comments4(self) -> None:
274         source, expected = read_data("comments4")
275         actual = fs(source)
276         self.assertFormatEqual(expected, actual)
277         black.assert_equivalent(source, actual)
278         black.assert_stable(source, actual, line_length=ll)
279
280     @patch("black.dump_to_file", dump_to_stderr)
281     def test_comments5(self) -> None:
282         source, expected = read_data("comments5")
283         actual = fs(source)
284         self.assertFormatEqual(expected, actual)
285         black.assert_equivalent(source, actual)
286         black.assert_stable(source, actual, line_length=ll)
287
288     @patch("black.dump_to_file", dump_to_stderr)
289     def test_cantfit(self) -> None:
290         source, expected = read_data("cantfit")
291         actual = fs(source)
292         self.assertFormatEqual(expected, actual)
293         black.assert_equivalent(source, actual)
294         black.assert_stable(source, actual, line_length=ll)
295
296     @patch("black.dump_to_file", dump_to_stderr)
297     def test_import_spacing(self) -> None:
298         source, expected = read_data("import_spacing")
299         actual = fs(source)
300         self.assertFormatEqual(expected, actual)
301         black.assert_equivalent(source, actual)
302         black.assert_stable(source, actual, line_length=ll)
303
304     @patch("black.dump_to_file", dump_to_stderr)
305     def test_composition(self) -> None:
306         source, expected = read_data("composition")
307         actual = fs(source)
308         self.assertFormatEqual(expected, actual)
309         black.assert_equivalent(source, actual)
310         black.assert_stable(source, actual, line_length=ll)
311
312     @patch("black.dump_to_file", dump_to_stderr)
313     def test_empty_lines(self) -> None:
314         source, expected = read_data("empty_lines")
315         actual = fs(source)
316         self.assertFormatEqual(expected, actual)
317         black.assert_equivalent(source, actual)
318         black.assert_stable(source, actual, line_length=ll)
319
320     @patch("black.dump_to_file", dump_to_stderr)
321     def test_python2(self) -> None:
322         source, expected = read_data("python2")
323         actual = fs(source)
324         self.assertFormatEqual(expected, actual)
325         # black.assert_equivalent(source, actual)
326         black.assert_stable(source, actual, line_length=ll)
327
328     @patch("black.dump_to_file", dump_to_stderr)
329     def test_fmtonoff(self) -> None:
330         source, expected = read_data("fmtonoff")
331         actual = fs(source)
332         self.assertFormatEqual(expected, actual)
333         black.assert_equivalent(source, actual)
334         black.assert_stable(source, actual, line_length=ll)
335
336     @patch("black.dump_to_file", dump_to_stderr)
337     def test_remove_empty_parentheses_after_class(self) -> None:
338         source, expected = read_data("class_blank_parentheses")
339         actual = fs(source)
340         self.assertFormatEqual(expected, actual)
341         black.assert_equivalent(source, actual)
342         black.assert_stable(source, actual, line_length=ll)
343
    def test_report(self) -> None:
        # Drive one black.Report through a fixed sequence of done()/failed()
        # calls. After each call, check the message sent to black.out/black.err
        # (both replaced by list-appending stubs), the summary from str(report)
        # and the value of report.return_code. The steps build on each other, so
        # their order matters.
        report = black.Report()
        out_lines = []
        err_lines = []

        def out(msg: str, **kwargs: Any) -> None:
            out_lines.append(msg)

        def err(msg: str, **kwargs: Any) -> None:
            err_lines.append(msg)

        with patch("black.out", out), patch("black.err", err):
            # One unchanged file.
            report.done(Path("f1"), black.Changed.NO)
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "f1 already well formatted, good job.")
            self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
            self.assertEqual(report.return_code, 0)
            # One reformatted file.
            report.done(Path("f2"), black.Changed.YES)
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "reformatted f2")
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
            )
            # A cached file is counted in the summary as left unchanged.
            report.done(Path("f3"), black.Changed.CACHED)
            self.assertEqual(len(out_lines), 3)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(
                out_lines[-1], "f3 wasn't modified on disk since last run."
            )
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
            )
            # With a reformatted file recorded, the return code is 0 normally
            # and 1 while report.check is set.
            self.assertEqual(report.return_code, 0)
            report.check = True
            self.assertEqual(report.return_code, 1)
            report.check = False
            # The first failure goes to black.err and makes the return code 123.
            report.failed(Path("e1"), "boom")
            self.assertEqual(len(out_lines), 3)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
            self.assertEqual(
                unstyle(str(report)),
                "1 file reformatted, 2 files left unchanged, "
                "1 file failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # A second reformatted file.
            report.done(Path("f3"), black.Changed.YES)
            self.assertEqual(len(out_lines), 4)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(out_lines[-1], "reformatted f3")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, "
                "1 file failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # A second failure.
            report.failed(Path("e2"), "boom")
            self.assertEqual(len(out_lines), 4)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, "
                "2 files failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Another unchanged file.
            report.done(Path("f4"), black.Changed.NO)
            self.assertEqual(len(out_lines), 5)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(out_lines[-1], "f4 already well formatted, good job.")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 3 files left unchanged, "
                "2 files failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # With report.check set, the summary is worded conditionally.
            report.check = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, "
                "2 files would fail to reformat.",
            )
428
429     def test_is_python36(self) -> None:
430         node = black.lib2to3_parse("def f(*, arg): ...\n")
431         self.assertFalse(black.is_python36(node))
432         node = black.lib2to3_parse("def f(*, arg,): ...\n")
433         self.assertTrue(black.is_python36(node))
434         node = black.lib2to3_parse("def f(*, arg): f'string'\n")
435         self.assertTrue(black.is_python36(node))
436         source, expected = read_data("function")
437         node = black.lib2to3_parse(source)
438         self.assertTrue(black.is_python36(node))
439         node = black.lib2to3_parse(expected)
440         self.assertTrue(black.is_python36(node))
441         source, expected = read_data("expression")
442         node = black.lib2to3_parse(source)
443         self.assertFalse(black.is_python36(node))
444         node = black.lib2to3_parse(expected)
445         self.assertFalse(black.is_python36(node))
446
447     def test_debug_visitor(self) -> None:
448         source, _ = read_data("debug_visitor.py")
449         expected, _ = read_data("debug_visitor.out")
450         out_lines = []
451         err_lines = []
452
453         def out(msg: str, **kwargs: Any) -> None:
454             out_lines.append(msg)
455
456         def err(msg: str, **kwargs: Any) -> None:
457             err_lines.append(msg)
458
459         with patch("black.out", out), patch("black.err", err):
460             black.DebugVisitor.show(source)
461         actual = "\n".join(out_lines) + "\n"
462         log_name = ""
463         if expected != actual:
464             log_name = black.dump_to_file(*out_lines)
465         self.assertEqual(
466             expected,
467             actual,
468             f"AST print out is different. Actual version dumped to {log_name}",
469         )
470
471     def test_format_file_contents(self) -> None:
472         empty = ""
473         with self.assertRaises(black.NothingChanged):
474             black.format_file_contents(empty, line_length=ll, fast=False)
475         just_nl = "\n"
476         with self.assertRaises(black.NothingChanged):
477             black.format_file_contents(just_nl, line_length=ll, fast=False)
478         same = "l = [1, 2, 3]\n"
479         with self.assertRaises(black.NothingChanged):
480             black.format_file_contents(same, line_length=ll, fast=False)
481         different = "l = [1,2,3]"
482         expected = same
483         actual = black.format_file_contents(different, line_length=ll, fast=False)
484         self.assertEqual(expected, actual)
485         invalid = "return if you can"
486         with self.assertRaises(ValueError) as e:
487             black.format_file_contents(invalid, line_length=ll, fast=False)
488         self.assertEqual(str(e.exception), "Cannot parse: 1:7: return if you can")
489
490     def test_endmarker(self) -> None:
491         n = black.lib2to3_parse("\n")
492         self.assertEqual(n.type, black.syms.file_input)
493         self.assertEqual(len(n.children), 1)
494         self.assertEqual(n.children[0].type, black.token.ENDMARKER)
495
496     @unittest.skipIf(os.environ.get("SKIP_AST_PRINT"), "user set SKIP_AST_PRINT")
497     def test_assertFormatEqual(self) -> None:
498         out_lines = []
499         err_lines = []
500
501         def out(msg: str, **kwargs: Any) -> None:
502             out_lines.append(msg)
503
504         def err(msg: str, **kwargs: Any) -> None:
505             err_lines.append(msg)
506
507         with patch("black.out", out), patch("black.err", err):
508             with self.assertRaises(AssertionError):
509                 self.assertFormatEqual("l = [1, 2, 3]", "l = [1, 2, 3,]")
510
511         out_str = "".join(out_lines)
512         self.assertTrue("Expected tree:" in out_str)
513         self.assertTrue("Actual tree:" in out_str)
514         self.assertEqual("".join(err_lines), "")
515
516     def test_cache_broken_file(self) -> None:
517         with cache_dir() as workspace:
518             cache_file = black.get_cache_file(black.DEFAULT_LINE_LENGTH)
519             with cache_file.open("w") as fobj:
520                 fobj.write("this is not a pickle")
521             self.assertEqual(black.read_cache(black.DEFAULT_LINE_LENGTH), {})
522             src = (workspace / "test.py").resolve()
523             with src.open("w") as fobj:
524                 fobj.write("print('hello')")
525             result = CliRunner().invoke(black.main, [str(src)])
526             self.assertEqual(result.exit_code, 0)
527             cache = black.read_cache(black.DEFAULT_LINE_LENGTH)
528             self.assertIn(src, cache)
529
530     def test_cache_single_file_already_cached(self) -> None:
531         with cache_dir() as workspace:
532             src = (workspace / "test.py").resolve()
533             with src.open("w") as fobj:
534                 fobj.write("print('hello')")
535             black.write_cache({}, [src], black.DEFAULT_LINE_LENGTH)
536             result = CliRunner().invoke(black.main, [str(src)])
537             self.assertEqual(result.exit_code, 0)
538             with src.open("r") as fobj:
539                 self.assertEqual(fobj.read(), "print('hello')")
540
541     @event_loop(close=False)
542     def test_cache_multiple_files(self) -> None:
543         with cache_dir() as workspace, patch(
544             "black.ProcessPoolExecutor", new=ThreadPoolExecutor
545         ):
546             one = (workspace / "one.py").resolve()
547             with one.open("w") as fobj:
548                 fobj.write("print('hello')")
549             two = (workspace / "two.py").resolve()
550             with two.open("w") as fobj:
551                 fobj.write("print('hello')")
552             black.write_cache({}, [one], black.DEFAULT_LINE_LENGTH)
553             result = CliRunner().invoke(black.main, [str(workspace)])
554             self.assertEqual(result.exit_code, 0)
555             with one.open("r") as fobj:
556                 self.assertEqual(fobj.read(), "print('hello')")
557             with two.open("r") as fobj:
558                 self.assertEqual(fobj.read(), 'print("hello")\n')
559             cache = black.read_cache(black.DEFAULT_LINE_LENGTH)
560             self.assertIn(one, cache)
561             self.assertIn(two, cache)
562
563     def test_no_cache_when_writeback_diff(self) -> None:
564         with cache_dir() as workspace:
565             src = (workspace / "test.py").resolve()
566             with src.open("w") as fobj:
567                 fobj.write("print('hello')")
568             result = CliRunner().invoke(black.main, [str(src), "--diff"])
569             self.assertEqual(result.exit_code, 0)
570             cache_file = black.get_cache_file(black.DEFAULT_LINE_LENGTH)
571             self.assertFalse(cache_file.exists())
572
573     def test_no_cache_when_stdin(self) -> None:
574         with cache_dir():
575             result = CliRunner().invoke(black.main, ["-"], input="print('hello')")
576             self.assertEqual(result.exit_code, 0)
577             cache_file = black.get_cache_file(black.DEFAULT_LINE_LENGTH)
578             self.assertFalse(cache_file.exists())
579
580     def test_read_cache_no_cachefile(self) -> None:
581         with cache_dir():
582             self.assertEqual(black.read_cache(black.DEFAULT_LINE_LENGTH), {})
583
584     def test_write_cache_read_cache(self) -> None:
585         with cache_dir() as workspace:
586             src = (workspace / "test.py").resolve()
587             src.touch()
588             black.write_cache({}, [src], black.DEFAULT_LINE_LENGTH)
589             cache = black.read_cache(black.DEFAULT_LINE_LENGTH)
590             self.assertIn(src, cache)
591             self.assertEqual(cache[src], black.get_cache_info(src))
592
593     def test_filter_cached(self) -> None:
594         with TemporaryDirectory() as workspace:
595             path = Path(workspace)
596             uncached = (path / "uncached").resolve()
597             cached = (path / "cached").resolve()
598             cached_but_changed = (path / "changed").resolve()
599             uncached.touch()
600             cached.touch()
601             cached_but_changed.touch()
602             cache = {cached: black.get_cache_info(cached), cached_but_changed: (0.0, 0)}
603             todo, done = black.filter_cached(
604                 cache, [uncached, cached, cached_but_changed]
605             )
606             self.assertEqual(todo, [uncached, cached_but_changed])
607             self.assertEqual(done, [cached])
608
609     def test_write_cache_creates_directory_if_needed(self) -> None:
610         with cache_dir(exists=False) as workspace:
611             self.assertFalse(workspace.exists())
612             black.write_cache({}, [], black.DEFAULT_LINE_LENGTH)
613             self.assertTrue(workspace.exists())
614
615     @event_loop(close=False)
616     def test_failed_formatting_does_not_get_cached(self) -> None:
617         with cache_dir() as workspace, patch(
618             "black.ProcessPoolExecutor", new=ThreadPoolExecutor
619         ):
620             failing = (workspace / "failing.py").resolve()
621             with failing.open("w") as fobj:
622                 fobj.write("not actually python")
623             clean = (workspace / "clean.py").resolve()
624             with clean.open("w") as fobj:
625                 fobj.write('print("hello")\n')
626             result = CliRunner().invoke(black.main, [str(workspace)])
627             self.assertEqual(result.exit_code, 123)
628             cache = black.read_cache(black.DEFAULT_LINE_LENGTH)
629             self.assertNotIn(failing, cache)
630             self.assertIn(clean, cache)
631
632     def test_write_cache_write_fail(self) -> None:
633         with cache_dir(), patch.object(Path, "open") as mock:
634             mock.side_effect = OSError
635             black.write_cache({}, [], black.DEFAULT_LINE_LENGTH)
636
637     def test_check_diff_use_together(self) -> None:
638         with cache_dir():
639             # Files which will be reformatted.
640             src1 = (THIS_DIR / "string_quotes.py").resolve()
641             result = CliRunner().invoke(black.main, [str(src1), "--diff", "--check"])
642             self.assertEqual(result.exit_code, 1)
643
644             # Files which will not be reformatted.
645             src2 = (THIS_DIR / "composition.py").resolve()
646             result = CliRunner().invoke(black.main, [str(src2), "--diff", "--check"])
647             self.assertEqual(result.exit_code, 0)
648
649             # Multi file command.
650             result = CliRunner().invoke(
651                 black.main, [str(src1), str(src2), "--diff", "--check"]
652             )
653             self.assertEqual(result.exit_code, 1)
654
655     def test_no_files(self) -> None:
656         with cache_dir():
657             # Without an argument, black exits with error code 0.
658             result = CliRunner().invoke(black.main, [])
659             self.assertEqual(result.exit_code, 0)
660
661     def test_read_cache_line_lengths(self) -> None:
662         with cache_dir() as workspace:
663             path = (workspace / "file.py").resolve()
664             path.touch()
665             black.write_cache({}, [path], 1)
666             one = black.read_cache(1)
667             self.assertIn(path, one)
668             two = black.read_cache(2)
669             self.assertNotIn(path, two)
670
671
# Run the test suite when this file is executed directly.
if __name__ == "__main__":
    unittest.main()