]> git.madduck.net Git - etc/vim.git/blob - tests/test_black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Add PyCon talk link to README (#844)
[etc/vim.git] / tests / test_black.py
1 #!/usr/bin/env python3
2 import asyncio
3 import logging
4 from concurrent.futures import ThreadPoolExecutor
5 from contextlib import contextmanager
6 from functools import partial, wraps
7 from io import BytesIO, TextIOWrapper
8 import os
9 from pathlib import Path
10 import re
11 import sys
12 from tempfile import TemporaryDirectory
13 from typing import (
14     Any,
15     BinaryIO,
16     Callable,
17     Coroutine,
18     Generator,
19     List,
20     Tuple,
21     Iterator,
22     TypeVar,
23 )
24 import unittest
25 from unittest.mock import patch, MagicMock
26
27 from click import unstyle
28 from click.testing import CliRunner
29
30 import black
31 from black import Feature, TargetVersion
32
33 try:
34     import blackd
35     from aiohttp.test_utils import TestClient, TestServer
36 except ImportError:
37     has_blackd_deps = False
38 else:
39     has_blackd_deps = True
40
# Locations used to find test fixtures and the project's own sources.
THIS_FILE = Path(__file__)
THIS_DIR = THIS_FILE.parent

# Shorthands used by the tests: format a string / a file in place, both with
# the default FileMode (the file variant additionally passes fast=True).
fs = partial(black.format_str, mode=black.FileMode())
ff = partial(black.format_file_in_place, mode=black.FileMode(), fast=True)

# Marker stripped from fixture lines by read_data.  It is assembled from two
# literals, presumably so that this file (read by test_self) never contains
# the marker verbatim.
EMPTY_LINE = "# EMPTY LINE WITH WHITESPACE" + " (this comment will be removed)"

# One ``--target-version=...`` CLI flag per entry of black.PY36_VERSIONS.
PY36_ARGS = [
    f"--target-version={target.name.lower()}" for target in black.PY36_VERSIONS
]

# Type variables for the generic helper annotations below.
T = TypeVar("T")
R = TypeVar("R")
51
52
def dump_to_stderr(*output: str) -> str:
    """Replacement for ``black.dump_to_file`` used via ``@patch`` in the tests.

    Instead of writing a temporary file, return the chunks joined by newlines,
    with one extra newline before and after the joined text.
    """
    body = "\n".join(output)
    return f"\n{body}\n"
55
56
def read_data(name: str, data: bool = True) -> Tuple[str, str]:
    """read_data('test_name') -> 'input', 'output'

    Load a test case file and split it at the ``# output`` marker line.
    ``.py`` is appended to *name* unless it already ends in ``.py``, ``.pyi``,
    ``.out`` or ``.diff``.  The file is looked up in ``THIS_DIR / "data"``, or
    in ``THIS_DIR`` itself when *data* is false.  Occurrences of ``EMPTY_LINE``
    are removed from every line.  If the file has content but no output
    section, the input doubles as the expected output.  Both halves are
    stripped and terminated with a single newline.
    """
    if not name.endswith((".py", ".pyi", ".out", ".diff")):
        name = f"{name}.py"
    root = THIS_DIR / "data" if data else THIS_DIR
    with open(root / name, "r", encoding="utf8") as handle:
        contents = handle.readlines()

    source_lines: List[str] = []
    output_lines: List[str] = []
    seen_marker = False
    for raw in contents:
        cleaned = raw.replace(EMPTY_LINE, "")
        if cleaned.rstrip() == "# output":
            seen_marker = True
        elif seen_marker:
            output_lines.append(cleaned)
        else:
            source_lines.append(cleaned)

    if source_lines and not output_lines:
        # No output section: the case is expected to come back unchanged.
        output_lines = list(source_lines)

    def _normalize(chunk: List[str]) -> str:
        return "".join(chunk).strip() + "\n"

    return _normalize(source_lines), _normalize(output_lines)
78
79
@contextmanager
def cache_dir(exists: bool = True) -> Iterator[Path]:
    """Patch ``black.CACHE_DIR`` to a throwaway location and yield that path.

    With ``exists=False`` the patched path is a ``new`` subdirectory of the
    temporary workspace, which this helper does not create.
    """
    with TemporaryDirectory() as workspace:
        root = Path(workspace)
        target = root if exists else root / "new"
        with patch("black.CACHE_DIR", target):
            yield target
88
89
@contextmanager
def event_loop(close: bool) -> Iterator[None]:
    """Install a brand-new asyncio event loop for the duration of the block.

    The loop that was current on entry is reinstalled on exit, even if the
    block raises.  When *close* is true the temporary loop is closed as well.
    """
    policy = asyncio.get_event_loop_policy()
    previous_loop = policy.get_event_loop()
    temporary_loop = policy.new_event_loop()
    asyncio.set_event_loop(temporary_loop)
    try:
        yield
    finally:
        policy.set_event_loop(previous_loop)
        if close:
            temporary_loop.close()
103
104
def async_test(f: Callable[..., Coroutine[Any, None, R]]) -> Callable[..., None]:
    """Turn the coroutine function *f* into a plain synchronous callable.

    Every call runs the coroutine to completion on a fresh event loop (via
    ``event_loop(close=True)``), so each test gets its own loop, closed after.
    """

    @wraps(f)
    def wrapper(*args: Any, **kwargs: Any) -> None:
        with event_loop(close=True):
            loop = asyncio.get_event_loop()
            loop.run_until_complete(f(*args, **kwargs))

    return wrapper
112
113
class BlackRunner(CliRunner):
    """Modify CliRunner so that stderr is not merged with stdout.

    This is a hack that can be removed once we depend on Click 7.x"""

    def __init__(self) -> None:
        self.stderrbuf = BytesIO()
        self.stdoutbuf = BytesIO()
        self.stdout_bytes = b""
        self.stderr_bytes = b""
        super().__init__()

    @contextmanager
    def isolation(self, *args: Any, **kwargs: Any) -> Generator[BinaryIO, None, None]:
        """Wrap ``CliRunner.isolation`` so stderr is captured separately.

        While active, ``sys.stderr`` writes into ``self.stderrbuf``.  On exit,
        the raw bytes written to stdout and stderr are stored in
        ``self.stdout_bytes`` / ``self.stderr_bytes`` and the previous
        ``sys.stderr`` is restored.
        """
        with super().isolation(*args, **kwargs) as output:
            hold_stderr = sys.stderr
            try:
                sys.stderr = TextIOWrapper(self.stderrbuf, encoding=self.charset)
                yield output
            finally:
                try:
                    # TextIOWrapper buffers text; flush both streams so that
                    # everything written has reached the underlying BytesIO
                    # before its contents are read.
                    sys.stdout.flush()
                    sys.stderr.flush()
                    self.stdout_bytes = sys.stdout.buffer.getvalue()  # type: ignore
                    self.stderr_bytes = sys.stderr.buffer.getvalue()  # type: ignore
                finally:
                    # Always put the real stderr back, even if reading failed.
                    sys.stderr = hold_stderr
137
138
139 class BlackTestCase(unittest.TestCase):
    # None removes unittest's cap on the length of failure diffs, so long
    # formatting mismatches are reported in full.
    maxDiff = None
141
142     def assertFormatEqual(self, expected: str, actual: str) -> None:
143         if actual != expected and not os.environ.get("SKIP_AST_PRINT"):
144             bdv: black.DebugVisitor[Any]
145             black.out("Expected tree:", fg="green")
146             try:
147                 exp_node = black.lib2to3_parse(expected)
148                 bdv = black.DebugVisitor()
149                 list(bdv.visit(exp_node))
150             except Exception as ve:
151                 black.err(str(ve))
152             black.out("Actual tree:", fg="red")
153             try:
154                 exp_node = black.lib2to3_parse(actual)
155                 bdv = black.DebugVisitor()
156                 list(bdv.visit(exp_node))
157             except Exception as ve:
158                 black.err(str(ve))
159         self.assertEqual(expected, actual)
160
161     def invokeBlack(
162         self, args: List[str], exit_code: int = 0, ignore_config: bool = True
163     ) -> None:
164         runner = BlackRunner()
165         if ignore_config:
166             args = ["--config", str(THIS_DIR / "empty.toml"), *args]
167         result = runner.invoke(black.main, args)
168         self.assertEqual(result.exit_code, exit_code, msg=runner.stderr_bytes.decode())
169
170     @patch("black.dump_to_file", dump_to_stderr)
171     def test_empty(self) -> None:
172         source = expected = ""
173         actual = fs(source)
174         self.assertFormatEqual(expected, actual)
175         black.assert_equivalent(source, actual)
176         black.assert_stable(source, actual, black.FileMode())
177
178     def test_empty_ff(self) -> None:
179         expected = ""
180         tmp_file = Path(black.dump_to_file())
181         try:
182             self.assertFalse(ff(tmp_file, write_back=black.WriteBack.YES))
183             with open(tmp_file, encoding="utf8") as f:
184                 actual = f.read()
185         finally:
186             os.unlink(tmp_file)
187         self.assertFormatEqual(expected, actual)
188
189     @patch("black.dump_to_file", dump_to_stderr)
190     def test_self(self) -> None:
191         source, expected = read_data("test_black", data=False)
192         actual = fs(source)
193         self.assertFormatEqual(expected, actual)
194         black.assert_equivalent(source, actual)
195         black.assert_stable(source, actual, black.FileMode())
196         self.assertFalse(ff(THIS_FILE))
197
198     @patch("black.dump_to_file", dump_to_stderr)
199     def test_black(self) -> None:
200         source, expected = read_data("../black", data=False)
201         actual = fs(source)
202         self.assertFormatEqual(expected, actual)
203         black.assert_equivalent(source, actual)
204         black.assert_stable(source, actual, black.FileMode())
205         self.assertFalse(ff(THIS_DIR / ".." / "black.py"))
206
207     def test_piping(self) -> None:
208         source, expected = read_data("../black", data=False)
209         result = BlackRunner().invoke(
210             black.main,
211             ["-", "--fast", f"--line-length={black.DEFAULT_LINE_LENGTH}"],
212             input=BytesIO(source.encode("utf8")),
213         )
214         self.assertEqual(result.exit_code, 0)
215         self.assertFormatEqual(expected, result.output)
216         black.assert_equivalent(source, result.output)
217         black.assert_stable(source, result.output, black.FileMode())
218
219     def test_piping_diff(self) -> None:
220         diff_header = re.compile(
221             rf"(STDIN|STDOUT)\t\d\d\d\d-\d\d-\d\d "
222             rf"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
223         )
224         source, _ = read_data("expression.py")
225         expected, _ = read_data("expression.diff")
226         config = THIS_DIR / "data" / "empty_pyproject.toml"
227         args = [
228             "-",
229             "--fast",
230             f"--line-length={black.DEFAULT_LINE_LENGTH}",
231             "--diff",
232             f"--config={config}",
233         ]
234         result = BlackRunner().invoke(
235             black.main, args, input=BytesIO(source.encode("utf8"))
236         )
237         self.assertEqual(result.exit_code, 0)
238         actual = diff_header.sub("[Deterministic header]", result.output)
239         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
240         self.assertEqual(expected, actual)
241
242     @patch("black.dump_to_file", dump_to_stderr)
243     def test_setup(self) -> None:
244         source, expected = read_data("../setup", data=False)
245         actual = fs(source)
246         self.assertFormatEqual(expected, actual)
247         black.assert_equivalent(source, actual)
248         black.assert_stable(source, actual, black.FileMode())
249         self.assertFalse(ff(THIS_DIR / ".." / "setup.py"))
250
251     @patch("black.dump_to_file", dump_to_stderr)
252     def test_function(self) -> None:
253         source, expected = read_data("function")
254         actual = fs(source)
255         self.assertFormatEqual(expected, actual)
256         black.assert_equivalent(source, actual)
257         black.assert_stable(source, actual, black.FileMode())
258
259     @patch("black.dump_to_file", dump_to_stderr)
260     def test_function2(self) -> None:
261         source, expected = read_data("function2")
262         actual = fs(source)
263         self.assertFormatEqual(expected, actual)
264         black.assert_equivalent(source, actual)
265         black.assert_stable(source, actual, black.FileMode())
266
267     @patch("black.dump_to_file", dump_to_stderr)
268     def test_expression(self) -> None:
269         source, expected = read_data("expression")
270         actual = fs(source)
271         self.assertFormatEqual(expected, actual)
272         black.assert_equivalent(source, actual)
273         black.assert_stable(source, actual, black.FileMode())
274
275     def test_expression_ff(self) -> None:
276         source, expected = read_data("expression")
277         tmp_file = Path(black.dump_to_file(source))
278         try:
279             self.assertTrue(ff(tmp_file, write_back=black.WriteBack.YES))
280             with open(tmp_file, encoding="utf8") as f:
281                 actual = f.read()
282         finally:
283             os.unlink(tmp_file)
284         self.assertFormatEqual(expected, actual)
285         with patch("black.dump_to_file", dump_to_stderr):
286             black.assert_equivalent(source, actual)
287             black.assert_stable(source, actual, black.FileMode())
288
289     def test_expression_diff(self) -> None:
290         source, _ = read_data("expression.py")
291         expected, _ = read_data("expression.diff")
292         tmp_file = Path(black.dump_to_file(source))
293         diff_header = re.compile(
294             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
295             rf"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
296         )
297         try:
298             result = BlackRunner().invoke(black.main, ["--diff", str(tmp_file)])
299             self.assertEqual(result.exit_code, 0)
300         finally:
301             os.unlink(tmp_file)
302         actual = result.output
303         actual = diff_header.sub("[Deterministic header]", actual)
304         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
305         if expected != actual:
306             dump = black.dump_to_file(actual)
307             msg = (
308                 f"Expected diff isn't equal to the actual. If you made changes "
309                 f"to expression.py and this is an anticipated difference, "
310                 f"overwrite tests/data/expression.diff with {dump}"
311             )
312             self.assertEqual(expected, actual, msg)
313
314     @patch("black.dump_to_file", dump_to_stderr)
315     def test_fstring(self) -> None:
316         source, expected = read_data("fstring")
317         actual = fs(source)
318         self.assertFormatEqual(expected, actual)
319         black.assert_equivalent(source, actual)
320         black.assert_stable(source, actual, black.FileMode())
321
322     @patch("black.dump_to_file", dump_to_stderr)
323     def test_string_quotes(self) -> None:
324         source, expected = read_data("string_quotes")
325         actual = fs(source)
326         self.assertFormatEqual(expected, actual)
327         black.assert_equivalent(source, actual)
328         black.assert_stable(source, actual, black.FileMode())
329         mode = black.FileMode(string_normalization=False)
330         not_normalized = fs(source, mode=mode)
331         self.assertFormatEqual(source, not_normalized)
332         black.assert_equivalent(source, not_normalized)
333         black.assert_stable(source, not_normalized, mode=mode)
334
335     @patch("black.dump_to_file", dump_to_stderr)
336     def test_slices(self) -> None:
337         source, expected = read_data("slices")
338         actual = fs(source)
339         self.assertFormatEqual(expected, actual)
340         black.assert_equivalent(source, actual)
341         black.assert_stable(source, actual, black.FileMode())
342
343     @patch("black.dump_to_file", dump_to_stderr)
344     def test_comments(self) -> None:
345         source, expected = read_data("comments")
346         actual = fs(source)
347         self.assertFormatEqual(expected, actual)
348         black.assert_equivalent(source, actual)
349         black.assert_stable(source, actual, black.FileMode())
350
351     @patch("black.dump_to_file", dump_to_stderr)
352     def test_comments2(self) -> None:
353         source, expected = read_data("comments2")
354         actual = fs(source)
355         self.assertFormatEqual(expected, actual)
356         black.assert_equivalent(source, actual)
357         black.assert_stable(source, actual, black.FileMode())
358
359     @patch("black.dump_to_file", dump_to_stderr)
360     def test_comments3(self) -> None:
361         source, expected = read_data("comments3")
362         actual = fs(source)
363         self.assertFormatEqual(expected, actual)
364         black.assert_equivalent(source, actual)
365         black.assert_stable(source, actual, black.FileMode())
366
367     @patch("black.dump_to_file", dump_to_stderr)
368     def test_comments4(self) -> None:
369         source, expected = read_data("comments4")
370         actual = fs(source)
371         self.assertFormatEqual(expected, actual)
372         black.assert_equivalent(source, actual)
373         black.assert_stable(source, actual, black.FileMode())
374
375     @patch("black.dump_to_file", dump_to_stderr)
376     def test_comments5(self) -> None:
377         source, expected = read_data("comments5")
378         actual = fs(source)
379         self.assertFormatEqual(expected, actual)
380         black.assert_equivalent(source, actual)
381         black.assert_stable(source, actual, black.FileMode())
382
383     @patch("black.dump_to_file", dump_to_stderr)
384     def test_comments6(self) -> None:
385         source, expected = read_data("comments6")
386         actual = fs(source)
387         self.assertFormatEqual(expected, actual)
388         black.assert_equivalent(source, actual)
389         black.assert_stable(source, actual, black.FileMode())
390
391     @patch("black.dump_to_file", dump_to_stderr)
392     def test_comments7(self) -> None:
393         source, expected = read_data("comments7")
394         actual = fs(source)
395         self.assertFormatEqual(expected, actual)
396         black.assert_equivalent(source, actual)
397         black.assert_stable(source, actual, black.FileMode())
398
399     @patch("black.dump_to_file", dump_to_stderr)
400     def test_comment_after_escaped_newline(self) -> None:
401         source, expected = read_data("comment_after_escaped_newline")
402         actual = fs(source)
403         self.assertFormatEqual(expected, actual)
404         black.assert_equivalent(source, actual)
405         black.assert_stable(source, actual, black.FileMode())
406
407     @patch("black.dump_to_file", dump_to_stderr)
408     def test_cantfit(self) -> None:
409         source, expected = read_data("cantfit")
410         actual = fs(source)
411         self.assertFormatEqual(expected, actual)
412         black.assert_equivalent(source, actual)
413         black.assert_stable(source, actual, black.FileMode())
414
415     @patch("black.dump_to_file", dump_to_stderr)
416     def test_import_spacing(self) -> None:
417         source, expected = read_data("import_spacing")
418         actual = fs(source)
419         self.assertFormatEqual(expected, actual)
420         black.assert_equivalent(source, actual)
421         black.assert_stable(source, actual, black.FileMode())
422
423     @patch("black.dump_to_file", dump_to_stderr)
424     def test_composition(self) -> None:
425         source, expected = read_data("composition")
426         actual = fs(source)
427         self.assertFormatEqual(expected, actual)
428         black.assert_equivalent(source, actual)
429         black.assert_stable(source, actual, black.FileMode())
430
431     @patch("black.dump_to_file", dump_to_stderr)
432     def test_empty_lines(self) -> None:
433         source, expected = read_data("empty_lines")
434         actual = fs(source)
435         self.assertFormatEqual(expected, actual)
436         black.assert_equivalent(source, actual)
437         black.assert_stable(source, actual, black.FileMode())
438
439     @patch("black.dump_to_file", dump_to_stderr)
440     def test_string_prefixes(self) -> None:
441         source, expected = read_data("string_prefixes")
442         actual = fs(source)
443         self.assertFormatEqual(expected, actual)
444         black.assert_equivalent(source, actual)
445         black.assert_stable(source, actual, black.FileMode())
446
447     @patch("black.dump_to_file", dump_to_stderr)
448     def test_numeric_literals(self) -> None:
449         source, expected = read_data("numeric_literals")
450         mode = black.FileMode(target_versions=black.PY36_VERSIONS)
451         actual = fs(source, mode=mode)
452         self.assertFormatEqual(expected, actual)
453         black.assert_equivalent(source, actual)
454         black.assert_stable(source, actual, mode)
455
456     @patch("black.dump_to_file", dump_to_stderr)
457     def test_numeric_literals_ignoring_underscores(self) -> None:
458         source, expected = read_data("numeric_literals_skip_underscores")
459         mode = black.FileMode(target_versions=black.PY36_VERSIONS)
460         actual = fs(source, mode=mode)
461         self.assertFormatEqual(expected, actual)
462         black.assert_equivalent(source, actual)
463         black.assert_stable(source, actual, mode)
464
465     @patch("black.dump_to_file", dump_to_stderr)
466     def test_numeric_literals_py2(self) -> None:
467         source, expected = read_data("numeric_literals_py2")
468         actual = fs(source)
469         self.assertFormatEqual(expected, actual)
470         black.assert_stable(source, actual, black.FileMode())
471
472     @patch("black.dump_to_file", dump_to_stderr)
473     def test_python2(self) -> None:
474         source, expected = read_data("python2")
475         actual = fs(source)
476         self.assertFormatEqual(expected, actual)
477         black.assert_equivalent(source, actual)
478         black.assert_stable(source, actual, black.FileMode())
479
480     @patch("black.dump_to_file", dump_to_stderr)
481     def test_python2_print_function(self) -> None:
482         source, expected = read_data("python2_print_function")
483         mode = black.FileMode(target_versions={TargetVersion.PY27})
484         actual = fs(source, mode=mode)
485         self.assertFormatEqual(expected, actual)
486         black.assert_equivalent(source, actual)
487         black.assert_stable(source, actual, mode)
488
489     @patch("black.dump_to_file", dump_to_stderr)
490     def test_python2_unicode_literals(self) -> None:
491         source, expected = read_data("python2_unicode_literals")
492         actual = fs(source)
493         self.assertFormatEqual(expected, actual)
494         black.assert_equivalent(source, actual)
495         black.assert_stable(source, actual, black.FileMode())
496
497     @patch("black.dump_to_file", dump_to_stderr)
498     def test_stub(self) -> None:
499         mode = black.FileMode(is_pyi=True)
500         source, expected = read_data("stub.pyi")
501         actual = fs(source, mode=mode)
502         self.assertFormatEqual(expected, actual)
503         black.assert_stable(source, actual, mode)
504
505     @patch("black.dump_to_file", dump_to_stderr)
506     def test_python37(self) -> None:
507         source, expected = read_data("python37")
508         actual = fs(source)
509         self.assertFormatEqual(expected, actual)
510         major, minor = sys.version_info[:2]
511         if major > 3 or (major == 3 and minor >= 7):
512             black.assert_equivalent(source, actual)
513         black.assert_stable(source, actual, black.FileMode())
514
515     @patch("black.dump_to_file", dump_to_stderr)
516     def test_fmtonoff(self) -> None:
517         source, expected = read_data("fmtonoff")
518         actual = fs(source)
519         self.assertFormatEqual(expected, actual)
520         black.assert_equivalent(source, actual)
521         black.assert_stable(source, actual, black.FileMode())
522
523     @patch("black.dump_to_file", dump_to_stderr)
524     def test_fmtonoff2(self) -> None:
525         source, expected = read_data("fmtonoff2")
526         actual = fs(source)
527         self.assertFormatEqual(expected, actual)
528         black.assert_equivalent(source, actual)
529         black.assert_stable(source, actual, black.FileMode())
530
531     @patch("black.dump_to_file", dump_to_stderr)
532     def test_remove_empty_parentheses_after_class(self) -> None:
533         source, expected = read_data("class_blank_parentheses")
534         actual = fs(source)
535         self.assertFormatEqual(expected, actual)
536         black.assert_equivalent(source, actual)
537         black.assert_stable(source, actual, black.FileMode())
538
539     @patch("black.dump_to_file", dump_to_stderr)
540     def test_new_line_between_class_and_code(self) -> None:
541         source, expected = read_data("class_methods_new_line")
542         actual = fs(source)
543         self.assertFormatEqual(expected, actual)
544         black.assert_equivalent(source, actual)
545         black.assert_stable(source, actual, black.FileMode())
546
547     @patch("black.dump_to_file", dump_to_stderr)
548     def test_bracket_match(self) -> None:
549         source, expected = read_data("bracketmatch")
550         actual = fs(source)
551         self.assertFormatEqual(expected, actual)
552         black.assert_equivalent(source, actual)
553         black.assert_stable(source, actual, black.FileMode())
554
555     @patch("black.dump_to_file", dump_to_stderr)
556     def test_tuple_assign(self) -> None:
557         source, expected = read_data("tupleassign")
558         actual = fs(source)
559         self.assertFormatEqual(expected, actual)
560         black.assert_equivalent(source, actual)
561         black.assert_stable(source, actual, black.FileMode())
562
563     def test_tab_comment_indentation(self) -> None:
564         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t# comment\n\tpass\n"
565         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
566         self.assertFormatEqual(contents_spc, fs(contents_spc))
567         self.assertFormatEqual(contents_spc, fs(contents_tab))
568
569         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t\t# comment\n\tpass\n"
570         contents_spc = "if 1:\n    if 2:\n        pass\n        # comment\n    pass\n"
571         self.assertFormatEqual(contents_spc, fs(contents_spc))
572         self.assertFormatEqual(contents_spc, fs(contents_tab))
573
574         # mixed tabs and spaces (valid Python 2 code)
575         contents_tab = "if 1:\n        if 2:\n\t\tpass\n\t# comment\n        pass\n"
576         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
577         self.assertFormatEqual(contents_spc, fs(contents_spc))
578         self.assertFormatEqual(contents_spc, fs(contents_tab))
579
580         contents_tab = "if 1:\n        if 2:\n\t\tpass\n\t\t# comment\n        pass\n"
581         contents_spc = "if 1:\n    if 2:\n        pass\n        # comment\n    pass\n"
582         self.assertFormatEqual(contents_spc, fs(contents_spc))
583         self.assertFormatEqual(contents_spc, fs(contents_tab))
584
    def test_report_verbose(self) -> None:
        # Drive a verbose Report through a sequence of per-file results and,
        # after each step, check the lines emitted, the summary text and the
        # return code.  The steps are cumulative, so their order matters.
        report = black.Report(verbose=True)
        out_lines: List[str] = []
        err_lines: List[str] = []

        # Capture the messages the report sends through black.out / black.err.
        def out(msg: str, **kwargs: Any) -> None:
            out_lines.append(msg)

        def err(msg: str, **kwargs: Any) -> None:
            err_lines.append(msg)

        with patch("black.out", out), patch("black.err", err):
            report.done(Path("f1"), black.Changed.NO)
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "f1 already well formatted, good job.")
            self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
            self.assertEqual(report.return_code, 0)
            report.done(Path("f2"), black.Changed.YES)
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "reformatted f2")
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
            )
            report.done(Path("f3"), black.Changed.CACHED)
            self.assertEqual(len(out_lines), 3)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(
                out_lines[-1], "f3 wasn't modified on disk since last run."
            )
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
            )
            self.assertEqual(report.return_code, 0)
            # In check mode, having reformatted a file turns the return code to 1.
            report.check = True
            self.assertEqual(report.return_code, 1)
            report.check = False
            # Once any file has failed, the return code is 123.
            report.failed(Path("e1"), "boom")
            self.assertEqual(len(out_lines), 3)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
            self.assertEqual(
                unstyle(str(report)),
                "1 file reformatted, 2 files left unchanged, "
                "1 file failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # f3 is reported a second time; the earlier "unchanged" count stays.
            report.done(Path("f3"), black.Changed.YES)
            self.assertEqual(len(out_lines), 4)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(out_lines[-1], "reformatted f3")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, "
                "1 file failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.failed(Path("e2"), "boom")
            self.assertEqual(len(out_lines), 4)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, "
                "2 files failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Ignored paths are logged in verbose mode but not counted.
            report.path_ignored(Path("wat"), "no match")
            self.assertEqual(len(out_lines), 5)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(out_lines[-1], "wat ignored: no match")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, "
                "2 files failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f4"), black.Changed.NO)
            self.assertEqual(len(out_lines), 6)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(out_lines[-1], "f4 already well formatted, good job.")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 3 files left unchanged, "
                "2 files failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Check mode switches the summary to conditional ("would") wording.
            report.check = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, "
                "2 files would fail to reformat.",
            )
