git.madduck.net Git - etc/vim.git/blob - tests/test_black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Clean up PEP 257 support
[etc/vim.git] / tests / test_black.py
1 #!/usr/bin/env python3
2 import asyncio
3 from concurrent.futures import ThreadPoolExecutor
4 from contextlib import contextmanager
5 from functools import partial
6 from io import StringIO
7 import os
8 from pathlib import Path
9 import sys
10 from tempfile import TemporaryDirectory
11 from typing import Any, List, Tuple, Iterator
12 import unittest
13 from unittest.mock import patch
14
15 from click import unstyle
16 from click.testing import CliRunner
17
18 import black
19
# Line length used by every formatting call in these tests.
ll = 88
# Format a file on disk in place. `fast=True` presumably skips black's built-in
# safety checks; the tests below call assert_equivalent/assert_stable themselves.
ff = partial(black.format_file_in_place, line_length=ll, fast=True)
# Format a source string and return the formatted text.
fs = partial(black.format_str, line_length=ll)
THIS_FILE = Path(__file__)
THIS_DIR = THIS_FILE.parent
# Marker that `read_data` deletes from every line of a test data file (presumably
# used to keep whitespace-only lines in fixtures). It is deliberately built by
# concatenation so the complete marker never appears literally in this module:
# `test_self` loads this very file through `read_data`, which would strip it.
EMPTY_LINE = "# EMPTY LINE WITH WHITESPACE" + " (this comment will be removed)"
26
27
def dump_to_stderr(*output: str) -> str:
    """Stand-in for `black.dump_to_file` that the tests patch in.

    Rather than writing a temporary file and returning its name, return the
    pieces of `output` joined by newlines and framed by a leading and trailing
    newline. Despite the name, nothing is written to stderr here; the caller
    decides what to do with the returned text.
    """
    body = "\n".join(output)
    return f"\n{body}\n"
30
31
def read_data(name: str) -> Tuple[str, str]:
    """Load a test data file and split it into (input, expected output).

    `name` is resolved against the tests directory; ".py" is appended unless it
    already ends in ".py", ".pyi", ".out" or ".diff". The EMPTY_LINE marker is
    removed from every line. Lines before a `# output` line form the input and
    lines after it the expected output. If the output half ends up empty while
    the input is not, the input doubles as the expected output. Both halves are
    stripped and terminated with a single newline.
    """
    if not name.endswith((".py", ".pyi", ".out", ".diff")):
        name += ".py"
    with open(THIS_DIR / name, "r", encoding="utf8") as test:
        raw_lines = test.readlines()
    before: List[str] = []
    after: List[str] = []
    target = before
    for raw in raw_lines:
        text = raw.replace(EMPTY_LINE, "")
        if text.rstrip() == "# output":
            target = after
        else:
            target.append(text)
    if before and not after:
        after = list(before)
    return "".join(before).strip() + "\n", "".join(after).strip() + "\n"
52
53
@contextmanager
def cache_dir(exists: bool = True) -> Iterator[Path]:
    """Point `black.CACHE_DIR` at a throwaway location while the block runs.

    With `exists=False` the patched path is a not-yet-created "new"
    subdirectory of a fresh temporary directory. Yields the patched path; the
    temporary directory is deleted on exit.
    """
    with TemporaryDirectory() as workspace:
        root = Path(workspace)
        target = root if exists else root / "new"
        with patch("black.CACHE_DIR", target):
            yield target
62
63
@contextmanager
def event_loop(close: bool) -> Iterator[None]:
    """Run the enclosed code with a freshly created asyncio event loop.

    The loop that was current on entry is reinstated on exit. When `close` is
    true, the fresh loop is also closed on exit; otherwise it is left open.
    Being built with `contextmanager`, this also works as a method decorator.
    """
    policy = asyncio.get_event_loop_policy()
    previous = policy.get_event_loop()
    fresh = policy.new_event_loop()
    policy.set_event_loop(fresh)
    try:
        yield
    finally:
        policy.set_event_loop(previous)
        if close:
            fresh.close()
