git.madduck.net Git - etc/vim.git/blob - tests/test_black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access, can be arranged individually.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Don't explode trailers that fit in a single line
[etc/vim.git] / tests / test_black.py
1 #!/usr/bin/env python3
2 import asyncio
3 from concurrent.futures import ThreadPoolExecutor
4 from contextlib import contextmanager
5 from functools import partial
6 from io import StringIO
7 import os
8 from pathlib import Path
9 import sys
10 from tempfile import TemporaryDirectory
11 from typing import Any, List, Tuple, Iterator
12 import unittest
13 from unittest.mock import patch
14
15 from click import unstyle
16 from click.testing import CliRunner
17
18 import black
19
# This test module and the directory holding it; read_data() resolves fixture
# names relative to THIS_DIR.
THIS_FILE = Path(__file__)
THIS_DIR = THIS_FILE.parent

# Line length shared by all formatting shorthands in this module.
ll = 88
# Shorthands for black's file-level and string-level formatters, bound to ``ll``.
ff = partial(black.format_file_in_place, fast=True, line_length=ll)
fs = partial(black.format_str, line_length=ll)

# Text that read_data() deletes from every fixture line. It is assembled from two
# literals, presumably so that this very line is left intact when test_self reads
# this file as a fixture.
EMPTY_LINE = "# EMPTY LINE WITH WHITESPACE" + " (this comment will be removed)"
26
27
def dump_to_stderr(*output: str) -> str:
    """Return *output* joined by newlines, with one extra newline on each side.

    Tests patch ``black.dump_to_file`` with this function. The returned text then
    shows up wherever black would otherwise use the dump file's path (presumably
    in failure messages). Despite the name, nothing is written to stderr here.
    """
    body = "\n".join(output)
    return "\n{}\n".format(body)
30
31
def read_data(name: str) -> Tuple[str, str]:
    """read_data('test_name') -> 'input', 'output'

    Load fixture *name* from THIS_DIR. ``.py`` is appended unless the name already
    ends in ``.py``, ``.out`` or ``.diff``.

    Every occurrence of EMPTY_LINE is deleted from each line. Lines before the
    first ``# output`` marker line form the input, and lines after it form the
    output. Marker lines themselves are dropped. If the input is non-empty but
    there is no output, the input is taken to be already formatted and is used as
    the output too. Both parts are stripped and end with exactly one newline.
    """
    if not name.endswith((".py", ".out", ".diff")):
        name += ".py"
    with open(THIS_DIR / name, "r", encoding="utf8") as fixture:
        cleaned = [raw.replace(EMPTY_LINE, "") for raw in fixture.readlines()]

    before: List[str] = []
    after: List[str] = []
    past_marker = False
    for line in cleaned:
        if line.rstrip() == "# output":
            past_marker = True
            continue
        (after if past_marker else before).append(line)

    if before and not after:
        after = list(before)

    def normalize(chunk: List[str]) -> str:
        return "".join(chunk).strip() + "\n"

    return normalize(before), normalize(after)
52
53
@contextmanager
def cache_dir(exists: bool = True) -> Iterator[Path]:
    """Patch ``black.CACHE_DIR`` to a throwaway directory and yield that path.

    With ``exists=False``, the yielded path is a not-yet-created ``new``
    subdirectory of the temporary workspace. This suits tests that check the
    cache directory gets created on demand. The whole workspace is removed on
    exit.
    """
    with TemporaryDirectory() as workspace:
        root = Path(workspace)
        target = root if exists else root / "new"
        with patch("black.CACHE_DIR", target):
            yield target
62
63
@contextmanager
def event_loop(close: bool) -> Iterator[None]:
    """Make a fresh asyncio event loop current for the duration of the block.

    The previously current loop is reinstated on exit, including on error. The
    fresh loop is closed afterwards only when *close* is true. Like any
    ``contextmanager``, this can also be used as a decorator.
    """
    policy = asyncio.get_event_loop_policy()
    previous = policy.get_event_loop()
    fresh = policy.new_event_loop()
    policy.set_event_loop(fresh)
    try:
        yield
    finally:
        policy.set_event_loop(previous)
        if close:
            fresh.close()
