git.madduck.net Git - etc/vim.git/blob - tests/test_black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you read the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access, can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Sort default excludes, include the leading slash
[etc/vim.git] / tests / test_black.py
1 #!/usr/bin/env python3
2 import asyncio
3 from concurrent.futures import ThreadPoolExecutor
4 from contextlib import contextmanager
5 from functools import partial
6 from io import StringIO
7 import os
8 from pathlib import Path
9 import sys
10 from tempfile import TemporaryDirectory
11 from typing import Any, List, Tuple, Iterator
12 import unittest
13 from unittest.mock import patch
14 import re
15
16 from click import unstyle
17 from click.testing import CliRunner
18
19 import black
20
ll = 88
THIS_FILE = Path(__file__)
THIS_DIR = THIS_FILE.parent
# Shortcuts that run Black with the line length shared by these tests.
ff = partial(black.format_file_in_place, fast=True, line_length=ll)
fs = partial(black.format_str, line_length=ll)
# Marker that read_data() deletes from every fixture line. It is built from two
# pieces on purpose: test_self reads this very file through read_data(), and a
# single joined literal here would be altered when the file is read back.
EMPTY_LINE = "# EMPTY LINE WITH WHITESPACE" + " (this comment will be removed)"
27
28
def dump_to_stderr(*output: str) -> str:
    """Replacement for ``black.dump_to_file`` used when patching in tests.

    Instead of writing *output* to a temporary file and returning its path,
    return the pieces joined by newlines and framed by a leading and a
    trailing newline.
    """
    body = "\n".join(output)
    return f"\n{body}\n"
31
32
def read_data(name: str) -> Tuple[str, str]:
    """read_data('test_name') -> 'input', 'output'

    Load the fixture ``THIS_DIR / name``; ``.py`` is appended unless *name*
    already ends in ``.py``, ``.pyi``, ``.out`` or ``.diff``. Lines before a
    line reading ``# output`` form the input, lines after it the expected
    output. If the output part is empty while the input is not, the input is
    reused as the expected output (the fixture is considered already
    formatted). EMPTY_LINE is removed from every line, and each half is
    stripped and terminated by exactly one newline.
    """
    if not name.endswith((".py", ".pyi", ".out", ".diff")):
        name += ".py"
    with open(THIS_DIR / name, "r", encoding="utf8") as test:
        lines = test.readlines()
    sections: Tuple[List[str], List[str]] = ([], [])
    current = 0  # 0 while collecting input, 1 once the marker was seen
    for raw in lines:
        line = raw.replace(EMPTY_LINE, "")
        if line.rstrip() == "# output":
            current = 1
            continue
        sections[current].append(line)
    before, after = sections
    if before and not after:
        after = list(before)
    return "".join(before).strip() + "\n", "".join(after).strip() + "\n"
53
54
@contextmanager
def cache_dir(exists: bool = True) -> Iterator[Path]:
    """Patch ``black.CACHE_DIR`` to a throwaway location and yield that path.

    With ``exists=False`` the yielded path is a not-yet-created ``new``
    subdirectory of the temporary directory, so tests can check that the
    cache directory gets created on demand.
    """
    with TemporaryDirectory() as tmp:
        target = Path(tmp) if exists else Path(tmp) / "new"
        with patch("black.CACHE_DIR", target):
            yield target
63
64
@contextmanager
def event_loop(close: bool) -> Iterator[None]:
    """Run the managed code (or decorated test) with a fresh asyncio loop.

    The loop that was current beforehand is restored on exit, even if the
    body raises. The fresh loop is closed afterwards only when *close* is true.
    """
    policy = asyncio.get_event_loop_policy()
    previous = policy.get_event_loop()
    fresh = policy.new_event_loop()
    asyncio.set_event_loop(fresh)
    try:
        yield
    finally:
        policy.set_event_loop(previous)
        if close:
            fresh.close()