679
    def test_report_quiet(self) -> None:
        """Check Report in quiet mode: per-file output is suppressed, errors are not.

        Drives a single Report through a sequence of done/failed/path_ignored
        calls and, after each step, checks what reached black.out/black.err,
        the rendered summary, and the return code.
        """
        report = black.Report(quiet=True)
        # Everything Report sends to black.out / black.err is captured here.
        out_lines: List[str] = []
        err_lines: List[str] = []

        def out(msg: str, **kwargs: Any) -> None:
            out_lines.append(msg)

        def err(msg: str, **kwargs: Any) -> None:
            err_lines.append(msg)

        with patch("black.out", out), patch("black.err", err):
            report.done(Path("f1"), black.Changed.NO)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
            self.assertEqual(report.return_code, 0)
            report.done(Path("f2"), black.Changed.YES)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
            )
            # CACHED files are counted as left unchanged.
            report.done(Path("f3"), black.Changed.CACHED)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
            )
            self.assertEqual(report.return_code, 0)
            # In check mode a pending reformat alone makes the return code 1.
            report.check = True
            self.assertEqual(report.return_code, 1)
            report.check = False
            # Failures still reach err even in quiet mode.
            report.failed(Path("e1"), "boom")
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
            self.assertEqual(
                unstyle(str(report)),
                "1 file reformatted, 2 files left unchanged, "
                "1 file failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Reporting f3 again bumps the reformatted count; its earlier
            # CACHED report still counts as unchanged.
            report.done(Path("f3"), black.Changed.YES)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, "
                "1 file failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.failed(Path("e2"), "boom")
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, "
                "2 files failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Ignored paths affect neither the output nor the summary.
            report.path_ignored(Path("wat"), "no match")
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, "
                "2 files failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f4"), black.Changed.NO)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 3 files left unchanged, "
                "2 files failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Check mode switches the summary to conditional ("would") wording.
            report.check = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, "
                "2 files would fail to reformat.",
            )
