git.madduck.net Git - etc/vim.git/blob - tests/test_black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access, can be arranged individually.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

53d175066a95e4e1e56d4f7ee2460fada0e9309b
[etc/vim.git] / tests / test_black.py
1 #!/usr/bin/env python3
2 import asyncio
3 import logging
4 from concurrent.futures import ThreadPoolExecutor
5 from contextlib import contextmanager, redirect_stderr
6 from functools import partial, wraps
7 from io import BytesIO, TextIOWrapper
8 import os
9 from pathlib import Path
10 import re
11 import sys
12 from tempfile import TemporaryDirectory
13 from typing import (
14     Any,
15     BinaryIO,
16     Callable,
17     Coroutine,
18     Generator,
19     List,
20     Tuple,
21     Iterator,
22     TypeVar,
23 )
24 import unittest
25 from unittest.mock import patch, MagicMock
26
27 from click import unstyle
28 from click.testing import CliRunner
29
30 import black
31 from black import Feature, TargetVersion
32
33 try:
34     import blackd
35     from aiohttp.test_utils import TestClient, TestServer
36 except ImportError:
37     has_blackd_deps = False
38 else:
39     has_blackd_deps = True
40
# Shorthands used throughout the tests: ``ff`` formats a file in place (fast
# mode) and ``fs`` formats a source string, each with a default FileMode.
ff = partial(black.format_file_in_place, mode=black.FileMode(), fast=True)
fs = partial(black.format_str, mode=black.FileMode())
THIS_FILE = Path(__file__)
THIS_DIR = THIS_FILE.parent
# Marker text that read_data() removes from every line it appears on. The
# literal is built from two pieces, presumably so the full marker never appears
# verbatim in this file (test_self() feeds this file through read_data()).
EMPTY_LINE = "# EMPTY LINE WITH WHITESPACE" + " (this comment will be removed)"
# One --target-version flag per entry of black.PY36_VERSIONS.
PY36_ARGS = [f"--target-version={v.name.lower()}" for v in black.PY36_VERSIONS]
T = TypeVar("T")
R = TypeVar("R")
51
52
def dump_to_stderr(*output: str) -> str:
    """Stand-in for ``black.dump_to_file`` that returns text instead of a path.

    The pieces are joined with newlines and framed by one leading and one
    trailing newline. Despite the name, nothing is written to stderr here.
    """
    body = "\n".join(output)
    return f"\n{body}\n"
55
56
def read_data(name: str, data: bool = True) -> Tuple[str, str]:
    """read_data('test_name') -> 'input', 'output'

    ``name`` gets a ``.py`` suffix unless it already ends with ``.py``,
    ``.pyi``, ``.out`` or ``.diff``. The file is read from ``THIS_DIR / "data"``
    when ``data`` is true, otherwise from ``THIS_DIR`` itself. Lines before a
    line reading ``# output`` form the input and lines after it form the
    expected output; ``EMPTY_LINE`` markers are removed from every line. Both
    halves are stripped and terminated with exactly one newline.
    """
    if not name.endswith((".py", ".pyi", ".out", ".diff")):
        name += ".py"
    base_dir = THIS_DIR / "data" if data else THIS_DIR
    with open(base_dir / name, "r", encoding="utf8") as test_file:
        lines = test_file.readlines()

    sections: Tuple[List[str], List[str]] = ([], [])
    current = 0  # 0 while collecting input, 1 once the marker line was seen
    for raw in lines:
        line = raw.replace(EMPTY_LINE, "")
        if line.rstrip() == "# output":
            current = 1
        else:
            sections[current].append(line)
    source_lines, output_lines = sections
    if source_lines and not output_lines:
        # No (non-empty) expected-output section: the input is taken to be
        # already formatted and doubles as the expected output.
        output_lines = source_lines[:]
    return "".join(source_lines).strip() + "\n", "".join(output_lines).strip() + "\n"
78
79
@contextmanager
def cache_dir(exists: bool = True) -> Iterator[Path]:
    """Point ``black.CACHE_DIR`` at a throwaway directory while the block runs.

    The patched path is yielded. With ``exists=False`` it is a ``new``
    subdirectory of the temporary directory that has not been created.
    """
    with TemporaryDirectory() as workspace:
        root = Path(workspace)
        target = root if exists else root / "new"
        with patch("black.CACHE_DIR", target):
            yield target
88
89
@contextmanager
def event_loop(close: bool) -> Iterator[None]:
    """Make a brand-new event loop current for the duration of the block.

    On exit the loop that the policy reported as current beforehand is
    reinstated. The temporary loop is closed afterwards only when ``close``
    is true.
    """
    policy = asyncio.get_event_loop_policy()
    previous = policy.get_event_loop()
    fresh = policy.new_event_loop()
    asyncio.set_event_loop(fresh)
    try:
        yield

    finally:
        policy.set_event_loop(previous)
        if close:
            fresh.close()
103
104
def async_test(f: Callable[..., Coroutine[Any, None, R]]) -> Callable[..., None]:
    """Decorator: run the coroutine function ``f`` synchronously to completion.

    Every call of the result runs inside ``event_loop(close=True)``, i.e. on a
    fresh event loop that is closed afterwards. The coroutine's result is
    discarded.
    """

    @wraps(f)
    def wrapper(*args: Any, **kwargs: Any) -> None:
        loop = asyncio.get_event_loop()
        loop.run_until_complete(f(*args, **kwargs))

    return event_loop(close=True)(wrapper)
112
113
class BlackRunner(CliRunner):
    """Modify CliRunner so that stderr is not merged with stdout.

    This is a hack that can be removed once we depend on Click 7.x

    After each invocation, the raw bytes written to stdout and stderr are
    available as ``stdout_bytes`` and ``stderr_bytes``.
    """

    def __init__(self) -> None:
        self.stderrbuf = BytesIO()
        self.stdoutbuf = BytesIO()
        self.stdout_bytes = b""
        self.stderr_bytes = b""
        super().__init__()

    @contextmanager
    def isolation(self, *args: Any, **kwargs: Any) -> Generator[BinaryIO, None, None]:
        """Run Click's isolation with ``sys.stderr`` sent to a private buffer."""
        with super().isolation(*args, **kwargs) as output:
            # A fresh buffer per invocation keeps one run's stderr out of the
            # next run's ``stderr_bytes`` (stdout is per-run as well).
            self.stderrbuf = BytesIO()
            hold_stderr = sys.stderr
            stderr_wrapper = TextIOWrapper(self.stderrbuf, encoding=self.charset)
            sys.stderr = stderr_wrapper
            try:
                yield output
            finally:
                # Restore first so a failure below cannot leave stderr hijacked.
                sys.stderr = hold_stderr
                # TextIOWrapper keeps written text in an internal buffer; flush
                # it so unflushed writes still reach ``stderrbuf``.
                stderr_wrapper.flush()
                self.stdout_bytes = sys.stdout.buffer.getvalue()  # type: ignore
                # Read our own buffer: code under test may have replaced
                # sys.stderr with an object that has no ``.buffer``.
                self.stderr_bytes = self.stderrbuf.getvalue()
