git.madduck.net Git - etc/vim.git/blob - tests/test_black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access, can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Simplify `is_trivial_*` methods
[etc/vim.git] / tests / test_black.py
1 #!/usr/bin/env python3
2 import asyncio
3 from concurrent.futures import ThreadPoolExecutor
4 from contextlib import contextmanager
5 from functools import partial
6 from io import StringIO
7 import os
8 from pathlib import Path
9 import sys
10 from tempfile import TemporaryDirectory
11 from typing import Any, List, Tuple, Iterator
12 import unittest
13 from unittest.mock import patch
14
15 from click import unstyle
16 from click.testing import CliRunner
17
18 import black
19
# Line length passed to every formatting call in this module.
ll = 88
# Shorthands: format a file in place (ff) / format a string (fs) with `ll`.
# NOTE(review): fast=True presumably skips black's built-in safety checks; the
# tests that want them call black.assert_equivalent / assert_stable explicitly.
ff = partial(black.format_file_in_place, line_length=ll, fast=True)
fs = partial(black.format_str, line_length=ll)
THIS_FILE = Path(__file__)
THIS_DIR = THIS_FILE.parent
# Placeholder comment used inside test data files (judging by its text, on
# whitespace-only lines); read_data() deletes every occurrence of it. It is built
# from two literals so the complete marker never appears in this file, which
# test_self itself loads through read_data().
EMPTY_LINE = "# EMPTY LINE WITH WHITESPACE" + " (this comment will be removed)"
26
27
def dump_to_stderr(*output: str) -> str:
    """Stand-in for black.dump_to_file, installed with @patch by the tests below.

    Nothing is written anywhere: the pieces are joined with newlines, wrapped in a
    leading and a trailing newline, and that text is returned in place of the file
    name black.dump_to_file would have returned.
    """
    body = "\n".join(output)
    return f"\n{body}\n"
30
31
def read_data(name: str) -> Tuple[str, str]:
    """read_data('test_name') -> 'input', 'output'

    Load the test data file `name` from this directory (".py" is appended unless
    the name already ends in .py, .pyi, .out or .diff). The EMPTY_LINE placeholder
    is deleted from every line. Lines before a `# output` marker line (trailing
    whitespace ignored) form the input, lines after it the expected output; marker
    lines themselves are dropped. If the output part ends up empty, the input part
    doubles as the expected output. Both parts are stripped and returned with a
    single trailing newline.
    """
    if not name.endswith((".py", ".pyi", ".out", ".diff")):
        name += ".py"
    with open(THIS_DIR / name, "r", encoding="utf8") as test:
        lines = [line.replace(EMPTY_LINE, "") for line in test.readlines()]
    parts: List[List[str]] = [[], []]
    target = parts[0]
    for line in lines:
        if line.rstrip() == "# output":
            target = parts[1]
        else:
            target.append(line)
    source, expected = parts
    if source and not expected:
        # No expected output in the file: treat the input as already formatted.
        expected = source[:]
    return "".join(source).strip() + "\n", "".join(expected).strip() + "\n"
52
53
@contextmanager
def cache_dir(exists: bool = True) -> Iterator[Path]:
    """Patch black.CACHE_DIR to a throwaway temporary location for the block.

    With `exists=False` the patched path is a not-yet-created "new" subdirectory
    of the temporary directory. The patched path is what gets yielded; everything
    is removed again when the block exits.
    """
    with TemporaryDirectory() as workspace:
        root = Path(workspace)
        target = root if exists else root / "new"
        with patch("black.CACHE_DIR", target):
            yield target
62
63
@contextmanager
def event_loop(close: bool) -> Iterator[None]:
    """Install a brand-new asyncio event loop as the current one for the block.

    On exit the previously current loop is restored; the new loop is closed only
    when `close` is true. Usable as a decorator as well as in a `with` statement.
    """
    policy = asyncio.get_event_loop_policy()
    previous = policy.get_event_loop()
    fresh = policy.new_event_loop()
    asyncio.set_event_loop(fresh)
    try:
        yield
    finally:
        # Restore first so the caller's loop is current even if close() raises.
        policy.set_event_loop(previous)
        if close:
            fresh.close()