78
79
80 class BlackTestCase(unittest.TestCase):
81     maxDiff = None
82
83     def assertFormatEqual(self, expected: str, actual: str) -> None:
84         if actual != expected and not os.environ.get("SKIP_AST_PRINT"):
85             bdv: black.DebugVisitor[Any]
86             black.out("Expected tree:", fg="green")
87             try:
88                 exp_node = black.lib2to3_parse(expected)
89                 bdv = black.DebugVisitor()
90                 list(bdv.visit(exp_node))
91             except Exception as ve:
92                 black.err(str(ve))
93             black.out("Actual tree:", fg="red")
94             try:
95                 exp_node = black.lib2to3_parse(actual)
96                 bdv = black.DebugVisitor()
97                 list(bdv.visit(exp_node))
98             except Exception as ve:
99                 black.err(str(ve))
100         self.assertEqual(expected, actual)
101
102     @patch("black.dump_to_file", dump_to_stderr)
103     def test_self(self) -> None:
104         source, expected = read_data("test_black")
105         actual = fs(source)
106         self.assertFormatEqual(expected, actual)
107         black.assert_equivalent(source, actual)
108         black.assert_stable(source, actual, line_length=ll)
109         self.assertFalse(ff(THIS_FILE))
110
111     @patch("black.dump_to_file", dump_to_stderr)
112     def test_black(self) -> None:
113         source, expected = read_data("../black")
114         actual = fs(source)
115         self.assertFormatEqual(expected, actual)
116         black.assert_equivalent(source, actual)
117         black.assert_stable(source, actual, line_length=ll)
118         self.assertFalse(ff(THIS_DIR / ".." / "black.py"))
119
120     def test_piping(self) -> None:
121         source, expected = read_data("../black")
122         hold_stdin, hold_stdout = sys.stdin, sys.stdout
123         try:
124             sys.stdin, sys.stdout = StringIO(source), StringIO()
125             sys.stdin.name = "<stdin>"
126             black.format_stdin_to_stdout(
127                 line_length=ll, fast=True, write_back=black.WriteBack.YES
128             )
129             sys.stdout.seek(0)
130             actual = sys.stdout.read()
131         finally:
132             sys.stdin, sys.stdout = hold_stdin, hold_stdout
133         self.assertFormatEqual(expected, actual)
134         black.assert_equivalent(source, actual)
135         black.assert_stable(source, actual, line_length=ll)
136
137     def test_piping_diff(self) -> None:
138         source, _ = read_data("expression.py")
139         expected, _ = read_data("expression.diff")
140         hold_stdin, hold_stdout = sys.stdin, sys.stdout
141         try:
142             sys.stdin, sys.stdout = StringIO(source), StringIO()
143             sys.stdin.name = "<stdin>"
144             black.format_stdin_to_stdout(
145                 line_length=ll, fast=True, write_back=black.WriteBack.DIFF
146             )
147             sys.stdout.seek(0)
148             actual = sys.stdout.read()
149         finally:
150             sys.stdin, sys.stdout = hold_stdin, hold_stdout
151         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
152         self.assertEqual(expected, actual)
153
154     @patch("black.dump_to_file", dump_to_stderr)
155     def test_setup(self) -> None:
156         source, expected = read_data("../setup")
157         actual = fs(source)
158         self.assertFormatEqual(expected, actual)
159         black.assert_equivalent(source, actual)
160         black.assert_stable(source, actual, line_length=ll)
161         self.assertFalse(ff(THIS_DIR / ".." / "setup.py"))
162
163     @patch("black.dump_to_file", dump_to_stderr)
164     def test_function(self) -> None:
165         source, expected = read_data("function")
166         actual = fs(source)
167         self.assertFormatEqual(expected, actual)
168         black.assert_equivalent(source, actual)
169         black.assert_stable(source, actual, line_length=ll)
170
171     @patch("black.dump_to_file", dump_to_stderr)
172     def test_function2(self) -> None:
173         source, expected = read_data("function2")
174         actual = fs(source)
175         self.assertFormatEqual(expected, actual)
176         black.assert_equivalent(source, actual)
177         black.assert_stable(source, actual, line_length=ll)
178
179     @patch("black.dump_to_file", dump_to_stderr)
180     def test_expression(self) -> None:
181         source, expected = read_data("expression")
182         actual = fs(source)
183         self.assertFormatEqual(expected, actual)
184         black.assert_equivalent(source, actual)
185         black.assert_stable(source, actual, line_length=ll)
186
187     def test_expression_ff(self) -> None:
188         source, expected = read_data("expression")
189         tmp_file = Path(black.dump_to_file(source))
190         try:
191             self.assertTrue(ff(tmp_file, write_back=black.WriteBack.YES))
192             with open(tmp_file, encoding="utf8") as f:
193                 actual = f.read()
194         finally:
195             os.unlink(tmp_file)
196         self.assertFormatEqual(expected, actual)
197         with patch("black.dump_to_file", dump_to_stderr):
198             black.assert_equivalent(source, actual)
199             black.assert_stable(source, actual, line_length=ll)
200
201     def test_expression_diff(self) -> None:
202         source, _ = read_data("expression.py")
203         expected, _ = read_data("expression.diff")
204         tmp_file = Path(black.dump_to_file(source))
205         hold_stdout = sys.stdout
206         try:
207             sys.stdout = StringIO()
208             self.assertTrue(ff(tmp_file, write_back=black.WriteBack.DIFF))
209             sys.stdout.seek(0)
210             actual = sys.stdout.read()
211             actual = actual.replace(str(tmp_file), "<stdin>")
212         finally:
213             sys.stdout = hold_stdout
214             os.unlink(tmp_file)
215         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
216         if expected != actual:
217             dump = black.dump_to_file(actual)
218             msg = (
219                 f"Expected diff isn't equal to the actual. If you made changes "
220                 f"to expression.py and this is an anticipated difference, "
221                 f"overwrite tests/expression.diff with {dump}"
222             )
223             self.assertEqual(expected, actual, msg)
224
225     @patch("black.dump_to_file", dump_to_stderr)
226     def test_fstring(self) -> None:
227         source, expected = read_data("fstring")
228         actual = fs(source)
229         self.assertFormatEqual(expected, actual)
230         black.assert_equivalent(source, actual)
231         black.assert_stable(source, actual, line_length=ll)
232
233     @patch("black.dump_to_file", dump_to_stderr)
234     def test_string_quotes(self) -> None:
235         source, expected = read_data("string_quotes")
236         actual = fs(source)
237         self.assertFormatEqual(expected, actual)
238         black.assert_equivalent(source, actual)
239         black.assert_stable(source, actual, line_length=ll)
240         mode = black.FileMode.NO_STRING_NORMALIZATION
241         not_normalized = fs(source, mode=mode)
242         self.assertFormatEqual(source, not_normalized)
243         black.assert_equivalent(source, not_normalized)
244         black.assert_stable(source, not_normalized, line_length=ll, mode=mode)
245
246     @patch("black.dump_to_file", dump_to_stderr)
247     def test_slices(self) -> None:
248         source, expected = read_data("slices")
249         actual = fs(source)
250         self.assertFormatEqual(expected, actual)
251         black.assert_equivalent(source, actual)
252         black.assert_stable(source, actual, line_length=ll)
253
254     @patch("black.dump_to_file", dump_to_stderr)
255     def test_comments(self) -> None:
256         source, expected = read_data("comments")
257         actual = fs(source)
258         self.assertFormatEqual(expected, actual)
259         black.assert_equivalent(source, actual)
260         black.assert_stable(source, actual, line_length=ll)
261
262     @patch("black.dump_to_file", dump_to_stderr)
263     def test_comments2(self) -> None:
264         source, expected = read_data("comments2")
265         actual = fs(source)
266         self.assertFormatEqual(expected, actual)
267         black.assert_equivalent(source, actual)
268         black.assert_stable(source, actual, line_length=ll)
269
270     @patch("black.dump_to_file", dump_to_stderr)
271     def test_comments3(self) -> None:
272         source, expected = read_data("comments3")
273         actual = fs(source)
274         self.assertFormatEqual(expected, actual)
275         black.assert_equivalent(source, actual)
276         black.assert_stable(source, actual, line_length=ll)
277
278     @patch("black.dump_to_file", dump_to_stderr)
279     def test_comments4(self) -> None:
280         source, expected = read_data("comments4")
281         actual = fs(source)
282         self.assertFormatEqual(expected, actual)
283         black.assert_equivalent(source, actual)
284         black.assert_stable(source, actual, line_length=ll)
285
286     @patch("black.dump_to_file", dump_to_stderr)
287     def test_comments5(self) -> None:
288         source, expected = read_data("comments5")
289         actual = fs(source)
290         self.assertFormatEqual(expected, actual)
291         black.assert_equivalent(source, actual)
292         black.assert_stable(source, actual, line_length=ll)
293
294     @patch("black.dump_to_file", dump_to_stderr)
295     def test_cantfit(self) -> None:
296         source, expected = read_data("cantfit")
297         actual = fs(source)
298         self.assertFormatEqual(expected, actual)
299         black.assert_equivalent(source, actual)
300         black.assert_stable(source, actual, line_length=ll)
301
302     @patch("black.dump_to_file", dump_to_stderr)
303     def test_import_spacing(self) -> None:
304         source, expected = read_data("import_spacing")
305         actual = fs(source)
306         self.assertFormatEqual(expected, actual)
307         black.assert_equivalent(source, actual)
308         black.assert_stable(source, actual, line_length=ll)
309
310     @patch("black.dump_to_file", dump_to_stderr)
311     def test_composition(self) -> None:
312         source, expected = read_data("composition")
313         actual = fs(source)
314         self.assertFormatEqual(expected, actual)
315         black.assert_equivalent(source, actual)
316         black.assert_stable(source, actual, line_length=ll)
317
318     @patch("black.dump_to_file", dump_to_stderr)
319     def test_empty_lines(self) -> None:
320         source, expected = read_data("empty_lines")
321         actual = fs(source)
322         self.assertFormatEqual(expected, actual)
323         black.assert_equivalent(source, actual)
324         black.assert_stable(source, actual, line_length=ll)
325
326     @patch("black.dump_to_file", dump_to_stderr)
327     def test_string_prefixes(self) -> None:
328         source, expected = read_data("string_prefixes")
329         actual = fs(source)
330         self.assertFormatEqual(expected, actual)
331         black.assert_equivalent(source, actual)
332         black.assert_stable(source, actual, line_length=ll)
333
334     @patch("black.dump_to_file", dump_to_stderr)
335     def test_python2(self) -> None:
336         source, expected = read_data("python2")
337         actual = fs(source)
338         self.assertFormatEqual(expected, actual)
339         # black.assert_equivalent(source, actual)
340         black.assert_stable(source, actual, line_length=ll)
341
342     @patch("black.dump_to_file", dump_to_stderr)
343     def test_python2_unicode_literals(self) -> None:
344         source, expected = read_data("python2_unicode_literals")
345         actual = fs(source)
346         self.assertFormatEqual(expected, actual)
347         black.assert_stable(source, actual, line_length=ll)
348
349     @patch("black.dump_to_file", dump_to_stderr)
350     def test_stub(self) -> None:
351         mode = black.FileMode.PYI
352         source, expected = read_data("stub.pyi")
353         actual = fs(source, mode=mode)
354         self.assertFormatEqual(expected, actual)
355         black.assert_stable(source, actual, line_length=ll, mode=mode)
356
357     @patch("black.dump_to_file", dump_to_stderr)
358     def test_fmtonoff(self) -> None:
359         source, expected = read_data("fmtonoff")
360         actual = fs(source)
361         self.assertFormatEqual(expected, actual)
362         black.assert_equivalent(source, actual)
363         black.assert_stable(source, actual, line_length=ll)
364
365     @patch("black.dump_to_file", dump_to_stderr)
366     def test_remove_empty_parentheses_after_class(self) -> None:
367         source, expected = read_data("class_blank_parentheses")
368         actual = fs(source)
369         self.assertFormatEqual(expected, actual)
370         black.assert_equivalent(source, actual)
371         black.assert_stable(source, actual, line_length=ll)
372
373     @patch("black.dump_to_file", dump_to_stderr)
374     def test_new_line_between_class_and_code(self) -> None:
375         source, expected = read_data("class_methods_new_line")
376         actual = fs(source)
377         self.assertFormatEqual(expected, actual)
378         black.assert_equivalent(source, actual)
379         black.assert_stable(source, actual, line_length=ll)
380
381     def test_report(self) -> None:
382         report = black.Report()
383         out_lines = []
384         err_lines = []
385
386         def out(msg: str, **kwargs: Any) -> None:
387             out_lines.append(msg)
388
389         def err(msg: str, **kwargs: Any) -> None:
390             err_lines.append(msg)
391
392         with patch("black.out", out), patch("black.err", err):
393             report.done(Path("f1"), black.Changed.NO)
394             self.assertEqual(len(out_lines), 1)
395             self.assertEqual(len(err_lines), 0)
396             self.assertEqual(out_lines[-1], "f1 already well formatted, good job.")
397             self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
398             self.assertEqual(report.return_code, 0)
399             report.done(Path("f2"), black.Changed.YES)
400             self.assertEqual(len(out_lines), 2)
401             self.assertEqual(len(err_lines), 0)
402             self.assertEqual(out_lines[-1], "reformatted f2")
403             self.assertEqual(
404                 unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
405             )
406             report.done(Path("f3"), black.Changed.CACHED)
407             self.assertEqual(len(out_lines), 3)
408             self.assertEqual(len(err_lines), 0)
409             self.assertEqual(
410                 out_lines[-1], "f3 wasn't modified on disk since last run."
411             )
412             self.assertEqual(
413                 unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
414             )
415             self.assertEqual(report.return_code, 0)
416             report.check = True
417             self.assertEqual(report.return_code, 1)
418             report.check = False
419             report.failed(Path("e1"), "boom")
420             self.assertEqual(len(out_lines), 3)
421             self.assertEqual(len(err_lines), 1)
422             self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
423             self.assertEqual(
424                 unstyle(str(report)),
425                 "1 file reformatted, 2 files left unchanged, "
426                 "1 file failed to reformat.",
427             )
428             self.assertEqual(report.return_code, 123)
429             report.done(Path("f3"), black.Changed.YES)
430             self.assertEqual(len(out_lines), 4)
431             self.assertEqual(len(err_lines), 1)
432             self.assertEqual(out_lines[-1], "reformatted f3")
433             self.assertEqual(
434                 unstyle(str(report)),
435                 "2 files reformatted, 2 files left unchanged, "
436                 "1 file failed to reformat.",
437             )
438             self.assertEqual(report.return_code, 123)
439             report.failed(Path("e2"), "boom")
440             self.assertEqual(len(out_lines), 4)
441             self.assertEqual(len(err_lines), 2)
442             self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
443             self.assertEqual(
444                 unstyle(str(report)),
445                 "2 files reformatted, 2 files left unchanged, "
446                 "2 files failed to reformat.",
447             )
448             self.assertEqual(report.return_code, 123)
449             report.done(Path("f4"), black.Changed.NO)
450             self.assertEqual(len(out_lines), 5)
451             self.assertEqual(len(err_lines), 2)
452             self.assertEqual(out_lines[-1], "f4 already well formatted, good job.")
453             self.assertEqual(
454                 unstyle(str(report)),
455                 "2 files reformatted, 3 files left unchanged, "
456                 "2 files failed to reformat.",
457             )
458             self.assertEqual(report.return_code, 123)
459             report.check = True
460             self.assertEqual(
461                 unstyle(str(report)),
462                 "2 files would be reformatted, 3 files would be left unchanged, "
463                 "2 files would fail to reformat.",
464             )
465
466     def test_is_python36(self) -> None:
467         node = black.lib2to3_parse("def f(*, arg): ...\n")
468         self.assertFalse(black.is_python36(node))
469         node = black.lib2to3_parse("def f(*, arg,): ...\n")
470         self.assertTrue(black.is_python36(node))
471         node = black.lib2to3_parse("def f(*, arg): f'string'\n")
472         self.assertTrue(black.is_python36(node))
473         source, expected = read_data("function")
474         node = black.lib2to3_parse(source)
475         self.assertTrue(black.is_python36(node))
476         node = black.lib2to3_parse(expected)
477         self.assertTrue(black.is_python36(node))
478         source, expected = read_data("expression")
479         node = black.lib2to3_parse(source)
480         self.assertFalse(black.is_python36(node))
481         node = black.lib2to3_parse(expected)
482         self.assertFalse(black.is_python36(node))
483
484     def test_get_future_imports(self) -> None:
485         node = black.lib2to3_parse("\n")
486         self.assertEqual(set(), black.get_future_imports(node))
487         node = black.lib2to3_parse("from __future__ import black\n")
488         self.assertEqual({"black"}, black.get_future_imports(node))
489         node = black.lib2to3_parse("from __future__ import multiple, imports\n")
490         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
491         node = black.lib2to3_parse("from __future__ import (parenthesized, imports)\n")
492         self.assertEqual({"parenthesized", "imports"}, black.get_future_imports(node))
493         node = black.lib2to3_parse(
494             "from __future__ import multiple\nfrom __future__ import imports\n"
495         )
496         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
497         node = black.lib2to3_parse("# comment\nfrom __future__ import black\n")
498         self.assertEqual({"black"}, black.get_future_imports(node))
499         node = black.lib2to3_parse('"""docstring"""\nfrom __future__ import black\n')
500         self.assertEqual({"black"}, black.get_future_imports(node))
501         node = black.lib2to3_parse("some(other, code)\nfrom __future__ import black\n")
502         self.assertEqual(set(), black.get_future_imports(node))
503         node = black.lib2to3_parse("from some.module import black\n")
504         self.assertEqual(set(), black.get_future_imports(node))
505
506     def test_debug_visitor(self) -> None:
507         source, _ = read_data("debug_visitor.py")
508         expected, _ = read_data("debug_visitor.out")
509         out_lines = []
510         err_lines = []
511
512         def out(msg: str, **kwargs: Any) -> None:
513             out_lines.append(msg)
514
515         def err(msg: str, **kwargs: Any) -> None:
516             err_lines.append(msg)
517
518         with patch("black.out", out), patch("black.err", err):
519             black.DebugVisitor.show(source)
520         actual = "\n".join(out_lines) + "\n"
521         log_name = ""
522         if expected != actual:
523             log_name = black.dump_to_file(*out_lines)
524         self.assertEqual(
525             expected,
526             actual,
527             f"AST print out is different. Actual version dumped to {log_name}",
528         )
529
530     def test_format_file_contents(self) -> None:
531         empty = ""
532         with self.assertRaises(black.NothingChanged):
533             black.format_file_contents(empty, line_length=ll, fast=False)
534         just_nl = "\n"
535         with self.assertRaises(black.NothingChanged):
536             black.format_file_contents(just_nl, line_length=ll, fast=False)
537         same = "l = [1, 2, 3]\n"
538         with self.assertRaises(black.NothingChanged):
539             black.format_file_contents(same, line_length=ll, fast=False)
540         different = "l = [1,2,3]"
541         expected = same
542         actual = black.format_file_contents(different, line_length=ll, fast=False)
543         self.assertEqual(expected, actual)
544         invalid = "return if you can"
545         with self.assertRaises(ValueError) as e:
546             black.format_file_contents(invalid, line_length=ll, fast=False)
547         self.assertEqual(str(e.exception), "Cannot parse: 1:7: return if you can")
548
549     def test_endmarker(self) -> None:
550         n = black.lib2to3_parse("\n")
551         self.assertEqual(n.type, black.syms.file_input)
552         self.assertEqual(len(n.children), 1)
553         self.assertEqual(n.children[0].type, black.token.ENDMARKER)
554
555     @unittest.skipIf(os.environ.get("SKIP_AST_PRINT"), "user set SKIP_AST_PRINT")
556     def test_assertFormatEqual(self) -> None:
557         out_lines = []
558         err_lines = []
559
560         def out(msg: str, **kwargs: Any) -> None:
561             out_lines.append(msg)
562
563         def err(msg: str, **kwargs: Any) -> None:
564             err_lines.append(msg)
565
566         with patch("black.out", out), patch("black.err", err):
567             with self.assertRaises(AssertionError):
568                 self.assertFormatEqual("l = [1, 2, 3]", "l = [1, 2, 3,]")
569
570         out_str = "".join(out_lines)
571         self.assertTrue("Expected tree:" in out_str)
572         self.assertTrue("Actual tree:" in out_str)
573         self.assertEqual("".join(err_lines), "")
574
575     def test_cache_broken_file(self) -> None:
576         mode = black.FileMode.AUTO_DETECT
577         with cache_dir() as workspace:
578             cache_file = black.get_cache_file(black.DEFAULT_LINE_LENGTH, mode)
579             with cache_file.open("w") as fobj:
580                 fobj.write("this is not a pickle")
581             self.assertEqual(black.read_cache(black.DEFAULT_LINE_LENGTH, mode), {})
582             src = (workspace / "test.py").resolve()
583             with src.open("w") as fobj:
584                 fobj.write("print('hello')")
585             result = CliRunner().invoke(black.main, [str(src)])
586             self.assertEqual(result.exit_code, 0)
587             cache = black.read_cache(black.DEFAULT_LINE_LENGTH, mode)
588             self.assertIn(src, cache)
589
590     def test_cache_single_file_already_cached(self) -> None:
591         mode = black.FileMode.AUTO_DETECT
592         with cache_dir() as workspace:
593             src = (workspace / "test.py").resolve()
594             with src.open("w") as fobj:
595                 fobj.write("print('hello')")
596             black.write_cache({}, [src], black.DEFAULT_LINE_LENGTH, mode)
597             result = CliRunner().invoke(black.main, [str(src)])
598             self.assertEqual(result.exit_code, 0)
599             with src.open("r") as fobj:
600                 self.assertEqual(fobj.read(), "print('hello')")
601
602     @event_loop(close=False)
603     def test_cache_multiple_files(self) -> None:
604         mode = black.FileMode.AUTO_DETECT
605         with cache_dir() as workspace, patch(
606             "black.ProcessPoolExecutor", new=ThreadPoolExecutor
607         ):
608             one = (workspace / "one.py").resolve()
609             with one.open("w") as fobj:
610                 fobj.write("print('hello')")
611             two = (workspace / "two.py").resolve()
612             with two.open("w") as fobj:
613                 fobj.write("print('hello')")
614             black.write_cache({}, [one], black.DEFAULT_LINE_LENGTH, mode)
615             result = CliRunner().invoke(black.main, [str(workspace)])
616             self.assertEqual(result.exit_code, 0)
617             with one.open("r") as fobj:
618                 self.assertEqual(fobj.read(), "print('hello')")
619             with two.open("r") as fobj:
620                 self.assertEqual(fobj.read(), 'print("hello")\n')
621             cache = black.read_cache(black.DEFAULT_LINE_LENGTH, mode)
622             self.assertIn(one, cache)
623             self.assertIn(two, cache)
624
625     def test_no_cache_when_writeback_diff(self) -> None:
626         mode = black.FileMode.AUTO_DETECT
627         with cache_dir() as workspace:
628             src = (workspace / "test.py").resolve()
629             with src.open("w") as fobj:
630                 fobj.write("print('hello')")
631             result = CliRunner().invoke(black.main, [str(src), "--diff"])
632             self.assertEqual(result.exit_code, 0)
633             cache_file = black.get_cache_file(black.DEFAULT_LINE_LENGTH, mode)
634             self.assertFalse(cache_file.exists())
635
636     def test_no_cache_when_stdin(self) -> None:
637         mode = black.FileMode.AUTO_DETECT
638         with cache_dir():
639             result = CliRunner().invoke(black.main, ["-"], input="print('hello')")
640             self.assertEqual(result.exit_code, 0)
641             cache_file = black.get_cache_file(black.DEFAULT_LINE_LENGTH, mode)
642             self.assertFalse(cache_file.exists())
643
644     def test_read_cache_no_cachefile(self) -> None:
645         mode = black.FileMode.AUTO_DETECT
646         with cache_dir():
647             self.assertEqual(black.read_cache(black.DEFAULT_LINE_LENGTH, mode), {})
648
649     def test_write_cache_read_cache(self) -> None:
650         mode = black.FileMode.AUTO_DETECT
651         with cache_dir() as workspace:
652             src = (workspace / "test.py").resolve()
653             src.touch()
654             black.write_cache({}, [src], black.DEFAULT_LINE_LENGTH, mode)
655             cache = black.read_cache(black.DEFAULT_LINE_LENGTH, mode)
656             self.assertIn(src, cache)
657             self.assertEqual(cache[src], black.get_cache_info(src))
658
659     def test_filter_cached(self) -> None:
660         with TemporaryDirectory() as workspace:
661             path = Path(workspace)
662             uncached = (path / "uncached").resolve()
663             cached = (path / "cached").resolve()
664             cached_but_changed = (path / "changed").resolve()
665             uncached.touch()
666             cached.touch()
667             cached_but_changed.touch()
668             cache = {cached: black.get_cache_info(cached), cached_but_changed: (0.0, 0)}
669             todo, done = black.filter_cached(
670                 cache, [uncached, cached, cached_but_changed]
671             )
672             self.assertEqual(todo, [uncached, cached_but_changed])
673             self.assertEqual(done, [cached])
674
675     def test_write_cache_creates_directory_if_needed(self) -> None:
676         mode = black.FileMode.AUTO_DETECT
677         with cache_dir(exists=False) as workspace:
678             self.assertFalse(workspace.exists())
679             black.write_cache({}, [], black.DEFAULT_LINE_LENGTH, mode)
680             self.assertTrue(workspace.exists())
681
682     @event_loop(close=False)
683     def test_failed_formatting_does_not_get_cached(self) -> None:
684         mode = black.FileMode.AUTO_DETECT
685         with cache_dir() as workspace, patch(
686             "black.ProcessPoolExecutor", new=ThreadPoolExecutor
687         ):
688             failing = (workspace / "failing.py").resolve()
689             with failing.open("w") as fobj:
690                 fobj.write("not actually python")
691             clean = (workspace / "clean.py").resolve()
692             with clean.open("w") as fobj:
693                 fobj.write('print("hello")\n')
694             result = CliRunner().invoke(black.main, [str(workspace)])
695             self.assertEqual(result.exit_code, 123)
696             cache = black.read_cache(black.DEFAULT_LINE_LENGTH, mode)
697             self.assertNotIn(failing, cache)
698             self.assertIn(clean, cache)
699
700     def test_write_cache_write_fail(self) -> None:
701         mode = black.FileMode.AUTO_DETECT
702         with cache_dir(), patch.object(Path, "open") as mock:
703             mock.side_effect = OSError
704             black.write_cache({}, [], black.DEFAULT_LINE_LENGTH, mode)
705
706     @event_loop(close=False)
707     def test_check_diff_use_together(self) -> None:
708         with cache_dir():
709             # Files which will be reformatted.
710             src1 = (THIS_DIR / "string_quotes.py").resolve()
711             result = CliRunner().invoke(black.main, [str(src1), "--diff", "--check"])
712             self.assertEqual(result.exit_code, 1)
713
714             # Files which will not be reformatted.
715             src2 = (THIS_DIR / "composition.py").resolve()
716             result = CliRunner().invoke(black.main, [str(src2), "--diff", "--check"])
717             self.assertEqual(result.exit_code, 0)
718
719             # Multi file command.
720             result = CliRunner().invoke(
721                 black.main, [str(src1), str(src2), "--diff", "--check"]
722             )
723             self.assertEqual(result.exit_code, 1, result.output)
724
725     def test_no_files(self) -> None:
726         with cache_dir():
727             # Without an argument, black exits with error code 0.
728             result = CliRunner().invoke(black.main, [])
729             self.assertEqual(result.exit_code, 0)
730
731     def test_broken_symlink(self) -> None:
732         with cache_dir() as workspace:
733             symlink = workspace / "broken_link.py"
734             symlink.symlink_to("nonexistent.py")
735             result = CliRunner().invoke(black.main, [str(workspace.resolve())])
736             self.assertEqual(result.exit_code, 0)
737
738     def test_read_cache_line_lengths(self) -> None:
739         mode = black.FileMode.AUTO_DETECT
740         with cache_dir() as workspace:
741             path = (workspace / "file.py").resolve()
742             path.touch()
743             black.write_cache({}, [path], 1, mode)
744             one = black.read_cache(1, mode)
745             self.assertIn(path, one)
746             two = black.read_cache(2, mode)
747             self.assertNotIn(path, two)
748
749     def test_single_file_force_pyi(self) -> None:
750         reg_mode = black.FileMode.AUTO_DETECT
751         pyi_mode = black.FileMode.PYI
752         contents, expected = read_data("force_pyi")
753         with cache_dir() as workspace:
754             path = (workspace / "file.py").resolve()
755             with open(path, "w") as fh:
756                 fh.write(contents)
757             result = CliRunner().invoke(black.main, [str(path), "--pyi"])
758             self.assertEqual(result.exit_code, 0)
759             with open(path, "r") as fh:
760                 actual = fh.read()
761             # verify cache with --pyi is separate
762             pyi_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, pyi_mode)
763             self.assertIn(path, pyi_cache)
764             normal_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, reg_mode)
765             self.assertNotIn(path, normal_cache)
766         self.assertEqual(actual, expected)
767
768     @event_loop(close=False)
769     def test_multi_file_force_pyi(self) -> None:
770         reg_mode = black.FileMode.AUTO_DETECT
771         pyi_mode = black.FileMode.PYI
772         contents, expected = read_data("force_pyi")
773         with cache_dir() as workspace:
774             paths = [
775                 (workspace / "file1.py").resolve(),
776                 (workspace / "file2.py").resolve(),
777             ]
778             for path in paths:
779                 with open(path, "w") as fh:
780                     fh.write(contents)
781             result = CliRunner().invoke(black.main, [str(p) for p in paths] + ["--pyi"])
782             self.assertEqual(result.exit_code, 0)
783             for path in paths:
784                 with open(path, "r") as fh:
785                     actual = fh.read()
786                 self.assertEqual(actual, expected)
787             # verify cache with --pyi is separate
788             pyi_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, pyi_mode)
789             normal_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, reg_mode)
790             for path in paths:
791                 self.assertIn(path, pyi_cache)
792                 self.assertNotIn(path, normal_cache)
793
794     def test_pipe_force_pyi(self) -> None:
795         source, expected = read_data("force_pyi")
796         result = CliRunner().invoke(black.main, ["-", "-q", "--pyi"], input=source)
797         self.assertEqual(result.exit_code, 0)
798         actual = result.output
799         self.assertFormatEqual(actual, expected)
800
801     def test_single_file_force_py36(self) -> None:
802         reg_mode = black.FileMode.AUTO_DETECT
803         py36_mode = black.FileMode.PYTHON36
804         source, expected = read_data("force_py36")
805         with cache_dir() as workspace:
806             path = (workspace / "file.py").resolve()
807             with open(path, "w") as fh:
808                 fh.write(source)
809             result = CliRunner().invoke(black.main, [str(path), "--py36"])
810             self.assertEqual(result.exit_code, 0)
811             with open(path, "r") as fh:
812                 actual = fh.read()
813             # verify cache with --py36 is separate
814             py36_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, py36_mode)
815             self.assertIn(path, py36_cache)
816             normal_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, reg_mode)
817             self.assertNotIn(path, normal_cache)
818         self.assertEqual(actual, expected)
819
820     @event_loop(close=False)
821     def test_multi_file_force_py36(self) -> None:
822         reg_mode = black.FileMode.AUTO_DETECT
823         py36_mode = black.FileMode.PYTHON36
824         source, expected = read_data("force_py36")
825         with cache_dir() as workspace:
826             paths = [
827                 (workspace / "file1.py").resolve(),
828                 (workspace / "file2.py").resolve(),
829             ]
830             for path in paths:
831                 with open(path, "w") as fh:
832                     fh.write(source)
833             result = CliRunner().invoke(
834                 black.main, [str(p) for p in paths] + ["--py36"]
835             )
836             self.assertEqual(result.exit_code, 0)
837             for path in paths:
838                 with open(path, "r") as fh:
839                     actual = fh.read()
840                 self.assertEqual(actual, expected)
841             # verify cache with --py36 is separate
842             pyi_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, py36_mode)
843             normal_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, reg_mode)
844             for path in paths:
845                 self.assertIn(path, pyi_cache)
846                 self.assertNotIn(path, normal_cache)
847
848     def test_pipe_force_py36(self) -> None:
849         source, expected = read_data("force_py36")
850         result = CliRunner().invoke(black.main, ["-", "-q", "--py36"], input=source)
851         self.assertEqual(result.exit_code, 0)
852         actual = result.output
853         self.assertFormatEqual(actual, expected)
854
855     def test_include_exclude(self) -> None:
856         path = THIS_DIR / "include_exclude_tests"
857         include = re.compile(r"\.pyi?$")
858         exclude = re.compile(r"/exclude/|/\.definitely_exclude/")
859         sources: List[Path] = []
860         expected = [
861             Path(THIS_DIR / "include_exclude_tests/b/dont_exclude/a.py"),
862             Path(THIS_DIR / "include_exclude_tests/b/dont_exclude/a.pyi"),
863         ]
864         sources.extend(black.gen_python_files_in_dir(path, include, exclude))
865         self.assertEqual(sorted(expected), sorted(sources))
866
867     def test_empty_include(self) -> None:
868         path = THIS_DIR / "include_exclude_tests"
869         empty = re.compile(r"")
870         sources: List[Path] = []
871         sources.extend(
872             black.gen_python_files_in_dir(
873                 path, empty, re.compile(black.DEFAULT_EXCLUDES)
874             )
875         )
876         self.assertEqual([], (sources))
877
878     def test_empty_exclude(self) -> None:
879         path = THIS_DIR / "include_exclude_tests"
880         empty = re.compile(r"")
881         sources: List[Path] = []
882         expected = [
883             Path(THIS_DIR / "include_exclude_tests/b/dont_exclude/a.py"),
884             Path(THIS_DIR / "include_exclude_tests/b/dont_exclude/a.pyi"),
885             Path(THIS_DIR / "include_exclude_tests/b/exclude/a.py"),
886             Path(THIS_DIR / "include_exclude_tests/b/exclude/a.pyi"),
887             Path(THIS_DIR / "include_exclude_tests/b/.definitely_exclude/a.py"),
888             Path(THIS_DIR / "include_exclude_tests/b/.definitely_exclude/a.pyi"),
889         ]
890         sources.extend(
891             black.gen_python_files_in_dir(
892                 path, re.compile(black.DEFAULT_INCLUDES), empty
893             )
894         )
895         self.assertEqual(sorted(expected), sorted(sources))
896
897     def test_invalid_include_exclude(self) -> None:
898         for option in ["--include", "--exclude"]:
899             result = CliRunner().invoke(black.main, ["-", option, "**()(!!*)"])
900             self.assertEqual(result.exit_code, 2)
901
902
if __name__ == "__main__":
    # Run this module's test cases through the unittest command-line runner.
    unittest.main(module="__main__")