137
138
139 class BlackTestCase(unittest.TestCase):
140     maxDiff = None
141
142     def assertFormatEqual(self, expected: str, actual: str) -> None:
143         if actual != expected and not os.environ.get("SKIP_AST_PRINT"):
144             bdv: black.DebugVisitor[Any]
145             black.out("Expected tree:", fg="green")
146             try:
147                 exp_node = black.lib2to3_parse(expected)
148                 bdv = black.DebugVisitor()
149                 list(bdv.visit(exp_node))
150             except Exception as ve:
151                 black.err(str(ve))
152             black.out("Actual tree:", fg="red")
153             try:
154                 exp_node = black.lib2to3_parse(actual)
155                 bdv = black.DebugVisitor()
156                 list(bdv.visit(exp_node))
157             except Exception as ve:
158                 black.err(str(ve))
159         self.assertEqual(expected, actual)
160
161     def invokeBlack(
162         self, args: List[str], exit_code: int = 0, ignore_config: bool = True
163     ) -> None:
164         runner = BlackRunner()
165         if ignore_config:
166             args = ["--config", str(THIS_DIR / "empty.toml"), *args]
167         result = runner.invoke(black.main, args)
168         self.assertEqual(result.exit_code, exit_code, msg=runner.stderr_bytes.decode())
169
170     @patch("black.dump_to_file", dump_to_stderr)
171     def test_empty(self) -> None:
172         source = expected = ""
173         actual = fs(source)
174         self.assertFormatEqual(expected, actual)
175         black.assert_equivalent(source, actual)
176         black.assert_stable(source, actual, black.FileMode())
177
178     def test_empty_ff(self) -> None:
179         expected = ""
180         tmp_file = Path(black.dump_to_file())
181         try:
182             self.assertFalse(ff(tmp_file, write_back=black.WriteBack.YES))
183             with open(tmp_file, encoding="utf8") as f:
184                 actual = f.read()
185         finally:
186             os.unlink(tmp_file)
187         self.assertFormatEqual(expected, actual)
188
189     @patch("black.dump_to_file", dump_to_stderr)
190     def test_self(self) -> None:
191         source, expected = read_data("test_black", data=False)
192         actual = fs(source)
193         self.assertFormatEqual(expected, actual)
194         black.assert_equivalent(source, actual)
195         black.assert_stable(source, actual, black.FileMode())
196         self.assertFalse(ff(THIS_FILE))
197
198     @patch("black.dump_to_file", dump_to_stderr)
199     def test_black(self) -> None:
200         source, expected = read_data("../black", data=False)
201         actual = fs(source)
202         self.assertFormatEqual(expected, actual)
203         black.assert_equivalent(source, actual)
204         black.assert_stable(source, actual, black.FileMode())
205         self.assertFalse(ff(THIS_DIR / ".." / "black.py"))
206
207     def test_piping(self) -> None:
208         source, expected = read_data("../black", data=False)
209         result = BlackRunner().invoke(
210             black.main,
211             ["-", "--fast", f"--line-length={black.DEFAULT_LINE_LENGTH}"],
212             input=BytesIO(source.encode("utf8")),
213         )
214         self.assertEqual(result.exit_code, 0)
215         self.assertFormatEqual(expected, result.output)
216         black.assert_equivalent(source, result.output)
217         black.assert_stable(source, result.output, black.FileMode())
218
219     def test_piping_diff(self) -> None:
220         diff_header = re.compile(
221             rf"(STDIN|STDOUT)\t\d\d\d\d-\d\d-\d\d "
222             rf"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
223         )
224         source, _ = read_data("expression.py")
225         expected, _ = read_data("expression.diff")
226         config = THIS_DIR / "data" / "empty_pyproject.toml"
227         args = [
228             "-",
229             "--fast",
230             f"--line-length={black.DEFAULT_LINE_LENGTH}",
231             "--diff",
232             f"--config={config}",
233         ]
234         result = BlackRunner().invoke(
235             black.main, args, input=BytesIO(source.encode("utf8"))
236         )
237         self.assertEqual(result.exit_code, 0)
238         actual = diff_header.sub("[Deterministic header]", result.output)
239         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
240         self.assertEqual(expected, actual)
241
242     @patch("black.dump_to_file", dump_to_stderr)
243     def test_setup(self) -> None:
244         source, expected = read_data("../setup", data=False)
245         actual = fs(source)
246         self.assertFormatEqual(expected, actual)
247         black.assert_equivalent(source, actual)
248         black.assert_stable(source, actual, black.FileMode())
249         self.assertFalse(ff(THIS_DIR / ".." / "setup.py"))
250
251     @patch("black.dump_to_file", dump_to_stderr)
252     def test_function(self) -> None:
253         source, expected = read_data("function")
254         actual = fs(source)
255         self.assertFormatEqual(expected, actual)
256         black.assert_equivalent(source, actual)
257         black.assert_stable(source, actual, black.FileMode())
258
259     @patch("black.dump_to_file", dump_to_stderr)
260     def test_function2(self) -> None:
261         source, expected = read_data("function2")
262         actual = fs(source)
263         self.assertFormatEqual(expected, actual)
264         black.assert_equivalent(source, actual)
265         black.assert_stable(source, actual, black.FileMode())
266
267     @patch("black.dump_to_file", dump_to_stderr)
268     def test_expression(self) -> None:
269         source, expected = read_data("expression")
270         actual = fs(source)
271         self.assertFormatEqual(expected, actual)
272         black.assert_equivalent(source, actual)
273         black.assert_stable(source, actual, black.FileMode())
274
275     def test_expression_ff(self) -> None:
276         source, expected = read_data("expression")
277         tmp_file = Path(black.dump_to_file(source))
278         try:
279             self.assertTrue(ff(tmp_file, write_back=black.WriteBack.YES))
280             with open(tmp_file, encoding="utf8") as f:
281                 actual = f.read()
282         finally:
283             os.unlink(tmp_file)
284         self.assertFormatEqual(expected, actual)
285         with patch("black.dump_to_file", dump_to_stderr):
286             black.assert_equivalent(source, actual)
287             black.assert_stable(source, actual, black.FileMode())
288
289     def test_expression_diff(self) -> None:
290         source, _ = read_data("expression.py")
291         expected, _ = read_data("expression.diff")
292         tmp_file = Path(black.dump_to_file(source))
293         diff_header = re.compile(
294             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
295             rf"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
296         )
297         try:
298             result = BlackRunner().invoke(black.main, ["--diff", str(tmp_file)])
299             self.assertEqual(result.exit_code, 0)
300         finally:
301             os.unlink(tmp_file)
302         actual = result.output
303         actual = diff_header.sub("[Deterministic header]", actual)
304         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
305         if expected != actual:
306             dump = black.dump_to_file(actual)
307             msg = (
308                 f"Expected diff isn't equal to the actual. If you made changes "
309                 f"to expression.py and this is an anticipated difference, "
310                 f"overwrite tests/data/expression.diff with {dump}"
311             )
312             self.assertEqual(expected, actual, msg)
313
314     @patch("black.dump_to_file", dump_to_stderr)
315     def test_fstring(self) -> None:
316         source, expected = read_data("fstring")
317         actual = fs(source)
318         self.assertFormatEqual(expected, actual)
319         black.assert_equivalent(source, actual)
320         black.assert_stable(source, actual, black.FileMode())
321
322     @patch("black.dump_to_file", dump_to_stderr)
323     def test_string_quotes(self) -> None:
324         source, expected = read_data("string_quotes")
325         actual = fs(source)
326         self.assertFormatEqual(expected, actual)
327         black.assert_equivalent(source, actual)
328         black.assert_stable(source, actual, black.FileMode())
329         mode = black.FileMode(string_normalization=False)
330         not_normalized = fs(source, mode=mode)
331         self.assertFormatEqual(source, not_normalized)
332         black.assert_equivalent(source, not_normalized)
333         black.assert_stable(source, not_normalized, mode=mode)
334
335     @patch("black.dump_to_file", dump_to_stderr)
336     def test_slices(self) -> None:
337         source, expected = read_data("slices")
338         actual = fs(source)
339         self.assertFormatEqual(expected, actual)
340         black.assert_equivalent(source, actual)
341         black.assert_stable(source, actual, black.FileMode())
342
343     @patch("black.dump_to_file", dump_to_stderr)
344     def test_comments(self) -> None:
345         source, expected = read_data("comments")
346         actual = fs(source)
347         self.assertFormatEqual(expected, actual)
348         black.assert_equivalent(source, actual)
349         black.assert_stable(source, actual, black.FileMode())
350
351     @patch("black.dump_to_file", dump_to_stderr)
352     def test_comments2(self) -> None:
353         source, expected = read_data("comments2")
354         actual = fs(source)
355         self.assertFormatEqual(expected, actual)
356         black.assert_equivalent(source, actual)
357         black.assert_stable(source, actual, black.FileMode())
358
359     @patch("black.dump_to_file", dump_to_stderr)
360     def test_comments3(self) -> None:
361         source, expected = read_data("comments3")
362         actual = fs(source)
363         self.assertFormatEqual(expected, actual)
364         black.assert_equivalent(source, actual)
365         black.assert_stable(source, actual, black.FileMode())
366
367     @patch("black.dump_to_file", dump_to_stderr)
368     def test_comments4(self) -> None:
369         source, expected = read_data("comments4")
370         actual = fs(source)
371         self.assertFormatEqual(expected, actual)
372         black.assert_equivalent(source, actual)
373         black.assert_stable(source, actual, black.FileMode())
374
375     @patch("black.dump_to_file", dump_to_stderr)
376     def test_comments5(self) -> None:
377         source, expected = read_data("comments5")
378         actual = fs(source)
379         self.assertFormatEqual(expected, actual)
380         black.assert_equivalent(source, actual)
381         black.assert_stable(source, actual, black.FileMode())
382
383     @patch("black.dump_to_file", dump_to_stderr)
384     def test_comments6(self) -> None:
385         source, expected = read_data("comments6")
386         actual = fs(source)
387         self.assertFormatEqual(expected, actual)
388         black.assert_equivalent(source, actual)
389         black.assert_stable(source, actual, black.FileMode())
390
391     @patch("black.dump_to_file", dump_to_stderr)
392     def test_comments7(self) -> None:
393         source, expected = read_data("comments7")
394         actual = fs(source)
395         self.assertFormatEqual(expected, actual)
396         black.assert_equivalent(source, actual)
397         black.assert_stable(source, actual, black.FileMode())
398
399     @patch("black.dump_to_file", dump_to_stderr)
400     def test_comment_after_escaped_newline(self) -> None:
401         source, expected = read_data("comment_after_escaped_newline")
402         actual = fs(source)
403         self.assertFormatEqual(expected, actual)
404         black.assert_equivalent(source, actual)
405         black.assert_stable(source, actual, black.FileMode())
406
407     @patch("black.dump_to_file", dump_to_stderr)
408     def test_cantfit(self) -> None:
409         source, expected = read_data("cantfit")
410         actual = fs(source)
411         self.assertFormatEqual(expected, actual)
412         black.assert_equivalent(source, actual)
413         black.assert_stable(source, actual, black.FileMode())
414
415     @patch("black.dump_to_file", dump_to_stderr)
416     def test_import_spacing(self) -> None:
417         source, expected = read_data("import_spacing")
418         actual = fs(source)
419         self.assertFormatEqual(expected, actual)
420         black.assert_equivalent(source, actual)
421         black.assert_stable(source, actual, black.FileMode())
422
423     @patch("black.dump_to_file", dump_to_stderr)
424     def test_composition(self) -> None:
425         source, expected = read_data("composition")
426         actual = fs(source)
427         self.assertFormatEqual(expected, actual)
428         black.assert_equivalent(source, actual)
429         black.assert_stable(source, actual, black.FileMode())
430
431     @patch("black.dump_to_file", dump_to_stderr)
432     def test_empty_lines(self) -> None:
433         source, expected = read_data("empty_lines")
434         actual = fs(source)
435         self.assertFormatEqual(expected, actual)
436         black.assert_equivalent(source, actual)
437         black.assert_stable(source, actual, black.FileMode())
438
439     @patch("black.dump_to_file", dump_to_stderr)
440     def test_string_prefixes(self) -> None:
441         source, expected = read_data("string_prefixes")
442         actual = fs(source)
443         self.assertFormatEqual(expected, actual)
444         black.assert_equivalent(source, actual)
445         black.assert_stable(source, actual, black.FileMode())
446
447     @patch("black.dump_to_file", dump_to_stderr)
448     def test_numeric_literals(self) -> None:
449         source, expected = read_data("numeric_literals")
450         mode = black.FileMode(target_versions=black.PY36_VERSIONS)
451         actual = fs(source, mode=mode)
452         self.assertFormatEqual(expected, actual)
453         black.assert_equivalent(source, actual)
454         black.assert_stable(source, actual, mode)
455
456     @patch("black.dump_to_file", dump_to_stderr)
457     def test_numeric_literals_ignoring_underscores(self) -> None:
458         source, expected = read_data("numeric_literals_skip_underscores")
459         mode = black.FileMode(target_versions=black.PY36_VERSIONS)
460         actual = fs(source, mode=mode)
461         self.assertFormatEqual(expected, actual)
462         black.assert_equivalent(source, actual)
463         black.assert_stable(source, actual, mode)
464
465     @patch("black.dump_to_file", dump_to_stderr)
466     def test_numeric_literals_py2(self) -> None:
467         source, expected = read_data("numeric_literals_py2")
468         actual = fs(source)
469         self.assertFormatEqual(expected, actual)
470         black.assert_stable(source, actual, black.FileMode())
471
472     @patch("black.dump_to_file", dump_to_stderr)
473     def test_python2(self) -> None:
474         source, expected = read_data("python2")
475         actual = fs(source)
476         self.assertFormatEqual(expected, actual)
477         # black.assert_equivalent(source, actual)
478         black.assert_stable(source, actual, black.FileMode())
479
480     @patch("black.dump_to_file", dump_to_stderr)
481     def test_python2_print_function(self) -> None:
482         source, expected = read_data("python2_print_function")
483         mode = black.FileMode(target_versions={TargetVersion.PY27})
484         actual = fs(source, mode=mode)
485         self.assertFormatEqual(expected, actual)
486         black.assert_stable(source, actual, mode)
487
488     @patch("black.dump_to_file", dump_to_stderr)
489     def test_python2_unicode_literals(self) -> None:
490         source, expected = read_data("python2_unicode_literals")
491         actual = fs(source)
492         self.assertFormatEqual(expected, actual)
493         black.assert_stable(source, actual, black.FileMode())
494
495     @patch("black.dump_to_file", dump_to_stderr)
496     def test_stub(self) -> None:
497         mode = black.FileMode(is_pyi=True)
498         source, expected = read_data("stub.pyi")
499         actual = fs(source, mode=mode)
500         self.assertFormatEqual(expected, actual)
501         black.assert_stable(source, actual, mode)
502
503     @patch("black.dump_to_file", dump_to_stderr)
504     def test_python37(self) -> None:
505         source, expected = read_data("python37")
506         actual = fs(source)
507         self.assertFormatEqual(expected, actual)
508         major, minor = sys.version_info[:2]
509         if major > 3 or (major == 3 and minor >= 7):
510             black.assert_equivalent(source, actual)
511         black.assert_stable(source, actual, black.FileMode())
512
513     @patch("black.dump_to_file", dump_to_stderr)
514     def test_fmtonoff(self) -> None:
515         source, expected = read_data("fmtonoff")
516         actual = fs(source)
517         self.assertFormatEqual(expected, actual)
518         black.assert_equivalent(source, actual)
519         black.assert_stable(source, actual, black.FileMode())
520
521     @patch("black.dump_to_file", dump_to_stderr)
522     def test_fmtonoff2(self) -> None:
523         source, expected = read_data("fmtonoff2")
524         actual = fs(source)
525         self.assertFormatEqual(expected, actual)
526         black.assert_equivalent(source, actual)
527         black.assert_stable(source, actual, black.FileMode())
528
529     @patch("black.dump_to_file", dump_to_stderr)
530     def test_remove_empty_parentheses_after_class(self) -> None:
531         source, expected = read_data("class_blank_parentheses")
532         actual = fs(source)
533         self.assertFormatEqual(expected, actual)
534         black.assert_equivalent(source, actual)
535         black.assert_stable(source, actual, black.FileMode())
536
537     @patch("black.dump_to_file", dump_to_stderr)
538     def test_new_line_between_class_and_code(self) -> None:
539         source, expected = read_data("class_methods_new_line")
540         actual = fs(source)
541         self.assertFormatEqual(expected, actual)
542         black.assert_equivalent(source, actual)
543         black.assert_stable(source, actual, black.FileMode())
544
545     @patch("black.dump_to_file", dump_to_stderr)
546     def test_bracket_match(self) -> None:
547         source, expected = read_data("bracketmatch")
548         actual = fs(source)
549         self.assertFormatEqual(expected, actual)
550         black.assert_equivalent(source, actual)
551         black.assert_stable(source, actual, black.FileMode())
552
553     @patch("black.dump_to_file", dump_to_stderr)
554     def test_tuple_assign(self) -> None:
555         source, expected = read_data("tupleassign")
556         actual = fs(source)
557         self.assertFormatEqual(expected, actual)
558         black.assert_equivalent(source, actual)
559         black.assert_stable(source, actual, black.FileMode())
560
561     def test_tab_comment_indentation(self) -> None:
562         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t# comment\n\tpass\n"
563         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
564         self.assertFormatEqual(contents_spc, fs(contents_spc))
565         self.assertFormatEqual(contents_spc, fs(contents_tab))
566
567         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t\t# comment\n\tpass\n"
568         contents_spc = "if 1:\n    if 2:\n        pass\n        # comment\n    pass\n"
569         self.assertFormatEqual(contents_spc, fs(contents_spc))
570         self.assertFormatEqual(contents_spc, fs(contents_tab))
571
572         # mixed tabs and spaces (valid Python 2 code)
573         contents_tab = "if 1:\n        if 2:\n\t\tpass\n\t# comment\n        pass\n"
574         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
575         self.assertFormatEqual(contents_spc, fs(contents_spc))
576         self.assertFormatEqual(contents_spc, fs(contents_tab))
577
578         contents_tab = "if 1:\n        if 2:\n\t\tpass\n\t\t# comment\n        pass\n"
579         contents_spc = "if 1:\n    if 2:\n        pass\n        # comment\n    pass\n"
580         self.assertFormatEqual(contents_spc, fs(contents_spc))
581         self.assertFormatEqual(contents_spc, fs(contents_tab))
582
583     def test_report_verbose(self) -> None:
584         report = black.Report(verbose=True)
585         out_lines = []
586         err_lines = []
587
588         def out(msg: str, **kwargs: Any) -> None:
589             out_lines.append(msg)
590
591         def err(msg: str, **kwargs: Any) -> None:
592             err_lines.append(msg)
593
594         with patch("black.out", out), patch("black.err", err):
595             report.done(Path("f1"), black.Changed.NO)
596             self.assertEqual(len(out_lines), 1)
597             self.assertEqual(len(err_lines), 0)
598             self.assertEqual(out_lines[-1], "f1 already well formatted, good job.")
599             self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
600             self.assertEqual(report.return_code, 0)
601             report.done(Path("f2"), black.Changed.YES)
602             self.assertEqual(len(out_lines), 2)
603             self.assertEqual(len(err_lines), 0)
604             self.assertEqual(out_lines[-1], "reformatted f2")
605             self.assertEqual(
606                 unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
607             )
608             report.done(Path("f3"), black.Changed.CACHED)
609             self.assertEqual(len(out_lines), 3)
610             self.assertEqual(len(err_lines), 0)
611             self.assertEqual(
612                 out_lines[-1], "f3 wasn't modified on disk since last run."
613             )
614             self.assertEqual(
615                 unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
616             )
617             self.assertEqual(report.return_code, 0)
618             report.check = True
619             self.assertEqual(report.return_code, 1)
620             report.check = False
621             report.failed(Path("e1"), "boom")
622             self.assertEqual(len(out_lines), 3)
623             self.assertEqual(len(err_lines), 1)
624             self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
625             self.assertEqual(
626                 unstyle(str(report)),
627                 "1 file reformatted, 2 files left unchanged, "
628                 "1 file failed to reformat.",
629             )
630             self.assertEqual(report.return_code, 123)
631             report.done(Path("f3"), black.Changed.YES)
632             self.assertEqual(len(out_lines), 4)
633             self.assertEqual(len(err_lines), 1)
634             self.assertEqual(out_lines[-1], "reformatted f3")
635             self.assertEqual(
636                 unstyle(str(report)),
637                 "2 files reformatted, 2 files left unchanged, "
638                 "1 file failed to reformat.",
639             )
640             self.assertEqual(report.return_code, 123)
641             report.failed(Path("e2"), "boom")
642             self.assertEqual(len(out_lines), 4)
643             self.assertEqual(len(err_lines), 2)
644             self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
645             self.assertEqual(
646                 unstyle(str(report)),
647                 "2 files reformatted, 2 files left unchanged, "
648                 "2 files failed to reformat.",
649             )
650             self.assertEqual(report.return_code, 123)
651             report.path_ignored(Path("wat"), "no match")
652             self.assertEqual(len(out_lines), 5)
653             self.assertEqual(len(err_lines), 2)
654             self.assertEqual(out_lines[-1], "wat ignored: no match")
655             self.assertEqual(
656                 unstyle(str(report)),
657                 "2 files reformatted, 2 files left unchanged, "
658                 "2 files failed to reformat.",
659             )
660             self.assertEqual(report.return_code, 123)
661             report.done(Path("f4"), black.Changed.NO)
662             self.assertEqual(len(out_lines), 6)
663             self.assertEqual(len(err_lines), 2)
664             self.assertEqual(out_lines[-1], "f4 already well formatted, good job.")
665             self.assertEqual(
666                 unstyle(str(report)),
667                 "2 files reformatted, 3 files left unchanged, "
668                 "2 files failed to reformat.",
669             )
670             self.assertEqual(report.return_code, 123)
671             report.check = True
672             self.assertEqual(
673                 unstyle(str(report)),
674                 "2 files would be reformatted, 3 files would be left unchanged, "
675                 "2 files would fail to reformat.",
676             )
677
    def test_report_quiet(self) -> None:
        """Check Report(quiet=True): tallies and return codes, errors only on err.

        ``black.out`` and ``black.err`` are replaced with collectors so every
        step can check how many messages were emitted. The steps build on one
        another (the report accumulates state), so their order is significant.
        """
        report = black.Report(quiet=True)
        out_lines: List[str] = []
        err_lines: List[str] = []

        def out(msg: str, **kwargs: Any) -> None:
            out_lines.append(msg)

        def err(msg: str, **kwargs: Any) -> None:
            err_lines.append(msg)

        with patch("black.out", out), patch("black.err", err):
            # In quiet mode unchanged and reformatted files are tallied but
            # nothing is printed for them.
            report.done(Path("f1"), black.Changed.NO)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
            self.assertEqual(report.return_code, 0)
            report.done(Path("f2"), black.Changed.YES)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
            )
            # A CACHED result is counted as "left unchanged".
            report.done(Path("f3"), black.Changed.CACHED)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
            )
            self.assertEqual(report.return_code, 0)
            # With check enabled, having reformatted files makes the code 1.
            report.check = True
            self.assertEqual(report.return_code, 1)
            report.check = False
            # Failures are still written to err even in quiet mode, and any
            # failure makes the return code 123.
            report.failed(Path("e1"), "boom")
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
            self.assertEqual(
                unstyle(str(report)),
                "1 file reformatted, 2 files left unchanged, "
                "1 file failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Reporting f3 again (now as YES) bumps the reformatted tally, while
            # its earlier CACHED entry still counts as unchanged.
            report.done(Path("f3"), black.Changed.YES)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, "
                "1 file failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.failed(Path("e2"), "boom")
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, "
                "2 files failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Ignored paths emit nothing and do not change the summary.
            report.path_ignored(Path("wat"), "no match")
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, "
                "2 files failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f4"), black.Changed.NO)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 3 files left unchanged, "
                "2 files failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # In check mode the summary is phrased conditionally ("would be").
            report.check = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, "
                "2 files would fail to reformat.",
            )
