git.madduck.net Git - etc/vim.git/blob - tests/test_black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

acks += Stavros; document fix, add to Pipfile
[etc/vim.git] / tests / test_black.py
1 #!/usr/bin/env python3
2 import asyncio
3 from concurrent.futures import ThreadPoolExecutor
4 from contextlib import contextmanager
5 from functools import partial
6 from io import StringIO
7 import os
8 from pathlib import Path
9 import sys
10 from tempfile import TemporaryDirectory
11 from typing import Any, List, Tuple, Iterator
12 import unittest
13 from unittest.mock import patch
14
15 from click import unstyle
16 from click.testing import CliRunner
17
18 import black
19
# Line length used by every formatting helper in this module.
ll = 88
# Shorthands: format a file in place / format a source string at ``ll``.
ff = partial(black.format_file_in_place, line_length=ll, fast=True)
fs = partial(black.format_str, line_length=ll)
THIS_FILE = Path(__file__)
THIS_DIR = THIS_FILE.parent
# Deliberately built from two literals: read_data() deletes this marker from
# every line it reads, and test_self reads this very module through read_data(),
# so a single literal here would be stripped from its own definition.
EMPTY_LINE = "# EMPTY LINE WITH WHITESPACE" + " (this comment will be removed)"
26
27
def dump_to_stderr(*output: str) -> str:
    """Stand-in for ``black.dump_to_file`` used by the tests.

    Rather than writing a file, return the given chunks joined by newlines,
    with one extra newline before and after.
    """
    body = "\n".join(output)
    return f"\n{body}\n"
30
31
def read_data(name: str) -> Tuple[str, str]:
    """Load a fixture from the tests directory and split it into (input, output).

    ``.py`` is appended unless *name* already ends with ``.py``, ``.pyi``,
    ``.out`` or ``.diff``. Lines before a line reading ``# output`` form the
    input and lines after it form the expected output. Without that marker the
    whole file is used for both. Each half is stripped and ends in one newline.
    """
    if not name.endswith((".py", ".pyi", ".out", ".diff")):
        name += ".py"
    with open(THIS_DIR / name, "r", encoding="utf8") as fixture:
        raw_lines = fixture.readlines()
    before: List[str] = []
    after: List[str] = []
    bucket = before
    for raw in raw_lines:
        line = raw.replace(EMPTY_LINE, "")
        if line.rstrip() == "# output":
            bucket = after
        else:
            bucket.append(line)
    if before and not after:
        # No marker at all: the fixture is expected to be left untouched.
        after = before[:]
    return "".join(before).strip() + "\n", "".join(after).strip() + "\n"
52
53
@contextmanager
def cache_dir(exists: bool = True) -> Iterator[Path]:
    """Patch ``black.CACHE_DIR`` to a throwaway location for the ``with`` body.

    With ``exists=False`` the patched path is a ``new`` subdirectory of the
    temporary directory that has not been created yet. The temporary directory
    is removed on exit.
    """
    with TemporaryDirectory() as workspace:
        root = Path(workspace)
        target = root if exists else root / "new"
        with patch("black.CACHE_DIR", target):
            yield target
62
63
@contextmanager
def event_loop(close: bool) -> Iterator[None]:
    """Install a fresh asyncio event loop for the duration of the ``with`` body.

    The loop that was current beforehand is reinstated on exit. The fresh loop
    is closed only when *close* is true.
    """
    policy = asyncio.get_event_loop_policy()
    previous = policy.get_event_loop()
    fresh = policy.new_event_loop()
    asyncio.set_event_loop(fresh)
    try:
        yield
    finally:
        policy.set_event_loop(previous)
        if close:
            fresh.close()