77
78
class BlackTestCase(unittest.TestCase):
    """Tests for the `black` module: data-file round trips, CLI, report and cache."""

    maxDiff = None

    def assertFormatEqual(self, expected: str, actual: str) -> None:
        """Assert that `actual` equals the `expected` formatted source.

        On a mismatch, and unless the SKIP_AST_PRINT environment variable is set,
        the lib2to3 trees of both strings are printed with black.DebugVisitor
        first, so the structural difference shows up in the test output. Errors
        raised while printing a tree are reported via black.err, not propagated.
        """
        if actual != expected and not os.environ.get("SKIP_AST_PRINT"):
            bdv: black.DebugVisitor[Any]
            black.out("Expected tree:", fg="green")
            try:
                exp_node = black.lib2to3_parse(expected)
                bdv = black.DebugVisitor()
                # The result is discarded; list() just drives the full traversal.
                list(bdv.visit(exp_node))
            except Exception as ve:
                black.err(str(ve))
            black.out("Actual tree:", fg="red")
            try:
                act_node = black.lib2to3_parse(actual)
                bdv = black.DebugVisitor()
                list(bdv.visit(act_node))
            except Exception as ve:
                black.err(str(ve))
        self.assertEqual(expected, actual)

    def check_file(self, name: str) -> None:
        """Format the `name` test data file and compare it with its expected output.

        Afterwards black's equivalence and stability checks are run on the result,
        using the module-wide line length `ll`.
        """
        source, expected = read_data(name)
        actual = fs(source)
        self.assertFormatEqual(expected, actual)
        black.assert_equivalent(source, actual)
        black.assert_stable(source, actual, line_length=ll)

    @patch("black.dump_to_file", dump_to_stderr)
    def test_self(self) -> None:
        self.check_file("test_black")
        self.assertFalse(ff(THIS_FILE))

    @patch("black.dump_to_file", dump_to_stderr)
    def test_black(self) -> None:
        self.check_file("../black")
        self.assertFalse(ff(THIS_DIR / ".." / "black.py"))

    def test_piping(self) -> None:
        source, expected = read_data("../black")
        hold_stdin, hold_stdout = sys.stdin, sys.stdout
        try:
            sys.stdin, sys.stdout = StringIO(source), StringIO()
            sys.stdin.name = "<stdin>"
            black.format_stdin_to_stdout(
                line_length=ll, fast=True, write_back=black.WriteBack.YES
            )
            sys.stdout.seek(0)
            actual = sys.stdout.read()
        finally:
            sys.stdin, sys.stdout = hold_stdin, hold_stdout
        self.assertFormatEqual(expected, actual)
        black.assert_equivalent(source, actual)
        black.assert_stable(source, actual, line_length=ll)

    def test_piping_diff(self) -> None:
        source, _ = read_data("expression.py")
        expected, _ = read_data("expression.diff")
        hold_stdin, hold_stdout = sys.stdin, sys.stdout
        try:
            sys.stdin, sys.stdout = StringIO(source), StringIO()
            sys.stdin.name = "<stdin>"
            black.format_stdin_to_stdout(
                line_length=ll, fast=True, write_back=black.WriteBack.DIFF
            )
            sys.stdout.seek(0)
            actual = sys.stdout.read()
        finally:
            sys.stdin, sys.stdout = hold_stdin, hold_stdout
        actual = actual.rstrip() + "\n"  # the diff output has a trailing space
        self.assertEqual(expected, actual)

    @patch("black.dump_to_file", dump_to_stderr)
    def test_setup(self) -> None:
        self.check_file("../setup")
        self.assertFalse(ff(THIS_DIR / ".." / "setup.py"))

    @patch("black.dump_to_file", dump_to_stderr)
    def test_function(self) -> None:
        self.check_file("function")

    @patch("black.dump_to_file", dump_to_stderr)
    def test_function2(self) -> None:
        self.check_file("function2")

    @patch("black.dump_to_file", dump_to_stderr)
    def test_expression(self) -> None:
        self.check_file("expression")

    def test_expression_ff(self) -> None:
        source, expected = read_data("expression")
        tmp_file = Path(black.dump_to_file(source))
        try:
            self.assertTrue(ff(tmp_file, write_back=black.WriteBack.YES))
            with open(tmp_file, encoding="utf8") as f:
                actual = f.read()
        finally:
            os.unlink(tmp_file)
        self.assertFormatEqual(expected, actual)
        with patch("black.dump_to_file", dump_to_stderr):
            black.assert_equivalent(source, actual)
            black.assert_stable(source, actual, line_length=ll)

    def test_expression_diff(self) -> None:
        source, _ = read_data("expression.py")
        expected, _ = read_data("expression.diff")
        tmp_file = Path(black.dump_to_file(source))
        hold_stdout = sys.stdout
        try:
            sys.stdout = StringIO()
            self.assertTrue(ff(tmp_file, write_back=black.WriteBack.DIFF))
            sys.stdout.seek(0)
            actual = sys.stdout.read()
            actual = actual.replace(str(tmp_file), "<stdin>")
        finally:
            sys.stdout = hold_stdout
            os.unlink(tmp_file)
        actual = actual.rstrip() + "\n"  # the diff output has a trailing space
        if expected != actual:
            dump = black.dump_to_file(actual)
            msg = (
                "Expected diff isn't equal to the actual. If you made changes "
                "to expression.py and this is an anticipated difference, "
                f"overwrite tests/expression.diff with {dump}"
            )
            self.assertEqual(expected, actual, msg)

    @patch("black.dump_to_file", dump_to_stderr)
    def test_fstring(self) -> None:
        self.check_file("fstring")

    @patch("black.dump_to_file", dump_to_stderr)
    def test_string_quotes(self) -> None:
        self.check_file("string_quotes")

    @patch("black.dump_to_file", dump_to_stderr)
    def test_slices(self) -> None:
        self.check_file("slices")

    @patch("black.dump_to_file", dump_to_stderr)
    def test_comments(self) -> None:
        self.check_file("comments")

    @patch("black.dump_to_file", dump_to_stderr)
    def test_comments2(self) -> None:
        self.check_file("comments2")

    @patch("black.dump_to_file", dump_to_stderr)
    def test_comments3(self) -> None:
        self.check_file("comments3")

    @patch("black.dump_to_file", dump_to_stderr)
    def test_comments4(self) -> None:
        self.check_file("comments4")

    @patch("black.dump_to_file", dump_to_stderr)
    def test_comments5(self) -> None:
        self.check_file("comments5")

    @patch("black.dump_to_file", dump_to_stderr)
    def test_cantfit(self) -> None:
        self.check_file("cantfit")

    @patch("black.dump_to_file", dump_to_stderr)
    def test_import_spacing(self) -> None:
        self.check_file("import_spacing")

    @patch("black.dump_to_file", dump_to_stderr)
    def test_composition(self) -> None:
        self.check_file("composition")

    @patch("black.dump_to_file", dump_to_stderr)
    def test_empty_lines(self) -> None:
        self.check_file("empty_lines")

    @patch("black.dump_to_file", dump_to_stderr)
    def test_string_prefixes(self) -> None:
        self.check_file("string_prefixes")

    @patch("black.dump_to_file", dump_to_stderr)
    def test_python2(self) -> None:
        source, expected = read_data("python2")
        actual = fs(source)
        self.assertFormatEqual(expected, actual)
        # assert_equivalent is deliberately not called for this Python 2 data;
        # NOTE(review): presumably it cannot handle Python 2-only syntax.
        black.assert_stable(source, actual, line_length=ll)

    @patch("black.dump_to_file", dump_to_stderr)
    def test_python2_unicode_literals(self) -> None:
        source, expected = read_data("python2_unicode_literals")
        actual = fs(source)
        self.assertFormatEqual(expected, actual)
        black.assert_stable(source, actual, line_length=ll)

    @patch("black.dump_to_file", dump_to_stderr)
    def test_stub(self) -> None:
        source, expected = read_data("stub.pyi")
        actual = fs(source, is_pyi=True)
        self.assertFormatEqual(expected, actual)
        black.assert_stable(source, actual, line_length=ll, is_pyi=True)

    @patch("black.dump_to_file", dump_to_stderr)
    def test_fmtonoff(self) -> None:
        self.check_file("fmtonoff")

    @patch("black.dump_to_file", dump_to_stderr)
    def test_remove_empty_parentheses_after_class(self) -> None:
        self.check_file("class_blank_parentheses")

    def test_report(self) -> None:
        report = black.Report()
        out_lines: List[str] = []
        err_lines: List[str] = []

        def out(msg: str, **kwargs: Any) -> None:
            out_lines.append(msg)

        def err(msg: str, **kwargs: Any) -> None:
            err_lines.append(msg)

        with patch("black.out", out), patch("black.err", err):
            report.done(Path("f1"), black.Changed.NO)
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "f1 already well formatted, good job.")
            self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
            self.assertEqual(report.return_code, 0)
            report.done(Path("f2"), black.Changed.YES)
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "reformatted f2")
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
            )
            report.done(Path("f3"), black.Changed.CACHED)
            self.assertEqual(len(out_lines), 3)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(
                out_lines[-1], "f3 wasn't modified on disk since last run."
            )
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
            )
            self.assertEqual(report.return_code, 0)
            report.check = True
            self.assertEqual(report.return_code, 1)
            report.check = False
            report.failed(Path("e1"), "boom")
            self.assertEqual(len(out_lines), 3)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
            self.assertEqual(
                unstyle(str(report)),
                "1 file reformatted, 2 files left unchanged, "
                "1 file failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f3"), black.Changed.YES)
            self.assertEqual(len(out_lines), 4)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(out_lines[-1], "reformatted f3")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, "
                "1 file failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.failed(Path("e2"), "boom")
            self.assertEqual(len(out_lines), 4)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, "
                "2 files failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f4"), black.Changed.NO)
            self.assertEqual(len(out_lines), 5)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(out_lines[-1], "f4 already well formatted, good job.")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 3 files left unchanged, "
                "2 files failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.check = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, "
                "2 files would fail to reformat.",
            )

    def test_is_python36(self) -> None:
        node = black.lib2to3_parse("def f(*, arg): ...\n")
        self.assertFalse(black.is_python36(node))
        node = black.lib2to3_parse("def f(*, arg,): ...\n")
        self.assertTrue(black.is_python36(node))
        node = black.lib2to3_parse("def f(*, arg): f'string'\n")
        self.assertTrue(black.is_python36(node))
        source, expected = read_data("function")
        node = black.lib2to3_parse(source)
        self.assertTrue(black.is_python36(node))
        node = black.lib2to3_parse(expected)
        self.assertTrue(black.is_python36(node))
        source, expected = read_data("expression")
        node = black.lib2to3_parse(source)
        self.assertFalse(black.is_python36(node))
        node = black.lib2to3_parse(expected)
        self.assertFalse(black.is_python36(node))

    def test_get_future_imports(self) -> None:
        node = black.lib2to3_parse("\n")
        self.assertEqual(set(), black.get_future_imports(node))
        node = black.lib2to3_parse("from __future__ import black\n")
        self.assertEqual({"black"}, black.get_future_imports(node))
        node = black.lib2to3_parse("from __future__ import multiple, imports\n")
        self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
        node = black.lib2to3_parse("from __future__ import (parenthesized, imports)\n")
        self.assertEqual({"parenthesized", "imports"}, black.get_future_imports(node))
        node = black.lib2to3_parse(
            "from __future__ import multiple\nfrom __future__ import imports\n"
        )
        self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
        node = black.lib2to3_parse("# comment\nfrom __future__ import black\n")
        self.assertEqual({"black"}, black.get_future_imports(node))
        node = black.lib2to3_parse('"""docstring"""\nfrom __future__ import black\n')
        self.assertEqual({"black"}, black.get_future_imports(node))
        node = black.lib2to3_parse("some(other, code)\nfrom __future__ import black\n")
        self.assertEqual(set(), black.get_future_imports(node))
        node = black.lib2to3_parse("from some.module import black\n")
        self.assertEqual(set(), black.get_future_imports(node))

    def test_debug_visitor(self) -> None:
        source, _ = read_data("debug_visitor.py")
        expected, _ = read_data("debug_visitor.out")
        out_lines: List[str] = []
        err_lines: List[str] = []

        def out(msg: str, **kwargs: Any) -> None:
            out_lines.append(msg)

        def err(msg: str, **kwargs: Any) -> None:
            err_lines.append(msg)

        with patch("black.out", out), patch("black.err", err):
            black.DebugVisitor.show(source)
        actual = "\n".join(out_lines) + "\n"
        log_name = ""
        if expected != actual:
            log_name = black.dump_to_file(*out_lines)
        self.assertEqual(
            expected,
            actual,
            f"AST print out is different. Actual version dumped to {log_name}",
        )

    def test_format_file_contents(self) -> None:
        empty = ""
        with self.assertRaises(black.NothingChanged):
            black.format_file_contents(empty, line_length=ll, fast=False)
        just_nl = "\n"
        with self.assertRaises(black.NothingChanged):
            black.format_file_contents(just_nl, line_length=ll, fast=False)
        same = "l = [1, 2, 3]\n"
        with self.assertRaises(black.NothingChanged):
            black.format_file_contents(same, line_length=ll, fast=False)
        different = "l = [1,2,3]"
        expected = same
        actual = black.format_file_contents(different, line_length=ll, fast=False)
        self.assertEqual(expected, actual)
        invalid = "return if you can"
        with self.assertRaises(ValueError) as e:
            black.format_file_contents(invalid, line_length=ll, fast=False)
        self.assertEqual(str(e.exception), "Cannot parse: 1:7: return if you can")

    def test_endmarker(self) -> None:
        n = black.lib2to3_parse("\n")
        self.assertEqual(n.type, black.syms.file_input)
        self.assertEqual(len(n.children), 1)
        self.assertEqual(n.children[0].type, black.token.ENDMARKER)

    @unittest.skipIf(os.environ.get("SKIP_AST_PRINT"), "user set SKIP_AST_PRINT")
    def test_assertFormatEqual(self) -> None:
        out_lines: List[str] = []
        err_lines: List[str] = []

        def out(msg: str, **kwargs: Any) -> None:
            out_lines.append(msg)

        def err(msg: str, **kwargs: Any) -> None:
            err_lines.append(msg)

        with patch("black.out", out), patch("black.err", err):
            with self.assertRaises(AssertionError):
                self.assertFormatEqual("l = [1, 2, 3]", "l = [1, 2, 3,]")

        out_str = "".join(out_lines)
        self.assertIn("Expected tree:", out_str)
        self.assertIn("Actual tree:", out_str)
        self.assertEqual("".join(err_lines), "")

    def test_cache_broken_file(self) -> None:
        with cache_dir() as workspace:
            cache_file = black.get_cache_file(black.DEFAULT_LINE_LENGTH)
            with cache_file.open("w") as fobj:
                fobj.write("this is not a pickle")
            self.assertEqual(black.read_cache(black.DEFAULT_LINE_LENGTH), {})
            src = (workspace / "test.py").resolve()
            with src.open("w") as fobj:
                fobj.write("print('hello')")
            result = CliRunner().invoke(black.main, [str(src)])
            self.assertEqual(result.exit_code, 0)
            cache = black.read_cache(black.DEFAULT_LINE_LENGTH)
            self.assertIn(src, cache)

    def test_cache_single_file_already_cached(self) -> None:
        with cache_dir() as workspace:
            src = (workspace / "test.py").resolve()
            with src.open("w") as fobj:
                fobj.write("print('hello')")
            black.write_cache({}, [src], black.DEFAULT_LINE_LENGTH)
            result = CliRunner().invoke(black.main, [str(src)])
            self.assertEqual(result.exit_code, 0)
            with src.open("r") as fobj:
                self.assertEqual(fobj.read(), "print('hello')")

    @event_loop(close=False)
    def test_cache_multiple_files(self) -> None:
        with cache_dir() as workspace, patch(
            "black.ProcessPoolExecutor", new=ThreadPoolExecutor
        ):
            one = (workspace / "one.py").resolve()
            with one.open("w") as fobj:
                fobj.write("print('hello')")
            two = (workspace / "two.py").resolve()
            with two.open("w") as fobj:
                fobj.write("print('hello')")
            black.write_cache({}, [one], black.DEFAULT_LINE_LENGTH)
            result = CliRunner().invoke(black.main, [str(workspace)])
            self.assertEqual(result.exit_code, 0)
            with one.open("r") as fobj:
                self.assertEqual(fobj.read(), "print('hello')")
            with two.open("r") as fobj:
                self.assertEqual(fobj.read(), 'print("hello")\n')
            cache = black.read_cache(black.DEFAULT_LINE_LENGTH)
            self.assertIn(one, cache)
            self.assertIn(two, cache)

    def test_no_cache_when_writeback_diff(self) -> None:
        with cache_dir() as workspace:
            src = (workspace / "test.py").resolve()
            with src.open("w") as fobj:
                fobj.write("print('hello')")
            result = CliRunner().invoke(black.main, [str(src), "--diff"])
            self.assertEqual(result.exit_code, 0)
            cache_file = black.get_cache_file(black.DEFAULT_LINE_LENGTH)
            self.assertFalse(cache_file.exists())

    def test_no_cache_when_stdin(self) -> None:
        with cache_dir():
            result = CliRunner().invoke(black.main, ["-"], input="print('hello')")
            self.assertEqual(result.exit_code, 0)
            cache_file = black.get_cache_file(black.DEFAULT_LINE_LENGTH)
            self.assertFalse(cache_file.exists())

    def test_read_cache_no_cachefile(self) -> None:
        with cache_dir():
            self.assertEqual(black.read_cache(black.DEFAULT_LINE_LENGTH), {})

    def test_write_cache_read_cache(self) -> None:
        with cache_dir() as workspace:
            src = (workspace / "test.py").resolve()
            src.touch()
            black.write_cache({}, [src], black.DEFAULT_LINE_LENGTH)
            cache = black.read_cache(black.DEFAULT_LINE_LENGTH)
            self.assertIn(src, cache)
            self.assertEqual(cache[src], black.get_cache_info(src))

    def test_filter_cached(self) -> None:
        with TemporaryDirectory() as workspace:
            path = Path(workspace)
            uncached = (path / "uncached").resolve()
            cached = (path / "cached").resolve()
            cached_but_changed = (path / "changed").resolve()
            uncached.touch()
            cached.touch()
            cached_but_changed.touch()
            cache = {cached: black.get_cache_info(cached), cached_but_changed: (0.0, 0)}
            todo, done = black.filter_cached(
                cache, [uncached, cached, cached_but_changed]
            )
            self.assertEqual(todo, [uncached, cached_but_changed])
            self.assertEqual(done, [cached])

    def test_write_cache_creates_directory_if_needed(self) -> None:
        with cache_dir(exists=False) as workspace:
            self.assertFalse(workspace.exists())
            black.write_cache({}, [], black.DEFAULT_LINE_LENGTH)
            self.assertTrue(workspace.exists())

    @event_loop(close=False)
    def test_failed_formatting_does_not_get_cached(self) -> None:
        with cache_dir() as workspace, patch(
            "black.ProcessPoolExecutor", new=ThreadPoolExecutor
        ):
            failing = (workspace / "failing.py").resolve()
            with failing.open("w") as fobj:
                fobj.write("not actually python")
            clean = (workspace / "clean.py").resolve()
            with clean.open("w") as fobj:
                fobj.write('print("hello")\n')
            result = CliRunner().invoke(black.main, [str(workspace)])
            self.assertEqual(result.exit_code, 123)
            cache = black.read_cache(black.DEFAULT_LINE_LENGTH)
            self.assertNotIn(failing, cache)
            self.assertIn(clean, cache)

    def test_write_cache_write_fail(self) -> None:
        with cache_dir(), patch.object(Path, "open") as mock:
            mock.side_effect = OSError
            black.write_cache({}, [], black.DEFAULT_LINE_LENGTH)

    def test_check_diff_use_together(self) -> None:
        with cache_dir():
            # Files which will be reformatted.
            src1 = (THIS_DIR / "string_quotes.py").resolve()
            result = CliRunner().invoke(black.main, [str(src1), "--diff", "--check"])
            self.assertEqual(result.exit_code, 1)

            # Files which will not be reformatted.
            src2 = (THIS_DIR / "composition.py").resolve()
            result = CliRunner().invoke(black.main, [str(src2), "--diff", "--check"])
            self.assertEqual(result.exit_code, 0)

            # Multi file command.
            result = CliRunner().invoke(
                black.main, [str(src1), str(src2), "--diff", "--check"]
            )
            self.assertEqual(result.exit_code, 1)

    def test_no_files(self) -> None:
        with cache_dir():
            # Without an argument, black exits with error code 0.
            result = CliRunner().invoke(black.main, [])
            self.assertEqual(result.exit_code, 0)

    def test_broken_symlink(self) -> None:
        with cache_dir() as workspace:
            symlink = workspace / "broken_link.py"
            symlink.symlink_to("nonexistent.py")
            result = CliRunner().invoke(black.main, [str(workspace.resolve())])
            self.assertEqual(result.exit_code, 0)

    def test_read_cache_line_lengths(self) -> None:
        with cache_dir() as workspace:
            path = (workspace / "file.py").resolve()
            path.touch()
            black.write_cache({}, [path], 1)
            one = black.read_cache(1)
            self.assertIn(path, one)
            two = black.read_cache(2)
            self.assertNotIn(path, two)
722
if __name__ == "__main__":
    # Allow running this test module directly as a script via unittest's runner.
    unittest.main()