764
    def test_report_normal(self) -> None:
        """Check a default Report: per-file messages, tallies and return codes.

        ``black.out`` and ``black.err`` are replaced with collectors. Each
        reformatted file prints "reformatted <path>" on out; each failure
        prints "error: cannot format <path>: <message>" on err. The steps
        build on one another, so their order is significant.
        """
        report = black.Report()
        out_lines: List[str] = []
        err_lines: List[str] = []

        def out(msg: str, **kwargs: Any) -> None:
            out_lines.append(msg)

        def err(msg: str, **kwargs: Any) -> None:
            err_lines.append(msg)

        with patch("black.out", out), patch("black.err", err):
            # Unchanged files are tallied silently.
            report.done(Path("f1"), black.Changed.NO)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
            self.assertEqual(report.return_code, 0)
            # A reformatted file is announced on out.
            report.done(Path("f2"), black.Changed.YES)
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "reformatted f2")
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
            )
            # A CACHED result prints nothing and counts as "left unchanged".
            report.done(Path("f3"), black.Changed.CACHED)
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "reformatted f2")
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
            )
            self.assertEqual(report.return_code, 0)
            # With check enabled, having reformatted files makes the code 1.
            report.check = True
            self.assertEqual(report.return_code, 1)
            report.check = False
            # A failure goes to err and makes the return code 123.
            report.failed(Path("e1"), "boom")
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
            self.assertEqual(
                unstyle(str(report)),
                "1 file reformatted, 2 files left unchanged, "
                "1 file failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Reporting f3 again (now as YES) prints it and bumps the
            # reformatted tally; its earlier CACHED entry still counts.
            report.done(Path("f3"), black.Changed.YES)
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(out_lines[-1], "reformatted f3")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, "
                "1 file failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.failed(Path("e2"), "boom")
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, "
                "2 files failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Ignored paths emit nothing (in non-verbose mode) and do not change
            # the summary.
            report.path_ignored(Path("wat"), "no match")
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, "
                "2 files failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f4"), black.Changed.NO)
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 3 files left unchanged, "
                "2 files failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # In check mode the summary is phrased conditionally ("would be").
            report.check = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, "
                "2 files would fail to reformat.",
            )