77
78
79 class BlackTestCase(unittest.TestCase):
80     maxDiff = None
81
82     def assertFormatEqual(self, expected: str, actual: str) -> None:
83         if actual != expected and not os.environ.get("SKIP_AST_PRINT"):
84             bdv: black.DebugVisitor[Any]
85             black.out("Expected tree:", fg="green")
86             try:
87                 exp_node = black.lib2to3_parse(expected)
88                 bdv = black.DebugVisitor()
89                 list(bdv.visit(exp_node))
90             except Exception as ve:
91                 black.err(str(ve))
92             black.out("Actual tree:", fg="red")
93             try:
94                 exp_node = black.lib2to3_parse(actual)
95                 bdv = black.DebugVisitor()
96                 list(bdv.visit(exp_node))
97             except Exception as ve:
98                 black.err(str(ve))
99         self.assertEqual(expected, actual)
100
101     @patch("black.dump_to_file", dump_to_stderr)
102     def test_self(self) -> None:
103         source, expected = read_data("test_black")
104         actual = fs(source)
105         self.assertFormatEqual(expected, actual)
106         black.assert_equivalent(source, actual)
107         black.assert_stable(source, actual, line_length=ll)
108         self.assertFalse(ff(THIS_FILE))
109
110     @patch("black.dump_to_file", dump_to_stderr)
111     def test_black(self) -> None:
112         source, expected = read_data("../black")
113         actual = fs(source)
114         self.assertFormatEqual(expected, actual)
115         black.assert_equivalent(source, actual)
116         black.assert_stable(source, actual, line_length=ll)
117         self.assertFalse(ff(THIS_DIR / ".." / "black.py"))
118
119     def test_piping(self) -> None:
120         source, expected = read_data("../black")
121         hold_stdin, hold_stdout = sys.stdin, sys.stdout
122         try:
123             sys.stdin, sys.stdout = StringIO(source), StringIO()
124             sys.stdin.name = "<stdin>"
125             black.format_stdin_to_stdout(
126                 line_length=ll, fast=True, write_back=black.WriteBack.YES
127             )
128             sys.stdout.seek(0)
129             actual = sys.stdout.read()
130         finally:
131             sys.stdin, sys.stdout = hold_stdin, hold_stdout
132         self.assertFormatEqual(expected, actual)
133         black.assert_equivalent(source, actual)
134         black.assert_stable(source, actual, line_length=ll)
135
136     def test_piping_diff(self) -> None:
137         source, _ = read_data("expression.py")
138         expected, _ = read_data("expression.diff")
139         hold_stdin, hold_stdout = sys.stdin, sys.stdout
140         try:
141             sys.stdin, sys.stdout = StringIO(source), StringIO()
142             sys.stdin.name = "<stdin>"
143             black.format_stdin_to_stdout(
144                 line_length=ll, fast=True, write_back=black.WriteBack.DIFF
145             )
146             sys.stdout.seek(0)
147             actual = sys.stdout.read()
148         finally:
149             sys.stdin, sys.stdout = hold_stdin, hold_stdout
150         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
151         self.assertEqual(expected, actual)
152
153     @patch("black.dump_to_file", dump_to_stderr)
154     def test_setup(self) -> None:
155         source, expected = read_data("../setup")
156         actual = fs(source)
157         self.assertFormatEqual(expected, actual)
158         black.assert_equivalent(source, actual)
159         black.assert_stable(source, actual, line_length=ll)
160         self.assertFalse(ff(THIS_DIR / ".." / "setup.py"))
161
162     @patch("black.dump_to_file", dump_to_stderr)
163     def test_function(self) -> None:
164         source, expected = read_data("function")
165         actual = fs(source)
166         self.assertFormatEqual(expected, actual)
167         black.assert_equivalent(source, actual)
168         black.assert_stable(source, actual, line_length=ll)
169
170     @patch("black.dump_to_file", dump_to_stderr)
171     def test_function2(self) -> None:
172         source, expected = read_data("function2")
173         actual = fs(source)
174         self.assertFormatEqual(expected, actual)
175         black.assert_equivalent(source, actual)
176         black.assert_stable(source, actual, line_length=ll)
177
178     @patch("black.dump_to_file", dump_to_stderr)
179     def test_expression(self) -> None:
180         source, expected = read_data("expression")
181         actual = fs(source)
182         self.assertFormatEqual(expected, actual)
183         black.assert_equivalent(source, actual)
184         black.assert_stable(source, actual, line_length=ll)
185
186     def test_expression_ff(self) -> None:
187         source, expected = read_data("expression")
188         tmp_file = Path(black.dump_to_file(source))
189         try:
190             self.assertTrue(ff(tmp_file, write_back=black.WriteBack.YES))
191             with open(tmp_file, encoding="utf8") as f:
192                 actual = f.read()
193         finally:
194             os.unlink(tmp_file)
195         self.assertFormatEqual(expected, actual)
196         with patch("black.dump_to_file", dump_to_stderr):
197             black.assert_equivalent(source, actual)
198             black.assert_stable(source, actual, line_length=ll)
199
200     def test_expression_diff(self) -> None:
201         source, _ = read_data("expression.py")
202         expected, _ = read_data("expression.diff")
203         tmp_file = Path(black.dump_to_file(source))
204         hold_stdout = sys.stdout
205         try:
206             sys.stdout = StringIO()
207             self.assertTrue(ff(tmp_file, write_back=black.WriteBack.DIFF))
208             sys.stdout.seek(0)
209             actual = sys.stdout.read()
210             actual = actual.replace(str(tmp_file), "<stdin>")
211         finally:
212             sys.stdout = hold_stdout
213             os.unlink(tmp_file)
214         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
215         if expected != actual:
216             dump = black.dump_to_file(actual)
217             msg = (
218                 f"Expected diff isn't equal to the actual. If you made changes "
219                 f"to expression.py and this is an anticipated difference, "
220                 f"overwrite tests/expression.diff with {dump}"
221             )
222             self.assertEqual(expected, actual, msg)
223
224     @patch("black.dump_to_file", dump_to_stderr)
225     def test_fstring(self) -> None:
226         source, expected = read_data("fstring")
227         actual = fs(source)
228         self.assertFormatEqual(expected, actual)
229         black.assert_equivalent(source, actual)
230         black.assert_stable(source, actual, line_length=ll)
231
232     @patch("black.dump_to_file", dump_to_stderr)
233     def test_string_quotes(self) -> None:
234         source, expected = read_data("string_quotes")
235         actual = fs(source)
236         self.assertFormatEqual(expected, actual)
237         black.assert_equivalent(source, actual)
238         black.assert_stable(source, actual, line_length=ll)
239
240     @patch("black.dump_to_file", dump_to_stderr)
241     def test_slices(self) -> None:
242         source, expected = read_data("slices")
243         actual = fs(source)
244         self.assertFormatEqual(expected, actual)
245         black.assert_equivalent(source, actual)
246         black.assert_stable(source, actual, line_length=ll)
247
248     @patch("black.dump_to_file", dump_to_stderr)
249     def test_comments(self) -> None:
250         source, expected = read_data("comments")
251         actual = fs(source)
252         self.assertFormatEqual(expected, actual)
253         black.assert_equivalent(source, actual)
254         black.assert_stable(source, actual, line_length=ll)
255
256     @patch("black.dump_to_file", dump_to_stderr)
257     def test_comments2(self) -> None:
258         source, expected = read_data("comments2")
259         actual = fs(source)
260         self.assertFormatEqual(expected, actual)
261         black.assert_equivalent(source, actual)
262         black.assert_stable(source, actual, line_length=ll)
263
264     @patch("black.dump_to_file", dump_to_stderr)
265     def test_comments3(self) -> None:
266         source, expected = read_data("comments3")
267         actual = fs(source)
268         self.assertFormatEqual(expected, actual)
269         black.assert_equivalent(source, actual)
270         black.assert_stable(source, actual, line_length=ll)
271
272     @patch("black.dump_to_file", dump_to_stderr)
273     def test_comments4(self) -> None:
274         source, expected = read_data("comments4")
275         actual = fs(source)
276         self.assertFormatEqual(expected, actual)
277         black.assert_equivalent(source, actual)
278         black.assert_stable(source, actual, line_length=ll)
279
280     @patch("black.dump_to_file", dump_to_stderr)
281     def test_comments5(self) -> None:
282         source, expected = read_data("comments5")
283         actual = fs(source)
284         self.assertFormatEqual(expected, actual)
285         black.assert_equivalent(source, actual)
286         black.assert_stable(source, actual, line_length=ll)
287
288     @patch("black.dump_to_file", dump_to_stderr)
289     def test_cantfit(self) -> None:
290         source, expected = read_data("cantfit")
291         actual = fs(source)
292         self.assertFormatEqual(expected, actual)
293         black.assert_equivalent(source, actual)
294         black.assert_stable(source, actual, line_length=ll)
295
296     @patch("black.dump_to_file", dump_to_stderr)
297     def test_import_spacing(self) -> None:
298         source, expected = read_data("import_spacing")
299         actual = fs(source)
300         self.assertFormatEqual(expected, actual)
301         black.assert_equivalent(source, actual)
302         black.assert_stable(source, actual, line_length=ll)
303
304     @patch("black.dump_to_file", dump_to_stderr)
305     def test_composition(self) -> None:
306         source, expected = read_data("composition")
307         actual = fs(source)
308         self.assertFormatEqual(expected, actual)
309         black.assert_equivalent(source, actual)
310         black.assert_stable(source, actual, line_length=ll)
311
312     @patch("black.dump_to_file", dump_to_stderr)
313     def test_empty_lines(self) -> None:
314         source, expected = read_data("empty_lines")
315         actual = fs(source)
316         self.assertFormatEqual(expected, actual)
317         black.assert_equivalent(source, actual)
318         black.assert_stable(source, actual, line_length=ll)
319
320     @patch("black.dump_to_file", dump_to_stderr)
321     def test_string_prefixes(self) -> None:
322         source, expected = read_data("string_prefixes")
323         actual = fs(source)
324         self.assertFormatEqual(expected, actual)
325         black.assert_equivalent(source, actual)
326         black.assert_stable(source, actual, line_length=ll)
327
328     @patch("black.dump_to_file", dump_to_stderr)
329     def test_python2(self) -> None:
330         source, expected = read_data("python2")
331         actual = fs(source)
332         self.assertFormatEqual(expected, actual)
333         # black.assert_equivalent(source, actual)
334         black.assert_stable(source, actual, line_length=ll)
335
336     @patch("black.dump_to_file", dump_to_stderr)
337     def test_python2_unicode_literals(self) -> None:
338         source, expected = read_data("python2_unicode_literals")
339         actual = fs(source)
340         self.assertFormatEqual(expected, actual)
341         black.assert_stable(source, actual, line_length=ll)
342
343     @patch("black.dump_to_file", dump_to_stderr)
344     def test_stub(self) -> None:
345         source, expected = read_data("stub.pyi")
346         actual = fs(source, is_pyi=True)
347         self.assertFormatEqual(expected, actual)
348         black.assert_stable(source, actual, line_length=ll, is_pyi=True)
349
350     @patch("black.dump_to_file", dump_to_stderr)
351     def test_fmtonoff(self) -> None:
352         source, expected = read_data("fmtonoff")
353         actual = fs(source)
354         self.assertFormatEqual(expected, actual)
355         black.assert_equivalent(source, actual)
356         black.assert_stable(source, actual, line_length=ll)
357
358     @patch("black.dump_to_file", dump_to_stderr)
359     def test_remove_empty_parentheses_after_class(self) -> None:
360         source, expected = read_data("class_blank_parentheses")
361         actual = fs(source)
362         self.assertFormatEqual(expected, actual)
363         black.assert_equivalent(source, actual)
364         black.assert_stable(source, actual, line_length=ll)
365
366     @patch("black.dump_to_file", dump_to_stderr)
367     def test_new_line_between_class_and_code(self) -> None:
368         source, expected = read_data("class_methods_new_line")
369         actual = fs(source)
370         self.assertFormatEqual(expected, actual)
371         black.assert_equivalent(source, actual)
372         black.assert_stable(source, actual, line_length=ll)
373
    def test_report(self) -> None:
        """Check `black.Report` messages, summary text and return codes.

        The steps are cumulative: every `done`/`failed` call updates the
        report's running totals, so the assertions must stay in this order.
        """
        report = black.Report()
        # Capture everything the report sends through black.out / black.err.
        out_lines: List[str] = []
        err_lines: List[str] = []

        def out(msg: str, **kwargs: Any) -> None:
            out_lines.append(msg)

        def err(msg: str, **kwargs: Any) -> None:
            err_lines.append(msg)

        with patch("black.out", out), patch("black.err", err):
            report.done(Path("f1"), black.Changed.NO)
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "f1 already well formatted, good job.")
            self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
            self.assertEqual(report.return_code, 0)
            report.done(Path("f2"), black.Changed.YES)
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "reformatted f2")
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
            )
            report.done(Path("f3"), black.Changed.CACHED)
            self.assertEqual(len(out_lines), 3)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(
                out_lines[-1], "f3 wasn't modified on disk since last run."
            )
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
            )
            self.assertEqual(report.return_code, 0)
            # In check mode, a reformatted file turns the return code into 1.
            report.check = True
            self.assertEqual(report.return_code, 1)
            report.check = False
            # A failure is reported on err and raises the return code to 123.
            report.failed(Path("e1"), "boom")
            self.assertEqual(len(out_lines), 3)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
            self.assertEqual(
                unstyle(str(report)),
                "1 file reformatted, 2 files left unchanged, "
                "1 file failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f3"), black.Changed.YES)
            self.assertEqual(len(out_lines), 4)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(out_lines[-1], "reformatted f3")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, "
                "1 file failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.failed(Path("e2"), "boom")
            self.assertEqual(len(out_lines), 4)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, "
                "2 files failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f4"), black.Changed.NO)
            self.assertEqual(len(out_lines), 5)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(out_lines[-1], "f4 already well formatted, good job.")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 3 files left unchanged, "
                "2 files failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Check mode phrases the summary conditionally ("would be ...").
            report.check = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, "
                "2 files would fail to reformat.",
            )