766
    def test_report_normal(self) -> None:
        """Check Report in default mode: reformatted files are announced on out.

        Drives a single Report through a sequence of done/failed/path_ignored
        calls and, after each step, checks what reached black.out/black.err,
        the rendered summary, and the return code. Only Changed.YES produces
        an out line; NO and CACHED stay silent.
        """
        report = black.Report()
        # Everything Report sends to black.out / black.err is captured here.
        out_lines: List[str] = []
        err_lines: List[str] = []

        def out(msg: str, **kwargs: Any) -> None:
            out_lines.append(msg)

        def err(msg: str, **kwargs: Any) -> None:
            err_lines.append(msg)

        with patch("black.out", out), patch("black.err", err):
            report.done(Path("f1"), black.Changed.NO)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
            self.assertEqual(report.return_code, 0)
            # A reformatted file is announced by name.
            report.done(Path("f2"), black.Changed.YES)
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "reformatted f2")
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
            )
            # CACHED is silent and counted as left unchanged.
            report.done(Path("f3"), black.Changed.CACHED)
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "reformatted f2")
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
            )
            self.assertEqual(report.return_code, 0)
            # In check mode a pending reformat alone makes the return code 1.
            report.check = True
            self.assertEqual(report.return_code, 1)
            report.check = False
            report.failed(Path("e1"), "boom")
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
            self.assertEqual(
                unstyle(str(report)),
                "1 file reformatted, 2 files left unchanged, "
                "1 file failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Reporting f3 again bumps the reformatted count; its earlier
            # CACHED report still counts as unchanged.
            report.done(Path("f3"), black.Changed.YES)
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(out_lines[-1], "reformatted f3")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, "
                "1 file failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.failed(Path("e2"), "boom")
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, "
                "2 files failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Ignored paths affect neither the output nor the summary.
            report.path_ignored(Path("wat"), "no match")
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, "
                "2 files failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f4"), black.Changed.NO)
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 3 files left unchanged, "
                "2 files failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Check mode switches the summary to conditional ("would") wording.
            report.check = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, "
                "2 files would fail to reformat.",
            )