854
855     def test_lib2to3_parse(self) -> None:
856         with self.assertRaises(black.InvalidInput):
857             black.lib2to3_parse("invalid syntax")
858
859         straddling = "x + y"
860         black.lib2to3_parse(straddling)
861         black.lib2to3_parse(straddling, {TargetVersion.PY27})
862         black.lib2to3_parse(straddling, {TargetVersion.PY36})
863         black.lib2to3_parse(straddling, {TargetVersion.PY27, TargetVersion.PY36})
864
865         py2_only = "print x"
866         black.lib2to3_parse(py2_only)
867         black.lib2to3_parse(py2_only, {TargetVersion.PY27})
868         with self.assertRaises(black.InvalidInput):
869             black.lib2to3_parse(py2_only, {TargetVersion.PY36})
870         with self.assertRaises(black.InvalidInput):
871             black.lib2to3_parse(py2_only, {TargetVersion.PY27, TargetVersion.PY36})
872
873         py3_only = "exec(x, end=y)"
874         black.lib2to3_parse(py3_only)
875         with self.assertRaises(black.InvalidInput):
876             black.lib2to3_parse(py3_only, {TargetVersion.PY27})
877         black.lib2to3_parse(py3_only, {TargetVersion.PY36})
878         black.lib2to3_parse(py3_only, {TargetVersion.PY27, TargetVersion.PY36})
879
880     def test_get_features_used(self) -> None:
881         node = black.lib2to3_parse("def f(*, arg): ...\n")
882         self.assertEqual(black.get_features_used(node), set())
883         node = black.lib2to3_parse("def f(*, arg,): ...\n")
884         self.assertEqual(black.get_features_used(node), {Feature.TRAILING_COMMA_IN_DEF})
885         node = black.lib2to3_parse("f(*arg,)\n")
886         self.assertEqual(
887             black.get_features_used(node), {Feature.TRAILING_COMMA_IN_CALL}
888         )
889         node = black.lib2to3_parse("def f(*, arg): f'string'\n")
890         self.assertEqual(black.get_features_used(node), {Feature.F_STRINGS})
891         node = black.lib2to3_parse("123_456\n")
892         self.assertEqual(black.get_features_used(node), {Feature.NUMERIC_UNDERSCORES})
893         node = black.lib2to3_parse("123456\n")
894         self.assertEqual(black.get_features_used(node), set())
895         source, expected = read_data("function")
896         node = black.lib2to3_parse(source)
897         expected_features = {
898             Feature.TRAILING_COMMA_IN_CALL,
899             Feature.TRAILING_COMMA_IN_DEF,
900             Feature.F_STRINGS,
901         }
902         self.assertEqual(black.get_features_used(node), expected_features)
903         node = black.lib2to3_parse(expected)
904         self.assertEqual(black.get_features_used(node), expected_features)
905         source, expected = read_data("expression")
906         node = black.lib2to3_parse(source)
907         self.assertEqual(black.get_features_used(node), set())
908         node = black.lib2to3_parse(expected)
909         self.assertEqual(black.get_features_used(node), set())
910
911     def test_get_future_imports(self) -> None:
912         node = black.lib2to3_parse("\n")
913         self.assertEqual(set(), black.get_future_imports(node))
914         node = black.lib2to3_parse("from __future__ import black\n")
915         self.assertEqual({"black"}, black.get_future_imports(node))
916         node = black.lib2to3_parse("from __future__ import multiple, imports\n")
917         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
918         node = black.lib2to3_parse("from __future__ import (parenthesized, imports)\n")
919         self.assertEqual({"parenthesized", "imports"}, black.get_future_imports(node))
920         node = black.lib2to3_parse(
921             "from __future__ import multiple\nfrom __future__ import imports\n"
922         )
923         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
924         node = black.lib2to3_parse("# comment\nfrom __future__ import black\n")
925         self.assertEqual({"black"}, black.get_future_imports(node))
926         node = black.lib2to3_parse('"""docstring"""\nfrom __future__ import black\n')
927         self.assertEqual({"black"}, black.get_future_imports(node))
928         node = black.lib2to3_parse("some(other, code)\nfrom __future__ import black\n")
929         self.assertEqual(set(), black.get_future_imports(node))
930         node = black.lib2to3_parse("from some.module import black\n")
931         self.assertEqual(set(), black.get_future_imports(node))
932         node = black.lib2to3_parse(
933             "from __future__ import unicode_literals as _unicode_literals"
934         )
935         self.assertEqual({"unicode_literals"}, black.get_future_imports(node))
936         node = black.lib2to3_parse(
937             "from __future__ import unicode_literals as _lol, print"
938         )
939         self.assertEqual({"unicode_literals", "print"}, black.get_future_imports(node))
940
941     def test_debug_visitor(self) -> None:
942         source, _ = read_data("debug_visitor.py")
943         expected, _ = read_data("debug_visitor.out")
944         out_lines = []
945         err_lines = []
946
947         def out(msg: str, **kwargs: Any) -> None:
948             out_lines.append(msg)
949
950         def err(msg: str, **kwargs: Any) -> None:
951             err_lines.append(msg)
952
953         with patch("black.out", out), patch("black.err", err):
954             black.DebugVisitor.show(source)
955         actual = "\n".join(out_lines) + "\n"
956         log_name = ""
957         if expected != actual:
958             log_name = black.dump_to_file(*out_lines)
959         self.assertEqual(
960             expected,
961             actual,
962             f"AST print out is different. Actual version dumped to {log_name}",
963         )
964
965     def test_format_file_contents(self) -> None:
966         empty = ""
967         mode = black.FileMode()
968         with self.assertRaises(black.NothingChanged):
969             black.format_file_contents(empty, mode=mode, fast=False)
970         just_nl = "\n"
971         with self.assertRaises(black.NothingChanged):
972             black.format_file_contents(just_nl, mode=mode, fast=False)
973         same = "l = [1, 2, 3]\n"
974         with self.assertRaises(black.NothingChanged):
975             black.format_file_contents(same, mode=mode, fast=False)
976         different = "l = [1,2,3]"
977         expected = same
978         actual = black.format_file_contents(different, mode=mode, fast=False)
979         self.assertEqual(expected, actual)
980         invalid = "return if you can"
981         with self.assertRaises(black.InvalidInput) as e:
982             black.format_file_contents(invalid, mode=mode, fast=False)
983         self.assertEqual(str(e.exception), "Cannot parse: 1:7: return if you can")
984
985     def test_endmarker(self) -> None:
986         n = black.lib2to3_parse("\n")
987         self.assertEqual(n.type, black.syms.file_input)
988         self.assertEqual(len(n.children), 1)
989         self.assertEqual(n.children[0].type, black.token.ENDMARKER)
990
991     @unittest.skipIf(os.environ.get("SKIP_AST_PRINT"), "user set SKIP_AST_PRINT")
992     def test_assertFormatEqual(self) -> None:
993         out_lines = []
994         err_lines = []
995
996         def out(msg: str, **kwargs: Any) -> None:
997             out_lines.append(msg)
998
999         def err(msg: str, **kwargs: Any) -> None:
1000             err_lines.append(msg)
1001
1002         with patch("black.out", out), patch("black.err", err):
1003             with self.assertRaises(AssertionError):
1004                 self.assertFormatEqual("l = [1, 2, 3]", "l = [1, 2, 3,]")
1005
1006         out_str = "".join(out_lines)
1007         self.assertTrue("Expected tree:" in out_str)
1008         self.assertTrue("Actual tree:" in out_str)
1009         self.assertEqual("".join(err_lines), "")
1010
1011     def test_cache_broken_file(self) -> None:
1012         mode = black.FileMode()
1013         with cache_dir() as workspace:
1014             cache_file = black.get_cache_file(mode)
1015             with cache_file.open("w") as fobj:
1016                 fobj.write("this is not a pickle")
1017             self.assertEqual(black.read_cache(mode), {})
1018             src = (workspace / "test.py").resolve()
1019             with src.open("w") as fobj:
1020                 fobj.write("print('hello')")
1021             self.invokeBlack([str(src)])
1022             cache = black.read_cache(mode)
1023             self.assertIn(src, cache)
1024
1025     def test_cache_single_file_already_cached(self) -> None:
1026         mode = black.FileMode()
1027         with cache_dir() as workspace:
1028             src = (workspace / "test.py").resolve()
1029             with src.open("w") as fobj:
1030                 fobj.write("print('hello')")
1031             black.write_cache({}, [src], mode)
1032             self.invokeBlack([str(src)])
1033             with src.open("r") as fobj:
1034                 self.assertEqual(fobj.read(), "print('hello')")
1035
1036     @event_loop(close=False)
1037     def test_cache_multiple_files(self) -> None:
1038         mode = black.FileMode()
1039         with cache_dir() as workspace, patch(
1040             "black.ProcessPoolExecutor", new=ThreadPoolExecutor
1041         ):
1042             one = (workspace / "one.py").resolve()
1043             with one.open("w") as fobj:
1044                 fobj.write("print('hello')")
1045             two = (workspace / "two.py").resolve()
1046             with two.open("w") as fobj:
1047                 fobj.write("print('hello')")
1048             black.write_cache({}, [one], mode)
1049             self.invokeBlack([str(workspace)])
1050             with one.open("r") as fobj:
1051                 self.assertEqual(fobj.read(), "print('hello')")
1052             with two.open("r") as fobj:
1053                 self.assertEqual(fobj.read(), 'print("hello")\n')
1054             cache = black.read_cache(mode)
1055             self.assertIn(one, cache)
1056             self.assertIn(two, cache)
1057
1058     def test_no_cache_when_writeback_diff(self) -> None:
1059         mode = black.FileMode()
1060         with cache_dir() as workspace:
1061             src = (workspace / "test.py").resolve()
1062             with src.open("w") as fobj:
1063                 fobj.write("print('hello')")
1064             self.invokeBlack([str(src), "--diff"])
1065             cache_file = black.get_cache_file(mode)
1066             self.assertFalse(cache_file.exists())
1067
1068     def test_no_cache_when_stdin(self) -> None:
1069         mode = black.FileMode()
1070         with cache_dir():
1071             result = CliRunner().invoke(
1072                 black.main, ["-"], input=BytesIO(b"print('hello')")
1073             )
1074             self.assertEqual(result.exit_code, 0)
1075             cache_file = black.get_cache_file(mode)
1076             self.assertFalse(cache_file.exists())
1077
1078     def test_read_cache_no_cachefile(self) -> None:
1079         mode = black.FileMode()
1080         with cache_dir():
1081             self.assertEqual(black.read_cache(mode), {})
1082
1083     def test_write_cache_read_cache(self) -> None:
1084         mode = black.FileMode()
1085         with cache_dir() as workspace:
1086             src = (workspace / "test.py").resolve()
1087             src.touch()
1088             black.write_cache({}, [src], mode)
1089             cache = black.read_cache(mode)
1090             self.assertIn(src, cache)
1091             self.assertEqual(cache[src], black.get_cache_info(src))
1092
1093     def test_filter_cached(self) -> None:
1094         with TemporaryDirectory() as workspace:
1095             path = Path(workspace)
1096             uncached = (path / "uncached").resolve()
1097             cached = (path / "cached").resolve()
1098             cached_but_changed = (path / "changed").resolve()
1099             uncached.touch()
1100             cached.touch()
1101             cached_but_changed.touch()
1102             cache = {cached: black.get_cache_info(cached), cached_but_changed: (0.0, 0)}
1103             todo, done = black.filter_cached(
1104                 cache, {uncached, cached, cached_but_changed}
1105             )
1106             self.assertEqual(todo, {uncached, cached_but_changed})
1107             self.assertEqual(done, {cached})
1108
1109     def test_write_cache_creates_directory_if_needed(self) -> None:
1110         mode = black.FileMode()
1111         with cache_dir(exists=False) as workspace:
1112             self.assertFalse(workspace.exists())
1113             black.write_cache({}, [], mode)
1114             self.assertTrue(workspace.exists())
1115
1116     @event_loop(close=False)
1117     def test_failed_formatting_does_not_get_cached(self) -> None:
1118         mode = black.FileMode()
1119         with cache_dir() as workspace, patch(
1120             "black.ProcessPoolExecutor", new=ThreadPoolExecutor
1121         ):
1122             failing = (workspace / "failing.py").resolve()
1123             with failing.open("w") as fobj:
1124                 fobj.write("not actually python")
1125             clean = (workspace / "clean.py").resolve()
1126             with clean.open("w") as fobj:
1127                 fobj.write('print("hello")\n')
1128             self.invokeBlack([str(workspace)], exit_code=123)
1129             cache = black.read_cache(mode)
1130             self.assertNotIn(failing, cache)
1131             self.assertIn(clean, cache)
1132
1133     def test_write_cache_write_fail(self) -> None:
1134         mode = black.FileMode()
1135         with cache_dir(), patch.object(Path, "open") as mock:
1136             mock.side_effect = OSError
1137             black.write_cache({}, [], mode)
1138
1139     @event_loop(close=False)
1140     def test_check_diff_use_together(self) -> None:
1141         with cache_dir():
1142             # Files which will be reformatted.
1143             src1 = (THIS_DIR / "data" / "string_quotes.py").resolve()
1144             self.invokeBlack([str(src1), "--diff", "--check"], exit_code=1)
1145             # Files which will not be reformatted.
1146             src2 = (THIS_DIR / "data" / "composition.py").resolve()
1147             self.invokeBlack([str(src2), "--diff", "--check"])
1148             # Multi file command.
1149             self.invokeBlack([str(src1), str(src2), "--diff", "--check"], exit_code=1)
1150
1151     def test_no_files(self) -> None:
1152         with cache_dir():
1153             # Without an argument, black exits with error code 0.
1154             self.invokeBlack([])
1155
1156     def test_broken_symlink(self) -> None:
1157         with cache_dir() as workspace:
1158             symlink = workspace / "broken_link.py"
1159             try:
1160                 symlink.symlink_to("nonexistent.py")
1161             except OSError as e:
1162                 self.skipTest(f"Can't create symlinks: {e}")
1163             self.invokeBlack([str(workspace.resolve())])
1164
1165     def test_read_cache_line_lengths(self) -> None:
1166         mode = black.FileMode()
1167         short_mode = black.FileMode(line_length=1)
1168         with cache_dir() as workspace:
1169             path = (workspace / "file.py").resolve()
1170             path.touch()
1171             black.write_cache({}, [path], mode)
1172             one = black.read_cache(mode)
1173             self.assertIn(path, one)
1174             two = black.read_cache(short_mode)
1175             self.assertNotIn(path, two)
1176
1177     def test_single_file_force_pyi(self) -> None:
1178         reg_mode = black.FileMode()
1179         pyi_mode = black.FileMode(is_pyi=True)
1180         contents, expected = read_data("force_pyi")
1181         with cache_dir() as workspace:
1182             path = (workspace / "file.py").resolve()
1183             with open(path, "w") as fh:
1184                 fh.write(contents)
1185             self.invokeBlack([str(path), "--pyi"])
1186             with open(path, "r") as fh:
1187                 actual = fh.read()
1188             # verify cache with --pyi is separate
1189             pyi_cache = black.read_cache(pyi_mode)
1190             self.assertIn(path, pyi_cache)
1191             normal_cache = black.read_cache(reg_mode)
1192             self.assertNotIn(path, normal_cache)
1193         self.assertEqual(actual, expected)
1194
1195     @event_loop(close=False)
1196     def test_multi_file_force_pyi(self) -> None:
1197         reg_mode = black.FileMode()
1198         pyi_mode = black.FileMode(is_pyi=True)
1199         contents, expected = read_data("force_pyi")
1200         with cache_dir() as workspace:
1201             paths = [
1202                 (workspace / "file1.py").resolve(),
1203                 (workspace / "file2.py").resolve(),
1204             ]
1205             for path in paths:
1206                 with open(path, "w") as fh:
1207                     fh.write(contents)
1208             self.invokeBlack([str(p) for p in paths] + ["--pyi"])
1209             for path in paths:
1210                 with open(path, "r") as fh:
1211                     actual = fh.read()
1212                 self.assertEqual(actual, expected)
1213             # verify cache with --pyi is separate
1214             pyi_cache = black.read_cache(pyi_mode)
1215             normal_cache = black.read_cache(reg_mode)
1216             for path in paths:
1217                 self.assertIn(path, pyi_cache)
1218                 self.assertNotIn(path, normal_cache)
1219
1220     def test_pipe_force_pyi(self) -> None:
1221         source, expected = read_data("force_pyi")
1222         result = CliRunner().invoke(
1223             black.main, ["-", "-q", "--pyi"], input=BytesIO(source.encode("utf8"))
1224         )
1225         self.assertEqual(result.exit_code, 0)
1226         actual = result.output
1227         self.assertFormatEqual(actual, expected)
1228
1229     def test_single_file_force_py36(self) -> None:
1230         reg_mode = black.FileMode()
1231         py36_mode = black.FileMode(target_versions=black.PY36_VERSIONS)
1232         source, expected = read_data("force_py36")
1233         with cache_dir() as workspace:
1234             path = (workspace / "file.py").resolve()
1235             with open(path, "w") as fh:
1236                 fh.write(source)
1237             self.invokeBlack([str(path), *PY36_ARGS])
1238             with open(path, "r") as fh:
1239                 actual = fh.read()
1240             # verify cache with --target-version is separate
1241             py36_cache = black.read_cache(py36_mode)
1242             self.assertIn(path, py36_cache)
1243             normal_cache = black.read_cache(reg_mode)
1244             self.assertNotIn(path, normal_cache)
1245         self.assertEqual(actual, expected)
1246
1247     @event_loop(close=False)
1248     def test_multi_file_force_py36(self) -> None:
1249         reg_mode = black.FileMode()
1250         py36_mode = black.FileMode(target_versions=black.PY36_VERSIONS)
1251         source, expected = read_data("force_py36")
1252         with cache_dir() as workspace:
1253             paths = [
1254                 (workspace / "file1.py").resolve(),
1255                 (workspace / "file2.py").resolve(),
1256             ]
1257             for path in paths:
1258                 with open(path, "w") as fh:
1259                     fh.write(source)
1260             self.invokeBlack([str(p) for p in paths] + PY36_ARGS)
1261             for path in paths:
1262                 with open(path, "r") as fh:
1263                     actual = fh.read()
1264                 self.assertEqual(actual, expected)
1265             # verify cache with --target-version is separate
1266             pyi_cache = black.read_cache(py36_mode)
1267             normal_cache = black.read_cache(reg_mode)
1268             for path in paths:
1269                 self.assertIn(path, pyi_cache)
1270                 self.assertNotIn(path, normal_cache)
1271
1272     def test_pipe_force_py36(self) -> None:
1273         source, expected = read_data("force_py36")
1274         result = CliRunner().invoke(
1275             black.main,
1276             ["-", "-q", "--target-version=py36"],
1277             input=BytesIO(source.encode("utf8")),
1278         )
1279         self.assertEqual(result.exit_code, 0)
1280         actual = result.output
1281         self.assertFormatEqual(actual, expected)
1282
1283     def test_include_exclude(self) -> None:
1284         path = THIS_DIR / "data" / "include_exclude_tests"
1285         include = re.compile(r"\.pyi?$")
1286         exclude = re.compile(r"/exclude/|/\.definitely_exclude/")
1287         report = black.Report()
1288         sources: List[Path] = []
1289         expected = [
1290             Path(path / "b/dont_exclude/a.py"),
1291             Path(path / "b/dont_exclude/a.pyi"),
1292         ]
1293         this_abs = THIS_DIR.resolve()
1294         sources.extend(
1295             black.gen_python_files_in_dir(path, this_abs, include, exclude, report)
1296         )
1297         self.assertEqual(sorted(expected), sorted(sources))
1298
1299     def test_empty_include(self) -> None:
1300         path = THIS_DIR / "data" / "include_exclude_tests"
1301         report = black.Report()
1302         empty = re.compile(r"")
1303         sources: List[Path] = []
1304         expected = [
1305             Path(path / "b/exclude/a.pie"),
1306             Path(path / "b/exclude/a.py"),
1307             Path(path / "b/exclude/a.pyi"),
1308             Path(path / "b/dont_exclude/a.pie"),
1309             Path(path / "b/dont_exclude/a.py"),
1310             Path(path / "b/dont_exclude/a.pyi"),
1311             Path(path / "b/.definitely_exclude/a.pie"),
1312             Path(path / "b/.definitely_exclude/a.py"),
1313             Path(path / "b/.definitely_exclude/a.pyi"),
1314         ]
1315         this_abs = THIS_DIR.resolve()
1316         sources.extend(
1317             black.gen_python_files_in_dir(
1318                 path, this_abs, empty, re.compile(black.DEFAULT_EXCLUDES), report
1319             )
1320         )
1321         self.assertEqual(sorted(expected), sorted(sources))
1322
1323     def test_empty_exclude(self) -> None:
1324         path = THIS_DIR / "data" / "include_exclude_tests"
1325         report = black.Report()
1326         empty = re.compile(r"")
1327         sources: List[Path] = []
1328         expected = [
1329             Path(path / "b/dont_exclude/a.py"),
1330             Path(path / "b/dont_exclude/a.pyi"),
1331             Path(path / "b/exclude/a.py"),
1332             Path(path / "b/exclude/a.pyi"),
1333             Path(path / "b/.definitely_exclude/a.py"),
1334             Path(path / "b/.definitely_exclude/a.pyi"),
1335         ]
1336         this_abs = THIS_DIR.resolve()
1337         sources.extend(
1338             black.gen_python_files_in_dir(
1339                 path, this_abs, re.compile(black.DEFAULT_INCLUDES), empty, report
1340             )
1341         )
1342         self.assertEqual(sorted(expected), sorted(sources))
1343
1344     def test_invalid_include_exclude(self) -> None:
1345         for option in ["--include", "--exclude"]:
1346             self.invokeBlack(["-", option, "**()(!!*)"], exit_code=2)
1347
1348     def test_preserves_line_endings(self) -> None:
1349         with TemporaryDirectory() as workspace:
1350             test_file = Path(workspace) / "test.py"
1351             for nl in ["\n", "\r\n"]:
1352                 contents = nl.join(["def f(  ):", "    pass"])
1353                 test_file.write_bytes(contents.encode())
1354                 ff(test_file, write_back=black.WriteBack.YES)
1355                 updated_contents: bytes = test_file.read_bytes()
1356                 self.assertIn(nl.encode(), updated_contents)
1357                 if nl == "\n":
1358                     self.assertNotIn(b"\r\n", updated_contents)
1359
1360     def test_preserves_line_endings_via_stdin(self) -> None:
1361         for nl in ["\n", "\r\n"]:
1362             contents = nl.join(["def f(  ):", "    pass"])
1363             runner = BlackRunner()
1364             result = runner.invoke(
1365                 black.main, ["-", "--fast"], input=BytesIO(contents.encode("utf8"))
1366             )
1367             self.assertEqual(result.exit_code, 0)
1368             output = runner.stdout_bytes
1369             self.assertIn(nl.encode("utf8"), output)
1370             if nl == "\n":
1371                 self.assertNotIn(b"\r\n", output)
1372
1373     def test_assert_equivalent_different_asts(self) -> None:
1374         with self.assertRaises(AssertionError):
1375             black.assert_equivalent("{}", "None")
1376
1377     def test_symlink_out_of_root_directory(self) -> None:
1378         path = MagicMock()
1379         root = THIS_DIR
1380         child = MagicMock()
1381         include = re.compile(black.DEFAULT_INCLUDES)
1382         exclude = re.compile(black.DEFAULT_EXCLUDES)
1383         report = black.Report()
1384         # `child` should behave like a symlink which resolved path is clearly
1385         # outside of the `root` directory.
1386         path.iterdir.return_value = [child]
1387         child.resolve.return_value = Path("/a/b/c")
1388         child.is_symlink.return_value = True
1389         try:
1390             list(black.gen_python_files_in_dir(path, root, include, exclude, report))
1391         except ValueError as ve:
1392             self.fail(f"`get_python_files_in_dir()` failed: {ve}")
1393         path.iterdir.assert_called_once()
1394         child.resolve.assert_called_once()
1395         child.is_symlink.assert_called_once()
1396         # `child` should behave like a strange file which resolved path is clearly
1397         # outside of the `root` directory.
1398         child.is_symlink.return_value = False
1399         with self.assertRaises(ValueError):
1400             list(black.gen_python_files_in_dir(path, root, include, exclude, report))
1401         path.iterdir.assert_called()
1402         self.assertEqual(path.iterdir.call_count, 2)
1403         child.resolve.assert_called()
1404         self.assertEqual(child.resolve.call_count, 2)
1405         child.is_symlink.assert_called()
1406         self.assertEqual(child.is_symlink.call_count, 2)
1407
1408     def test_shhh_click(self) -> None:
1409         try:
1410             from click import _unicodefun  # type: ignore
1411         except ModuleNotFoundError:
1412             self.skipTest("Incompatible Click version")
1413         if not hasattr(_unicodefun, "_verify_python3_env"):
1414             self.skipTest("Incompatible Click version")
1415         # First, let's see if Click is crashing with a preferred ASCII charset.
1416         with patch("locale.getpreferredencoding") as gpe:
1417             gpe.return_value = "ASCII"
1418             with self.assertRaises(RuntimeError):
1419                 _unicodefun._verify_python3_env()
1420         # Now, let's silence Click...
1421         black.patch_click()
1422         # ...and confirm it's silent.
1423         with patch("locale.getpreferredencoding") as gpe:
1424             gpe.return_value = "ASCII"
1425             try:
1426                 _unicodefun._verify_python3_env()
1427             except RuntimeError as re:
1428                 self.fail(f"`patch_click()` failed, exception still raised: {re}")
1429
1430     def test_root_logger_not_used_directly(self) -> None:
1431         def fail(*args: Any, **kwargs: Any) -> None:
1432             self.fail("Record created with root logger")
1433
1434         with patch.multiple(
1435             logging.root,
1436             debug=fail,
1437             info=fail,
1438             warning=fail,
1439             error=fail,
1440             critical=fail,
1441             log=fail,
1442         ):
1443             ff(THIS_FILE)
1444
1445     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1446     @async_test
1447     async def test_blackd_request_needs_formatting(self) -> None:
1448         app = blackd.make_app()
1449         async with TestClient(TestServer(app)) as client:
1450             response = await client.post("/", data=b"print('hello world')")
1451             self.assertEqual(response.status, 200)
1452             self.assertEqual(response.charset, "utf8")
1453             self.assertEqual(await response.read(), b'print("hello world")\n')
1454
1455     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1456     @async_test
1457     async def test_blackd_request_no_change(self) -> None:
1458         app = blackd.make_app()
1459         async with TestClient(TestServer(app)) as client:
1460             response = await client.post("/", data=b'print("hello world")\n')
1461             self.assertEqual(response.status, 204)
1462             self.assertEqual(await response.read(), b"")
1463
1464     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1465     @async_test
1466     async def test_blackd_request_syntax_error(self) -> None:
1467         app = blackd.make_app()
1468         async with TestClient(TestServer(app)) as client:
1469             response = await client.post("/", data=b"what even ( is")
1470             self.assertEqual(response.status, 400)
1471             content = await response.text()
1472             self.assertTrue(
1473                 content.startswith("Cannot parse"),
1474                 msg=f"Expected error to start with 'Cannot parse', got {repr(content)}",
1475             )
1476
1477     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1478     @async_test
1479     async def test_blackd_unsupported_version(self) -> None:
1480         app = blackd.make_app()
1481         async with TestClient(TestServer(app)) as client:
1482             response = await client.post(
1483                 "/", data=b"what", headers={blackd.VERSION_HEADER: "2"}
1484             )
1485             self.assertEqual(response.status, 501)
1486
1487     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1488     @async_test
1489     async def test_blackd_supported_version(self) -> None:
1490         app = blackd.make_app()
1491         async with TestClient(TestServer(app)) as client:
1492             response = await client.post(
1493                 "/", data=b"what", headers={blackd.VERSION_HEADER: "1"}
1494             )
1495             self.assertEqual(response.status, 200)
1496
1497     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1498     @async_test
1499     async def test_blackd_invalid_python_variant(self) -> None:
1500         app = blackd.make_app()
1501         async with TestClient(TestServer(app)) as client:
1502
1503             async def check(header_value: str, expected_status: int = 400) -> None:
1504                 response = await client.post(
1505                     "/",
1506                     data=b"what",
1507                     headers={blackd.PYTHON_VARIANT_HEADER: header_value},
1508                 )
1509                 self.assertEqual(response.status, expected_status)
1510
1511             await check("lol")
1512             await check("ruby3.5")
1513             await check("pyi3.6")
1514             await check("py1.5")
1515             await check("2.8")
1516             await check("py2.8")
1517             await check("3.0")
1518             await check("pypy3.0")
1519             await check("jython3.4")
1520
1521     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1522     @async_test
1523     async def test_blackd_pyi(self) -> None:
1524         app = blackd.make_app()
1525         async with TestClient(TestServer(app)) as client:
1526             source, expected = read_data("stub.pyi")
1527             response = await client.post(
1528                 "/", data=source, headers={blackd.PYTHON_VARIANT_HEADER: "pyi"}
1529             )
1530             self.assertEqual(response.status, 200)
1531             self.assertEqual(await response.text(), expected)
1532
1533     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1534     @async_test
1535     async def test_blackd_python_variant(self) -> None:
1536         app = blackd.make_app()
1537         code = (
1538             "def f(\n"
1539             "    and_has_a_bunch_of,\n"
1540             "    very_long_arguments_too,\n"
1541             "    and_lots_of_them_as_well_lol,\n"
1542             "    **and_very_long_keyword_arguments\n"
1543             "):\n"
1544             "    pass\n"
1545         )
1546         async with TestClient(TestServer(app)) as client:
1547
1548             async def check(header_value: str, expected_status: int) -> None:
1549                 response = await client.post(
1550                     "/", data=code, headers={blackd.PYTHON_VARIANT_HEADER: header_value}
1551                 )
1552                 self.assertEqual(response.status, expected_status)
1553
1554             await check("3.6", 200)
1555             await check("py3.6", 200)
1556             await check("3.6,3.7", 200)
1557             await check("3.6,py3.7", 200)
1558
1559             await check("2", 204)
1560             await check("2.7", 204)
1561             await check("py2.7", 204)
1562             await check("3.4", 204)
1563             await check("py3.4", 204)
1564
1565     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1566     @async_test
1567     async def test_blackd_fast(self) -> None:
1568         with open(os.devnull, "w") as dn, redirect_stderr(dn):
1569             app = blackd.make_app()
1570             async with TestClient(TestServer(app)) as client:
1571                 response = await client.post("/", data=b"ur'hello'")
1572                 self.assertEqual(response.status, 500)
1573                 self.assertIn("failed to parse source file", await response.text())
1574                 response = await client.post(
1575                     "/", data=b"ur'hello'", headers={blackd.FAST_OR_SAFE_HEADER: "fast"}
1576                 )
1577                 self.assertEqual(response.status, 200)
1578
1579     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1580     @async_test
1581     async def test_blackd_line_length(self) -> None:
1582         app = blackd.make_app()
1583         async with TestClient(TestServer(app)) as client:
1584             response = await client.post(
1585                 "/", data=b'print("hello")\n', headers={blackd.LINE_LENGTH_HEADER: "7"}
1586             )
1587             self.assertEqual(response.status, 200)
1588
1589     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1590     @async_test
1591     async def test_blackd_invalid_line_length(self) -> None:
1592         app = blackd.make_app()
1593         async with TestClient(TestServer(app)) as client:
1594             response = await client.post(
1595                 "/",
1596                 data=b'print("hello")\n',
1597                 headers={blackd.LINE_LENGTH_HEADER: "NaN"},
1598             )
1599             self.assertEqual(response.status, 400)
1600
1601     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1602     def test_blackd_main(self) -> None:
1603         with patch("blackd.web.run_app"):
1604             result = CliRunner().invoke(blackd.main, [])
1605             if result.exception is not None:
1606                 raise result.exception
1607             self.assertEqual(result.exit_code, 0)
1608
1609
if __name__ == "__main__":
    # Executed directly as a script: run the tests in the `test_black` module.
    unittest.main(module="test_black")