77
78
79 class BlackTestCase(unittest.TestCase):
80     maxDiff = None
81
82     def assertFormatEqual(self, expected: str, actual: str) -> None:
83         if actual != expected and not os.environ.get("SKIP_AST_PRINT"):
84             bdv: black.DebugVisitor[Any]
85             black.out("Expected tree:", fg="green")
86             try:
87                 exp_node = black.lib2to3_parse(expected)
88                 bdv = black.DebugVisitor()
89                 list(bdv.visit(exp_node))
90             except Exception as ve:
91                 black.err(str(ve))
92             black.out("Actual tree:", fg="red")
93             try:
94                 exp_node = black.lib2to3_parse(actual)
95                 bdv = black.DebugVisitor()
96                 list(bdv.visit(exp_node))
97             except Exception as ve:
98                 black.err(str(ve))
99         self.assertEqual(expected, actual)
100
101     @patch("black.dump_to_file", dump_to_stderr)
102     def test_self(self) -> None:
103         source, expected = read_data("test_black")
104         actual = fs(source)
105         self.assertFormatEqual(expected, actual)
106         black.assert_equivalent(source, actual)
107         black.assert_stable(source, actual, line_length=ll)
108         self.assertFalse(ff(THIS_FILE))
109
110     @patch("black.dump_to_file", dump_to_stderr)
111     def test_black(self) -> None:
112         source, expected = read_data("../black")
113         actual = fs(source)
114         self.assertFormatEqual(expected, actual)
115         black.assert_equivalent(source, actual)
116         black.assert_stable(source, actual, line_length=ll)
117         self.assertFalse(ff(THIS_DIR / ".." / "black.py"))
118
119     def test_piping(self) -> None:
120         source, expected = read_data("../black")
121         hold_stdin, hold_stdout = sys.stdin, sys.stdout
122         try:
123             sys.stdin, sys.stdout = StringIO(source), StringIO()
124             sys.stdin.name = "<stdin>"
125             black.format_stdin_to_stdout(
126                 line_length=ll, fast=True, write_back=black.WriteBack.YES
127             )
128             sys.stdout.seek(0)
129             actual = sys.stdout.read()
130         finally:
131             sys.stdin, sys.stdout = hold_stdin, hold_stdout
132         self.assertFormatEqual(expected, actual)
133         black.assert_equivalent(source, actual)
134         black.assert_stable(source, actual, line_length=ll)
135
136     def test_piping_diff(self) -> None:
137         source, _ = read_data("expression.py")
138         expected, _ = read_data("expression.diff")
139         hold_stdin, hold_stdout = sys.stdin, sys.stdout
140         try:
141             sys.stdin, sys.stdout = StringIO(source), StringIO()
142             sys.stdin.name = "<stdin>"
143             black.format_stdin_to_stdout(
144                 line_length=ll, fast=True, write_back=black.WriteBack.DIFF
145             )
146             sys.stdout.seek(0)
147             actual = sys.stdout.read()
148         finally:
149             sys.stdin, sys.stdout = hold_stdin, hold_stdout
150         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
151         self.assertEqual(expected, actual)
152
153     @patch("black.dump_to_file", dump_to_stderr)
154     def test_setup(self) -> None:
155         source, expected = read_data("../setup")
156         actual = fs(source)
157         self.assertFormatEqual(expected, actual)
158         black.assert_equivalent(source, actual)
159         black.assert_stable(source, actual, line_length=ll)
160         self.assertFalse(ff(THIS_DIR / ".." / "setup.py"))
161
162     @patch("black.dump_to_file", dump_to_stderr)
163     def test_function(self) -> None:
164         source, expected = read_data("function")
165         actual = fs(source)
166         self.assertFormatEqual(expected, actual)
167         black.assert_equivalent(source, actual)
168         black.assert_stable(source, actual, line_length=ll)
169
170     @patch("black.dump_to_file", dump_to_stderr)
171     def test_function2(self) -> None:
172         source, expected = read_data("function2")
173         actual = fs(source)
174         self.assertFormatEqual(expected, actual)
175         black.assert_equivalent(source, actual)
176         black.assert_stable(source, actual, line_length=ll)
177
178     @patch("black.dump_to_file", dump_to_stderr)
179     def test_expression(self) -> None:
180         source, expected = read_data("expression")
181         actual = fs(source)
182         self.assertFormatEqual(expected, actual)
183         black.assert_equivalent(source, actual)
184         black.assert_stable(source, actual, line_length=ll)
185
186     def test_expression_ff(self) -> None:
187         source, expected = read_data("expression")
188         tmp_file = Path(black.dump_to_file(source))
189         try:
190             self.assertTrue(ff(tmp_file, write_back=black.WriteBack.YES))
191             with open(tmp_file, encoding="utf8") as f:
192                 actual = f.read()
193         finally:
194             os.unlink(tmp_file)
195         self.assertFormatEqual(expected, actual)
196         with patch("black.dump_to_file", dump_to_stderr):
197             black.assert_equivalent(source, actual)
198             black.assert_stable(source, actual, line_length=ll)
199
200     def test_expression_diff(self) -> None:
201         source, _ = read_data("expression.py")
202         expected, _ = read_data("expression.diff")
203         tmp_file = Path(black.dump_to_file(source))
204         hold_stdout = sys.stdout
205         try:
206             sys.stdout = StringIO()
207             self.assertTrue(ff(tmp_file, write_back=black.WriteBack.DIFF))
208             sys.stdout.seek(0)
209             actual = sys.stdout.read()
210             actual = actual.replace(str(tmp_file), "<stdin>")
211         finally:
212             sys.stdout = hold_stdout
213             os.unlink(tmp_file)
214         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
215         if expected != actual:
216             dump = black.dump_to_file(actual)
217             msg = (
218                 f"Expected diff isn't equal to the actual. If you made changes "
219                 f"to expression.py and this is an anticipated difference, "
220                 f"overwrite tests/expression.diff with {dump}"
221             )
222             self.assertEqual(expected, actual, msg)
223
224     @patch("black.dump_to_file", dump_to_stderr)
225     def test_fstring(self) -> None:
226         source, expected = read_data("fstring")
227         actual = fs(source)
228         self.assertFormatEqual(expected, actual)
229         black.assert_equivalent(source, actual)
230         black.assert_stable(source, actual, line_length=ll)
231
232     @patch("black.dump_to_file", dump_to_stderr)
233     def test_string_quotes(self) -> None:
234         source, expected = read_data("string_quotes")
235         actual = fs(source)
236         self.assertFormatEqual(expected, actual)
237         black.assert_equivalent(source, actual)
238         black.assert_stable(source, actual, line_length=ll)
239         mode = black.FileMode.NO_STRING_NORMALIZATION
240         not_normalized = fs(source, mode=mode)
241         self.assertFormatEqual(source, not_normalized)
242         black.assert_equivalent(source, not_normalized)
243         black.assert_stable(source, not_normalized, line_length=ll, mode=mode)
244
245     @patch("black.dump_to_file", dump_to_stderr)
246     def test_slices(self) -> None:
247         source, expected = read_data("slices")
248         actual = fs(source)
249         self.assertFormatEqual(expected, actual)
250         black.assert_equivalent(source, actual)
251         black.assert_stable(source, actual, line_length=ll)
252
253     @patch("black.dump_to_file", dump_to_stderr)
254     def test_comments(self) -> None:
255         source, expected = read_data("comments")
256         actual = fs(source)
257         self.assertFormatEqual(expected, actual)
258         black.assert_equivalent(source, actual)
259         black.assert_stable(source, actual, line_length=ll)
260
261     @patch("black.dump_to_file", dump_to_stderr)
262     def test_comments2(self) -> None:
263         source, expected = read_data("comments2")
264         actual = fs(source)
265         self.assertFormatEqual(expected, actual)
266         black.assert_equivalent(source, actual)
267         black.assert_stable(source, actual, line_length=ll)
268
269     @patch("black.dump_to_file", dump_to_stderr)
270     def test_comments3(self) -> None:
271         source, expected = read_data("comments3")
272         actual = fs(source)
273         self.assertFormatEqual(expected, actual)
274         black.assert_equivalent(source, actual)
275         black.assert_stable(source, actual, line_length=ll)
276
277     @patch("black.dump_to_file", dump_to_stderr)
278     def test_comments4(self) -> None:
279         source, expected = read_data("comments4")
280         actual = fs(source)
281         self.assertFormatEqual(expected, actual)
282         black.assert_equivalent(source, actual)
283         black.assert_stable(source, actual, line_length=ll)
284
285     @patch("black.dump_to_file", dump_to_stderr)
286     def test_comments5(self) -> None:
287         source, expected = read_data("comments5")
288         actual = fs(source)
289         self.assertFormatEqual(expected, actual)
290         black.assert_equivalent(source, actual)
291         black.assert_stable(source, actual, line_length=ll)
292
293     @patch("black.dump_to_file", dump_to_stderr)
294     def test_cantfit(self) -> None:
295         source, expected = read_data("cantfit")
296         actual = fs(source)
297         self.assertFormatEqual(expected, actual)
298         black.assert_equivalent(source, actual)
299         black.assert_stable(source, actual, line_length=ll)
300
301     @patch("black.dump_to_file", dump_to_stderr)
302     def test_import_spacing(self) -> None:
303         source, expected = read_data("import_spacing")
304         actual = fs(source)
305         self.assertFormatEqual(expected, actual)
306         black.assert_equivalent(source, actual)
307         black.assert_stable(source, actual, line_length=ll)
308
309     @patch("black.dump_to_file", dump_to_stderr)
310     def test_composition(self) -> None:
311         source, expected = read_data("composition")
312         actual = fs(source)
313         self.assertFormatEqual(expected, actual)
314         black.assert_equivalent(source, actual)
315         black.assert_stable(source, actual, line_length=ll)
316
317     @patch("black.dump_to_file", dump_to_stderr)
318     def test_empty_lines(self) -> None:
319         source, expected = read_data("empty_lines")
320         actual = fs(source)
321         self.assertFormatEqual(expected, actual)
322         black.assert_equivalent(source, actual)
323         black.assert_stable(source, actual, line_length=ll)
324
325     @patch("black.dump_to_file", dump_to_stderr)
326     def test_string_prefixes(self) -> None:
327         source, expected = read_data("string_prefixes")
328         actual = fs(source)
329         self.assertFormatEqual(expected, actual)
330         black.assert_equivalent(source, actual)
331         black.assert_stable(source, actual, line_length=ll)
332
333     @patch("black.dump_to_file", dump_to_stderr)
334     def test_python2(self) -> None:
335         source, expected = read_data("python2")
336         actual = fs(source)
337         self.assertFormatEqual(expected, actual)
338         # black.assert_equivalent(source, actual)
339         black.assert_stable(source, actual, line_length=ll)
340
341     @patch("black.dump_to_file", dump_to_stderr)
342     def test_python2_unicode_literals(self) -> None:
343         source, expected = read_data("python2_unicode_literals")
344         actual = fs(source)
345         self.assertFormatEqual(expected, actual)
346         black.assert_stable(source, actual, line_length=ll)
347
348     @patch("black.dump_to_file", dump_to_stderr)
349     def test_stub(self) -> None:
350         mode = black.FileMode.PYI
351         source, expected = read_data("stub.pyi")
352         actual = fs(source, mode=mode)
353         self.assertFormatEqual(expected, actual)
354         black.assert_stable(source, actual, line_length=ll, mode=mode)
355
356     @patch("black.dump_to_file", dump_to_stderr)
357     def test_fmtonoff(self) -> None:
358         source, expected = read_data("fmtonoff")
359         actual = fs(source)
360         self.assertFormatEqual(expected, actual)
361         black.assert_equivalent(source, actual)
362         black.assert_stable(source, actual, line_length=ll)
363
364     @patch("black.dump_to_file", dump_to_stderr)
365     def test_remove_empty_parentheses_after_class(self) -> None:
366         source, expected = read_data("class_blank_parentheses")
367         actual = fs(source)
368         self.assertFormatEqual(expected, actual)
369         black.assert_equivalent(source, actual)
370         black.assert_stable(source, actual, line_length=ll)
371
372     @patch("black.dump_to_file", dump_to_stderr)
373     def test_new_line_between_class_and_code(self) -> None:
374         source, expected = read_data("class_methods_new_line")
375         actual = fs(source)
376         self.assertFormatEqual(expected, actual)
377         black.assert_equivalent(source, actual)
378         black.assert_stable(source, actual, line_length=ll)
379
380     def test_report(self) -> None:
381         report = black.Report()
382         out_lines = []
383         err_lines = []
384
385         def out(msg: str, **kwargs: Any) -> None:
386             out_lines.append(msg)
387
388         def err(msg: str, **kwargs: Any) -> None:
389             err_lines.append(msg)
390
391         with patch("black.out", out), patch("black.err", err):
392             report.done(Path("f1"), black.Changed.NO)
393             self.assertEqual(len(out_lines), 1)
394             self.assertEqual(len(err_lines), 0)
395             self.assertEqual(out_lines[-1], "f1 already well formatted, good job.")
396             self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
397             self.assertEqual(report.return_code, 0)
398             report.done(Path("f2"), black.Changed.YES)
399             self.assertEqual(len(out_lines), 2)
400             self.assertEqual(len(err_lines), 0)
401             self.assertEqual(out_lines[-1], "reformatted f2")
402             self.assertEqual(
403                 unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
404             )
405             report.done(Path("f3"), black.Changed.CACHED)
406             self.assertEqual(len(out_lines), 3)
407             self.assertEqual(len(err_lines), 0)
408             self.assertEqual(
409                 out_lines[-1], "f3 wasn't modified on disk since last run."
410             )
411             self.assertEqual(
412                 unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
413             )
414             self.assertEqual(report.return_code, 0)
415             report.check = True
416             self.assertEqual(report.return_code, 1)
417             report.check = False
418             report.failed(Path("e1"), "boom")
419             self.assertEqual(len(out_lines), 3)
420             self.assertEqual(len(err_lines), 1)
421             self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
422             self.assertEqual(
423                 unstyle(str(report)),
424                 "1 file reformatted, 2 files left unchanged, "
425                 "1 file failed to reformat.",
426             )
427             self.assertEqual(report.return_code, 123)
428             report.done(Path("f3"), black.Changed.YES)
429             self.assertEqual(len(out_lines), 4)
430             self.assertEqual(len(err_lines), 1)
431             self.assertEqual(out_lines[-1], "reformatted f3")
432             self.assertEqual(
433                 unstyle(str(report)),
434                 "2 files reformatted, 2 files left unchanged, "
435                 "1 file failed to reformat.",
436             )
437             self.assertEqual(report.return_code, 123)
438             report.failed(Path("e2"), "boom")
439             self.assertEqual(len(out_lines), 4)
440             self.assertEqual(len(err_lines), 2)
441             self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
442             self.assertEqual(
443                 unstyle(str(report)),
444                 "2 files reformatted, 2 files left unchanged, "
445                 "2 files failed to reformat.",
446             )
447             self.assertEqual(report.return_code, 123)
448             report.done(Path("f4"), black.Changed.NO)
449             self.assertEqual(len(out_lines), 5)
450             self.assertEqual(len(err_lines), 2)
451             self.assertEqual(out_lines[-1], "f4 already well formatted, good job.")
452             self.assertEqual(
453                 unstyle(str(report)),
454                 "2 files reformatted, 3 files left unchanged, "
455                 "2 files failed to reformat.",
456             )
457             self.assertEqual(report.return_code, 123)
458             report.check = True
459             self.assertEqual(
460                 unstyle(str(report)),
461                 "2 files would be reformatted, 3 files would be left unchanged, "
462                 "2 files would fail to reformat.",
463             )
464
465     def test_is_python36(self) -> None:
466         node = black.lib2to3_parse("def f(*, arg): ...\n")
467         self.assertFalse(black.is_python36(node))
468         node = black.lib2to3_parse("def f(*, arg,): ...\n")
469         self.assertTrue(black.is_python36(node))
470         node = black.lib2to3_parse("def f(*, arg): f'string'\n")
471         self.assertTrue(black.is_python36(node))
472         source, expected = read_data("function")
473         node = black.lib2to3_parse(source)
474         self.assertTrue(black.is_python36(node))
475         node = black.lib2to3_parse(expected)
476         self.assertTrue(black.is_python36(node))
477         source, expected = read_data("expression")
478         node = black.lib2to3_parse(source)
479         self.assertFalse(black.is_python36(node))
480         node = black.lib2to3_parse(expected)
481         self.assertFalse(black.is_python36(node))
482
483     def test_get_future_imports(self) -> None:
484         node = black.lib2to3_parse("\n")
485         self.assertEqual(set(), black.get_future_imports(node))
486         node = black.lib2to3_parse("from __future__ import black\n")
487         self.assertEqual({"black"}, black.get_future_imports(node))
488         node = black.lib2to3_parse("from __future__ import multiple, imports\n")
489         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
490         node = black.lib2to3_parse("from __future__ import (parenthesized, imports)\n")
491         self.assertEqual({"parenthesized", "imports"}, black.get_future_imports(node))
492         node = black.lib2to3_parse(
493             "from __future__ import multiple\nfrom __future__ import imports\n"
494         )
495         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
496         node = black.lib2to3_parse("# comment\nfrom __future__ import black\n")
497         self.assertEqual({"black"}, black.get_future_imports(node))
498         node = black.lib2to3_parse('"""docstring"""\nfrom __future__ import black\n')
499         self.assertEqual({"black"}, black.get_future_imports(node))
500         node = black.lib2to3_parse("some(other, code)\nfrom __future__ import black\n")
501         self.assertEqual(set(), black.get_future_imports(node))
502         node = black.lib2to3_parse("from some.module import black\n")
503         self.assertEqual(set(), black.get_future_imports(node))
504
505     def test_debug_visitor(self) -> None:
506         source, _ = read_data("debug_visitor.py")
507         expected, _ = read_data("debug_visitor.out")
508         out_lines = []
509         err_lines = []
510
511         def out(msg: str, **kwargs: Any) -> None:
512             out_lines.append(msg)
513
514         def err(msg: str, **kwargs: Any) -> None:
515             err_lines.append(msg)
516
517         with patch("black.out", out), patch("black.err", err):
518             black.DebugVisitor.show(source)
519         actual = "\n".join(out_lines) + "\n"
520         log_name = ""
521         if expected != actual:
522             log_name = black.dump_to_file(*out_lines)
523         self.assertEqual(
524             expected,
525             actual,
526             f"AST print out is different. Actual version dumped to {log_name}",
527         )
528
529     def test_format_file_contents(self) -> None:
530         empty = ""
531         with self.assertRaises(black.NothingChanged):
532             black.format_file_contents(empty, line_length=ll, fast=False)
533         just_nl = "\n"
534         with self.assertRaises(black.NothingChanged):
535             black.format_file_contents(just_nl, line_length=ll, fast=False)
536         same = "l = [1, 2, 3]\n"
537         with self.assertRaises(black.NothingChanged):
538             black.format_file_contents(same, line_length=ll, fast=False)
539         different = "l = [1,2,3]"
540         expected = same
541         actual = black.format_file_contents(different, line_length=ll, fast=False)
542         self.assertEqual(expected, actual)
543         invalid = "return if you can"
544         with self.assertRaises(ValueError) as e:
545             black.format_file_contents(invalid, line_length=ll, fast=False)
546         self.assertEqual(str(e.exception), "Cannot parse: 1:7: return if you can")
547
548     def test_endmarker(self) -> None:
549         n = black.lib2to3_parse("\n")
550         self.assertEqual(n.type, black.syms.file_input)
551         self.assertEqual(len(n.children), 1)
552         self.assertEqual(n.children[0].type, black.token.ENDMARKER)
553
554     @unittest.skipIf(os.environ.get("SKIP_AST_PRINT"), "user set SKIP_AST_PRINT")
555     def test_assertFormatEqual(self) -> None:
556         out_lines = []
557         err_lines = []
558
559         def out(msg: str, **kwargs: Any) -> None:
560             out_lines.append(msg)
561
562         def err(msg: str, **kwargs: Any) -> None:
563             err_lines.append(msg)
564
565         with patch("black.out", out), patch("black.err", err):
566             with self.assertRaises(AssertionError):
567                 self.assertFormatEqual("l = [1, 2, 3]", "l = [1, 2, 3,]")
568
569         out_str = "".join(out_lines)
570         self.assertTrue("Expected tree:" in out_str)
571         self.assertTrue("Actual tree:" in out_str)
572         self.assertEqual("".join(err_lines), "")
573
574     def test_cache_broken_file(self) -> None:
575         mode = black.FileMode.AUTO_DETECT
576         with cache_dir() as workspace:
577             cache_file = black.get_cache_file(black.DEFAULT_LINE_LENGTH, mode)
578             with cache_file.open("w") as fobj:
579                 fobj.write("this is not a pickle")
580             self.assertEqual(black.read_cache(black.DEFAULT_LINE_LENGTH, mode), {})
581             src = (workspace / "test.py").resolve()
582             with src.open("w") as fobj:
583                 fobj.write("print('hello')")
584             result = CliRunner().invoke(black.main, [str(src)])
585             self.assertEqual(result.exit_code, 0)
586             cache = black.read_cache(black.DEFAULT_LINE_LENGTH, mode)
587             self.assertIn(src, cache)
588
589     def test_cache_single_file_already_cached(self) -> None:
590         mode = black.FileMode.AUTO_DETECT
591         with cache_dir() as workspace:
592             src = (workspace / "test.py").resolve()
593             with src.open("w") as fobj:
594                 fobj.write("print('hello')")
595             black.write_cache({}, [src], black.DEFAULT_LINE_LENGTH, mode)
596             result = CliRunner().invoke(black.main, [str(src)])
597             self.assertEqual(result.exit_code, 0)
598             with src.open("r") as fobj:
599                 self.assertEqual(fobj.read(), "print('hello')")
600
601     @event_loop(close=False)
602     def test_cache_multiple_files(self) -> None:
603         mode = black.FileMode.AUTO_DETECT
604         with cache_dir() as workspace, patch(
605             "black.ProcessPoolExecutor", new=ThreadPoolExecutor
606         ):
607             one = (workspace / "one.py").resolve()
608             with one.open("w") as fobj:
609                 fobj.write("print('hello')")
610             two = (workspace / "two.py").resolve()
611             with two.open("w") as fobj:
612                 fobj.write("print('hello')")
613             black.write_cache({}, [one], black.DEFAULT_LINE_LENGTH, mode)
614             result = CliRunner().invoke(black.main, [str(workspace)])
615             self.assertEqual(result.exit_code, 0)
616             with one.open("r") as fobj:
617                 self.assertEqual(fobj.read(), "print('hello')")
618             with two.open("r") as fobj:
619                 self.assertEqual(fobj.read(), 'print("hello")\n')
620             cache = black.read_cache(black.DEFAULT_LINE_LENGTH, mode)
621             self.assertIn(one, cache)
622             self.assertIn(two, cache)
623
624     def test_no_cache_when_writeback_diff(self) -> None:
625         mode = black.FileMode.AUTO_DETECT
626         with cache_dir() as workspace:
627             src = (workspace / "test.py").resolve()
628             with src.open("w") as fobj:
629                 fobj.write("print('hello')")
630             result = CliRunner().invoke(black.main, [str(src), "--diff"])
631             self.assertEqual(result.exit_code, 0)
632             cache_file = black.get_cache_file(black.DEFAULT_LINE_LENGTH, mode)
633             self.assertFalse(cache_file.exists())
634
635     def test_no_cache_when_stdin(self) -> None:
636         mode = black.FileMode.AUTO_DETECT
637         with cache_dir():
638             result = CliRunner().invoke(black.main, ["-"], input="print('hello')")
639             self.assertEqual(result.exit_code, 0)
640             cache_file = black.get_cache_file(black.DEFAULT_LINE_LENGTH, mode)
641             self.assertFalse(cache_file.exists())
642
643     def test_read_cache_no_cachefile(self) -> None:
644         mode = black.FileMode.AUTO_DETECT
645         with cache_dir():
646             self.assertEqual(black.read_cache(black.DEFAULT_LINE_LENGTH, mode), {})
647
648     def test_write_cache_read_cache(self) -> None:
649         mode = black.FileMode.AUTO_DETECT
650         with cache_dir() as workspace:
651             src = (workspace / "test.py").resolve()
652             src.touch()
653             black.write_cache({}, [src], black.DEFAULT_LINE_LENGTH, mode)
654             cache = black.read_cache(black.DEFAULT_LINE_LENGTH, mode)
655             self.assertIn(src, cache)
656             self.assertEqual(cache[src], black.get_cache_info(src))
657
658     def test_filter_cached(self) -> None:
659         with TemporaryDirectory() as workspace:
660             path = Path(workspace)
661             uncached = (path / "uncached").resolve()
662             cached = (path / "cached").resolve()
663             cached_but_changed = (path / "changed").resolve()
664             uncached.touch()
665             cached.touch()
666             cached_but_changed.touch()
667             cache = {cached: black.get_cache_info(cached), cached_but_changed: (0.0, 0)}
668             todo, done = black.filter_cached(
669                 cache, [uncached, cached, cached_but_changed]
670             )
671             self.assertEqual(todo, [uncached, cached_but_changed])
672             self.assertEqual(done, [cached])
673
674     def test_write_cache_creates_directory_if_needed(self) -> None:
675         mode = black.FileMode.AUTO_DETECT
676         with cache_dir(exists=False) as workspace:
677             self.assertFalse(workspace.exists())
678             black.write_cache({}, [], black.DEFAULT_LINE_LENGTH, mode)
679             self.assertTrue(workspace.exists())
680
681     @event_loop(close=False)
682     def test_failed_formatting_does_not_get_cached(self) -> None:
683         mode = black.FileMode.AUTO_DETECT
684         with cache_dir() as workspace, patch(
685             "black.ProcessPoolExecutor", new=ThreadPoolExecutor
686         ):
687             failing = (workspace / "failing.py").resolve()
688             with failing.open("w") as fobj:
689                 fobj.write("not actually python")
690             clean = (workspace / "clean.py").resolve()
691             with clean.open("w") as fobj:
692                 fobj.write('print("hello")\n')
693             result = CliRunner().invoke(black.main, [str(workspace)])
694             self.assertEqual(result.exit_code, 123)
695             cache = black.read_cache(black.DEFAULT_LINE_LENGTH, mode)
696             self.assertNotIn(failing, cache)
697             self.assertIn(clean, cache)
698
699     def test_write_cache_write_fail(self) -> None:
700         mode = black.FileMode.AUTO_DETECT
701         with cache_dir(), patch.object(Path, "open") as mock:
702             mock.side_effect = OSError
703             black.write_cache({}, [], black.DEFAULT_LINE_LENGTH, mode)
704
705     @event_loop(close=False)
706     def test_check_diff_use_together(self) -> None:
707         with cache_dir():
708             # Files which will be reformatted.
709             src1 = (THIS_DIR / "string_quotes.py").resolve()
710             result = CliRunner().invoke(black.main, [str(src1), "--diff", "--check"])
711             self.assertEqual(result.exit_code, 1)
712
713             # Files which will not be reformatted.
714             src2 = (THIS_DIR / "composition.py").resolve()
715             result = CliRunner().invoke(black.main, [str(src2), "--diff", "--check"])
716             self.assertEqual(result.exit_code, 0)
717
718             # Multi file command.
719             result = CliRunner().invoke(
720                 black.main, [str(src1), str(src2), "--diff", "--check"]
721             )
722             self.assertEqual(result.exit_code, 1, result.output)
723
724     def test_no_files(self) -> None:
725         with cache_dir():
726             # Without an argument, black exits with error code 0.
727             result = CliRunner().invoke(black.main, [])
728             self.assertEqual(result.exit_code, 0)
729
730     def test_broken_symlink(self) -> None:
731         with cache_dir() as workspace:
732             symlink = workspace / "broken_link.py"
733             symlink.symlink_to("nonexistent.py")
734             result = CliRunner().invoke(black.main, [str(workspace.resolve())])
735             self.assertEqual(result.exit_code, 0)
736
737     def test_read_cache_line_lengths(self) -> None:
738         mode = black.FileMode.AUTO_DETECT
739         with cache_dir() as workspace:
740             path = (workspace / "file.py").resolve()
741             path.touch()
742             black.write_cache({}, [path], 1, mode)
743             one = black.read_cache(1, mode)
744             self.assertIn(path, one)
745             two = black.read_cache(2, mode)
746             self.assertNotIn(path, two)
747
748     def test_single_file_force_pyi(self) -> None:
749         reg_mode = black.FileMode.AUTO_DETECT
750         pyi_mode = black.FileMode.PYI
751         contents, expected = read_data("force_pyi")
752         with cache_dir() as workspace:
753             path = (workspace / "file.py").resolve()
754             with open(path, "w") as fh:
755                 fh.write(contents)
756             result = CliRunner().invoke(black.main, [str(path), "--pyi"])
757             self.assertEqual(result.exit_code, 0)
758             with open(path, "r") as fh:
759                 actual = fh.read()
760             # verify cache with --pyi is separate
761             pyi_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, pyi_mode)
762             self.assertIn(path, pyi_cache)
763             normal_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, reg_mode)
764             self.assertNotIn(path, normal_cache)
765         self.assertEqual(actual, expected)
766
767     @event_loop(close=False)
768     def test_multi_file_force_pyi(self) -> None:
769         reg_mode = black.FileMode.AUTO_DETECT
770         pyi_mode = black.FileMode.PYI
771         contents, expected = read_data("force_pyi")
772         with cache_dir() as workspace:
773             paths = [
774                 (workspace / "file1.py").resolve(),
775                 (workspace / "file2.py").resolve(),
776             ]
777             for path in paths:
778                 with open(path, "w") as fh:
779                     fh.write(contents)
780             result = CliRunner().invoke(black.main, [str(p) for p in paths] + ["--pyi"])
781             self.assertEqual(result.exit_code, 0)
782             for path in paths:
783                 with open(path, "r") as fh:
784                     actual = fh.read()
785                 self.assertEqual(actual, expected)
786             # verify cache with --pyi is separate
787             pyi_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, pyi_mode)
788             normal_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, reg_mode)
789             for path in paths:
790                 self.assertIn(path, pyi_cache)
791                 self.assertNotIn(path, normal_cache)
792
793     def test_pipe_force_pyi(self) -> None:
794         source, expected = read_data("force_pyi")
795         result = CliRunner().invoke(black.main, ["-", "-q", "--pyi"], input=source)
796         self.assertEqual(result.exit_code, 0)
797         actual = result.output
798         self.assertFormatEqual(actual, expected)
799
800     def test_single_file_force_py36(self) -> None:
801         reg_mode = black.FileMode.AUTO_DETECT
802         py36_mode = black.FileMode.PYTHON36
803         source, expected = read_data("force_py36")
804         with cache_dir() as workspace:
805             path = (workspace / "file.py").resolve()
806             with open(path, "w") as fh:
807                 fh.write(source)
808             result = CliRunner().invoke(black.main, [str(path), "--py36"])
809             self.assertEqual(result.exit_code, 0)
810             with open(path, "r") as fh:
811                 actual = fh.read()
812             # verify cache with --py36 is separate
813             py36_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, py36_mode)
814             self.assertIn(path, py36_cache)
815             normal_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, reg_mode)
816             self.assertNotIn(path, normal_cache)
817         self.assertEqual(actual, expected)
818
819     @event_loop(close=False)
820     def test_multi_file_force_py36(self) -> None:
821         reg_mode = black.FileMode.AUTO_DETECT
822         py36_mode = black.FileMode.PYTHON36
823         source, expected = read_data("force_py36")
824         with cache_dir() as workspace:
825             paths = [
826                 (workspace / "file1.py").resolve(),
827                 (workspace / "file2.py").resolve(),
828             ]
829             for path in paths:
830                 with open(path, "w") as fh:
831                     fh.write(source)
832             result = CliRunner().invoke(
833                 black.main, [str(p) for p in paths] + ["--py36"]
834             )
835             self.assertEqual(result.exit_code, 0)
836             for path in paths:
837                 with open(path, "r") as fh:
838                     actual = fh.read()
839                 self.assertEqual(actual, expected)
840             # verify cache with --py36 is separate
841             pyi_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, py36_mode)
842             normal_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, reg_mode)
843             for path in paths:
844                 self.assertIn(path, pyi_cache)
845                 self.assertNotIn(path, normal_cache)
846
847     def test_pipe_force_py36(self) -> None:
848         source, expected = read_data("force_py36")
849         result = CliRunner().invoke(black.main, ["-", "-q", "--py36"], input=source)
850         self.assertEqual(result.exit_code, 0)
851         actual = result.output
852         self.assertFormatEqual(actual, expected)
853
854
# Allow running this test module directly, e.g. `python tests/test_black.py`.
if __name__ == "__main__":
    unittest.main()