77
78
class BlackTestCase(unittest.TestCase):
    """Tests for black's formatting, reporting, AST debugging and caching."""

    maxDiff = None

    def assertFormatEqual(self, expected: str, actual: str) -> None:
        """Assert that *expected* equals *actual*.

        On a mismatch, and unless the SKIP_AST_PRINT environment variable is set,
        the debug trees of both sources are printed first. This happens via
        ``black.out``, and parse errors are reported via ``black.err``.
        """
        if actual != expected and not os.environ.get("SKIP_AST_PRINT"):
            self._print_tree("Expected tree:", "green", expected)
            self._print_tree("Actual tree:", "red", actual)
        self.assertEqual(expected, actual)

    @staticmethod
    def _print_tree(title: str, color: str, src: str) -> None:
        """Print *title* in *color*, then run black.DebugVisitor over *src*.

        Any exception raised while parsing or visiting is reported through
        ``black.err`` instead of propagating.
        """
        bdv: black.DebugVisitor[Any]
        black.out(title, fg=color)
        try:
            node = black.lib2to3_parse(src)
            bdv = black.DebugVisitor()
            list(bdv.visit(node))
        except Exception as ve:
            black.err(str(ve))

    def _check_format(self, name: str, *, equivalent: bool = True) -> None:
        """Format fixture *name* with ``fs`` and compare with its expected output.

        Afterwards, black.assert_equivalent (skipped when *equivalent* is false)
        and black.assert_stable are run on the source and the result.
        """
        source, expected = read_data(name)
        actual = fs(source)
        self.assertFormatEqual(expected, actual)
        if equivalent:
            black.assert_equivalent(source, actual)
        black.assert_stable(source, actual, line_length=ll)

    @staticmethod
    def _format_stdin(source: str, write_back: black.WriteBack) -> str:
        """Run black.format_stdin_to_stdout on *source* and return what it printed.

        sys.stdin and sys.stdout are swapped for in-memory buffers and are always
        restored, even if formatting raises.
        """
        hold_stdin, hold_stdout = sys.stdin, sys.stdout
        try:
            sys.stdin, sys.stdout = StringIO(source), StringIO()
            sys.stdin.name = "<stdin>"
            black.format_stdin_to_stdout(
                line_length=ll, fast=True, write_back=write_back
            )
            sys.stdout.seek(0)
            return sys.stdout.read()
        finally:
            sys.stdin, sys.stdout = hold_stdin, hold_stdout

    @staticmethod
    @contextmanager
    def _capture_black_output() -> Iterator[Tuple[List[str], List[str]]]:
        """Patch ``black.out``/``black.err``; yield the lists collecting messages."""
        out_lines: List[str] = []
        err_lines: List[str] = []

        def out(msg: str, **kwargs: Any) -> None:
            out_lines.append(msg)

        def err(msg: str, **kwargs: Any) -> None:
            err_lines.append(msg)

        with patch("black.out", out), patch("black.err", err):
            yield out_lines, err_lines

    @patch("black.dump_to_file", dump_to_stderr)
    def test_self(self) -> None:
        self._check_format("test_black")
        self.assertFalse(ff(THIS_FILE))

    @patch("black.dump_to_file", dump_to_stderr)
    def test_black(self) -> None:
        self._check_format("../black")
        self.assertFalse(ff(THIS_DIR / ".." / "black.py"))

    def test_piping(self) -> None:
        source, expected = read_data("../black")
        actual = self._format_stdin(source, black.WriteBack.YES)
        self.assertFormatEqual(expected, actual)
        black.assert_equivalent(source, actual)
        black.assert_stable(source, actual, line_length=ll)

    def test_piping_diff(self) -> None:
        source, _ = read_data("expression.py")
        expected, _ = read_data("expression.diff")
        actual = self._format_stdin(source, black.WriteBack.DIFF)
        actual = actual.rstrip() + "\n"  # the diff output has a trailing space
        self.assertEqual(expected, actual)

    @patch("black.dump_to_file", dump_to_stderr)
    def test_setup(self) -> None:
        self._check_format("../setup")
        self.assertFalse(ff(THIS_DIR / ".." / "setup.py"))

    @patch("black.dump_to_file", dump_to_stderr)
    def test_function(self) -> None:
        self._check_format("function")

    @patch("black.dump_to_file", dump_to_stderr)
    def test_function2(self) -> None:
        self._check_format("function2")

    @patch("black.dump_to_file", dump_to_stderr)
    def test_expression(self) -> None:
        self._check_format("expression")

    def test_expression_ff(self) -> None:
        source, expected = read_data("expression")
        tmp_file = Path(black.dump_to_file(source))
        try:
            self.assertTrue(ff(tmp_file, write_back=black.WriteBack.YES))
            with open(tmp_file, encoding="utf8") as f:
                actual = f.read()
        finally:
            os.unlink(tmp_file)
        self.assertFormatEqual(expected, actual)
        with patch("black.dump_to_file", dump_to_stderr):
            black.assert_equivalent(source, actual)
            black.assert_stable(source, actual, line_length=ll)

    def test_expression_diff(self) -> None:
        source, _ = read_data("expression.py")
        expected, _ = read_data("expression.diff")
        tmp_file = Path(black.dump_to_file(source))
        hold_stdout = sys.stdout
        try:
            sys.stdout = StringIO()
            self.assertTrue(ff(tmp_file, write_back=black.WriteBack.DIFF))
            sys.stdout.seek(0)
            actual = sys.stdout.read()
            actual = actual.replace(str(tmp_file), "<stdin>")
        finally:
            sys.stdout = hold_stdout
            os.unlink(tmp_file)
        actual = actual.rstrip() + "\n"  # the diff output has a trailing space
        if expected != actual:
            # Keep the actual diff on disk so it can replace the fixture if the
            # change was intended.
            dump = black.dump_to_file(actual)
            msg = (
                f"Expected diff isn't equal to the actual. If you made changes "
                f"to expression.py and this is an anticipated difference, "
                f"overwrite tests/expression.diff with {dump}"
            )
            self.assertEqual(expected, actual, msg)

    @patch("black.dump_to_file", dump_to_stderr)
    def test_fstring(self) -> None:
        self._check_format("fstring")

    @patch("black.dump_to_file", dump_to_stderr)
    def test_string_quotes(self) -> None:
        self._check_format("string_quotes")

    @patch("black.dump_to_file", dump_to_stderr)
    def test_slices(self) -> None:
        self._check_format("slices")

    @patch("black.dump_to_file", dump_to_stderr)
    def test_comments(self) -> None:
        self._check_format("comments")

    @patch("black.dump_to_file", dump_to_stderr)
    def test_comments2(self) -> None:
        self._check_format("comments2")

    @patch("black.dump_to_file", dump_to_stderr)
    def test_comments3(self) -> None:
        self._check_format("comments3")

    @patch("black.dump_to_file", dump_to_stderr)
    def test_comments4(self) -> None:
        self._check_format("comments4")

    @patch("black.dump_to_file", dump_to_stderr)
    def test_comments5(self) -> None:
        self._check_format("comments5")

    @patch("black.dump_to_file", dump_to_stderr)
    def test_cantfit(self) -> None:
        self._check_format("cantfit")

    @patch("black.dump_to_file", dump_to_stderr)
    def test_import_spacing(self) -> None:
        self._check_format("import_spacing")

    @patch("black.dump_to_file", dump_to_stderr)
    def test_composition(self) -> None:
        self._check_format("composition")

    @patch("black.dump_to_file", dump_to_stderr)
    def test_empty_lines(self) -> None:
        self._check_format("empty_lines")

    @patch("black.dump_to_file", dump_to_stderr)
    def test_string_prefixes(self) -> None:
        self._check_format("string_prefixes")

    @patch("black.dump_to_file", dump_to_stderr)
    def test_python2(self) -> None:
        # Equivalence is deliberately not asserted for this fixture. Presumably
        # the Python 2 source cannot be compared under the running interpreter
        # (TODO confirm).
        self._check_format("python2", equivalent=False)

    @patch("black.dump_to_file", dump_to_stderr)
    def test_python2_unicode_literals(self) -> None:
        self._check_format("python2_unicode_literals", equivalent=False)

    @patch("black.dump_to_file", dump_to_stderr)
    def test_fmtonoff(self) -> None:
        self._check_format("fmtonoff")

    @patch("black.dump_to_file", dump_to_stderr)
    def test_remove_empty_parentheses_after_class(self) -> None:
        self._check_format("class_blank_parentheses")

    def test_report(self) -> None:
        report = black.Report()
        with self._capture_black_output() as (out_lines, err_lines):
            report.done(Path("f1"), black.Changed.NO)
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "f1 already well formatted, good job.")
            self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
            self.assertEqual(report.return_code, 0)
            report.done(Path("f2"), black.Changed.YES)
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "reformatted f2")
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
            )
            report.done(Path("f3"), black.Changed.CACHED)
            self.assertEqual(len(out_lines), 3)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(
                out_lines[-1], "f3 wasn't modified on disk since last run."
            )
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
            )
            self.assertEqual(report.return_code, 0)
            report.check = True
            self.assertEqual(report.return_code, 1)
            report.check = False
            report.failed(Path("e1"), "boom")
            self.assertEqual(len(out_lines), 3)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
            self.assertEqual(
                unstyle(str(report)),
                "1 file reformatted, 2 files left unchanged, "
                "1 file failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f3"), black.Changed.YES)
            self.assertEqual(len(out_lines), 4)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(out_lines[-1], "reformatted f3")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, "
                "1 file failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.failed(Path("e2"), "boom")
            self.assertEqual(len(out_lines), 4)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, "
                "2 files failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f4"), black.Changed.NO)
            self.assertEqual(len(out_lines), 5)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(out_lines[-1], "f4 already well formatted, good job.")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 3 files left unchanged, "
                "2 files failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.check = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, "
                "2 files would fail to reformat.",
            )

    def test_is_python36(self) -> None:
        node = black.lib2to3_parse("def f(*, arg): ...\n")
        self.assertFalse(black.is_python36(node))
        node = black.lib2to3_parse("def f(*, arg,): ...\n")
        self.assertTrue(black.is_python36(node))
        node = black.lib2to3_parse("def f(*, arg): f'string'\n")
        self.assertTrue(black.is_python36(node))
        source, expected = read_data("function")
        node = black.lib2to3_parse(source)
        self.assertTrue(black.is_python36(node))
        node = black.lib2to3_parse(expected)
        self.assertTrue(black.is_python36(node))
        source, expected = read_data("expression")
        node = black.lib2to3_parse(source)
        self.assertFalse(black.is_python36(node))
        node = black.lib2to3_parse(expected)
        self.assertFalse(black.is_python36(node))

    def test_get_future_imports(self) -> None:
        node = black.lib2to3_parse("\n")
        self.assertEqual(set(), black.get_future_imports(node))
        node = black.lib2to3_parse("from __future__ import black\n")
        self.assertEqual({"black"}, black.get_future_imports(node))
        node = black.lib2to3_parse("from __future__ import multiple, imports\n")
        self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
        node = black.lib2to3_parse("from __future__ import (parenthesized, imports)\n")
        self.assertEqual({"parenthesized", "imports"}, black.get_future_imports(node))
        node = black.lib2to3_parse(
            "from __future__ import multiple\nfrom __future__ import imports\n"
        )
        self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
        node = black.lib2to3_parse("# comment\nfrom __future__ import black\n")
        self.assertEqual({"black"}, black.get_future_imports(node))
        node = black.lib2to3_parse('"""docstring"""\nfrom __future__ import black\n')
        self.assertEqual({"black"}, black.get_future_imports(node))
        node = black.lib2to3_parse("some(other, code)\nfrom __future__ import black\n")
        self.assertEqual(set(), black.get_future_imports(node))
        node = black.lib2to3_parse("from some.module import black\n")
        self.assertEqual(set(), black.get_future_imports(node))

    def test_debug_visitor(self) -> None:
        source, _ = read_data("debug_visitor.py")
        expected, _ = read_data("debug_visitor.out")
        with self._capture_black_output() as (out_lines, _err_lines):
            black.DebugVisitor.show(source)
        actual = "\n".join(out_lines) + "\n"
        log_name = ""
        if expected != actual:
            log_name = black.dump_to_file(*out_lines)
        self.assertEqual(
            expected,
            actual,
            f"AST print out is different. Actual version dumped to {log_name}",
        )

    def test_format_file_contents(self) -> None:
        empty = ""
        with self.assertRaises(black.NothingChanged):
            black.format_file_contents(empty, line_length=ll, fast=False)
        just_nl = "\n"
        with self.assertRaises(black.NothingChanged):
            black.format_file_contents(just_nl, line_length=ll, fast=False)
        same = "l = [1, 2, 3]\n"
        with self.assertRaises(black.NothingChanged):
            black.format_file_contents(same, line_length=ll, fast=False)
        different = "l = [1,2,3]"
        expected = same
        actual = black.format_file_contents(different, line_length=ll, fast=False)
        self.assertEqual(expected, actual)
        invalid = "return if you can"
        with self.assertRaises(ValueError) as e:
            black.format_file_contents(invalid, line_length=ll, fast=False)
        self.assertEqual(str(e.exception), "Cannot parse: 1:7: return if you can")

    def test_endmarker(self) -> None:
        n = black.lib2to3_parse("\n")
        self.assertEqual(n.type, black.syms.file_input)
        self.assertEqual(len(n.children), 1)
        self.assertEqual(n.children[0].type, black.token.ENDMARKER)

    @unittest.skipIf(os.environ.get("SKIP_AST_PRINT"), "user set SKIP_AST_PRINT")
    def test_assertFormatEqual(self) -> None:
        with self._capture_black_output() as (out_lines, err_lines):
            with self.assertRaises(AssertionError):
                self.assertFormatEqual("l = [1, 2, 3]", "l = [1, 2, 3,]")

        out_str = "".join(out_lines)
        self.assertIn("Expected tree:", out_str)
        self.assertIn("Actual tree:", out_str)
        self.assertEqual("".join(err_lines), "")

    def test_cache_broken_file(self) -> None:
        with cache_dir() as workspace:
            cache_file = black.get_cache_file(black.DEFAULT_LINE_LENGTH)
            with cache_file.open("w") as fobj:
                fobj.write("this is not a pickle")
            self.assertEqual(black.read_cache(black.DEFAULT_LINE_LENGTH), {})
            src = (workspace / "test.py").resolve()
            with src.open("w") as fobj:
                fobj.write("print('hello')")
            result = CliRunner().invoke(black.main, [str(src)])
            self.assertEqual(result.exit_code, 0)
            cache = black.read_cache(black.DEFAULT_LINE_LENGTH)
            self.assertIn(src, cache)

    def test_cache_single_file_already_cached(self) -> None:
        with cache_dir() as workspace:
            src = (workspace / "test.py").resolve()
            with src.open("w") as fobj:
                fobj.write("print('hello')")
            black.write_cache({}, [src], black.DEFAULT_LINE_LENGTH)
            result = CliRunner().invoke(black.main, [str(src)])
            self.assertEqual(result.exit_code, 0)
            with src.open("r") as fobj:
                self.assertEqual(fobj.read(), "print('hello')")

    @event_loop(close=False)
    def test_cache_multiple_files(self) -> None:
        with cache_dir() as workspace, patch(
            "black.ProcessPoolExecutor", new=ThreadPoolExecutor
        ):
            one = (workspace / "one.py").resolve()
            with one.open("w") as fobj:
                fobj.write("print('hello')")
            two = (workspace / "two.py").resolve()
            with two.open("w") as fobj:
                fobj.write("print('hello')")
            black.write_cache({}, [one], black.DEFAULT_LINE_LENGTH)
            result = CliRunner().invoke(black.main, [str(workspace)])
            self.assertEqual(result.exit_code, 0)
            with one.open("r") as fobj:
                self.assertEqual(fobj.read(), "print('hello')")
            with two.open("r") as fobj:
                self.assertEqual(fobj.read(), 'print("hello")\n')
            cache = black.read_cache(black.DEFAULT_LINE_LENGTH)
            self.assertIn(one, cache)
            self.assertIn(two, cache)

    def test_no_cache_when_writeback_diff(self) -> None:
        with cache_dir() as workspace:
            src = (workspace / "test.py").resolve()
            with src.open("w") as fobj:
                fobj.write("print('hello')")
            result = CliRunner().invoke(black.main, [str(src), "--diff"])
            self.assertEqual(result.exit_code, 0)
            cache_file = black.get_cache_file(black.DEFAULT_LINE_LENGTH)
            self.assertFalse(cache_file.exists())

    def test_no_cache_when_stdin(self) -> None:
        with cache_dir():
            result = CliRunner().invoke(black.main, ["-"], input="print('hello')")
            self.assertEqual(result.exit_code, 0)
            cache_file = black.get_cache_file(black.DEFAULT_LINE_LENGTH)
            self.assertFalse(cache_file.exists())

    def test_read_cache_no_cachefile(self) -> None:
        with cache_dir():
            self.assertEqual(black.read_cache(black.DEFAULT_LINE_LENGTH), {})

    def test_write_cache_read_cache(self) -> None:
        with cache_dir() as workspace:
            src = (workspace / "test.py").resolve()
            src.touch()
            black.write_cache({}, [src], black.DEFAULT_LINE_LENGTH)
            cache = black.read_cache(black.DEFAULT_LINE_LENGTH)
            self.assertIn(src, cache)
            self.assertEqual(cache[src], black.get_cache_info(src))

    def test_filter_cached(self) -> None:
        with TemporaryDirectory() as workspace:
            path = Path(workspace)
            uncached = (path / "uncached").resolve()
            cached = (path / "cached").resolve()
            cached_but_changed = (path / "changed").resolve()
            uncached.touch()
            cached.touch()
            cached_but_changed.touch()
            cache = {cached: black.get_cache_info(cached), cached_but_changed: (0.0, 0)}
            todo, done = black.filter_cached(
                cache, [uncached, cached, cached_but_changed]
            )
            self.assertEqual(todo, [uncached, cached_but_changed])
            self.assertEqual(done, [cached])

    def test_write_cache_creates_directory_if_needed(self) -> None:
        with cache_dir(exists=False) as workspace:
            self.assertFalse(workspace.exists())
            black.write_cache({}, [], black.DEFAULT_LINE_LENGTH)
            self.assertTrue(workspace.exists())

    @event_loop(close=False)
    def test_failed_formatting_does_not_get_cached(self) -> None:
        with cache_dir() as workspace, patch(
            "black.ProcessPoolExecutor", new=ThreadPoolExecutor
        ):
            failing = (workspace / "failing.py").resolve()
            with failing.open("w") as fobj:
                fobj.write("not actually python")
            clean = (workspace / "clean.py").resolve()
            with clean.open("w") as fobj:
                fobj.write('print("hello")\n')
            result = CliRunner().invoke(black.main, [str(workspace)])
            self.assertEqual(result.exit_code, 123)
            cache = black.read_cache(black.DEFAULT_LINE_LENGTH)
            self.assertNotIn(failing, cache)
            self.assertIn(clean, cache)

    def test_write_cache_write_fail(self) -> None:
        # write_cache must not propagate the OSError raised by Path.open.
        with cache_dir(), patch.object(Path, "open") as mock:
            mock.side_effect = OSError
            black.write_cache({}, [], black.DEFAULT_LINE_LENGTH)

    def test_check_diff_use_together(self) -> None:
        with cache_dir():
            # Files which will be reformatted.
            src1 = (THIS_DIR / "string_quotes.py").resolve()
            result = CliRunner().invoke(black.main, [str(src1), "--diff", "--check"])
            self.assertEqual(result.exit_code, 1)

            # Files which will not be reformatted.
            src2 = (THIS_DIR / "composition.py").resolve()
            result = CliRunner().invoke(black.main, [str(src2), "--diff", "--check"])
            self.assertEqual(result.exit_code, 0)

            # Multi file command.
            result = CliRunner().invoke(
                black.main, [str(src1), str(src2), "--diff", "--check"]
            )
            self.assertEqual(result.exit_code, 1)

    def test_no_files(self) -> None:
        with cache_dir():
            # Without an argument, black exits with error code 0.
            result = CliRunner().invoke(black.main, [])
            self.assertEqual(result.exit_code, 0)

    def test_broken_symlink(self) -> None:
        with cache_dir() as workspace:
            symlink = workspace / "broken_link.py"
            symlink.symlink_to("nonexistent.py")
            result = CliRunner().invoke(black.main, [str(workspace.resolve())])
            self.assertEqual(result.exit_code, 0)

    def test_read_cache_line_lengths(self) -> None:
        with cache_dir() as workspace:
            path = (workspace / "file.py").resolve()
            path.touch()
            black.write_cache({}, [path], 1)
            one = black.read_cache(1)
            self.assertIn(path, one)
            two = black.read_cache(2)
            self.assertNotIn(path, two)
715
if __name__ == "__main__":
    # Allow running this module directly, e.g. ``python tests/test_black.py``.
    unittest.main(module="__main__")