git.madduck.net Git - etc/vim.git/blob - tests/test_black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access, can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Improve doc regarding PyCharm keyboard shortcut (#271)
[etc/vim.git] / tests / test_black.py
1 #!/usr/bin/env python3
2 import asyncio
3 from concurrent.futures import ThreadPoolExecutor
4 from contextlib import contextmanager
5 from functools import partial
6 from io import StringIO
7 import os
8 from pathlib import Path
9 import sys
10 from tempfile import TemporaryDirectory
11 from typing import Any, List, Tuple, Iterator
12 import unittest
13 from unittest.mock import patch
14
15 from click import unstyle
16 from click.testing import CliRunner
17
18 import black
19
# Line length shared by every formatting call in this module.
ll = 88
# Shorthands: format a string / format a file in place (fast mode skips checks).
fs = partial(black.format_str, line_length=ll)
ff = partial(black.format_file_in_place, fast=True, line_length=ll)
THIS_FILE = Path(__file__)
THIS_DIR = THIS_FILE.parent
# Assembled from two literals on purpose: read_data() deletes this marker from
# every line it reads, and test_self reads this very file, so the full marker
# must never appear verbatim here.
EMPTY_LINE = "# EMPTY LINE WITH WHITESPACE" + " (this comment will be removed)"
26
27
def dump_to_stderr(*output: str) -> str:
    """Return the given pieces joined by newlines, with a newline on each side.

    Several tests patch this in place of ``black.dump_to_file``. Instead of a
    file name, callers then receive the dumped text itself.
    """
    body = "\n".join(output)
    return f"\n{body}\n"
30
31
def read_data(name: str) -> Tuple[str, str]:
    """read_data('test_name') -> 'input', 'output'

    Load a fixture from ``THIS_DIR``. A ``.py`` suffix is appended unless the
    name already ends in ``.py``, ``.pyi``, ``.out`` or ``.diff``. Lines before
    the output marker line are the input and lines after it are the expected
    output; ``EMPTY_LINE`` is deleted from every line first. If there is input
    but no expected output, the input doubles as the expected output. Both
    results are stripped and end with exactly one newline.
    """
    if not name.endswith((".py", ".pyi", ".out", ".diff")):
        name += ".py"
    before: List[str] = []
    after: List[str] = []
    with open(THIS_DIR / name, "r", encoding="utf8") as test:
        lines = test.readlines()
    target = before
    for raw_line in lines:
        cleaned = raw_line.replace(EMPTY_LINE, "")
        if cleaned.rstrip() != "# output":
            target.append(cleaned)
        else:
            target = after
    if before and not after:
        # No expected section: the fixture is considered already formatted.
        after = list(before)
    source = "".join(before).strip() + "\n"
    expected = "".join(after).strip() + "\n"
    return source, expected
52
53
@contextmanager
def cache_dir(exists: bool = True) -> Iterator[Path]:
    """Patch ``black.CACHE_DIR`` to a throwaway location and yield that path.

    With ``exists=False`` the yielded path is a ``new`` subdirectory of the
    temporary directory that has not been created yet.
    """
    with TemporaryDirectory() as workspace:
        root = Path(workspace)
        target = root if exists else root / "new"
        with patch("black.CACHE_DIR", target):
            yield target
62
63
@contextmanager
def event_loop(close: bool) -> Iterator[None]:
    """Run the body with a freshly created asyncio event loop installed.

    The loop that was current on entry is reinstated on exit. The fresh loop
    is closed on exit only when ``close`` is true.
    """
    policy = asyncio.get_event_loop_policy()
    previous = policy.get_event_loop()
    fresh = policy.new_event_loop()
    asyncio.set_event_loop(fresh)
    try:
        yield

    finally:
        policy.set_event_loop(previous)
        if close:
            fresh.close()
77
78
79 class BlackTestCase(unittest.TestCase):
    # unittest truncates long assertion diffs unless maxDiff is None; keep them
    # complete so formatting mismatches are fully visible.
    maxDiff = None
81
82     def assertFormatEqual(self, expected: str, actual: str) -> None:
83         if actual != expected and not os.environ.get("SKIP_AST_PRINT"):
84             bdv: black.DebugVisitor[Any]
85             black.out("Expected tree:", fg="green")
86             try:
87                 exp_node = black.lib2to3_parse(expected)
88                 bdv = black.DebugVisitor()
89                 list(bdv.visit(exp_node))
90             except Exception as ve:
91                 black.err(str(ve))
92             black.out("Actual tree:", fg="red")
93             try:
94                 exp_node = black.lib2to3_parse(actual)
95                 bdv = black.DebugVisitor()
96                 list(bdv.visit(exp_node))
97             except Exception as ve:
98                 black.err(str(ve))
99         self.assertEqual(expected, actual)
100
101     @patch("black.dump_to_file", dump_to_stderr)
102     def test_self(self) -> None:
103         source, expected = read_data("test_black")
104         actual = fs(source)
105         self.assertFormatEqual(expected, actual)
106         black.assert_equivalent(source, actual)
107         black.assert_stable(source, actual, line_length=ll)
108         self.assertFalse(ff(THIS_FILE))
109
110     @patch("black.dump_to_file", dump_to_stderr)
111     def test_black(self) -> None:
112         source, expected = read_data("../black")
113         actual = fs(source)
114         self.assertFormatEqual(expected, actual)
115         black.assert_equivalent(source, actual)
116         black.assert_stable(source, actual, line_length=ll)
117         self.assertFalse(ff(THIS_DIR / ".." / "black.py"))
118
119     def test_piping(self) -> None:
120         source, expected = read_data("../black")
121         hold_stdin, hold_stdout = sys.stdin, sys.stdout
122         try:
123             sys.stdin, sys.stdout = StringIO(source), StringIO()
124             sys.stdin.name = "<stdin>"
125             black.format_stdin_to_stdout(
126                 line_length=ll, fast=True, write_back=black.WriteBack.YES
127             )
128             sys.stdout.seek(0)
129             actual = sys.stdout.read()
130         finally:
131             sys.stdin, sys.stdout = hold_stdin, hold_stdout
132         self.assertFormatEqual(expected, actual)
133         black.assert_equivalent(source, actual)
134         black.assert_stable(source, actual, line_length=ll)
135
136     def test_piping_diff(self) -> None:
137         source, _ = read_data("expression.py")
138         expected, _ = read_data("expression.diff")
139         hold_stdin, hold_stdout = sys.stdin, sys.stdout
140         try:
141             sys.stdin, sys.stdout = StringIO(source), StringIO()
142             sys.stdin.name = "<stdin>"
143             black.format_stdin_to_stdout(
144                 line_length=ll, fast=True, write_back=black.WriteBack.DIFF
145             )
146             sys.stdout.seek(0)
147             actual = sys.stdout.read()
148         finally:
149             sys.stdin, sys.stdout = hold_stdin, hold_stdout
150         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
151         self.assertEqual(expected, actual)
152
153     @patch("black.dump_to_file", dump_to_stderr)
154     def test_setup(self) -> None:
155         source, expected = read_data("../setup")
156         actual = fs(source)
157         self.assertFormatEqual(expected, actual)
158         black.assert_equivalent(source, actual)
159         black.assert_stable(source, actual, line_length=ll)
160         self.assertFalse(ff(THIS_DIR / ".." / "setup.py"))
161
162     @patch("black.dump_to_file", dump_to_stderr)
163     def test_function(self) -> None:
164         source, expected = read_data("function")
165         actual = fs(source)
166         self.assertFormatEqual(expected, actual)
167         black.assert_equivalent(source, actual)
168         black.assert_stable(source, actual, line_length=ll)
169
170     @patch("black.dump_to_file", dump_to_stderr)
171     def test_function2(self) -> None:
172         source, expected = read_data("function2")
173         actual = fs(source)
174         self.assertFormatEqual(expected, actual)
175         black.assert_equivalent(source, actual)
176         black.assert_stable(source, actual, line_length=ll)
177
178     @patch("black.dump_to_file", dump_to_stderr)
179     def test_expression(self) -> None:
180         source, expected = read_data("expression")
181         actual = fs(source)
182         self.assertFormatEqual(expected, actual)
183         black.assert_equivalent(source, actual)
184         black.assert_stable(source, actual, line_length=ll)
185
186     def test_expression_ff(self) -> None:
187         source, expected = read_data("expression")
188         tmp_file = Path(black.dump_to_file(source))
189         try:
190             self.assertTrue(ff(tmp_file, write_back=black.WriteBack.YES))
191             with open(tmp_file, encoding="utf8") as f:
192                 actual = f.read()
193         finally:
194             os.unlink(tmp_file)
195         self.assertFormatEqual(expected, actual)
196         with patch("black.dump_to_file", dump_to_stderr):
197             black.assert_equivalent(source, actual)
198             black.assert_stable(source, actual, line_length=ll)
199
200     def test_expression_diff(self) -> None:
201         source, _ = read_data("expression.py")
202         expected, _ = read_data("expression.diff")
203         tmp_file = Path(black.dump_to_file(source))
204         hold_stdout = sys.stdout
205         try:
206             sys.stdout = StringIO()
207             self.assertTrue(ff(tmp_file, write_back=black.WriteBack.DIFF))
208             sys.stdout.seek(0)
209             actual = sys.stdout.read()
210             actual = actual.replace(str(tmp_file), "<stdin>")
211         finally:
212             sys.stdout = hold_stdout
213             os.unlink(tmp_file)
214         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
215         if expected != actual:
216             dump = black.dump_to_file(actual)
217             msg = (
218                 f"Expected diff isn't equal to the actual. If you made changes "
219                 f"to expression.py and this is an anticipated difference, "
220                 f"overwrite tests/expression.diff with {dump}"
221             )
222             self.assertEqual(expected, actual, msg)
223
224     @patch("black.dump_to_file", dump_to_stderr)
225     def test_fstring(self) -> None:
226         source, expected = read_data("fstring")
227         actual = fs(source)
228         self.assertFormatEqual(expected, actual)
229         black.assert_equivalent(source, actual)
230         black.assert_stable(source, actual, line_length=ll)
231
232     @patch("black.dump_to_file", dump_to_stderr)
233     def test_string_quotes(self) -> None:
234         source, expected = read_data("string_quotes")
235         actual = fs(source)
236         self.assertFormatEqual(expected, actual)
237         black.assert_equivalent(source, actual)
238         black.assert_stable(source, actual, line_length=ll)
239
240     @patch("black.dump_to_file", dump_to_stderr)
241     def test_slices(self) -> None:
242         source, expected = read_data("slices")
243         actual = fs(source)
244         self.assertFormatEqual(expected, actual)
245         black.assert_equivalent(source, actual)
246         black.assert_stable(source, actual, line_length=ll)
247
248     @patch("black.dump_to_file", dump_to_stderr)
249     def test_comments(self) -> None:
250         source, expected = read_data("comments")
251         actual = fs(source)
252         self.assertFormatEqual(expected, actual)
253         black.assert_equivalent(source, actual)
254         black.assert_stable(source, actual, line_length=ll)
255
256     @patch("black.dump_to_file", dump_to_stderr)
257     def test_comments2(self) -> None:
258         source, expected = read_data("comments2")
259         actual = fs(source)
260         self.assertFormatEqual(expected, actual)
261         black.assert_equivalent(source, actual)
262         black.assert_stable(source, actual, line_length=ll)
263
264     @patch("black.dump_to_file", dump_to_stderr)
265     def test_comments3(self) -> None:
266         source, expected = read_data("comments3")
267         actual = fs(source)
268         self.assertFormatEqual(expected, actual)
269         black.assert_equivalent(source, actual)
270         black.assert_stable(source, actual, line_length=ll)
271
272     @patch("black.dump_to_file", dump_to_stderr)
273     def test_comments4(self) -> None:
274         source, expected = read_data("comments4")
275         actual = fs(source)
276         self.assertFormatEqual(expected, actual)
277         black.assert_equivalent(source, actual)
278         black.assert_stable(source, actual, line_length=ll)
279
280     @patch("black.dump_to_file", dump_to_stderr)
281     def test_comments5(self) -> None:
282         source, expected = read_data("comments5")
283         actual = fs(source)
284         self.assertFormatEqual(expected, actual)
285         black.assert_equivalent(source, actual)
286         black.assert_stable(source, actual, line_length=ll)
287
288     @patch("black.dump_to_file", dump_to_stderr)
289     def test_cantfit(self) -> None:
290         source, expected = read_data("cantfit")
291         actual = fs(source)
292         self.assertFormatEqual(expected, actual)
293         black.assert_equivalent(source, actual)
294         black.assert_stable(source, actual, line_length=ll)
295
296     @patch("black.dump_to_file", dump_to_stderr)
297     def test_import_spacing(self) -> None:
298         source, expected = read_data("import_spacing")
299         actual = fs(source)
300         self.assertFormatEqual(expected, actual)
301         black.assert_equivalent(source, actual)
302         black.assert_stable(source, actual, line_length=ll)
303
304     @patch("black.dump_to_file", dump_to_stderr)
305     def test_composition(self) -> None:
306         source, expected = read_data("composition")
307         actual = fs(source)
308         self.assertFormatEqual(expected, actual)
309         black.assert_equivalent(source, actual)
310         black.assert_stable(source, actual, line_length=ll)
311
312     @patch("black.dump_to_file", dump_to_stderr)
313     def test_empty_lines(self) -> None:
314         source, expected = read_data("empty_lines")
315         actual = fs(source)
316         self.assertFormatEqual(expected, actual)
317         black.assert_equivalent(source, actual)
318         black.assert_stable(source, actual, line_length=ll)
319
320     @patch("black.dump_to_file", dump_to_stderr)
321     def test_string_prefixes(self) -> None:
322         source, expected = read_data("string_prefixes")
323         actual = fs(source)
324         self.assertFormatEqual(expected, actual)
325         black.assert_equivalent(source, actual)
326         black.assert_stable(source, actual, line_length=ll)
327
328     @patch("black.dump_to_file", dump_to_stderr)
329     def test_python2(self) -> None:
330         source, expected = read_data("python2")
331         actual = fs(source)
332         self.assertFormatEqual(expected, actual)
333         # black.assert_equivalent(source, actual)
334         black.assert_stable(source, actual, line_length=ll)
335
336     @patch("black.dump_to_file", dump_to_stderr)
337     def test_python2_unicode_literals(self) -> None:
338         source, expected = read_data("python2_unicode_literals")
339         actual = fs(source)
340         self.assertFormatEqual(expected, actual)
341         black.assert_stable(source, actual, line_length=ll)
342
343     @patch("black.dump_to_file", dump_to_stderr)
344     def test_stub(self) -> None:
345         mode = black.FileMode.PYI
346         source, expected = read_data("stub.pyi")
347         actual = fs(source, mode=mode)
348         self.assertFormatEqual(expected, actual)
349         black.assert_stable(source, actual, line_length=ll, mode=mode)
350
351     @patch("black.dump_to_file", dump_to_stderr)
352     def test_fmtonoff(self) -> None:
353         source, expected = read_data("fmtonoff")
354         actual = fs(source)
355         self.assertFormatEqual(expected, actual)
356         black.assert_equivalent(source, actual)
357         black.assert_stable(source, actual, line_length=ll)
358
359     @patch("black.dump_to_file", dump_to_stderr)
360     def test_remove_empty_parentheses_after_class(self) -> None:
361         source, expected = read_data("class_blank_parentheses")
362         actual = fs(source)
363         self.assertFormatEqual(expected, actual)
364         black.assert_equivalent(source, actual)
365         black.assert_stable(source, actual, line_length=ll)
366
367     @patch("black.dump_to_file", dump_to_stderr)
368     def test_new_line_between_class_and_code(self) -> None:
369         source, expected = read_data("class_methods_new_line")
370         actual = fs(source)
371         self.assertFormatEqual(expected, actual)
372         black.assert_equivalent(source, actual)
373         black.assert_stable(source, actual, line_length=ll)
374
    def test_report(self) -> None:
        """Walk a ``black.Report`` through a sequence of results.

        After each ``done``/``failed`` call this checks three things: the
        message emitted through the patched ``black.out``/``black.err``, the
        summary from ``str(report)`` with styling removed, and
        ``report.return_code``.
        """
        report = black.Report()
        out_lines: List[str] = []
        err_lines: List[str] = []

        # Capture messages instead of printing them.
        def out(msg: str, **kwargs: Any) -> None:
            out_lines.append(msg)

        def err(msg: str, **kwargs: Any) -> None:
            err_lines.append(msg)

        with patch("black.out", out), patch("black.err", err):
            report.done(Path("f1"), black.Changed.NO)
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "f1 already well formatted, good job.")
            self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
            self.assertEqual(report.return_code, 0)
            report.done(Path("f2"), black.Changed.YES)
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "reformatted f2")
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
            )
            # A CACHED result is counted as "left unchanged" in the summary.
            report.done(Path("f3"), black.Changed.CACHED)
            self.assertEqual(len(out_lines), 3)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(
                out_lines[-1], "f3 wasn't modified on disk since last run."
            )
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
            )
            self.assertEqual(report.return_code, 0)
            # In check mode, having reformatted files turns the return code into 1.
            report.check = True
            self.assertEqual(report.return_code, 1)
            report.check = False
            # Failures go to err (not out) and force return code 123.
            report.failed(Path("e1"), "boom")
            self.assertEqual(len(out_lines), 3)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
            self.assertEqual(
                unstyle(str(report)),
                "1 file reformatted, 2 files left unchanged, "
                "1 file failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Reporting f3 again adds to "reformatted" without lowering the
            # unchanged count: the summary tallies calls, not distinct paths.
            report.done(Path("f3"), black.Changed.YES)
            self.assertEqual(len(out_lines), 4)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(out_lines[-1], "reformatted f3")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, "
                "1 file failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.failed(Path("e2"), "boom")
            self.assertEqual(len(out_lines), 4)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, "
                "2 files failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f4"), black.Changed.NO)
            self.assertEqual(len(out_lines), 5)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(out_lines[-1], "f4 already well formatted, good job.")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 3 files left unchanged, "
                "2 files failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # In check mode the summary switches to conditional wording.
            report.check = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, "
                "2 files would fail to reformat.",
            )
