]> git.madduck.net Git - etc/vim.git/blob - tests/test_black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access, can be arranged individually.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

[trivial] Simplify stdin handling
[etc/vim.git] / tests / test_black.py
1 #!/usr/bin/env python3
2 import asyncio
3 from concurrent.futures import ThreadPoolExecutor
4 from contextlib import contextmanager
5 from functools import partial
6 from io import StringIO
7 import os
8 from pathlib import Path
9 import sys
10 from tempfile import TemporaryDirectory
11 from typing import Any, List, Tuple, Iterator
12 import unittest
13 from unittest.mock import patch
14 import re
15
16 from click import unstyle
17 from click.testing import CliRunner
18
19 import black
20
THIS_FILE = Path(__file__)
THIS_DIR = THIS_FILE.parent
# Line length passed to every formatting helper in this module.
ll = 88
# ``ff`` reformats a file in place (fast mode); ``fs`` formats a source string.
ff = partial(black.format_file_in_place, fast=True, line_length=ll)
fs = partial(black.format_str, line_length=ll)
# Assembled from two pieces so the full marker never appears verbatim in this file;
# read_data() strips it from every file it loads, including this one (test_self).
EMPTY_LINE = "# EMPTY LINE WITH WHITESPACE" + " (this comment will be removed)"
27
28
def dump_to_stderr(*output: str) -> str:
    """Return ``output`` joined by newlines, with one newline added at each end.

    Tests patch ``black.dump_to_file`` with this function. Where black would return
    the path of a dump file, the patched call returns the text itself. Despite the
    name, nothing is written to stderr here.
    """
    body = "\n".join(output)
    return f"\n{body}\n"
31
32
def read_data(name: str) -> Tuple[str, str]:
    """Load test case ``name`` from THIS_DIR and return ``(source, expected)``.

    ``.py`` is appended unless the name already ends in .py, .pyi, .out or .diff.
    Every occurrence of EMPTY_LINE is deleted. Lines consisting of the output
    marker comment are dropped; text after the first such line is the expected
    result. If the expected half ends up empty while the source half is not, it
    becomes a copy of the source (the file is treated as already formatted).
    Both halves are stripped and given exactly one trailing newline.
    """
    if not name.endswith((".py", ".pyi", ".out", ".diff")):
        name += ".py"
    path = THIS_DIR / name
    with open(path, "r", encoding="utf8") as test:
        lines = test.readlines()
    source: List[str] = []
    expected: List[str] = []
    bucket = source
    for raw_line in lines:
        cleaned = raw_line.replace(EMPTY_LINE, "")
        if cleaned.rstrip() == "# output":
            bucket = expected
            continue

        bucket.append(cleaned)
    if source and not expected:
        # No marker (or nothing after it): the input must already be formatted.
        expected = list(source)
    return "".join(source).strip() + "\n", "".join(expected).strip() + "\n"
53
54
@contextmanager
def cache_dir(exists: bool = True) -> Iterator[Path]:
    """Point ``black.CACHE_DIR`` at a throwaway directory while the block runs.

    Yields the patched cache path. With ``exists=False`` it is a ``new``
    subdirectory of the temporary workspace that has not been created yet.
    """
    with TemporaryDirectory() as workspace:
        root = Path(workspace)
        target = root if exists else root / "new"
        with patch("black.CACHE_DIR", target):
            yield target
63
64
@contextmanager
def event_loop(close: bool) -> Iterator[None]:
    """Run the enclosed code with a brand-new asyncio event loop set as current.

    On exit the previously current loop is reinstated (or no loop, if none was
    current), and the new loop is closed only when ``close`` is true. Like any
    ``contextmanager`` object, this also works as ``@event_loop(close=...)``.
    """
    from typing import Optional  # only needed for the annotation below

    policy = asyncio.get_event_loop_policy()
    old_loop: Optional[asyncio.AbstractEventLoop]
    try:
        old_loop = policy.get_event_loop()
    except RuntimeError:
        # Raised when no loop is current, e.g. outside the main thread or on
        # Python versions that no longer create one implicitly.
        old_loop = None
    loop = policy.new_event_loop()
    policy.set_event_loop(loop)
    try:
        yield

    finally:
        policy.set_event_loop(old_loop)
        if close:
            loop.close()