856
857     def test_lib2to3_parse(self) -> None:
858         with self.assertRaises(black.InvalidInput):
859             black.lib2to3_parse("invalid syntax")
860
861         straddling = "x + y"
862         black.lib2to3_parse(straddling)
863         black.lib2to3_parse(straddling, {TargetVersion.PY27})
864         black.lib2to3_parse(straddling, {TargetVersion.PY36})
865         black.lib2to3_parse(straddling, {TargetVersion.PY27, TargetVersion.PY36})
866
867         py2_only = "print x"
868         black.lib2to3_parse(py2_only)
869         black.lib2to3_parse(py2_only, {TargetVersion.PY27})
870         with self.assertRaises(black.InvalidInput):
871             black.lib2to3_parse(py2_only, {TargetVersion.PY36})
872         with self.assertRaises(black.InvalidInput):
873             black.lib2to3_parse(py2_only, {TargetVersion.PY27, TargetVersion.PY36})
874
875         py3_only = "exec(x, end=y)"
876         black.lib2to3_parse(py3_only)
877         with self.assertRaises(black.InvalidInput):
878             black.lib2to3_parse(py3_only, {TargetVersion.PY27})
879         black.lib2to3_parse(py3_only, {TargetVersion.PY36})
880         black.lib2to3_parse(py3_only, {TargetVersion.PY27, TargetVersion.PY36})
881
882     def test_get_features_used(self) -> None:
883         node = black.lib2to3_parse("def f(*, arg): ...\n")
884         self.assertEqual(black.get_features_used(node), set())
885         node = black.lib2to3_parse("def f(*, arg,): ...\n")
886         self.assertEqual(black.get_features_used(node), {Feature.TRAILING_COMMA_IN_DEF})
887         node = black.lib2to3_parse("f(*arg,)\n")
888         self.assertEqual(
889             black.get_features_used(node), {Feature.TRAILING_COMMA_IN_CALL}
890         )
891         node = black.lib2to3_parse("def f(*, arg): f'string'\n")
892         self.assertEqual(black.get_features_used(node), {Feature.F_STRINGS})
893         node = black.lib2to3_parse("123_456\n")
894         self.assertEqual(black.get_features_used(node), {Feature.NUMERIC_UNDERSCORES})
895         node = black.lib2to3_parse("123456\n")
896         self.assertEqual(black.get_features_used(node), set())
897         source, expected = read_data("function")
898         node = black.lib2to3_parse(source)
899         expected_features = {
900             Feature.TRAILING_COMMA_IN_CALL,
901             Feature.TRAILING_COMMA_IN_DEF,
902             Feature.F_STRINGS,
903         }
904         self.assertEqual(black.get_features_used(node), expected_features)
905         node = black.lib2to3_parse(expected)
906         self.assertEqual(black.get_features_used(node), expected_features)
907         source, expected = read_data("expression")
908         node = black.lib2to3_parse(source)
909         self.assertEqual(black.get_features_used(node), set())
910         node = black.lib2to3_parse(expected)
911         self.assertEqual(black.get_features_used(node), set())
912
913     def test_get_future_imports(self) -> None:
914         node = black.lib2to3_parse("\n")
915         self.assertEqual(set(), black.get_future_imports(node))
916         node = black.lib2to3_parse("from __future__ import black\n")
917         self.assertEqual({"black"}, black.get_future_imports(node))
918         node = black.lib2to3_parse("from __future__ import multiple, imports\n")
919         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
920         node = black.lib2to3_parse("from __future__ import (parenthesized, imports)\n")
921         self.assertEqual({"parenthesized", "imports"}, black.get_future_imports(node))
922         node = black.lib2to3_parse(
923             "from __future__ import multiple\nfrom __future__ import imports\n"
924         )
925         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
926         node = black.lib2to3_parse("# comment\nfrom __future__ import black\n")
927         self.assertEqual({"black"}, black.get_future_imports(node))
928         node = black.lib2to3_parse('"""docstring"""\nfrom __future__ import black\n')
929         self.assertEqual({"black"}, black.get_future_imports(node))
930         node = black.lib2to3_parse("some(other, code)\nfrom __future__ import black\n")
931         self.assertEqual(set(), black.get_future_imports(node))
932         node = black.lib2to3_parse("from some.module import black\n")
933         self.assertEqual(set(), black.get_future_imports(node))
934         node = black.lib2to3_parse(
935             "from __future__ import unicode_literals as _unicode_literals"
936         )
937         self.assertEqual({"unicode_literals"}, black.get_future_imports(node))
938         node = black.lib2to3_parse(
939             "from __future__ import unicode_literals as _lol, print"
940         )
941         self.assertEqual({"unicode_literals", "print"}, black.get_future_imports(node))
942
943     def test_debug_visitor(self) -> None:
944         source, _ = read_data("debug_visitor.py")
945         expected, _ = read_data("debug_visitor.out")
946         out_lines = []
947         err_lines = []
948
949         def out(msg: str, **kwargs: Any) -> None:
950             out_lines.append(msg)
951
952         def err(msg: str, **kwargs: Any) -> None:
953             err_lines.append(msg)
954
955         with patch("black.out", out), patch("black.err", err):
956             black.DebugVisitor.show(source)
957         actual = "\n".join(out_lines) + "\n"
958         log_name = ""
959         if expected != actual:
960             log_name = black.dump_to_file(*out_lines)
961         self.assertEqual(
962             expected,
963             actual,
964             f"AST print out is different. Actual version dumped to {log_name}",
965         )
966
967     def test_format_file_contents(self) -> None:
968         empty = ""
969         mode = black.FileMode()
970         with self.assertRaises(black.NothingChanged):
971             black.format_file_contents(empty, mode=mode, fast=False)
972         just_nl = "\n"
973         with self.assertRaises(black.NothingChanged):
974             black.format_file_contents(just_nl, mode=mode, fast=False)
975         same = "l = [1, 2, 3]\n"
976         with self.assertRaises(black.NothingChanged):
977             black.format_file_contents(same, mode=mode, fast=False)
978         different = "l = [1,2,3]"
979         expected = same
980         actual = black.format_file_contents(different, mode=mode, fast=False)
981         self.assertEqual(expected, actual)
982         invalid = "return if you can"
983         with self.assertRaises(black.InvalidInput) as e:
984             black.format_file_contents(invalid, mode=mode, fast=False)
985         self.assertEqual(str(e.exception), "Cannot parse: 1:7: return if you can")
986
987     def test_endmarker(self) -> None:
988         n = black.lib2to3_parse("\n")
989         self.assertEqual(n.type, black.syms.file_input)
990         self.assertEqual(len(n.children), 1)
991         self.assertEqual(n.children[0].type, black.token.ENDMARKER)
992
993     @unittest.skipIf(os.environ.get("SKIP_AST_PRINT"), "user set SKIP_AST_PRINT")
994     def test_assertFormatEqual(self) -> None:
995         out_lines = []
996         err_lines = []
997
998         def out(msg: str, **kwargs: Any) -> None:
999             out_lines.append(msg)
1000
1001         def err(msg: str, **kwargs: Any) -> None:
1002             err_lines.append(msg)
1003
1004         with patch("black.out", out), patch("black.err", err):
1005             with self.assertRaises(AssertionError):
1006                 self.assertFormatEqual("l = [1, 2, 3]", "l = [1, 2, 3,]")
1007
1008         out_str = "".join(out_lines)
1009         self.assertTrue("Expected tree:" in out_str)
1010         self.assertTrue("Actual tree:" in out_str)
1011         self.assertEqual("".join(err_lines), "")
1012
1013     def test_cache_broken_file(self) -> None:
1014         mode = black.FileMode()
1015         with cache_dir() as workspace:
1016             cache_file = black.get_cache_file(mode)
1017             with cache_file.open("w") as fobj:
1018                 fobj.write("this is not a pickle")
1019             self.assertEqual(black.read_cache(mode), {})
1020             src = (workspace / "test.py").resolve()
1021             with src.open("w") as fobj:
1022                 fobj.write("print('hello')")
1023             self.invokeBlack([str(src)])
1024             cache = black.read_cache(mode)
1025             self.assertIn(src, cache)
1026
1027     def test_cache_single_file_already_cached(self) -> None:
1028         mode = black.FileMode()
1029         with cache_dir() as workspace:
1030             src = (workspace / "test.py").resolve()
1031             with src.open("w") as fobj:
1032                 fobj.write("print('hello')")
1033             black.write_cache({}, [src], mode)
1034             self.invokeBlack([str(src)])
1035             with src.open("r") as fobj:
1036                 self.assertEqual(fobj.read(), "print('hello')")
1037
1038     @event_loop(close=False)
1039     def test_cache_multiple_files(self) -> None:
1040         mode = black.FileMode()
1041         with cache_dir() as workspace, patch(
1042             "black.ProcessPoolExecutor", new=ThreadPoolExecutor
1043         ):
1044             one = (workspace / "one.py").resolve()
1045             with one.open("w") as fobj:
1046                 fobj.write("print('hello')")
1047             two = (workspace / "two.py").resolve()
1048             with two.open("w") as fobj:
1049                 fobj.write("print('hello')")
1050             black.write_cache({}, [one], mode)
1051             self.invokeBlack([str(workspace)])
1052             with one.open("r") as fobj:
1053                 self.assertEqual(fobj.read(), "print('hello')")
1054             with two.open("r") as fobj:
1055                 self.assertEqual(fobj.read(), 'print("hello")\n')
1056             cache = black.read_cache(mode)
1057             self.assertIn(one, cache)
1058             self.assertIn(two, cache)
1059
1060     def test_no_cache_when_writeback_diff(self) -> None:
1061         mode = black.FileMode()
1062         with cache_dir() as workspace:
1063             src = (workspace / "test.py").resolve()
1064             with src.open("w") as fobj:
1065                 fobj.write("print('hello')")
1066             self.invokeBlack([str(src), "--diff"])
1067             cache_file = black.get_cache_file(mode)
1068             self.assertFalse(cache_file.exists())
1069
1070     def test_no_cache_when_stdin(self) -> None:
1071         mode = black.FileMode()
1072         with cache_dir():
1073             result = CliRunner().invoke(
1074                 black.main, ["-"], input=BytesIO(b"print('hello')")
1075             )
1076             self.assertEqual(result.exit_code, 0)
1077             cache_file = black.get_cache_file(mode)
1078             self.assertFalse(cache_file.exists())
1079
1080     def test_read_cache_no_cachefile(self) -> None:
1081         mode = black.FileMode()
1082         with cache_dir():
1083             self.assertEqual(black.read_cache(mode), {})
1084
1085     def test_write_cache_read_cache(self) -> None:
1086         mode = black.FileMode()
1087         with cache_dir() as workspace:
1088             src = (workspace / "test.py").resolve()
1089             src.touch()
1090             black.write_cache({}, [src], mode)
1091             cache = black.read_cache(mode)
1092             self.assertIn(src, cache)
1093             self.assertEqual(cache[src], black.get_cache_info(src))
1094
1095     def test_filter_cached(self) -> None:
1096         with TemporaryDirectory() as workspace:
1097             path = Path(workspace)
1098             uncached = (path / "uncached").resolve()
1099             cached = (path / "cached").resolve()
1100             cached_but_changed = (path / "changed").resolve()
1101             uncached.touch()
1102             cached.touch()
1103             cached_but_changed.touch()
1104             cache = {cached: black.get_cache_info(cached), cached_but_changed: (0.0, 0)}
1105             todo, done = black.filter_cached(
1106                 cache, {uncached, cached, cached_but_changed}
1107             )
1108             self.assertEqual(todo, {uncached, cached_but_changed})
1109             self.assertEqual(done, {cached})
1110
1111     def test_write_cache_creates_directory_if_needed(self) -> None:
1112         mode = black.FileMode()
1113         with cache_dir(exists=False) as workspace:
1114             self.assertFalse(workspace.exists())
1115             black.write_cache({}, [], mode)
1116             self.assertTrue(workspace.exists())
1117
1118     @event_loop(close=False)
1119     def test_failed_formatting_does_not_get_cached(self) -> None:
1120         mode = black.FileMode()
1121         with cache_dir() as workspace, patch(
1122             "black.ProcessPoolExecutor", new=ThreadPoolExecutor
1123         ):
1124             failing = (workspace / "failing.py").resolve()
1125             with failing.open("w") as fobj:
1126                 fobj.write("not actually python")
1127             clean = (workspace / "clean.py").resolve()
1128             with clean.open("w") as fobj:
1129                 fobj.write('print("hello")\n')
1130             self.invokeBlack([str(workspace)], exit_code=123)
1131             cache = black.read_cache(mode)
1132             self.assertNotIn(failing, cache)
1133             self.assertIn(clean, cache)
1134
1135     def test_write_cache_write_fail(self) -> None:
1136         mode = black.FileMode()
1137         with cache_dir(), patch.object(Path, "open") as mock:
1138             mock.side_effect = OSError
1139             black.write_cache({}, [], mode)
1140
1141     @event_loop(close=False)
1142     def test_check_diff_use_together(self) -> None:
1143         with cache_dir():
1144             # Files which will be reformatted.
1145             src1 = (THIS_DIR / "data" / "string_quotes.py").resolve()
1146             self.invokeBlack([str(src1), "--diff", "--check"], exit_code=1)
1147             # Files which will not be reformatted.
1148             src2 = (THIS_DIR / "data" / "composition.py").resolve()
1149             self.invokeBlack([str(src2), "--diff", "--check"])
1150             # Multi file command.
1151             self.invokeBlack([str(src1), str(src2), "--diff", "--check"], exit_code=1)
1152
1153     def test_no_files(self) -> None:
1154         with cache_dir():
1155             # Without an argument, black exits with error code 0.
1156             self.invokeBlack([])
1157
1158     def test_broken_symlink(self) -> None:
1159         with cache_dir() as workspace:
1160             symlink = workspace / "broken_link.py"
1161             try:
1162                 symlink.symlink_to("nonexistent.py")
1163             except OSError as e:
1164                 self.skipTest(f"Can't create symlinks: {e}")
1165             self.invokeBlack([str(workspace.resolve())])
1166
1167     def test_read_cache_line_lengths(self) -> None:
1168         mode = black.FileMode()
1169         short_mode = black.FileMode(line_length=1)
1170         with cache_dir() as workspace:
1171             path = (workspace / "file.py").resolve()
1172             path.touch()
1173             black.write_cache({}, [path], mode)
1174             one = black.read_cache(mode)
1175             self.assertIn(path, one)
1176             two = black.read_cache(short_mode)
1177             self.assertNotIn(path, two)
1178
1179     def test_single_file_force_pyi(self) -> None:
1180         reg_mode = black.FileMode()
1181         pyi_mode = black.FileMode(is_pyi=True)
1182         contents, expected = read_data("force_pyi")
1183         with cache_dir() as workspace:
1184             path = (workspace / "file.py").resolve()
1185             with open(path, "w") as fh:
1186                 fh.write(contents)
1187             self.invokeBlack([str(path), "--pyi"])
1188             with open(path, "r") as fh:
1189                 actual = fh.read()
1190             # verify cache with --pyi is separate
1191             pyi_cache = black.read_cache(pyi_mode)
1192             self.assertIn(path, pyi_cache)
1193             normal_cache = black.read_cache(reg_mode)
1194             self.assertNotIn(path, normal_cache)
1195         self.assertEqual(actual, expected)
1196
1197     @event_loop(close=False)
1198     def test_multi_file_force_pyi(self) -> None:
1199         reg_mode = black.FileMode()
1200         pyi_mode = black.FileMode(is_pyi=True)
1201         contents, expected = read_data("force_pyi")
1202         with cache_dir() as workspace:
1203             paths = [
1204                 (workspace / "file1.py").resolve(),
1205                 (workspace / "file2.py").resolve(),
1206             ]
1207             for path in paths:
1208                 with open(path, "w") as fh:
1209                     fh.write(contents)
1210             self.invokeBlack([str(p) for p in paths] + ["--pyi"])
1211             for path in paths:
1212                 with open(path, "r") as fh:
1213                     actual = fh.read()
1214                 self.assertEqual(actual, expected)
1215             # verify cache with --pyi is separate
1216             pyi_cache = black.read_cache(pyi_mode)
1217             normal_cache = black.read_cache(reg_mode)
1218             for path in paths:
1219                 self.assertIn(path, pyi_cache)
1220                 self.assertNotIn(path, normal_cache)
1221
1222     def test_pipe_force_pyi(self) -> None:
1223         source, expected = read_data("force_pyi")
1224         result = CliRunner().invoke(
1225             black.main, ["-", "-q", "--pyi"], input=BytesIO(source.encode("utf8"))
1226         )
1227         self.assertEqual(result.exit_code, 0)
1228         actual = result.output
1229         self.assertFormatEqual(actual, expected)
1230
1231     def test_single_file_force_py36(self) -> None:
1232         reg_mode = black.FileMode()
1233         py36_mode = black.FileMode(target_versions=black.PY36_VERSIONS)
1234         source, expected = read_data("force_py36")
1235         with cache_dir() as workspace:
1236             path = (workspace / "file.py").resolve()
1237             with open(path, "w") as fh:
1238                 fh.write(source)
1239             self.invokeBlack([str(path), *PY36_ARGS])
1240             with open(path, "r") as fh:
1241                 actual = fh.read()
1242             # verify cache with --target-version is separate
1243             py36_cache = black.read_cache(py36_mode)
1244             self.assertIn(path, py36_cache)
1245             normal_cache = black.read_cache(reg_mode)
1246             self.assertNotIn(path, normal_cache)
1247         self.assertEqual(actual, expected)
1248
1249     @event_loop(close=False)
1250     def test_multi_file_force_py36(self) -> None:
1251         reg_mode = black.FileMode()
1252         py36_mode = black.FileMode(target_versions=black.PY36_VERSIONS)
1253         source, expected = read_data("force_py36")
1254         with cache_dir() as workspace:
1255             paths = [
1256                 (workspace / "file1.py").resolve(),
1257                 (workspace / "file2.py").resolve(),
1258             ]
1259             for path in paths:
1260                 with open(path, "w") as fh:
1261                     fh.write(source)
1262             self.invokeBlack([str(p) for p in paths] + PY36_ARGS)
1263             for path in paths:
1264                 with open(path, "r") as fh:
1265                     actual = fh.read()
1266                 self.assertEqual(actual, expected)
1267             # verify cache with --target-version is separate
1268             pyi_cache = black.read_cache(py36_mode)
1269             normal_cache = black.read_cache(reg_mode)
1270             for path in paths:
1271                 self.assertIn(path, pyi_cache)
1272                 self.assertNotIn(path, normal_cache)
1273
1274     def test_pipe_force_py36(self) -> None:
1275         source, expected = read_data("force_py36")
1276         result = CliRunner().invoke(
1277             black.main,
1278             ["-", "-q", "--target-version=py36"],
1279             input=BytesIO(source.encode("utf8")),
1280         )
1281         self.assertEqual(result.exit_code, 0)
1282         actual = result.output
1283         self.assertFormatEqual(actual, expected)
1284
1285     def test_include_exclude(self) -> None:
1286         path = THIS_DIR / "data" / "include_exclude_tests"
1287         include = re.compile(r"\.pyi?$")
1288         exclude = re.compile(r"/exclude/|/\.definitely_exclude/")
1289         report = black.Report()
1290         sources: List[Path] = []
1291         expected = [
1292             Path(path / "b/dont_exclude/a.py"),
1293             Path(path / "b/dont_exclude/a.pyi"),
1294         ]
1295         this_abs = THIS_DIR.resolve()
1296         sources.extend(
1297             black.gen_python_files_in_dir(path, this_abs, include, exclude, report)
1298         )
1299         self.assertEqual(sorted(expected), sorted(sources))
1300
1301     def test_empty_include(self) -> None:
1302         path = THIS_DIR / "data" / "include_exclude_tests"
1303         report = black.Report()
1304         empty = re.compile(r"")
1305         sources: List[Path] = []
1306         expected = [
1307             Path(path / "b/exclude/a.pie"),
1308             Path(path / "b/exclude/a.py"),
1309             Path(path / "b/exclude/a.pyi"),
1310             Path(path / "b/dont_exclude/a.pie"),
1311             Path(path / "b/dont_exclude/a.py"),
1312             Path(path / "b/dont_exclude/a.pyi"),
1313             Path(path / "b/.definitely_exclude/a.pie"),
1314             Path(path / "b/.definitely_exclude/a.py"),
1315             Path(path / "b/.definitely_exclude/a.pyi"),
1316         ]
1317         this_abs = THIS_DIR.resolve()
1318         sources.extend(
1319             black.gen_python_files_in_dir(
1320                 path, this_abs, empty, re.compile(black.DEFAULT_EXCLUDES), report
1321             )
1322         )
1323         self.assertEqual(sorted(expected), sorted(sources))
1324
1325     def test_empty_exclude(self) -> None:
1326         path = THIS_DIR / "data" / "include_exclude_tests"
1327         report = black.Report()
1328         empty = re.compile(r"")
1329         sources: List[Path] = []
1330         expected = [
1331             Path(path / "b/dont_exclude/a.py"),
1332             Path(path / "b/dont_exclude/a.pyi"),
1333             Path(path / "b/exclude/a.py"),
1334             Path(path / "b/exclude/a.pyi"),
1335             Path(path / "b/.definitely_exclude/a.py"),
1336             Path(path / "b/.definitely_exclude/a.pyi"),
1337         ]
1338         this_abs = THIS_DIR.resolve()
1339         sources.extend(
1340             black.gen_python_files_in_dir(
1341                 path, this_abs, re.compile(black.DEFAULT_INCLUDES), empty, report
1342             )
1343         )
1344         self.assertEqual(sorted(expected), sorted(sources))
1345
1346     def test_invalid_include_exclude(self) -> None:
1347         for option in ["--include", "--exclude"]:
1348             self.invokeBlack(["-", option, "**()(!!*)"], exit_code=2)
1349
1350     def test_preserves_line_endings(self) -> None:
1351         with TemporaryDirectory() as workspace:
1352             test_file = Path(workspace) / "test.py"
1353             for nl in ["\n", "\r\n"]:
1354                 contents = nl.join(["def f(  ):", "    pass"])
1355                 test_file.write_bytes(contents.encode())
1356                 ff(test_file, write_back=black.WriteBack.YES)
1357                 updated_contents: bytes = test_file.read_bytes()
1358                 self.assertIn(nl.encode(), updated_contents)
1359                 if nl == "\n":
1360                     self.assertNotIn(b"\r\n", updated_contents)
1361
1362     def test_preserves_line_endings_via_stdin(self) -> None:
1363         for nl in ["\n", "\r\n"]:
1364             contents = nl.join(["def f(  ):", "    pass"])
1365             runner = BlackRunner()
1366             result = runner.invoke(
1367                 black.main, ["-", "--fast"], input=BytesIO(contents.encode("utf8"))
1368             )
1369             self.assertEqual(result.exit_code, 0)
1370             output = runner.stdout_bytes
1371             self.assertIn(nl.encode("utf8"), output)
1372             if nl == "\n":
1373                 self.assertNotIn(b"\r\n", output)
1374
1375     def test_assert_equivalent_different_asts(self) -> None:
1376         with self.assertRaises(AssertionError):
1377             black.assert_equivalent("{}", "None")
1378
1379     def test_symlink_out_of_root_directory(self) -> None:
1380         path = MagicMock()
1381         root = THIS_DIR
1382         child = MagicMock()
1383         include = re.compile(black.DEFAULT_INCLUDES)
1384         exclude = re.compile(black.DEFAULT_EXCLUDES)
1385         report = black.Report()
1386         # `child` should behave like a symlink which resolved path is clearly
1387         # outside of the `root` directory.
1388         path.iterdir.return_value = [child]
1389         child.resolve.return_value = Path("/a/b/c")
1390         child.is_symlink.return_value = True
1391         try:
1392             list(black.gen_python_files_in_dir(path, root, include, exclude, report))
1393         except ValueError as ve:
1394             self.fail(f"`get_python_files_in_dir()` failed: {ve}")
1395         path.iterdir.assert_called_once()
1396         child.resolve.assert_called_once()
1397         child.is_symlink.assert_called_once()
1398         # `child` should behave like a strange file which resolved path is clearly
1399         # outside of the `root` directory.
1400         child.is_symlink.return_value = False
1401         with self.assertRaises(ValueError):
1402             list(black.gen_python_files_in_dir(path, root, include, exclude, report))
1403         path.iterdir.assert_called()
1404         self.assertEqual(path.iterdir.call_count, 2)
1405         child.resolve.assert_called()
1406         self.assertEqual(child.resolve.call_count, 2)
1407         child.is_symlink.assert_called()
1408         self.assertEqual(child.is_symlink.call_count, 2)
1409
1410     def test_shhh_click(self) -> None:
1411         try:
1412             from click import _unicodefun  # type: ignore
1413         except ModuleNotFoundError:
1414             self.skipTest("Incompatible Click version")
1415         if not hasattr(_unicodefun, "_verify_python3_env"):
1416             self.skipTest("Incompatible Click version")
1417         # First, let's see if Click is crashing with a preferred ASCII charset.
1418         with patch("locale.getpreferredencoding") as gpe:
1419             gpe.return_value = "ASCII"
1420             with self.assertRaises(RuntimeError):
1421                 _unicodefun._verify_python3_env()
1422         # Now, let's silence Click...
1423         black.patch_click()
1424         # ...and confirm it's silent.
1425         with patch("locale.getpreferredencoding") as gpe:
1426             gpe.return_value = "ASCII"
1427             try:
1428                 _unicodefun._verify_python3_env()
1429             except RuntimeError as re:
1430                 self.fail(f"`patch_click()` failed, exception still raised: {re}")
1431
1432     def test_root_logger_not_used_directly(self) -> None:
1433         def fail(*args: Any, **kwargs: Any) -> None:
1434             self.fail("Record created with root logger")
1435
1436         with patch.multiple(
1437             logging.root,
1438             debug=fail,
1439             info=fail,
1440             warning=fail,
1441             error=fail,
1442             critical=fail,
1443             log=fail,
1444         ):
1445             ff(THIS_FILE)
1446
1447     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1448     @async_test
1449     async def test_blackd_request_needs_formatting(self) -> None:
1450         app = blackd.make_app()
1451         async with TestClient(TestServer(app)) as client:
1452             response = await client.post("/", data=b"print('hello world')")
1453             self.assertEqual(response.status, 200)
1454             self.assertEqual(response.charset, "utf8")
1455             self.assertEqual(await response.read(), b'print("hello world")\n')
1456
1457     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1458     @async_test
1459     async def test_blackd_request_no_change(self) -> None:
1460         app = blackd.make_app()
1461         async with TestClient(TestServer(app)) as client:
1462             response = await client.post("/", data=b'print("hello world")\n')
1463             self.assertEqual(response.status, 204)
1464             self.assertEqual(await response.read(), b"")
1465
1466     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1467     @async_test
1468     async def test_blackd_request_syntax_error(self) -> None:
1469         app = blackd.make_app()
1470         async with TestClient(TestServer(app)) as client:
1471             response = await client.post("/", data=b"what even ( is")
1472             self.assertEqual(response.status, 400)
1473             content = await response.text()
1474             self.assertTrue(
1475                 content.startswith("Cannot parse"),
1476                 msg=f"Expected error to start with 'Cannot parse', got {repr(content)}",
1477             )
1478
1479     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1480     @async_test
1481     async def test_blackd_unsupported_version(self) -> None:
1482         app = blackd.make_app()
1483         async with TestClient(TestServer(app)) as client:
1484             response = await client.post(
1485                 "/", data=b"what", headers={blackd.VERSION_HEADER: "2"}
1486             )
1487             self.assertEqual(response.status, 501)
1488
1489     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1490     @async_test
1491     async def test_blackd_supported_version(self) -> None:
1492         app = blackd.make_app()
1493         async with TestClient(TestServer(app)) as client:
1494             response = await client.post(
1495                 "/", data=b"what", headers={blackd.VERSION_HEADER: "1"}
1496             )
1497             self.assertEqual(response.status, 200)
1498
1499     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1500     @async_test
1501     async def test_blackd_invalid_python_variant(self) -> None:
1502         app = blackd.make_app()
1503         async with TestClient(TestServer(app)) as client:
1504
1505             async def check(header_value: str, expected_status: int = 400) -> None:
1506                 response = await client.post(
1507                     "/",
1508                     data=b"what",
1509                     headers={blackd.PYTHON_VARIANT_HEADER: header_value},
1510                 )
1511                 self.assertEqual(response.status, expected_status)
1512
1513             await check("lol")
1514             await check("ruby3.5")
1515             await check("pyi3.6")
1516             await check("py1.5")
1517             await check("2.8")
1518             await check("py2.8")
1519             await check("3.0")
1520             await check("pypy3.0")
1521             await check("jython3.4")
1522
1523     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1524     @async_test
1525     async def test_blackd_pyi(self) -> None:
1526         app = blackd.make_app()
1527         async with TestClient(TestServer(app)) as client:
1528             source, expected = read_data("stub.pyi")
1529             response = await client.post(
1530                 "/", data=source, headers={blackd.PYTHON_VARIANT_HEADER: "pyi"}
1531             )
1532             self.assertEqual(response.status, 200)
1533             self.assertEqual(await response.text(), expected)
1534
1535     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1536     @async_test
1537     async def test_blackd_python_variant(self) -> None:
1538         app = blackd.make_app()
1539         code = (
1540             "def f(\n"
1541             "    and_has_a_bunch_of,\n"
1542             "    very_long_arguments_too,\n"
1543             "    and_lots_of_them_as_well_lol,\n"
1544             "    **and_very_long_keyword_arguments\n"
1545             "):\n"
1546             "    pass\n"
1547         )
1548         async with TestClient(TestServer(app)) as client:
1549
1550             async def check(header_value: str, expected_status: int) -> None:
1551                 response = await client.post(
1552                     "/", data=code, headers={blackd.PYTHON_VARIANT_HEADER: header_value}
1553                 )
1554                 self.assertEqual(response.status, expected_status)
1555
1556             await check("3.6", 200)
1557             await check("py3.6", 200)
1558             await check("3.6,3.7", 200)
1559             await check("3.6,py3.7", 200)
1560
1561             await check("2", 204)
1562             await check("2.7", 204)
1563             await check("py2.7", 204)
1564             await check("3.4", 204)
1565             await check("py3.4", 204)
1566
1567     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1568     @async_test
1569     async def test_blackd_line_length(self) -> None:
1570         app = blackd.make_app()
1571         async with TestClient(TestServer(app)) as client:
1572             response = await client.post(
1573                 "/", data=b'print("hello")\n', headers={blackd.LINE_LENGTH_HEADER: "7"}
1574             )
1575             self.assertEqual(response.status, 200)
1576
1577     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1578     @async_test
1579     async def test_blackd_invalid_line_length(self) -> None:
1580         app = blackd.make_app()
1581         async with TestClient(TestServer(app)) as client:
1582             response = await client.post(
1583                 "/",
1584                 data=b'print("hello")\n',
1585                 headers={blackd.LINE_LENGTH_HEADER: "NaN"},
1586             )
1587             self.assertEqual(response.status, 400)
1588
1589     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1590     def test_blackd_main(self) -> None:
1591         with patch("blackd.web.run_app"):
1592             result = CliRunner().invoke(blackd.main, [])
1593             if result.exception is not None:
1594                 raise result.exception
1595             self.assertEqual(result.exit_code, 0)
1596
1597
# Running this file directly runs the test suite. unittest.main() is told to
# load tests from the module named "test_black" instead of from __main__.
# NOTE(review): this relies on "test_black" being importable (e.g. the
# script's own directory on sys.path) — confirm when invoking from elsewhere.
if __name__ == "__main__":
    unittest.main(module="test_black")