459
460     def test_is_python36(self) -> None:
461         node = black.lib2to3_parse("def f(*, arg): ...\n")
462         self.assertFalse(black.is_python36(node))
463         node = black.lib2to3_parse("def f(*, arg,): ...\n")
464         self.assertTrue(black.is_python36(node))
465         node = black.lib2to3_parse("def f(*, arg): f'string'\n")
466         self.assertTrue(black.is_python36(node))
467         source, expected = read_data("function")
468         node = black.lib2to3_parse(source)
469         self.assertTrue(black.is_python36(node))
470         node = black.lib2to3_parse(expected)
471         self.assertTrue(black.is_python36(node))
472         source, expected = read_data("expression")
473         node = black.lib2to3_parse(source)
474         self.assertFalse(black.is_python36(node))
475         node = black.lib2to3_parse(expected)
476         self.assertFalse(black.is_python36(node))
477
478     def test_get_future_imports(self) -> None:
479         node = black.lib2to3_parse("\n")
480         self.assertEqual(set(), black.get_future_imports(node))
481         node = black.lib2to3_parse("from __future__ import black\n")
482         self.assertEqual({"black"}, black.get_future_imports(node))
483         node = black.lib2to3_parse("from __future__ import multiple, imports\n")
484         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
485         node = black.lib2to3_parse("from __future__ import (parenthesized, imports)\n")
486         self.assertEqual({"parenthesized", "imports"}, black.get_future_imports(node))
487         node = black.lib2to3_parse(
488             "from __future__ import multiple\nfrom __future__ import imports\n"
489         )
490         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
491         node = black.lib2to3_parse("# comment\nfrom __future__ import black\n")
492         self.assertEqual({"black"}, black.get_future_imports(node))
493         node = black.lib2to3_parse('"""docstring"""\nfrom __future__ import black\n')
494         self.assertEqual({"black"}, black.get_future_imports(node))
495         node = black.lib2to3_parse("some(other, code)\nfrom __future__ import black\n")
496         self.assertEqual(set(), black.get_future_imports(node))
497         node = black.lib2to3_parse("from some.module import black\n")
498         self.assertEqual(set(), black.get_future_imports(node))
499
500     def test_debug_visitor(self) -> None:
501         source, _ = read_data("debug_visitor.py")
502         expected, _ = read_data("debug_visitor.out")
503         out_lines = []
504         err_lines = []
505
506         def out(msg: str, **kwargs: Any) -> None:
507             out_lines.append(msg)
508
509         def err(msg: str, **kwargs: Any) -> None:
510             err_lines.append(msg)
511
512         with patch("black.out", out), patch("black.err", err):
513             black.DebugVisitor.show(source)
514         actual = "\n".join(out_lines) + "\n"
515         log_name = ""
516         if expected != actual:
517             log_name = black.dump_to_file(*out_lines)
518         self.assertEqual(
519             expected,
520             actual,
521             f"AST print out is different. Actual version dumped to {log_name}",
522         )
523
524     def test_format_file_contents(self) -> None:
525         empty = ""
526         with self.assertRaises(black.NothingChanged):
527             black.format_file_contents(empty, line_length=ll, fast=False)
528         just_nl = "\n"
529         with self.assertRaises(black.NothingChanged):
530             black.format_file_contents(just_nl, line_length=ll, fast=False)
531         same = "l = [1, 2, 3]\n"
532         with self.assertRaises(black.NothingChanged):
533             black.format_file_contents(same, line_length=ll, fast=False)
534         different = "l = [1,2,3]"
535         expected = same
536         actual = black.format_file_contents(different, line_length=ll, fast=False)
537         self.assertEqual(expected, actual)
538         invalid = "return if you can"
539         with self.assertRaises(ValueError) as e:
540             black.format_file_contents(invalid, line_length=ll, fast=False)
541         self.assertEqual(str(e.exception), "Cannot parse: 1:7: return if you can")
542
543     def test_endmarker(self) -> None:
544         n = black.lib2to3_parse("\n")
545         self.assertEqual(n.type, black.syms.file_input)
546         self.assertEqual(len(n.children), 1)
547         self.assertEqual(n.children[0].type, black.token.ENDMARKER)
548
549     @unittest.skipIf(os.environ.get("SKIP_AST_PRINT"), "user set SKIP_AST_PRINT")
550     def test_assertFormatEqual(self) -> None:
551         out_lines = []
552         err_lines = []
553
554         def out(msg: str, **kwargs: Any) -> None:
555             out_lines.append(msg)
556
557         def err(msg: str, **kwargs: Any) -> None:
558             err_lines.append(msg)
559
560         with patch("black.out", out), patch("black.err", err):
561             with self.assertRaises(AssertionError):
562                 self.assertFormatEqual("l = [1, 2, 3]", "l = [1, 2, 3,]")
563
564         out_str = "".join(out_lines)
565         self.assertTrue("Expected tree:" in out_str)
566         self.assertTrue("Actual tree:" in out_str)
567         self.assertEqual("".join(err_lines), "")
568
569     def test_cache_broken_file(self) -> None:
570         mode = black.FileMode.AUTO_DETECT
571         with cache_dir() as workspace:
572             cache_file = black.get_cache_file(black.DEFAULT_LINE_LENGTH, mode)
573             with cache_file.open("w") as fobj:
574                 fobj.write("this is not a pickle")
575             self.assertEqual(black.read_cache(black.DEFAULT_LINE_LENGTH, mode), {})
576             src = (workspace / "test.py").resolve()
577             with src.open("w") as fobj:
578                 fobj.write("print('hello')")
579             result = CliRunner().invoke(black.main, [str(src)])
580             self.assertEqual(result.exit_code, 0)
581             cache = black.read_cache(black.DEFAULT_LINE_LENGTH, mode)
582             self.assertIn(src, cache)
583
584     def test_cache_single_file_already_cached(self) -> None:
585         mode = black.FileMode.AUTO_DETECT
586         with cache_dir() as workspace:
587             src = (workspace / "test.py").resolve()
588             with src.open("w") as fobj:
589                 fobj.write("print('hello')")
590             black.write_cache({}, [src], black.DEFAULT_LINE_LENGTH, mode)
591             result = CliRunner().invoke(black.main, [str(src)])
592             self.assertEqual(result.exit_code, 0)
593             with src.open("r") as fobj:
594                 self.assertEqual(fobj.read(), "print('hello')")
595
596     @event_loop(close=False)
597     def test_cache_multiple_files(self) -> None:
598         mode = black.FileMode.AUTO_DETECT
599         with cache_dir() as workspace, patch(
600             "black.ProcessPoolExecutor", new=ThreadPoolExecutor
601         ):
602             one = (workspace / "one.py").resolve()
603             with one.open("w") as fobj:
604                 fobj.write("print('hello')")
605             two = (workspace / "two.py").resolve()
606             with two.open("w") as fobj:
607                 fobj.write("print('hello')")
608             black.write_cache({}, [one], black.DEFAULT_LINE_LENGTH, mode)
609             result = CliRunner().invoke(black.main, [str(workspace)])
610             self.assertEqual(result.exit_code, 0)
611             with one.open("r") as fobj:
612                 self.assertEqual(fobj.read(), "print('hello')")
613             with two.open("r") as fobj:
614                 self.assertEqual(fobj.read(), 'print("hello")\n')
615             cache = black.read_cache(black.DEFAULT_LINE_LENGTH, mode)
616             self.assertIn(one, cache)
617             self.assertIn(two, cache)
618
619     def test_no_cache_when_writeback_diff(self) -> None:
620         mode = black.FileMode.AUTO_DETECT
621         with cache_dir() as workspace:
622             src = (workspace / "test.py").resolve()
623             with src.open("w") as fobj:
624                 fobj.write("print('hello')")
625             result = CliRunner().invoke(black.main, [str(src), "--diff"])
626             self.assertEqual(result.exit_code, 0)
627             cache_file = black.get_cache_file(black.DEFAULT_LINE_LENGTH, mode)
628             self.assertFalse(cache_file.exists())
629
630     def test_no_cache_when_stdin(self) -> None:
631         mode = black.FileMode.AUTO_DETECT
632         with cache_dir():
633             result = CliRunner().invoke(black.main, ["-"], input="print('hello')")
634             self.assertEqual(result.exit_code, 0)
635             cache_file = black.get_cache_file(black.DEFAULT_LINE_LENGTH, mode)
636             self.assertFalse(cache_file.exists())
637
638     def test_read_cache_no_cachefile(self) -> None:
639         mode = black.FileMode.AUTO_DETECT
640         with cache_dir():
641             self.assertEqual(black.read_cache(black.DEFAULT_LINE_LENGTH, mode), {})
642
643     def test_write_cache_read_cache(self) -> None:
644         mode = black.FileMode.AUTO_DETECT
645         with cache_dir() as workspace:
646             src = (workspace / "test.py").resolve()
647             src.touch()
648             black.write_cache({}, [src], black.DEFAULT_LINE_LENGTH, mode)
649             cache = black.read_cache(black.DEFAULT_LINE_LENGTH, mode)
650             self.assertIn(src, cache)
651             self.assertEqual(cache[src], black.get_cache_info(src))
652
653     def test_filter_cached(self) -> None:
654         with TemporaryDirectory() as workspace:
655             path = Path(workspace)
656             uncached = (path / "uncached").resolve()
657             cached = (path / "cached").resolve()
658             cached_but_changed = (path / "changed").resolve()
659             uncached.touch()
660             cached.touch()
661             cached_but_changed.touch()
662             cache = {cached: black.get_cache_info(cached), cached_but_changed: (0.0, 0)}
663             todo, done = black.filter_cached(
664                 cache, [uncached, cached, cached_but_changed]
665             )
666             self.assertEqual(todo, [uncached, cached_but_changed])
667             self.assertEqual(done, [cached])
668
669     def test_write_cache_creates_directory_if_needed(self) -> None:
670         mode = black.FileMode.AUTO_DETECT
671         with cache_dir(exists=False) as workspace:
672             self.assertFalse(workspace.exists())
673             black.write_cache({}, [], black.DEFAULT_LINE_LENGTH, mode)
674             self.assertTrue(workspace.exists())
675
676     @event_loop(close=False)
677     def test_failed_formatting_does_not_get_cached(self) -> None:
678         mode = black.FileMode.AUTO_DETECT
679         with cache_dir() as workspace, patch(
680             "black.ProcessPoolExecutor", new=ThreadPoolExecutor
681         ):
682             failing = (workspace / "failing.py").resolve()
683             with failing.open("w") as fobj:
684                 fobj.write("not actually python")
685             clean = (workspace / "clean.py").resolve()
686             with clean.open("w") as fobj:
687                 fobj.write('print("hello")\n')
688             result = CliRunner().invoke(black.main, [str(workspace)])
689             self.assertEqual(result.exit_code, 123)
690             cache = black.read_cache(black.DEFAULT_LINE_LENGTH, mode)
691             self.assertNotIn(failing, cache)
692             self.assertIn(clean, cache)
693
694     def test_write_cache_write_fail(self) -> None:
695         mode = black.FileMode.AUTO_DETECT
696         with cache_dir(), patch.object(Path, "open") as mock:
697             mock.side_effect = OSError
698             black.write_cache({}, [], black.DEFAULT_LINE_LENGTH, mode)
699
700     @event_loop(close=False)
701     def test_check_diff_use_together(self) -> None:
702         with cache_dir():
703             # Files which will be reformatted.
704             src1 = (THIS_DIR / "string_quotes.py").resolve()
705             result = CliRunner().invoke(black.main, [str(src1), "--diff", "--check"])
706             self.assertEqual(result.exit_code, 1)
707
708             # Files which will not be reformatted.
709             src2 = (THIS_DIR / "composition.py").resolve()
710             result = CliRunner().invoke(black.main, [str(src2), "--diff", "--check"])
711             self.assertEqual(result.exit_code, 0)
712
713             # Multi file command.
714             result = CliRunner().invoke(
715                 black.main, [str(src1), str(src2), "--diff", "--check"]
716             )
717             self.assertEqual(result.exit_code, 1, result.output)
718
719     def test_no_files(self) -> None:
720         with cache_dir():
721             # Without an argument, black exits with error code 0.
722             result = CliRunner().invoke(black.main, [])
723             self.assertEqual(result.exit_code, 0)
724
725     def test_broken_symlink(self) -> None:
726         with cache_dir() as workspace:
727             symlink = workspace / "broken_link.py"
728             symlink.symlink_to("nonexistent.py")
729             result = CliRunner().invoke(black.main, [str(workspace.resolve())])
730             self.assertEqual(result.exit_code, 0)
731
732     def test_read_cache_line_lengths(self) -> None:
733         mode = black.FileMode.AUTO_DETECT
734         with cache_dir() as workspace:
735             path = (workspace / "file.py").resolve()
736             path.touch()
737             black.write_cache({}, [path], 1, mode)
738             one = black.read_cache(1, mode)
739             self.assertIn(path, one)
740             two = black.read_cache(2, mode)
741             self.assertNotIn(path, two)
742
743     def test_single_file_force_pyi(self) -> None:
744         reg_mode = black.FileMode.AUTO_DETECT
745         pyi_mode = black.FileMode.PYI
746         contents, expected = read_data("force_pyi")
747         with cache_dir() as workspace:
748             path = (workspace / "file.py").resolve()
749             with open(path, "w") as fh:
750                 fh.write(contents)
751             result = CliRunner().invoke(black.main, [str(path), "--pyi"])
752             self.assertEqual(result.exit_code, 0)
753             with open(path, "r") as fh:
754                 actual = fh.read()
755             # verify cache with --pyi is separate
756             pyi_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, pyi_mode)
757             self.assertIn(path, pyi_cache)
758             normal_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, reg_mode)
759             self.assertNotIn(path, normal_cache)
760         self.assertEqual(actual, expected)
761
762     @event_loop(close=False)
763     def test_multi_file_force_pyi(self) -> None:
764         reg_mode = black.FileMode.AUTO_DETECT
765         pyi_mode = black.FileMode.PYI
766         contents, expected = read_data("force_pyi")
767         with cache_dir() as workspace:
768             paths = [
769                 (workspace / "file1.py").resolve(),
770                 (workspace / "file2.py").resolve(),
771             ]
772             for path in paths:
773                 with open(path, "w") as fh:
774                     fh.write(contents)
775             result = CliRunner().invoke(black.main, [str(p) for p in paths] + ["--pyi"])
776             self.assertEqual(result.exit_code, 0)
777             for path in paths:
778                 with open(path, "r") as fh:
779                     actual = fh.read()
780                 self.assertEqual(actual, expected)
781             # verify cache with --pyi is separate
782             pyi_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, pyi_mode)
783             normal_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, reg_mode)
784             for path in paths:
785                 self.assertIn(path, pyi_cache)
786                 self.assertNotIn(path, normal_cache)
787
788     def test_pipe_force_pyi(self) -> None:
789         source, expected = read_data("force_pyi")
790         result = CliRunner().invoke(black.main, ["-", "-q", "--pyi"], input=source)
791         self.assertEqual(result.exit_code, 0)
792         actual = result.output
793         self.assertFormatEqual(actual, expected)
794
795     def test_single_file_force_py36(self) -> None:
796         reg_mode = black.FileMode.AUTO_DETECT
797         py36_mode = black.FileMode.PYTHON36
798         source, expected = read_data("force_py36")
799         with cache_dir() as workspace:
800             path = (workspace / "file.py").resolve()
801             with open(path, "w") as fh:
802                 fh.write(source)
803             result = CliRunner().invoke(black.main, [str(path), "--py36"])
804             self.assertEqual(result.exit_code, 0)
805             with open(path, "r") as fh:
806                 actual = fh.read()
807             # verify cache with --py36 is separate
808             py36_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, py36_mode)
809             self.assertIn(path, py36_cache)
810             normal_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, reg_mode)
811             self.assertNotIn(path, normal_cache)
812         self.assertEqual(actual, expected)
813
814     @event_loop(close=False)
815     def test_multi_file_force_py36(self) -> None:
816         reg_mode = black.FileMode.AUTO_DETECT
817         py36_mode = black.FileMode.PYTHON36
818         source, expected = read_data("force_py36")
819         with cache_dir() as workspace:
820             paths = [
821                 (workspace / "file1.py").resolve(),
822                 (workspace / "file2.py").resolve(),
823             ]
824             for path in paths:
825                 with open(path, "w") as fh:
826                     fh.write(source)
827             result = CliRunner().invoke(
828                 black.main, [str(p) for p in paths] + ["--py36"]
829             )
830             self.assertEqual(result.exit_code, 0)
831             for path in paths:
832                 with open(path, "r") as fh:
833                     actual = fh.read()
834                 self.assertEqual(actual, expected)
835             # verify cache with --py36 is separate
836             pyi_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, py36_mode)
837             normal_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, reg_mode)
838             for path in paths:
839                 self.assertIn(path, pyi_cache)
840                 self.assertNotIn(path, normal_cache)
841
842     def test_pipe_force_py36(self) -> None:
843         source, expected = read_data("force_py36")
844         result = CliRunner().invoke(black.main, ["-", "-q", "--py36"], input=source)
845         self.assertEqual(result.exit_code, 0)
846         actual = result.output
847         self.assertFormatEqual(actual, expected)
848
849
# Allow running this module directly (e.g. `python tests/test_black.py`);
# unittest.main() discovers and runs the TestCase classes defined here.
if __name__ == "__main__":
    unittest.main()