78
79
80 class BlackTestCase(unittest.TestCase):
81     maxDiff = None
82
83     def assertFormatEqual(self, expected: str, actual: str) -> None:
84         if actual != expected and not os.environ.get("SKIP_AST_PRINT"):
85             bdv: black.DebugVisitor[Any]
86             black.out("Expected tree:", fg="green")
87             try:
88                 exp_node = black.lib2to3_parse(expected)
89                 bdv = black.DebugVisitor()
90                 list(bdv.visit(exp_node))
91             except Exception as ve:
92                 black.err(str(ve))
93             black.out("Actual tree:", fg="red")
94             try:
95                 exp_node = black.lib2to3_parse(actual)
96                 bdv = black.DebugVisitor()
97                 list(bdv.visit(exp_node))
98             except Exception as ve:
99                 black.err(str(ve))
100         self.assertEqual(expected, actual)
101
102     @patch("black.dump_to_file", dump_to_stderr)
103     def test_self(self) -> None:
104         source, expected = read_data("test_black")
105         actual = fs(source)
106         self.assertFormatEqual(expected, actual)
107         black.assert_equivalent(source, actual)
108         black.assert_stable(source, actual, line_length=ll)
109         self.assertFalse(ff(THIS_FILE))
110
111     @patch("black.dump_to_file", dump_to_stderr)
112     def test_black(self) -> None:
113         source, expected = read_data("../black")
114         actual = fs(source)
115         self.assertFormatEqual(expected, actual)
116         black.assert_equivalent(source, actual)
117         black.assert_stable(source, actual, line_length=ll)
118         self.assertFalse(ff(THIS_DIR / ".." / "black.py"))
119
120     def test_piping(self) -> None:
121         source, expected = read_data("../black")
122         hold_stdin, hold_stdout = sys.stdin, sys.stdout
123         try:
124             sys.stdin, sys.stdout = StringIO(source), StringIO()
125             sys.stdin.name = "<stdin>"
126             black.format_stdin_to_stdout(
127                 line_length=ll, fast=True, write_back=black.WriteBack.YES
128             )
129             sys.stdout.seek(0)
130             actual = sys.stdout.read()
131         finally:
132             sys.stdin, sys.stdout = hold_stdin, hold_stdout
133         self.assertFormatEqual(expected, actual)
134         black.assert_equivalent(source, actual)
135         black.assert_stable(source, actual, line_length=ll)
136
137     def test_piping_diff(self) -> None:
138         source, _ = read_data("expression.py")
139         expected, _ = read_data("expression.diff")
140         hold_stdin, hold_stdout = sys.stdin, sys.stdout
141         try:
142             sys.stdin, sys.stdout = StringIO(source), StringIO()
143             sys.stdin.name = "<stdin>"
144             black.format_stdin_to_stdout(
145                 line_length=ll, fast=True, write_back=black.WriteBack.DIFF
146             )
147             sys.stdout.seek(0)
148             actual = sys.stdout.read()
149         finally:
150             sys.stdin, sys.stdout = hold_stdin, hold_stdout
151         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
152         self.assertEqual(expected, actual)
153
154     @patch("black.dump_to_file", dump_to_stderr)
155     def test_setup(self) -> None:
156         source, expected = read_data("../setup")
157         actual = fs(source)
158         self.assertFormatEqual(expected, actual)
159         black.assert_equivalent(source, actual)
160         black.assert_stable(source, actual, line_length=ll)
161         self.assertFalse(ff(THIS_DIR / ".." / "setup.py"))
162
163     @patch("black.dump_to_file", dump_to_stderr)
164     def test_function(self) -> None:
165         source, expected = read_data("function")
166         actual = fs(source)
167         self.assertFormatEqual(expected, actual)
168         black.assert_equivalent(source, actual)
169         black.assert_stable(source, actual, line_length=ll)
170
171     @patch("black.dump_to_file", dump_to_stderr)
172     def test_function2(self) -> None:
173         source, expected = read_data("function2")
174         actual = fs(source)
175         self.assertFormatEqual(expected, actual)
176         black.assert_equivalent(source, actual)
177         black.assert_stable(source, actual, line_length=ll)
178
179     @patch("black.dump_to_file", dump_to_stderr)
180     def test_expression(self) -> None:
181         source, expected = read_data("expression")
182         actual = fs(source)
183         self.assertFormatEqual(expected, actual)
184         black.assert_equivalent(source, actual)
185         black.assert_stable(source, actual, line_length=ll)
186
187     def test_expression_ff(self) -> None:
188         source, expected = read_data("expression")
189         tmp_file = Path(black.dump_to_file(source))
190         try:
191             self.assertTrue(ff(tmp_file, write_back=black.WriteBack.YES))
192             with open(tmp_file, encoding="utf8") as f:
193                 actual = f.read()
194         finally:
195             os.unlink(tmp_file)
196         self.assertFormatEqual(expected, actual)
197         with patch("black.dump_to_file", dump_to_stderr):
198             black.assert_equivalent(source, actual)
199             black.assert_stable(source, actual, line_length=ll)
200
201     def test_expression_diff(self) -> None:
202         source, _ = read_data("expression.py")
203         expected, _ = read_data("expression.diff")
204         tmp_file = Path(black.dump_to_file(source))
205         hold_stdout = sys.stdout
206         try:
207             sys.stdout = StringIO()
208             self.assertTrue(ff(tmp_file, write_back=black.WriteBack.DIFF))
209             sys.stdout.seek(0)
210             actual = sys.stdout.read()
211             actual = actual.replace(str(tmp_file), "<stdin>")
212         finally:
213             sys.stdout = hold_stdout
214             os.unlink(tmp_file)
215         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
216         if expected != actual:
217             dump = black.dump_to_file(actual)
218             msg = (
219                 f"Expected diff isn't equal to the actual. If you made changes "
220                 f"to expression.py and this is an anticipated difference, "
221                 f"overwrite tests/expression.diff with {dump}"
222             )
223             self.assertEqual(expected, actual, msg)
224
225     @patch("black.dump_to_file", dump_to_stderr)
226     def test_fstring(self) -> None:
227         source, expected = read_data("fstring")
228         actual = fs(source)
229         self.assertFormatEqual(expected, actual)
230         black.assert_equivalent(source, actual)
231         black.assert_stable(source, actual, line_length=ll)
232
233     @patch("black.dump_to_file", dump_to_stderr)
234     def test_string_quotes(self) -> None:
235         source, expected = read_data("string_quotes")
236         actual = fs(source)
237         self.assertFormatEqual(expected, actual)
238         black.assert_equivalent(source, actual)
239         black.assert_stable(source, actual, line_length=ll)
240         mode = black.FileMode.NO_STRING_NORMALIZATION
241         not_normalized = fs(source, mode=mode)
242         self.assertFormatEqual(source, not_normalized)
243         black.assert_equivalent(source, not_normalized)
244         black.assert_stable(source, not_normalized, line_length=ll, mode=mode)
245
246     @patch("black.dump_to_file", dump_to_stderr)
247     def test_slices(self) -> None:
248         source, expected = read_data("slices")
249         actual = fs(source)
250         self.assertFormatEqual(expected, actual)
251         black.assert_equivalent(source, actual)
252         black.assert_stable(source, actual, line_length=ll)
253
254     @patch("black.dump_to_file", dump_to_stderr)
255     def test_comments(self) -> None:
256         source, expected = read_data("comments")
257         actual = fs(source)
258         self.assertFormatEqual(expected, actual)
259         black.assert_equivalent(source, actual)
260         black.assert_stable(source, actual, line_length=ll)
261
262     @patch("black.dump_to_file", dump_to_stderr)
263     def test_comments2(self) -> None:
264         source, expected = read_data("comments2")
265         actual = fs(source)
266         self.assertFormatEqual(expected, actual)
267         black.assert_equivalent(source, actual)
268         black.assert_stable(source, actual, line_length=ll)
269
270     @patch("black.dump_to_file", dump_to_stderr)
271     def test_comments3(self) -> None:
272         source, expected = read_data("comments3")
273         actual = fs(source)
274         self.assertFormatEqual(expected, actual)
275         black.assert_equivalent(source, actual)
276         black.assert_stable(source, actual, line_length=ll)
277
278     @patch("black.dump_to_file", dump_to_stderr)
279     def test_comments4(self) -> None:
280         source, expected = read_data("comments4")
281         actual = fs(source)
282         self.assertFormatEqual(expected, actual)
283         black.assert_equivalent(source, actual)
284         black.assert_stable(source, actual, line_length=ll)
285
286     @patch("black.dump_to_file", dump_to_stderr)
287     def test_comments5(self) -> None:
288         source, expected = read_data("comments5")
289         actual = fs(source)
290         self.assertFormatEqual(expected, actual)
291         black.assert_equivalent(source, actual)
292         black.assert_stable(source, actual, line_length=ll)
293
294     @patch("black.dump_to_file", dump_to_stderr)
295     def test_cantfit(self) -> None:
296         source, expected = read_data("cantfit")
297         actual = fs(source)
298         self.assertFormatEqual(expected, actual)
299         black.assert_equivalent(source, actual)
300         black.assert_stable(source, actual, line_length=ll)
301
302     @patch("black.dump_to_file", dump_to_stderr)
303     def test_import_spacing(self) -> None:
304         source, expected = read_data("import_spacing")
305         actual = fs(source)
306         self.assertFormatEqual(expected, actual)
307         black.assert_equivalent(source, actual)
308         black.assert_stable(source, actual, line_length=ll)
309
310     @patch("black.dump_to_file", dump_to_stderr)
311     def test_composition(self) -> None:
312         source, expected = read_data("composition")
313         actual = fs(source)
314         self.assertFormatEqual(expected, actual)
315         black.assert_equivalent(source, actual)
316         black.assert_stable(source, actual, line_length=ll)
317
318     @patch("black.dump_to_file", dump_to_stderr)
319     def test_empty_lines(self) -> None:
320         source, expected = read_data("empty_lines")
321         actual = fs(source)
322         self.assertFormatEqual(expected, actual)
323         black.assert_equivalent(source, actual)
324         black.assert_stable(source, actual, line_length=ll)
325
326     @patch("black.dump_to_file", dump_to_stderr)
327     def test_string_prefixes(self) -> None:
328         source, expected = read_data("string_prefixes")
329         actual = fs(source)
330         self.assertFormatEqual(expected, actual)
331         black.assert_equivalent(source, actual)
332         black.assert_stable(source, actual, line_length=ll)
333
334     @patch("black.dump_to_file", dump_to_stderr)
335     def test_python2(self) -> None:
336         source, expected = read_data("python2")
337         actual = fs(source)
338         self.assertFormatEqual(expected, actual)
339         # black.assert_equivalent(source, actual)
340         black.assert_stable(source, actual, line_length=ll)
341
342     @patch("black.dump_to_file", dump_to_stderr)
343     def test_python2_unicode_literals(self) -> None:
344         source, expected = read_data("python2_unicode_literals")
345         actual = fs(source)
346         self.assertFormatEqual(expected, actual)
347         black.assert_stable(source, actual, line_length=ll)
348
349     @patch("black.dump_to_file", dump_to_stderr)
350     def test_stub(self) -> None:
351         mode = black.FileMode.PYI
352         source, expected = read_data("stub.pyi")
353         actual = fs(source, mode=mode)
354         self.assertFormatEqual(expected, actual)
355         black.assert_stable(source, actual, line_length=ll, mode=mode)
356
357     @patch("black.dump_to_file", dump_to_stderr)
358     def test_fmtonoff(self) -> None:
359         source, expected = read_data("fmtonoff")
360         actual = fs(source)
361         self.assertFormatEqual(expected, actual)
362         black.assert_equivalent(source, actual)
363         black.assert_stable(source, actual, line_length=ll)
364
365     @patch("black.dump_to_file", dump_to_stderr)
366     def test_remove_empty_parentheses_after_class(self) -> None:
367         source, expected = read_data("class_blank_parentheses")
368         actual = fs(source)
369         self.assertFormatEqual(expected, actual)
370         black.assert_equivalent(source, actual)
371         black.assert_stable(source, actual, line_length=ll)
372
373     @patch("black.dump_to_file", dump_to_stderr)
374     def test_new_line_between_class_and_code(self) -> None:
375         source, expected = read_data("class_methods_new_line")
376         actual = fs(source)
377         self.assertFormatEqual(expected, actual)
378         black.assert_equivalent(source, actual)
379         black.assert_stable(source, actual, line_length=ll)
380
381     def test_report(self) -> None:
382         report = black.Report()
383         out_lines = []
384         err_lines = []
385
386         def out(msg: str, **kwargs: Any) -> None:
387             out_lines.append(msg)
388
389         def err(msg: str, **kwargs: Any) -> None:
390             err_lines.append(msg)
391
392         with patch("black.out", out), patch("black.err", err):
393             report.done(Path("f1"), black.Changed.NO)
394             self.assertEqual(len(out_lines), 1)
395             self.assertEqual(len(err_lines), 0)
396             self.assertEqual(out_lines[-1], "f1 already well formatted, good job.")
397             self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
398             self.assertEqual(report.return_code, 0)
399             report.done(Path("f2"), black.Changed.YES)
400             self.assertEqual(len(out_lines), 2)
401             self.assertEqual(len(err_lines), 0)
402             self.assertEqual(out_lines[-1], "reformatted f2")
403             self.assertEqual(
404                 unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
405             )
406             report.done(Path("f3"), black.Changed.CACHED)
407             self.assertEqual(len(out_lines), 3)
408             self.assertEqual(len(err_lines), 0)
409             self.assertEqual(
410                 out_lines[-1], "f3 wasn't modified on disk since last run."
411             )
412             self.assertEqual(
413                 unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
414             )
415             self.assertEqual(report.return_code, 0)
416             report.check = True
417             self.assertEqual(report.return_code, 1)
418             report.check = False
419             report.failed(Path("e1"), "boom")
420             self.assertEqual(len(out_lines), 3)
421             self.assertEqual(len(err_lines), 1)
422             self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
423             self.assertEqual(
424                 unstyle(str(report)),
425                 "1 file reformatted, 2 files left unchanged, "
426                 "1 file failed to reformat.",
427             )
428             self.assertEqual(report.return_code, 123)
429             report.done(Path("f3"), black.Changed.YES)
430             self.assertEqual(len(out_lines), 4)
431             self.assertEqual(len(err_lines), 1)
432             self.assertEqual(out_lines[-1], "reformatted f3")
433             self.assertEqual(
434                 unstyle(str(report)),
435                 "2 files reformatted, 2 files left unchanged, "
436                 "1 file failed to reformat.",
437             )
438             self.assertEqual(report.return_code, 123)
439             report.failed(Path("e2"), "boom")
440             self.assertEqual(len(out_lines), 4)
441             self.assertEqual(len(err_lines), 2)
442             self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
443             self.assertEqual(
444                 unstyle(str(report)),
445                 "2 files reformatted, 2 files left unchanged, "
446                 "2 files failed to reformat.",
447             )
448             self.assertEqual(report.return_code, 123)
449             report.done(Path("f4"), black.Changed.NO)
450             self.assertEqual(len(out_lines), 5)
451             self.assertEqual(len(err_lines), 2)
452             self.assertEqual(out_lines[-1], "f4 already well formatted, good job.")
453             self.assertEqual(
454                 unstyle(str(report)),
455                 "2 files reformatted, 3 files left unchanged, "
456                 "2 files failed to reformat.",
457             )
458             self.assertEqual(report.return_code, 123)
459             report.check = True
460             self.assertEqual(
461                 unstyle(str(report)),
462                 "2 files would be reformatted, 3 files would be left unchanged, "
463                 "2 files would fail to reformat.",
464             )
465
466     def test_is_python36(self) -> None:
467         node = black.lib2to3_parse("def f(*, arg): ...\n")
468         self.assertFalse(black.is_python36(node))
469         node = black.lib2to3_parse("def f(*, arg,): ...\n")
470         self.assertTrue(black.is_python36(node))
471         node = black.lib2to3_parse("def f(*, arg): f'string'\n")
472         self.assertTrue(black.is_python36(node))
473         source, expected = read_data("function")
474         node = black.lib2to3_parse(source)
475         self.assertTrue(black.is_python36(node))
476         node = black.lib2to3_parse(expected)
477         self.assertTrue(black.is_python36(node))
478         source, expected = read_data("expression")
479         node = black.lib2to3_parse(source)
480         self.assertFalse(black.is_python36(node))
481         node = black.lib2to3_parse(expected)
482         self.assertFalse(black.is_python36(node))
483
484     def test_get_future_imports(self) -> None:
485         node = black.lib2to3_parse("\n")
486         self.assertEqual(set(), black.get_future_imports(node))
487         node = black.lib2to3_parse("from __future__ import black\n")
488         self.assertEqual({"black"}, black.get_future_imports(node))
489         node = black.lib2to3_parse("from __future__ import multiple, imports\n")
490         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
491         node = black.lib2to3_parse("from __future__ import (parenthesized, imports)\n")
492         self.assertEqual({"parenthesized", "imports"}, black.get_future_imports(node))
493         node = black.lib2to3_parse(
494             "from __future__ import multiple\nfrom __future__ import imports\n"
495         )
496         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
497         node = black.lib2to3_parse("# comment\nfrom __future__ import black\n")
498         self.assertEqual({"black"}, black.get_future_imports(node))
499         node = black.lib2to3_parse('"""docstring"""\nfrom __future__ import black\n')
500         self.assertEqual({"black"}, black.get_future_imports(node))
501         node = black.lib2to3_parse("some(other, code)\nfrom __future__ import black\n")
502         self.assertEqual(set(), black.get_future_imports(node))
503         node = black.lib2to3_parse("from some.module import black\n")
504         self.assertEqual(set(), black.get_future_imports(node))
505
506     def test_debug_visitor(self) -> None:
507         source, _ = read_data("debug_visitor.py")
508         expected, _ = read_data("debug_visitor.out")
509         out_lines = []
510         err_lines = []
511
512         def out(msg: str, **kwargs: Any) -> None:
513             out_lines.append(msg)
514
515         def err(msg: str, **kwargs: Any) -> None:
516             err_lines.append(msg)
517
518         with patch("black.out", out), patch("black.err", err):
519             black.DebugVisitor.show(source)
520         actual = "\n".join(out_lines) + "\n"
521         log_name = ""
522         if expected != actual:
523             log_name = black.dump_to_file(*out_lines)
524         self.assertEqual(
525             expected,
526             actual,
527             f"AST print out is different. Actual version dumped to {log_name}",
528         )
529
530     def test_format_file_contents(self) -> None:
531         empty = ""
532         with self.assertRaises(black.NothingChanged):
533             black.format_file_contents(empty, line_length=ll, fast=False)
534         just_nl = "\n"
535         with self.assertRaises(black.NothingChanged):
536             black.format_file_contents(just_nl, line_length=ll, fast=False)
537         same = "l = [1, 2, 3]\n"
538         with self.assertRaises(black.NothingChanged):
539             black.format_file_contents(same, line_length=ll, fast=False)
540         different = "l = [1,2,3]"
541         expected = same
542         actual = black.format_file_contents(different, line_length=ll, fast=False)
543         self.assertEqual(expected, actual)
544         invalid = "return if you can"
545         with self.assertRaises(ValueError) as e:
546             black.format_file_contents(invalid, line_length=ll, fast=False)
547         self.assertEqual(str(e.exception), "Cannot parse: 1:7: return if you can")
548
549     def test_endmarker(self) -> None:
550         n = black.lib2to3_parse("\n")
551         self.assertEqual(n.type, black.syms.file_input)
552         self.assertEqual(len(n.children), 1)
553         self.assertEqual(n.children[0].type, black.token.ENDMARKER)
554
555     @unittest.skipIf(os.environ.get("SKIP_AST_PRINT"), "user set SKIP_AST_PRINT")
556     def test_assertFormatEqual(self) -> None:
557         out_lines = []
558         err_lines = []
559
560         def out(msg: str, **kwargs: Any) -> None:
561             out_lines.append(msg)
562
563         def err(msg: str, **kwargs: Any) -> None:
564             err_lines.append(msg)
565
566         with patch("black.out", out), patch("black.err", err):
567             with self.assertRaises(AssertionError):
568                 self.assertFormatEqual("l = [1, 2, 3]", "l = [1, 2, 3,]")
569
570         out_str = "".join(out_lines)
571         self.assertTrue("Expected tree:" in out_str)
572         self.assertTrue("Actual tree:" in out_str)
573         self.assertEqual("".join(err_lines), "")
574
575     def test_cache_broken_file(self) -> None:
576         mode = black.FileMode.AUTO_DETECT
577         with cache_dir() as workspace:
578             cache_file = black.get_cache_file(black.DEFAULT_LINE_LENGTH, mode)
579             with cache_file.open("w") as fobj:
580                 fobj.write("this is not a pickle")
581             self.assertEqual(black.read_cache(black.DEFAULT_LINE_LENGTH, mode), {})
582             src = (workspace / "test.py").resolve()
583             with src.open("w") as fobj:
584                 fobj.write("print('hello')")
585             result = CliRunner().invoke(black.main, [str(src)])
586             self.assertEqual(result.exit_code, 0)
587             cache = black.read_cache(black.DEFAULT_LINE_LENGTH, mode)
588             self.assertIn(src, cache)
589
590     def test_cache_single_file_already_cached(self) -> None:
591         mode = black.FileMode.AUTO_DETECT
592         with cache_dir() as workspace:
593             src = (workspace / "test.py").resolve()
594             with src.open("w") as fobj:
595                 fobj.write("print('hello')")
596             black.write_cache({}, [src], black.DEFAULT_LINE_LENGTH, mode)
597             result = CliRunner().invoke(black.main, [str(src)])
598             self.assertEqual(result.exit_code, 0)
599             with src.open("r") as fobj:
600                 self.assertEqual(fobj.read(), "print('hello')")
601
602     @event_loop(close=False)
603     def test_cache_multiple_files(self) -> None:
604         mode = black.FileMode.AUTO_DETECT
605         with cache_dir() as workspace, patch(
606             "black.ProcessPoolExecutor", new=ThreadPoolExecutor
607         ):
608             one = (workspace / "one.py").resolve()
609             with one.open("w") as fobj:
610                 fobj.write("print('hello')")
611             two = (workspace / "two.py").resolve()
612             with two.open("w") as fobj:
613                 fobj.write("print('hello')")
614             black.write_cache({}, [one], black.DEFAULT_LINE_LENGTH, mode)
615             result = CliRunner().invoke(black.main, [str(workspace)])
616             self.assertEqual(result.exit_code, 0)
617             with one.open("r") as fobj:
618                 self.assertEqual(fobj.read(), "print('hello')")
619             with two.open("r") as fobj:
620                 self.assertEqual(fobj.read(), 'print("hello")\n')
621             cache = black.read_cache(black.DEFAULT_LINE_LENGTH, mode)
622             self.assertIn(one, cache)
623             self.assertIn(two, cache)
624
625     def test_no_cache_when_writeback_diff(self) -> None:
626         mode = black.FileMode.AUTO_DETECT
627         with cache_dir() as workspace:
628             src = (workspace / "test.py").resolve()
629             with src.open("w") as fobj:
630                 fobj.write("print('hello')")
631             result = CliRunner().invoke(black.main, [str(src), "--diff"])
632             self.assertEqual(result.exit_code, 0)
633             cache_file = black.get_cache_file(black.DEFAULT_LINE_LENGTH, mode)
634             self.assertFalse(cache_file.exists())
635
636     def test_no_cache_when_stdin(self) -> None:
637         mode = black.FileMode.AUTO_DETECT
638         with cache_dir():
639             result = CliRunner().invoke(black.main, ["-"], input="print('hello')")
640             self.assertEqual(result.exit_code, 0)
641             cache_file = black.get_cache_file(black.DEFAULT_LINE_LENGTH, mode)
642             self.assertFalse(cache_file.exists())
643
644     def test_read_cache_no_cachefile(self) -> None:
645         mode = black.FileMode.AUTO_DETECT
646         with cache_dir():
647             self.assertEqual(black.read_cache(black.DEFAULT_LINE_LENGTH, mode), {})
648
649     def test_write_cache_read_cache(self) -> None:
650         mode = black.FileMode.AUTO_DETECT
651         with cache_dir() as workspace:
652             src = (workspace / "test.py").resolve()
653             src.touch()
654             black.write_cache({}, [src], black.DEFAULT_LINE_LENGTH, mode)
655             cache = black.read_cache(black.DEFAULT_LINE_LENGTH, mode)
656             self.assertIn(src, cache)
657             self.assertEqual(cache[src], black.get_cache_info(src))
658
659     def test_filter_cached(self) -> None:
660         with TemporaryDirectory() as workspace:
661             path = Path(workspace)
662             uncached = (path / "uncached").resolve()
663             cached = (path / "cached").resolve()
664             cached_but_changed = (path / "changed").resolve()
665             uncached.touch()
666             cached.touch()
667             cached_but_changed.touch()
668             cache = {cached: black.get_cache_info(cached), cached_but_changed: (0.0, 0)}
669             todo, done = black.filter_cached(
670                 cache, [uncached, cached, cached_but_changed]
671             )
672             self.assertEqual(todo, [uncached, cached_but_changed])
673             self.assertEqual(done, [cached])
674
675     def test_write_cache_creates_directory_if_needed(self) -> None:
676         mode = black.FileMode.AUTO_DETECT
677         with cache_dir(exists=False) as workspace:
678             self.assertFalse(workspace.exists())
679             black.write_cache({}, [], black.DEFAULT_LINE_LENGTH, mode)
680             self.assertTrue(workspace.exists())
681
682     @event_loop(close=False)
683     def test_failed_formatting_does_not_get_cached(self) -> None:
684         mode = black.FileMode.AUTO_DETECT
685         with cache_dir() as workspace, patch(
686             "black.ProcessPoolExecutor", new=ThreadPoolExecutor
687         ):
688             failing = (workspace / "failing.py").resolve()
689             with failing.open("w") as fobj:
690                 fobj.write("not actually python")
691             clean = (workspace / "clean.py").resolve()
692             with clean.open("w") as fobj:
693                 fobj.write('print("hello")\n')
694             result = CliRunner().invoke(black.main, [str(workspace)])
695             self.assertEqual(result.exit_code, 123)
696             cache = black.read_cache(black.DEFAULT_LINE_LENGTH, mode)
697             self.assertNotIn(failing, cache)
698             self.assertIn(clean, cache)
699
700     def test_write_cache_write_fail(self) -> None:
701         mode = black.FileMode.AUTO_DETECT
702         with cache_dir(), patch.object(Path, "open") as mock:
703             mock.side_effect = OSError
704             black.write_cache({}, [], black.DEFAULT_LINE_LENGTH, mode)
705
706     @event_loop(close=False)
707     def test_check_diff_use_together(self) -> None:
708         with cache_dir():
709             # Files which will be reformatted.
710             src1 = (THIS_DIR / "string_quotes.py").resolve()
711             result = CliRunner().invoke(black.main, [str(src1), "--diff", "--check"])
712             self.assertEqual(result.exit_code, 1)
713
714             # Files which will not be reformatted.
715             src2 = (THIS_DIR / "composition.py").resolve()
716             result = CliRunner().invoke(black.main, [str(src2), "--diff", "--check"])
717             self.assertEqual(result.exit_code, 0)
718
719             # Multi file command.
720             result = CliRunner().invoke(
721                 black.main, [str(src1), str(src2), "--diff", "--check"]
722             )
723             self.assertEqual(result.exit_code, 1, result.output)
724
725     def test_no_files(self) -> None:
726         with cache_dir():
727             # Without an argument, black exits with error code 0.
728             result = CliRunner().invoke(black.main, [])
729             self.assertEqual(result.exit_code, 0)
730
731     def test_broken_symlink(self) -> None:
732         with cache_dir() as workspace:
733             symlink = workspace / "broken_link.py"
734             try:
735                 symlink.symlink_to("nonexistent.py")
736             except OSError as e:
737                 self.skipTest(f"Can't create symlinks: {e}")
738             result = CliRunner().invoke(black.main, [str(workspace.resolve())])
739             self.assertEqual(result.exit_code, 0)
740
741     def test_read_cache_line_lengths(self) -> None:
742         mode = black.FileMode.AUTO_DETECT
743         with cache_dir() as workspace:
744             path = (workspace / "file.py").resolve()
745             path.touch()
746             black.write_cache({}, [path], 1, mode)
747             one = black.read_cache(1, mode)
748             self.assertIn(path, one)
749             two = black.read_cache(2, mode)
750             self.assertNotIn(path, two)
751
752     def test_single_file_force_pyi(self) -> None:
753         reg_mode = black.FileMode.AUTO_DETECT
754         pyi_mode = black.FileMode.PYI
755         contents, expected = read_data("force_pyi")
756         with cache_dir() as workspace:
757             path = (workspace / "file.py").resolve()
758             with open(path, "w") as fh:
759                 fh.write(contents)
760             result = CliRunner().invoke(black.main, [str(path), "--pyi"])
761             self.assertEqual(result.exit_code, 0)
762             with open(path, "r") as fh:
763                 actual = fh.read()
764             # verify cache with --pyi is separate
765             pyi_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, pyi_mode)
766             self.assertIn(path, pyi_cache)
767             normal_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, reg_mode)
768             self.assertNotIn(path, normal_cache)
769         self.assertEqual(actual, expected)
770
771     @event_loop(close=False)
772     def test_multi_file_force_pyi(self) -> None:
773         reg_mode = black.FileMode.AUTO_DETECT
774         pyi_mode = black.FileMode.PYI
775         contents, expected = read_data("force_pyi")
776         with cache_dir() as workspace:
777             paths = [
778                 (workspace / "file1.py").resolve(),
779                 (workspace / "file2.py").resolve(),
780             ]
781             for path in paths:
782                 with open(path, "w") as fh:
783                     fh.write(contents)
784             result = CliRunner().invoke(black.main, [str(p) for p in paths] + ["--pyi"])
785             self.assertEqual(result.exit_code, 0)
786             for path in paths:
787                 with open(path, "r") as fh:
788                     actual = fh.read()
789                 self.assertEqual(actual, expected)
790             # verify cache with --pyi is separate
791             pyi_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, pyi_mode)
792             normal_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, reg_mode)
793             for path in paths:
794                 self.assertIn(path, pyi_cache)
795                 self.assertNotIn(path, normal_cache)
796
797     def test_pipe_force_pyi(self) -> None:
798         source, expected = read_data("force_pyi")
799         result = CliRunner().invoke(black.main, ["-", "-q", "--pyi"], input=source)
800         self.assertEqual(result.exit_code, 0)
801         actual = result.output
802         self.assertFormatEqual(actual, expected)
803
804     def test_single_file_force_py36(self) -> None:
805         reg_mode = black.FileMode.AUTO_DETECT
806         py36_mode = black.FileMode.PYTHON36
807         source, expected = read_data("force_py36")
808         with cache_dir() as workspace:
809             path = (workspace / "file.py").resolve()
810             with open(path, "w") as fh:
811                 fh.write(source)
812             result = CliRunner().invoke(black.main, [str(path), "--py36"])
813             self.assertEqual(result.exit_code, 0)
814             with open(path, "r") as fh:
815                 actual = fh.read()
816             # verify cache with --py36 is separate
817             py36_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, py36_mode)
818             self.assertIn(path, py36_cache)
819             normal_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, reg_mode)
820             self.assertNotIn(path, normal_cache)
821         self.assertEqual(actual, expected)
822
823     @event_loop(close=False)
824     def test_multi_file_force_py36(self) -> None:
825         reg_mode = black.FileMode.AUTO_DETECT
826         py36_mode = black.FileMode.PYTHON36
827         source, expected = read_data("force_py36")
828         with cache_dir() as workspace:
829             paths = [
830                 (workspace / "file1.py").resolve(),
831                 (workspace / "file2.py").resolve(),
832             ]
833             for path in paths:
834                 with open(path, "w") as fh:
835                     fh.write(source)
836             result = CliRunner().invoke(
837                 black.main, [str(p) for p in paths] + ["--py36"]
838             )
839             self.assertEqual(result.exit_code, 0)
840             for path in paths:
841                 with open(path, "r") as fh:
842                     actual = fh.read()
843                 self.assertEqual(actual, expected)
844             # verify cache with --py36 is separate
845             pyi_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, py36_mode)
846             normal_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, reg_mode)
847             for path in paths:
848                 self.assertIn(path, pyi_cache)
849                 self.assertNotIn(path, normal_cache)
850
851     def test_pipe_force_py36(self) -> None:
852         source, expected = read_data("force_py36")
853         result = CliRunner().invoke(black.main, ["-", "-q", "--py36"], input=source)
854         self.assertEqual(result.exit_code, 0)
855         actual = result.output
856         self.assertFormatEqual(actual, expected)
857
858     def test_include_exclude(self) -> None:
859         path = THIS_DIR / "include_exclude_tests"
860         include = re.compile(r"\.pyi?$")
861         exclude = re.compile(r"/exclude/|/\.definitely_exclude/")
862         sources: List[Path] = []
863         expected = [
864             Path(THIS_DIR / "include_exclude_tests/b/dont_exclude/a.py"),
865             Path(THIS_DIR / "include_exclude_tests/b/dont_exclude/a.pyi"),
866         ]
867         this_abs = THIS_DIR.resolve()
868         sources.extend(black.gen_python_files_in_dir(path, this_abs, include, exclude))
869         self.assertEqual(sorted(expected), sorted(sources))
870
871     def test_empty_include(self) -> None:
872         path = THIS_DIR / "include_exclude_tests"
873         empty = re.compile(r"")
874         sources: List[Path] = []
875         expected = [
876             Path(path / "b/exclude/a.pie"),
877             Path(path / "b/exclude/a.py"),
878             Path(path / "b/exclude/a.pyi"),
879             Path(path / "b/dont_exclude/a.pie"),
880             Path(path / "b/dont_exclude/a.py"),
881             Path(path / "b/dont_exclude/a.pyi"),
882             Path(path / "b/.definitely_exclude/a.pie"),
883             Path(path / "b/.definitely_exclude/a.py"),
884             Path(path / "b/.definitely_exclude/a.pyi"),
885         ]
886         this_abs = THIS_DIR.resolve()
887         sources.extend(
888             black.gen_python_files_in_dir(
889                 path, this_abs, empty, re.compile(black.DEFAULT_EXCLUDES)
890             )
891         )
892         self.assertEqual(sorted(expected), sorted(sources))
893
894     def test_empty_exclude(self) -> None:
895         path = THIS_DIR / "include_exclude_tests"
896         empty = re.compile(r"")
897         sources: List[Path] = []
898         expected = [
899             Path(path / "b/dont_exclude/a.py"),
900             Path(path / "b/dont_exclude/a.pyi"),
901             Path(path / "b/exclude/a.py"),
902             Path(path / "b/exclude/a.pyi"),
903             Path(path / "b/.definitely_exclude/a.py"),
904             Path(path / "b/.definitely_exclude/a.pyi"),
905         ]
906         this_abs = THIS_DIR.resolve()
907         sources.extend(
908             black.gen_python_files_in_dir(
909                 path, this_abs, re.compile(black.DEFAULT_INCLUDES), empty
910             )
911         )
912         self.assertEqual(sorted(expected), sorted(sources))
913
914     def test_invalid_include_exclude(self) -> None:
915         for option in ["--include", "--exclude"]:
916             result = CliRunner().invoke(black.main, ["-", option, "**()(!!*)"])
917             self.assertEqual(result.exit_code, 2)
918
919
if __name__ == "__main__":
    # Run every test case in this module when the file is executed directly.
    unittest.main()