458
459     def test_is_python36(self) -> None:
460         node = black.lib2to3_parse("def f(*, arg): ...\n")
461         self.assertFalse(black.is_python36(node))
462         node = black.lib2to3_parse("def f(*, arg,): ...\n")
463         self.assertTrue(black.is_python36(node))
464         node = black.lib2to3_parse("def f(*, arg): f'string'\n")
465         self.assertTrue(black.is_python36(node))
466         source, expected = read_data("function")
467         node = black.lib2to3_parse(source)
468         self.assertTrue(black.is_python36(node))
469         node = black.lib2to3_parse(expected)
470         self.assertTrue(black.is_python36(node))
471         source, expected = read_data("expression")
472         node = black.lib2to3_parse(source)
473         self.assertFalse(black.is_python36(node))
474         node = black.lib2to3_parse(expected)
475         self.assertFalse(black.is_python36(node))
476
477     def test_get_future_imports(self) -> None:
478         node = black.lib2to3_parse("\n")
479         self.assertEqual(set(), black.get_future_imports(node))
480         node = black.lib2to3_parse("from __future__ import black\n")
481         self.assertEqual({"black"}, black.get_future_imports(node))
482         node = black.lib2to3_parse("from __future__ import multiple, imports\n")
483         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
484         node = black.lib2to3_parse("from __future__ import (parenthesized, imports)\n")
485         self.assertEqual({"parenthesized", "imports"}, black.get_future_imports(node))
486         node = black.lib2to3_parse(
487             "from __future__ import multiple\nfrom __future__ import imports\n"
488         )
489         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
490         node = black.lib2to3_parse("# comment\nfrom __future__ import black\n")
491         self.assertEqual({"black"}, black.get_future_imports(node))
492         node = black.lib2to3_parse('"""docstring"""\nfrom __future__ import black\n')
493         self.assertEqual({"black"}, black.get_future_imports(node))
494         node = black.lib2to3_parse("some(other, code)\nfrom __future__ import black\n")
495         self.assertEqual(set(), black.get_future_imports(node))
496         node = black.lib2to3_parse("from some.module import black\n")
497         self.assertEqual(set(), black.get_future_imports(node))
498
499     def test_debug_visitor(self) -> None:
500         source, _ = read_data("debug_visitor.py")
501         expected, _ = read_data("debug_visitor.out")
502         out_lines = []
503         err_lines = []
504
505         def out(msg: str, **kwargs: Any) -> None:
506             out_lines.append(msg)
507
508         def err(msg: str, **kwargs: Any) -> None:
509             err_lines.append(msg)
510
511         with patch("black.out", out), patch("black.err", err):
512             black.DebugVisitor.show(source)
513         actual = "\n".join(out_lines) + "\n"
514         log_name = ""
515         if expected != actual:
516             log_name = black.dump_to_file(*out_lines)
517         self.assertEqual(
518             expected,
519             actual,
520             f"AST print out is different. Actual version dumped to {log_name}",
521         )
522
523     def test_format_file_contents(self) -> None:
524         empty = ""
525         with self.assertRaises(black.NothingChanged):
526             black.format_file_contents(empty, line_length=ll, fast=False)
527         just_nl = "\n"
528         with self.assertRaises(black.NothingChanged):
529             black.format_file_contents(just_nl, line_length=ll, fast=False)
530         same = "l = [1, 2, 3]\n"
531         with self.assertRaises(black.NothingChanged):
532             black.format_file_contents(same, line_length=ll, fast=False)
533         different = "l = [1,2,3]"
534         expected = same
535         actual = black.format_file_contents(different, line_length=ll, fast=False)
536         self.assertEqual(expected, actual)
537         invalid = "return if you can"
538         with self.assertRaises(ValueError) as e:
539             black.format_file_contents(invalid, line_length=ll, fast=False)
540         self.assertEqual(str(e.exception), "Cannot parse: 1:7: return if you can")
541
542     def test_endmarker(self) -> None:
543         n = black.lib2to3_parse("\n")
544         self.assertEqual(n.type, black.syms.file_input)
545         self.assertEqual(len(n.children), 1)
546         self.assertEqual(n.children[0].type, black.token.ENDMARKER)
547
548     @unittest.skipIf(os.environ.get("SKIP_AST_PRINT"), "user set SKIP_AST_PRINT")
549     def test_assertFormatEqual(self) -> None:
550         out_lines = []
551         err_lines = []
552
553         def out(msg: str, **kwargs: Any) -> None:
554             out_lines.append(msg)
555
556         def err(msg: str, **kwargs: Any) -> None:
557             err_lines.append(msg)
558
559         with patch("black.out", out), patch("black.err", err):
560             with self.assertRaises(AssertionError):
561                 self.assertFormatEqual("l = [1, 2, 3]", "l = [1, 2, 3,]")
562
563         out_str = "".join(out_lines)
564         self.assertTrue("Expected tree:" in out_str)
565         self.assertTrue("Actual tree:" in out_str)
566         self.assertEqual("".join(err_lines), "")
567
568     def test_cache_broken_file(self) -> None:
569         with cache_dir() as workspace:
570             cache_file = black.get_cache_file(black.DEFAULT_LINE_LENGTH)
571             with cache_file.open("w") as fobj:
572                 fobj.write("this is not a pickle")
573             self.assertEqual(black.read_cache(black.DEFAULT_LINE_LENGTH), {})
574             src = (workspace / "test.py").resolve()
575             with src.open("w") as fobj:
576                 fobj.write("print('hello')")
577             result = CliRunner().invoke(black.main, [str(src)])
578             self.assertEqual(result.exit_code, 0)
579             cache = black.read_cache(black.DEFAULT_LINE_LENGTH)
580             self.assertIn(src, cache)
581
582     def test_cache_single_file_already_cached(self) -> None:
583         with cache_dir() as workspace:
584             src = (workspace / "test.py").resolve()
585             with src.open("w") as fobj:
586                 fobj.write("print('hello')")
587             black.write_cache({}, [src], black.DEFAULT_LINE_LENGTH)
588             result = CliRunner().invoke(black.main, [str(src)])
589             self.assertEqual(result.exit_code, 0)
590             with src.open("r") as fobj:
591                 self.assertEqual(fobj.read(), "print('hello')")
592
593     @event_loop(close=False)
594     def test_cache_multiple_files(self) -> None:
595         with cache_dir() as workspace, patch(
596             "black.ProcessPoolExecutor", new=ThreadPoolExecutor
597         ):
598             one = (workspace / "one.py").resolve()
599             with one.open("w") as fobj:
600                 fobj.write("print('hello')")
601             two = (workspace / "two.py").resolve()
602             with two.open("w") as fobj:
603                 fobj.write("print('hello')")
604             black.write_cache({}, [one], black.DEFAULT_LINE_LENGTH)
605             result = CliRunner().invoke(black.main, [str(workspace)])
606             self.assertEqual(result.exit_code, 0)
607             with one.open("r") as fobj:
608                 self.assertEqual(fobj.read(), "print('hello')")
609             with two.open("r") as fobj:
610                 self.assertEqual(fobj.read(), 'print("hello")\n')
611             cache = black.read_cache(black.DEFAULT_LINE_LENGTH)
612             self.assertIn(one, cache)
613             self.assertIn(two, cache)
614
615     def test_no_cache_when_writeback_diff(self) -> None:
616         with cache_dir() as workspace:
617             src = (workspace / "test.py").resolve()
618             with src.open("w") as fobj:
619                 fobj.write("print('hello')")
620             result = CliRunner().invoke(black.main, [str(src), "--diff"])
621             self.assertEqual(result.exit_code, 0)
622             cache_file = black.get_cache_file(black.DEFAULT_LINE_LENGTH)
623             self.assertFalse(cache_file.exists())
624
625     def test_no_cache_when_stdin(self) -> None:
626         with cache_dir():
627             result = CliRunner().invoke(black.main, ["-"], input="print('hello')")
628             self.assertEqual(result.exit_code, 0)
629             cache_file = black.get_cache_file(black.DEFAULT_LINE_LENGTH)
630             self.assertFalse(cache_file.exists())
631
632     def test_read_cache_no_cachefile(self) -> None:
633         with cache_dir():
634             self.assertEqual(black.read_cache(black.DEFAULT_LINE_LENGTH), {})
635
636     def test_write_cache_read_cache(self) -> None:
637         with cache_dir() as workspace:
638             src = (workspace / "test.py").resolve()
639             src.touch()
640             black.write_cache({}, [src], black.DEFAULT_LINE_LENGTH)
641             cache = black.read_cache(black.DEFAULT_LINE_LENGTH)
642             self.assertIn(src, cache)
643             self.assertEqual(cache[src], black.get_cache_info(src))
644
645     def test_filter_cached(self) -> None:
646         with TemporaryDirectory() as workspace:
647             path = Path(workspace)
648             uncached = (path / "uncached").resolve()
649             cached = (path / "cached").resolve()
650             cached_but_changed = (path / "changed").resolve()
651             uncached.touch()
652             cached.touch()
653             cached_but_changed.touch()
654             cache = {cached: black.get_cache_info(cached), cached_but_changed: (0.0, 0)}
655             todo, done = black.filter_cached(
656                 cache, [uncached, cached, cached_but_changed]
657             )
658             self.assertEqual(todo, [uncached, cached_but_changed])
659             self.assertEqual(done, [cached])
660
661     def test_write_cache_creates_directory_if_needed(self) -> None:
662         with cache_dir(exists=False) as workspace:
663             self.assertFalse(workspace.exists())
664             black.write_cache({}, [], black.DEFAULT_LINE_LENGTH)
665             self.assertTrue(workspace.exists())
666
667     @event_loop(close=False)
668     def test_failed_formatting_does_not_get_cached(self) -> None:
669         with cache_dir() as workspace, patch(
670             "black.ProcessPoolExecutor", new=ThreadPoolExecutor
671         ):
672             failing = (workspace / "failing.py").resolve()
673             with failing.open("w") as fobj:
674                 fobj.write("not actually python")
675             clean = (workspace / "clean.py").resolve()
676             with clean.open("w") as fobj:
677                 fobj.write('print("hello")\n')
678             result = CliRunner().invoke(black.main, [str(workspace)])
679             self.assertEqual(result.exit_code, 123)
680             cache = black.read_cache(black.DEFAULT_LINE_LENGTH)
681             self.assertNotIn(failing, cache)
682             self.assertIn(clean, cache)
683
684     def test_write_cache_write_fail(self) -> None:
685         with cache_dir(), patch.object(Path, "open") as mock:
686             mock.side_effect = OSError
687             black.write_cache({}, [], black.DEFAULT_LINE_LENGTH)
688
689     @event_loop(close=False)
690     def test_check_diff_use_together(self) -> None:
691         with cache_dir():
692             # Files which will be reformatted.
693             src1 = (THIS_DIR / "string_quotes.py").resolve()
694             result = CliRunner().invoke(black.main, [str(src1), "--diff", "--check"])
695             self.assertEqual(result.exit_code, 1)
696
697             # Files which will not be reformatted.
698             src2 = (THIS_DIR / "composition.py").resolve()
699             result = CliRunner().invoke(black.main, [str(src2), "--diff", "--check"])
700             self.assertEqual(result.exit_code, 0)
701
702             # Multi file command.
703             result = CliRunner().invoke(
704                 black.main, [str(src1), str(src2), "--diff", "--check"]
705             )
706             self.assertEqual(result.exit_code, 1, result.output)
707
708     def test_no_files(self) -> None:
709         with cache_dir():
710             # Without an argument, black exits with error code 0.
711             result = CliRunner().invoke(black.main, [])
712             self.assertEqual(result.exit_code, 0)
713
714     def test_broken_symlink(self) -> None:
715         with cache_dir() as workspace:
716             symlink = workspace / "broken_link.py"
717             symlink.symlink_to("nonexistent.py")
718             result = CliRunner().invoke(black.main, [str(workspace.resolve())])
719             self.assertEqual(result.exit_code, 0)
720
721     def test_read_cache_line_lengths(self) -> None:
722         with cache_dir() as workspace:
723             path = (workspace / "file.py").resolve()
724             path.touch()
725             black.write_cache({}, [path], 1)
726             one = black.read_cache(1)
727             self.assertIn(path, one)
728             two = black.read_cache(2)
729             self.assertNotIn(path, two)
730
731     def test_single_file_force_pyi(self) -> None:
732         contents, expected = read_data("force_pyi")
733         with cache_dir() as workspace:
734             path = (workspace / "file.py").resolve()
735             with open(path, "w") as fh:
736                 fh.write(contents)
737             result = CliRunner().invoke(black.main, [str(path), "--pyi"])
738             self.assertEqual(result.exit_code, 0)
739             with open(path, "r") as fh:
740                 actual = fh.read()
741             # verify cache with --pyi is separate
742             pyi_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, pyi=True)
743             self.assertIn(path, pyi_cache)
744             normal_cache = black.read_cache(black.DEFAULT_LINE_LENGTH)
745             self.assertNotIn(path, normal_cache)
746         self.assertEqual(actual, expected)
747
748     @event_loop(close=False)
749     def test_multi_file_force_pyi(self) -> None:
750         contents, expected = read_data("force_pyi")
751         with cache_dir() as workspace:
752             paths = [
753                 (workspace / "file1.py").resolve(),
754                 (workspace / "file2.py").resolve(),
755             ]
756             for path in paths:
757                 with open(path, "w") as fh:
758                     fh.write(contents)
759             result = CliRunner().invoke(black.main, [str(p) for p in paths] + ["--pyi"])
760             self.assertEqual(result.exit_code, 0)
761             for path in paths:
762                 with open(path, "r") as fh:
763                     actual = fh.read()
764                 self.assertEqual(actual, expected)
765             # verify cache with --pyi is separate
766             pyi_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, pyi=True)
767             normal_cache = black.read_cache(black.DEFAULT_LINE_LENGTH)
768             for path in paths:
769                 self.assertIn(path, pyi_cache)
770                 self.assertNotIn(path, normal_cache)
771
772     def test_pipe_force_pyi(self) -> None:
773         source, expected = read_data("force_pyi")
774         result = CliRunner().invoke(black.main, ["-", "-q", "--pyi"], input=source)
775         self.assertEqual(result.exit_code, 0)
776         actual = result.output
777         self.assertFormatEqual(actual, expected)
778
779     def test_single_file_force_py36(self) -> None:
780         source, expected = read_data("force_py36")
781         with cache_dir() as workspace:
782             path = (workspace / "file.py").resolve()
783             with open(path, "w") as fh:
784                 fh.write(source)
785             result = CliRunner().invoke(black.main, [str(path), "--py36"])
786             self.assertEqual(result.exit_code, 0)
787             with open(path, "r") as fh:
788                 actual = fh.read()
789             # verify cache with --py36 is separate
790             py36_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, py36=True)
791             self.assertIn(path, py36_cache)
792             normal_cache = black.read_cache(black.DEFAULT_LINE_LENGTH)
793             self.assertNotIn(path, normal_cache)
794         self.assertEqual(actual, expected)
795
796     @event_loop(close=False)
797     def test_multi_file_force_py36(self) -> None:
798         source, expected = read_data("force_py36")
799         with cache_dir() as workspace:
800             paths = [
801                 (workspace / "file1.py").resolve(),
802                 (workspace / "file2.py").resolve(),
803             ]
804             for path in paths:
805                 with open(path, "w") as fh:
806                     fh.write(source)
807             result = CliRunner().invoke(
808                 black.main, [str(p) for p in paths] + ["--py36"]
809             )
810             self.assertEqual(result.exit_code, 0)
811             for path in paths:
812                 with open(path, "r") as fh:
813                     actual = fh.read()
814                 self.assertEqual(actual, expected)
815             # verify cache with --py36 is separate
816             pyi_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, py36=True)
817             normal_cache = black.read_cache(black.DEFAULT_LINE_LENGTH)
818             for path in paths:
819                 self.assertIn(path, pyi_cache)
820                 self.assertNotIn(path, normal_cache)
821
822     def test_pipe_force_py36(self) -> None:
823         source, expected = read_data("force_py36")
824         result = CliRunner().invoke(black.main, ["-", "-q", "--py36"], input=source)
825         self.assertEqual(result.exit_code, 0)
826         actual = result.output
827         self.assertFormatEqual(actual, expected)
828
829
if __name__ == "__main__":
    # Allow running this test module directly, not only via a test runner.
    unittest.main()