git.madduck.net Git - etc/vim.git/blob - tests/test_black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Fix print() function on Python 2 (#754)
[etc/vim.git] / tests / test_black.py
1 #!/usr/bin/env python3
2 import asyncio
3 import logging
4 from concurrent.futures import ThreadPoolExecutor
5 from contextlib import contextmanager, redirect_stderr
6 from functools import partial, wraps
7 from io import BytesIO, TextIOWrapper
8 import os
9 from pathlib import Path
10 import re
11 import sys
12 from tempfile import TemporaryDirectory
13 from typing import (
14     Any,
15     BinaryIO,
16     Callable,
17     Coroutine,
18     Generator,
19     List,
20     Tuple,
21     Iterator,
22     TypeVar,
23 )
24 import unittest
25 from unittest.mock import patch, MagicMock
26
27 from click import unstyle
28 from click.testing import CliRunner
29
30 import black
31 from black import Feature
32
33 try:
34     import blackd
35     from aiohttp.test_utils import TestClient, TestServer
36 except ImportError:
37     has_blackd_deps = False
38 else:
39     has_blackd_deps = True
40
ff = partial(black.format_file_in_place, mode=black.FileMode(), fast=True)
fs = partial(black.format_str, mode=black.FileMode())
THIS_FILE = Path(__file__)
THIS_DIR = THIS_FILE.parent
# Built by concatenation on purpose: read_data() strips this marker from every
# line it reads, and test_self reads this very file, so the marker text must
# never appear verbatim in this module.
EMPTY_LINE = "# EMPTY LINE WITH WHITESPACE" + " (this comment will be removed)"
PY36_ARGS = [f"--target-version={v.name.lower()}" for v in black.PY36_VERSIONS]
T = TypeVar("T")
R = TypeVar("R")
51
52
def dump_to_stderr(*output: str) -> str:
    """Stand-in for ``black.dump_to_file``: return *output* instead of saving it.

    The pieces are joined with newlines and framed by one leading and one
    trailing newline.
    """
    body = "\n".join(output)
    return f"\n{body}\n"
55
56
def read_data(name: str, data: bool = True) -> Tuple[str, str]:
    """read_data('test_name') -> 'input', 'output'

    Lines before the first ``# output`` marker line form the input; lines
    after it form the expected output (marker lines themselves are dropped).
    When nothing follows a marker, or there is none, the input doubles as the
    expected output. Both halves are stripped and end with exactly one newline.
    A ``.py`` suffix is appended unless *name* already has a known extension;
    *data* selects ``tests/data`` (default) or ``tests`` itself.
    """
    if not name.endswith((".py", ".pyi", ".out", ".diff")):
        name += ".py"
    base_dir = THIS_DIR / "data" if data else THIS_DIR
    with open(base_dir / name, "r", encoding="utf8") as test:
        raw_lines = test.readlines()

    src_lines: List[str] = []
    out_lines: List[str] = []
    seen_marker = False
    for raw in raw_lines:
        line = raw.replace(EMPTY_LINE, "")
        if line.rstrip() == "# output":
            seen_marker = True
        elif seen_marker:
            out_lines.append(line)
        else:
            src_lines.append(line)

    if src_lines and not out_lines:
        # No expected output given: the file is treated as already formatted.
        out_lines = list(src_lines)
    source = "".join(src_lines).strip() + "\n"
    expected = "".join(out_lines).strip() + "\n"
    return source, expected
78
79
@contextmanager
def cache_dir(exists: bool = True) -> Iterator[Path]:
    """Patch ``black.CACHE_DIR`` to a throwaway directory and yield that path.

    With ``exists=False`` the patched path is a ``new`` subdirectory of the
    temporary workspace that has not been created yet.
    """
    with TemporaryDirectory() as workspace:
        root = Path(workspace)
        target = root if exists else root / "new"
        with patch("black.CACHE_DIR", target):
            yield target
88
89
@contextmanager
def event_loop(close: bool) -> Iterator[None]:
    """Install a fresh event loop as the current loop for the ``with`` body.

    The previously current loop is reinstated on exit, and the fresh loop is
    closed afterwards only when *close* is true.
    """
    policy = asyncio.get_event_loop_policy()
    previous = policy.get_event_loop()
    fresh = policy.new_event_loop()
    asyncio.set_event_loop(fresh)
    try:
        yield
    finally:
        policy.set_event_loop(previous)
        if close:
            fresh.close()
103
104
def async_test(f: Callable[..., Coroutine[Any, None, R]]) -> Callable[..., None]:
    """Turn coroutine function *f* into a plain, synchronously callable test.

    Every call runs *f* to completion on a brand-new event loop installed by
    ``event_loop``, which is closed once the call finishes.
    """

    @wraps(f)
    def wrapper(*args: Any, **kwargs: Any) -> None:
        with event_loop(close=True):
            loop = asyncio.get_event_loop()
            loop.run_until_complete(f(*args, **kwargs))

    return wrapper
112
113
class BlackRunner(CliRunner):
    """Modify CliRunner so that stderr is not merged with stdout.

    This is a hack that can be removed once we depend on Click 7.x"""

    def __init__(self) -> None:
        self.stderrbuf = BytesIO()
        self.stdoutbuf = BytesIO()
        self.stdout_bytes = b""
        self.stderr_bytes = b""
        super().__init__()

    @contextmanager
    def isolation(self, *args: Any, **kwargs: Any) -> Generator[BinaryIO, None, None]:
        """Like ``CliRunner.isolation`` but send stderr to ``self.stderrbuf``.

        On exit, the raw bytes written to stdout and stderr are stored in
        ``self.stdout_bytes`` and ``self.stderr_bytes`` respectively.
        """
        with super().isolation(*args, **kwargs) as output:
            hold_stderr = sys.stderr
            sys.stderr = TextIOWrapper(self.stderrbuf, encoding=self.charset)
            try:
                yield output
            finally:
                try:
                    # TextIOWrapper keeps written text in an internal buffer;
                    # flush so every write reaches the BytesIO before reading.
                    sys.stdout.flush()
                    sys.stderr.flush()
                    self.stdout_bytes = sys.stdout.buffer.getvalue()  # type: ignore
                    self.stderr_bytes = sys.stderr.buffer.getvalue()  # type: ignore
                finally:
                    # Reinstate the previous stderr even if reading fails.
                    sys.stderr = hold_stderr
137
138
139 class BlackTestCase(unittest.TestCase):
140     maxDiff = None
141
142     def assertFormatEqual(self, expected: str, actual: str) -> None:
143         if actual != expected and not os.environ.get("SKIP_AST_PRINT"):
144             bdv: black.DebugVisitor[Any]
145             black.out("Expected tree:", fg="green")
146             try:
147                 exp_node = black.lib2to3_parse(expected)
148                 bdv = black.DebugVisitor()
149                 list(bdv.visit(exp_node))
150             except Exception as ve:
151                 black.err(str(ve))
152             black.out("Actual tree:", fg="red")
153             try:
154                 exp_node = black.lib2to3_parse(actual)
155                 bdv = black.DebugVisitor()
156                 list(bdv.visit(exp_node))
157             except Exception as ve:
158                 black.err(str(ve))
159         self.assertEqual(expected, actual)
160
161     def invokeBlack(
162         self, args: List[str], exit_code: int = 0, ignore_config: bool = True
163     ) -> None:
164         runner = BlackRunner()
165         if ignore_config:
166             args = ["--config", str(THIS_DIR / "empty.toml"), *args]
167         result = runner.invoke(black.main, args)
168         self.assertEqual(result.exit_code, exit_code, msg=runner.stderr_bytes.decode())
169
170     @patch("black.dump_to_file", dump_to_stderr)
171     def test_empty(self) -> None:
172         source = expected = ""
173         actual = fs(source)
174         self.assertFormatEqual(expected, actual)
175         black.assert_equivalent(source, actual)
176         black.assert_stable(source, actual, black.FileMode())
177
178     def test_empty_ff(self) -> None:
179         expected = ""
180         tmp_file = Path(black.dump_to_file())
181         try:
182             self.assertFalse(ff(tmp_file, write_back=black.WriteBack.YES))
183             with open(tmp_file, encoding="utf8") as f:
184                 actual = f.read()
185         finally:
186             os.unlink(tmp_file)
187         self.assertFormatEqual(expected, actual)
188
189     @patch("black.dump_to_file", dump_to_stderr)
190     def test_self(self) -> None:
191         source, expected = read_data("test_black", data=False)
192         actual = fs(source)
193         self.assertFormatEqual(expected, actual)
194         black.assert_equivalent(source, actual)
195         black.assert_stable(source, actual, black.FileMode())
196         self.assertFalse(ff(THIS_FILE))
197
198     @patch("black.dump_to_file", dump_to_stderr)
199     def test_black(self) -> None:
200         source, expected = read_data("../black", data=False)
201         actual = fs(source)
202         self.assertFormatEqual(expected, actual)
203         black.assert_equivalent(source, actual)
204         black.assert_stable(source, actual, black.FileMode())
205         self.assertFalse(ff(THIS_DIR / ".." / "black.py"))
206
207     def test_piping(self) -> None:
208         source, expected = read_data("../black", data=False)
209         result = BlackRunner().invoke(
210             black.main,
211             ["-", "--fast", f"--line-length={black.DEFAULT_LINE_LENGTH}"],
212             input=BytesIO(source.encode("utf8")),
213         )
214         self.assertEqual(result.exit_code, 0)
215         self.assertFormatEqual(expected, result.output)
216         black.assert_equivalent(source, result.output)
217         black.assert_stable(source, result.output, black.FileMode())
218
219     def test_piping_diff(self) -> None:
220         diff_header = re.compile(
221             rf"(STDIN|STDOUT)\t\d\d\d\d-\d\d-\d\d "
222             rf"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
223         )
224         source, _ = read_data("expression.py")
225         expected, _ = read_data("expression.diff")
226         config = THIS_DIR / "data" / "empty_pyproject.toml"
227         args = [
228             "-",
229             "--fast",
230             f"--line-length={black.DEFAULT_LINE_LENGTH}",
231             "--diff",
232             f"--config={config}",
233         ]
234         result = BlackRunner().invoke(
235             black.main, args, input=BytesIO(source.encode("utf8"))
236         )
237         self.assertEqual(result.exit_code, 0)
238         actual = diff_header.sub("[Deterministic header]", result.output)
239         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
240         self.assertEqual(expected, actual)
241
242     @patch("black.dump_to_file", dump_to_stderr)
243     def test_setup(self) -> None:
244         source, expected = read_data("../setup", data=False)
245         actual = fs(source)
246         self.assertFormatEqual(expected, actual)
247         black.assert_equivalent(source, actual)
248         black.assert_stable(source, actual, black.FileMode())
249         self.assertFalse(ff(THIS_DIR / ".." / "setup.py"))
250
251     @patch("black.dump_to_file", dump_to_stderr)
252     def test_function(self) -> None:
253         source, expected = read_data("function")
254         actual = fs(source)
255         self.assertFormatEqual(expected, actual)
256         black.assert_equivalent(source, actual)
257         black.assert_stable(source, actual, black.FileMode())
258
259     @patch("black.dump_to_file", dump_to_stderr)
260     def test_function2(self) -> None:
261         source, expected = read_data("function2")
262         actual = fs(source)
263         self.assertFormatEqual(expected, actual)
264         black.assert_equivalent(source, actual)
265         black.assert_stable(source, actual, black.FileMode())
266
267     @patch("black.dump_to_file", dump_to_stderr)
268     def test_expression(self) -> None:
269         source, expected = read_data("expression")
270         actual = fs(source)
271         self.assertFormatEqual(expected, actual)
272         black.assert_equivalent(source, actual)
273         black.assert_stable(source, actual, black.FileMode())
274
275     def test_expression_ff(self) -> None:
276         source, expected = read_data("expression")
277         tmp_file = Path(black.dump_to_file(source))
278         try:
279             self.assertTrue(ff(tmp_file, write_back=black.WriteBack.YES))
280             with open(tmp_file, encoding="utf8") as f:
281                 actual = f.read()
282         finally:
283             os.unlink(tmp_file)
284         self.assertFormatEqual(expected, actual)
285         with patch("black.dump_to_file", dump_to_stderr):
286             black.assert_equivalent(source, actual)
287             black.assert_stable(source, actual, black.FileMode())
288
289     def test_expression_diff(self) -> None:
290         source, _ = read_data("expression.py")
291         expected, _ = read_data("expression.diff")
292         tmp_file = Path(black.dump_to_file(source))
293         diff_header = re.compile(
294             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
295             rf"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
296         )
297         try:
298             result = BlackRunner().invoke(black.main, ["--diff", str(tmp_file)])
299             self.assertEqual(result.exit_code, 0)
300         finally:
301             os.unlink(tmp_file)
302         actual = result.output
303         actual = diff_header.sub("[Deterministic header]", actual)
304         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
305         if expected != actual:
306             dump = black.dump_to_file(actual)
307             msg = (
308                 f"Expected diff isn't equal to the actual. If you made changes "
309                 f"to expression.py and this is an anticipated difference, "
310                 f"overwrite tests/data/expression.diff with {dump}"
311             )
312             self.assertEqual(expected, actual, msg)
313
314     @patch("black.dump_to_file", dump_to_stderr)
315     def test_fstring(self) -> None:
316         source, expected = read_data("fstring")
317         actual = fs(source)
318         self.assertFormatEqual(expected, actual)
319         black.assert_equivalent(source, actual)
320         black.assert_stable(source, actual, black.FileMode())
321
322     @patch("black.dump_to_file", dump_to_stderr)
323     def test_string_quotes(self) -> None:
324         source, expected = read_data("string_quotes")
325         actual = fs(source)
326         self.assertFormatEqual(expected, actual)
327         black.assert_equivalent(source, actual)
328         black.assert_stable(source, actual, black.FileMode())
329         mode = black.FileMode(string_normalization=False)
330         not_normalized = fs(source, mode=mode)
331         self.assertFormatEqual(source, not_normalized)
332         black.assert_equivalent(source, not_normalized)
333         black.assert_stable(source, not_normalized, mode=mode)
334
335     @patch("black.dump_to_file", dump_to_stderr)
336     def test_slices(self) -> None:
337         source, expected = read_data("slices")
338         actual = fs(source)
339         self.assertFormatEqual(expected, actual)
340         black.assert_equivalent(source, actual)
341         black.assert_stable(source, actual, black.FileMode())
342
343     @patch("black.dump_to_file", dump_to_stderr)
344     def test_comments(self) -> None:
345         source, expected = read_data("comments")
346         actual = fs(source)
347         self.assertFormatEqual(expected, actual)
348         black.assert_equivalent(source, actual)
349         black.assert_stable(source, actual, black.FileMode())
350
351     @patch("black.dump_to_file", dump_to_stderr)
352     def test_comments2(self) -> None:
353         source, expected = read_data("comments2")
354         actual = fs(source)
355         self.assertFormatEqual(expected, actual)
356         black.assert_equivalent(source, actual)
357         black.assert_stable(source, actual, black.FileMode())
358
359     @patch("black.dump_to_file", dump_to_stderr)
360     def test_comments3(self) -> None:
361         source, expected = read_data("comments3")
362         actual = fs(source)
363         self.assertFormatEqual(expected, actual)
364         black.assert_equivalent(source, actual)
365         black.assert_stable(source, actual, black.FileMode())
366
367     @patch("black.dump_to_file", dump_to_stderr)
368     def test_comments4(self) -> None:
369         source, expected = read_data("comments4")
370         actual = fs(source)
371         self.assertFormatEqual(expected, actual)
372         black.assert_equivalent(source, actual)
373         black.assert_stable(source, actual, black.FileMode())
374
375     @patch("black.dump_to_file", dump_to_stderr)
376     def test_comments5(self) -> None:
377         source, expected = read_data("comments5")
378         actual = fs(source)
379         self.assertFormatEqual(expected, actual)
380         black.assert_equivalent(source, actual)
381         black.assert_stable(source, actual, black.FileMode())
382
383     @patch("black.dump_to_file", dump_to_stderr)
384     def test_comments6(self) -> None:
385         source, expected = read_data("comments6")
386         actual = fs(source)
387         self.assertFormatEqual(expected, actual)
388         black.assert_equivalent(source, actual)
389         black.assert_stable(source, actual, black.FileMode())
390
391     @patch("black.dump_to_file", dump_to_stderr)
392     def test_cantfit(self) -> None:
393         source, expected = read_data("cantfit")
394         actual = fs(source)
395         self.assertFormatEqual(expected, actual)
396         black.assert_equivalent(source, actual)
397         black.assert_stable(source, actual, black.FileMode())
398
399     @patch("black.dump_to_file", dump_to_stderr)
400     def test_import_spacing(self) -> None:
401         source, expected = read_data("import_spacing")
402         actual = fs(source)
403         self.assertFormatEqual(expected, actual)
404         black.assert_equivalent(source, actual)
405         black.assert_stable(source, actual, black.FileMode())
406
407     @patch("black.dump_to_file", dump_to_stderr)
408     def test_composition(self) -> None:
409         source, expected = read_data("composition")
410         actual = fs(source)
411         self.assertFormatEqual(expected, actual)
412         black.assert_equivalent(source, actual)
413         black.assert_stable(source, actual, black.FileMode())
414
415     @patch("black.dump_to_file", dump_to_stderr)
416     def test_empty_lines(self) -> None:
417         source, expected = read_data("empty_lines")
418         actual = fs(source)
419         self.assertFormatEqual(expected, actual)
420         black.assert_equivalent(source, actual)
421         black.assert_stable(source, actual, black.FileMode())
422
423     @patch("black.dump_to_file", dump_to_stderr)
424     def test_string_prefixes(self) -> None:
425         source, expected = read_data("string_prefixes")
426         actual = fs(source)
427         self.assertFormatEqual(expected, actual)
428         black.assert_equivalent(source, actual)
429         black.assert_stable(source, actual, black.FileMode())
430
431     @patch("black.dump_to_file", dump_to_stderr)
432     def test_numeric_literals(self) -> None:
433         source, expected = read_data("numeric_literals")
434         mode = black.FileMode(target_versions=black.PY36_VERSIONS)
435         actual = fs(source, mode=mode)
436         self.assertFormatEqual(expected, actual)
437         black.assert_equivalent(source, actual)
438         black.assert_stable(source, actual, mode)
439
440     @patch("black.dump_to_file", dump_to_stderr)
441     def test_numeric_literals_ignoring_underscores(self) -> None:
442         source, expected = read_data("numeric_literals_skip_underscores")
443         mode = black.FileMode(target_versions=black.PY36_VERSIONS)
444         actual = fs(source, mode=mode)
445         self.assertFormatEqual(expected, actual)
446         black.assert_equivalent(source, actual)
447         black.assert_stable(source, actual, mode)
448
449     @patch("black.dump_to_file", dump_to_stderr)
450     def test_numeric_literals_py2(self) -> None:
451         source, expected = read_data("numeric_literals_py2")
452         actual = fs(source)
453         self.assertFormatEqual(expected, actual)
454         black.assert_stable(source, actual, black.FileMode())
455
456     @patch("black.dump_to_file", dump_to_stderr)
457     def test_python2(self) -> None:
458         source, expected = read_data("python2")
459         actual = fs(source)
460         self.assertFormatEqual(expected, actual)
461         # black.assert_equivalent(source, actual)
462         black.assert_stable(source, actual, black.FileMode())
463
464     @patch("black.dump_to_file", dump_to_stderr)
465     def test_python2_print_function(self) -> None:
466         source, expected = read_data("python2_print_function")
467         mode = black.FileMode(target_versions={black.TargetVersion.PY27})
468         actual = fs(source, mode=mode)
469         self.assertFormatEqual(expected, actual)
470         black.assert_stable(source, actual, mode)
471
472     @patch("black.dump_to_file", dump_to_stderr)
473     def test_python2_unicode_literals(self) -> None:
474         source, expected = read_data("python2_unicode_literals")
475         actual = fs(source)
476         self.assertFormatEqual(expected, actual)
477         black.assert_stable(source, actual, black.FileMode())
478
479     @patch("black.dump_to_file", dump_to_stderr)
480     def test_stub(self) -> None:
481         mode = black.FileMode(is_pyi=True)
482         source, expected = read_data("stub.pyi")
483         actual = fs(source, mode=mode)
484         self.assertFormatEqual(expected, actual)
485         black.assert_stable(source, actual, mode)
486
487     @patch("black.dump_to_file", dump_to_stderr)
488     def test_python37(self) -> None:
489         source, expected = read_data("python37")
490         actual = fs(source)
491         self.assertFormatEqual(expected, actual)
492         major, minor = sys.version_info[:2]
493         if major > 3 or (major == 3 and minor >= 7):
494             black.assert_equivalent(source, actual)
495         black.assert_stable(source, actual, black.FileMode())
496
497     @patch("black.dump_to_file", dump_to_stderr)
498     def test_fmtonoff(self) -> None:
499         source, expected = read_data("fmtonoff")
500         actual = fs(source)
501         self.assertFormatEqual(expected, actual)
502         black.assert_equivalent(source, actual)
503         black.assert_stable(source, actual, black.FileMode())
504
505     @patch("black.dump_to_file", dump_to_stderr)
506     def test_fmtonoff2(self) -> None:
507         source, expected = read_data("fmtonoff2")
508         actual = fs(source)
509         self.assertFormatEqual(expected, actual)
510         black.assert_equivalent(source, actual)
511         black.assert_stable(source, actual, black.FileMode())
512
513     @patch("black.dump_to_file", dump_to_stderr)
514     def test_remove_empty_parentheses_after_class(self) -> None:
515         source, expected = read_data("class_blank_parentheses")
516         actual = fs(source)
517         self.assertFormatEqual(expected, actual)
518         black.assert_equivalent(source, actual)
519         black.assert_stable(source, actual, black.FileMode())
520
521     @patch("black.dump_to_file", dump_to_stderr)
522     def test_new_line_between_class_and_code(self) -> None:
523         source, expected = read_data("class_methods_new_line")
524         actual = fs(source)
525         self.assertFormatEqual(expected, actual)
526         black.assert_equivalent(source, actual)
527         black.assert_stable(source, actual, black.FileMode())
528
529     @patch("black.dump_to_file", dump_to_stderr)
530     def test_bracket_match(self) -> None:
531         source, expected = read_data("bracketmatch")
532         actual = fs(source)
533         self.assertFormatEqual(expected, actual)
534         black.assert_equivalent(source, actual)
535         black.assert_stable(source, actual, black.FileMode())
536
537     def test_tab_comment_indentation(self) -> None:
538         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t# comment\n\tpass\n"
539         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
540         self.assertFormatEqual(contents_spc, fs(contents_spc))
541         self.assertFormatEqual(contents_spc, fs(contents_tab))
542
543         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t\t# comment\n\tpass\n"
544         contents_spc = "if 1:\n    if 2:\n        pass\n        # comment\n    pass\n"
545         self.assertFormatEqual(contents_spc, fs(contents_spc))
546         self.assertFormatEqual(contents_spc, fs(contents_tab))
547
548         # mixed tabs and spaces (valid Python 2 code)
549         contents_tab = "if 1:\n        if 2:\n\t\tpass\n\t# comment\n        pass\n"
550         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
551         self.assertFormatEqual(contents_spc, fs(contents_spc))
552         self.assertFormatEqual(contents_spc, fs(contents_tab))
553
554         contents_tab = "if 1:\n        if 2:\n\t\tpass\n\t\t# comment\n        pass\n"
555         contents_spc = "if 1:\n    if 2:\n        pass\n        # comment\n    pass\n"
556         self.assertFormatEqual(contents_spc, fs(contents_spc))
557         self.assertFormatEqual(contents_spc, fs(contents_tab))
558
559     def test_report_verbose(self) -> None:
560         report = black.Report(verbose=True)
561         out_lines = []
562         err_lines = []
563
564         def out(msg: str, **kwargs: Any) -> None:
565             out_lines.append(msg)
566
567         def err(msg: str, **kwargs: Any) -> None:
568             err_lines.append(msg)
569
570         with patch("black.out", out), patch("black.err", err):
571             report.done(Path("f1"), black.Changed.NO)
572             self.assertEqual(len(out_lines), 1)
573             self.assertEqual(len(err_lines), 0)
574             self.assertEqual(out_lines[-1], "f1 already well formatted, good job.")
575             self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
576             self.assertEqual(report.return_code, 0)
577             report.done(Path("f2"), black.Changed.YES)
578             self.assertEqual(len(out_lines), 2)
579             self.assertEqual(len(err_lines), 0)
580             self.assertEqual(out_lines[-1], "reformatted f2")
581             self.assertEqual(
582                 unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
583             )
584             report.done(Path("f3"), black.Changed.CACHED)
585             self.assertEqual(len(out_lines), 3)
586             self.assertEqual(len(err_lines), 0)
587             self.assertEqual(
588                 out_lines[-1], "f3 wasn't modified on disk since last run."
589             )
590             self.assertEqual(
591                 unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
592             )
593             self.assertEqual(report.return_code, 0)
594             report.check = True
595             self.assertEqual(report.return_code, 1)
596             report.check = False
597             report.failed(Path("e1"), "boom")
598             self.assertEqual(len(out_lines), 3)
599             self.assertEqual(len(err_lines), 1)
600             self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
601             self.assertEqual(
602                 unstyle(str(report)),
603                 "1 file reformatted, 2 files left unchanged, "
604                 "1 file failed to reformat.",
605             )
606             self.assertEqual(report.return_code, 123)
607             report.done(Path("f3"), black.Changed.YES)
608             self.assertEqual(len(out_lines), 4)
609             self.assertEqual(len(err_lines), 1)
610             self.assertEqual(out_lines[-1], "reformatted f3")
611             self.assertEqual(
612                 unstyle(str(report)),
613                 "2 files reformatted, 2 files left unchanged, "
614                 "1 file failed to reformat.",
615             )
616             self.assertEqual(report.return_code, 123)
617             report.failed(Path("e2"), "boom")
618             self.assertEqual(len(out_lines), 4)
619             self.assertEqual(len(err_lines), 2)
620             self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
621             self.assertEqual(
622                 unstyle(str(report)),
623                 "2 files reformatted, 2 files left unchanged, "
624                 "2 files failed to reformat.",
625             )
626             self.assertEqual(report.return_code, 123)
627             report.path_ignored(Path("wat"), "no match")
628             self.assertEqual(len(out_lines), 5)
629             self.assertEqual(len(err_lines), 2)
630             self.assertEqual(out_lines[-1], "wat ignored: no match")
631             self.assertEqual(
632                 unstyle(str(report)),
633                 "2 files reformatted, 2 files left unchanged, "
634                 "2 files failed to reformat.",
635             )
636             self.assertEqual(report.return_code, 123)
637             report.done(Path("f4"), black.Changed.NO)
638             self.assertEqual(len(out_lines), 6)
639             self.assertEqual(len(err_lines), 2)
640             self.assertEqual(out_lines[-1], "f4 already well formatted, good job.")
641             self.assertEqual(
642                 unstyle(str(report)),
643                 "2 files reformatted, 3 files left unchanged, "
644                 "2 files failed to reformat.",
645             )
646             self.assertEqual(report.return_code, 123)
647             report.check = True
648             self.assertEqual(
649                 unstyle(str(report)),
650                 "2 files would be reformatted, 3 files would be left unchanged, "
651                 "2 files would fail to reformat.",
652             )
653
654     def test_report_quiet(self) -> None:
655         report = black.Report(quiet=True)
656         out_lines = []
657         err_lines = []
658
659         def out(msg: str, **kwargs: Any) -> None:
660             out_lines.append(msg)
661
662         def err(msg: str, **kwargs: Any) -> None:
663             err_lines.append(msg)
664
665         with patch("black.out", out), patch("black.err", err):
666             report.done(Path("f1"), black.Changed.NO)
667             self.assertEqual(len(out_lines), 0)
668             self.assertEqual(len(err_lines), 0)
669             self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
670             self.assertEqual(report.return_code, 0)
671             report.done(Path("f2"), black.Changed.YES)
672             self.assertEqual(len(out_lines), 0)
673             self.assertEqual(len(err_lines), 0)
674             self.assertEqual(
675                 unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
676             )
677             report.done(Path("f3"), black.Changed.CACHED)
678             self.assertEqual(len(out_lines), 0)
679             self.assertEqual(len(err_lines), 0)
680             self.assertEqual(
681                 unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
682             )
683             self.assertEqual(report.return_code, 0)
684             report.check = True
685             self.assertEqual(report.return_code, 1)
686             report.check = False
687             report.failed(Path("e1"), "boom")
688             self.assertEqual(len(out_lines), 0)
689             self.assertEqual(len(err_lines), 1)
690             self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
691             self.assertEqual(
692                 unstyle(str(report)),
693                 "1 file reformatted, 2 files left unchanged, "
694                 "1 file failed to reformat.",
695             )
696             self.assertEqual(report.return_code, 123)
697             report.done(Path("f3"), black.Changed.YES)
698             self.assertEqual(len(out_lines), 0)
699             self.assertEqual(len(err_lines), 1)
700             self.assertEqual(
701                 unstyle(str(report)),
702                 "2 files reformatted, 2 files left unchanged, "
703                 "1 file failed to reformat.",
704             )
705             self.assertEqual(report.return_code, 123)
706             report.failed(Path("e2"), "boom")
707             self.assertEqual(len(out_lines), 0)
708             self.assertEqual(len(err_lines), 2)
709             self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
710             self.assertEqual(
711                 unstyle(str(report)),
712                 "2 files reformatted, 2 files left unchanged, "
713                 "2 files failed to reformat.",
714             )
715             self.assertEqual(report.return_code, 123)
716             report.path_ignored(Path("wat"), "no match")
717             self.assertEqual(len(out_lines), 0)
718             self.assertEqual(len(err_lines), 2)
719             self.assertEqual(
720                 unstyle(str(report)),
721                 "2 files reformatted, 2 files left unchanged, "
722                 "2 files failed to reformat.",
723             )
724             self.assertEqual(report.return_code, 123)
725             report.done(Path("f4"), black.Changed.NO)
726             self.assertEqual(len(out_lines), 0)
727             self.assertEqual(len(err_lines), 2)
728             self.assertEqual(
729                 unstyle(str(report)),
730                 "2 files reformatted, 3 files left unchanged, "
731                 "2 files failed to reformat.",
732             )
733             self.assertEqual(report.return_code, 123)
734             report.check = True
735             self.assertEqual(
736                 unstyle(str(report)),
737                 "2 files would be reformatted, 3 files would be left unchanged, "
738                 "2 files would fail to reformat.",
739             )
740
    def test_report_normal(self) -> None:
        """Exercise black.Report at default verbosity.

        A single Report is driven through done/failed/path_ignored calls. After
        each step the test checks what was sent to black.out / black.err, the
        rendered (unstyled) summary, and report.return_code.
        """
        report = black.Report()
        # Messages captured from the patched black.out / black.err.
        out_lines: List[str] = []
        err_lines: List[str] = []

        # Extra keyword arguments are accepted and ignored; presumably they
        # mirror styling options of black.out/black.err — verify there.
        def out(msg: str, **kwargs: Any) -> None:
            out_lines.append(msg)

        def err(msg: str, **kwargs: Any) -> None:
            err_lines.append(msg)

        with patch("black.out", out), patch("black.err", err):
            # An unchanged file prints nothing at default verbosity.
            report.done(Path("f1"), black.Changed.NO)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
            self.assertEqual(report.return_code, 0)
            # A reformatted file is announced via black.out.
            report.done(Path("f2"), black.Changed.YES)
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "reformatted f2")
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
            )
            # A CACHED file prints nothing and is counted as unchanged.
            report.done(Path("f3"), black.Changed.CACHED)
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "reformatted f2")
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
            )
            self.assertEqual(report.return_code, 0)
            # With `check` enabled, the same counts yield return code 1.
            report.check = True
            self.assertEqual(report.return_code, 1)
            report.check = False
            # A failure goes to black.err and forces return code 123.
            report.failed(Path("e1"), "boom")
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
            self.assertEqual(
                unstyle(str(report)),
                "1 file reformatted, 2 files left unchanged, "
                "1 file failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # f3 was already counted as unchanged (CACHED). Reporting it again
            # as YES adds to the reformatted count without reducing the
            # unchanged count, so the Report tallies calls, not distinct paths.
            report.done(Path("f3"), black.Changed.YES)
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(out_lines[-1], "reformatted f3")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, "
                "1 file failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.failed(Path("e2"), "boom")
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, "
                "2 files failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # An ignored path prints nothing at default verbosity and does not
            # change any count.
            report.path_ignored(Path("wat"), "no match")
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, "
                "2 files failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f4"), black.Changed.NO)
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 3 files left unchanged, "
                "2 files failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # In check mode the summary switches to conditional wording.
            report.check = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, "
                "2 files would fail to reformat.",
            )
830
831     def test_get_features_used(self) -> None:
832         node = black.lib2to3_parse("def f(*, arg): ...\n")
833         self.assertEqual(black.get_features_used(node), set())
834         node = black.lib2to3_parse("def f(*, arg,): ...\n")
835         self.assertEqual(black.get_features_used(node), {Feature.TRAILING_COMMA})
836         node = black.lib2to3_parse("def f(*, arg): f'string'\n")
837         self.assertEqual(black.get_features_used(node), {Feature.F_STRINGS})
838         node = black.lib2to3_parse("123_456\n")
839         self.assertEqual(black.get_features_used(node), {Feature.NUMERIC_UNDERSCORES})
840         node = black.lib2to3_parse("123456\n")
841         self.assertEqual(black.get_features_used(node), set())
842         source, expected = read_data("function")
843         node = black.lib2to3_parse(source)
844         self.assertEqual(
845             black.get_features_used(node), {Feature.TRAILING_COMMA, Feature.F_STRINGS}
846         )
847         node = black.lib2to3_parse(expected)
848         self.assertEqual(
849             black.get_features_used(node), {Feature.TRAILING_COMMA, Feature.F_STRINGS}
850         )
851         source, expected = read_data("expression")
852         node = black.lib2to3_parse(source)
853         self.assertEqual(black.get_features_used(node), set())
854         node = black.lib2to3_parse(expected)
855         self.assertEqual(black.get_features_used(node), set())
856
857     def test_get_future_imports(self) -> None:
858         node = black.lib2to3_parse("\n")
859         self.assertEqual(set(), black.get_future_imports(node))
860         node = black.lib2to3_parse("from __future__ import black\n")
861         self.assertEqual({"black"}, black.get_future_imports(node))
862         node = black.lib2to3_parse("from __future__ import multiple, imports\n")
863         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
864         node = black.lib2to3_parse("from __future__ import (parenthesized, imports)\n")
865         self.assertEqual({"parenthesized", "imports"}, black.get_future_imports(node))
866         node = black.lib2to3_parse(
867             "from __future__ import multiple\nfrom __future__ import imports\n"
868         )
869         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
870         node = black.lib2to3_parse("# comment\nfrom __future__ import black\n")
871         self.assertEqual({"black"}, black.get_future_imports(node))
872         node = black.lib2to3_parse('"""docstring"""\nfrom __future__ import black\n')
873         self.assertEqual({"black"}, black.get_future_imports(node))
874         node = black.lib2to3_parse("some(other, code)\nfrom __future__ import black\n")
875         self.assertEqual(set(), black.get_future_imports(node))
876         node = black.lib2to3_parse("from some.module import black\n")
877         self.assertEqual(set(), black.get_future_imports(node))
878         node = black.lib2to3_parse(
879             "from __future__ import unicode_literals as _unicode_literals"
880         )
881         self.assertEqual({"unicode_literals"}, black.get_future_imports(node))
882         node = black.lib2to3_parse(
883             "from __future__ import unicode_literals as _lol, print"
884         )
885         self.assertEqual({"unicode_literals", "print"}, black.get_future_imports(node))
886
887     def test_debug_visitor(self) -> None:
888         source, _ = read_data("debug_visitor.py")
889         expected, _ = read_data("debug_visitor.out")
890         out_lines = []
891         err_lines = []
892
893         def out(msg: str, **kwargs: Any) -> None:
894             out_lines.append(msg)
895
896         def err(msg: str, **kwargs: Any) -> None:
897             err_lines.append(msg)
898
899         with patch("black.out", out), patch("black.err", err):
900             black.DebugVisitor.show(source)
901         actual = "\n".join(out_lines) + "\n"
902         log_name = ""
903         if expected != actual:
904             log_name = black.dump_to_file(*out_lines)
905         self.assertEqual(
906             expected,
907             actual,
908             f"AST print out is different. Actual version dumped to {log_name}",
909         )
910
911     def test_format_file_contents(self) -> None:
912         empty = ""
913         mode = black.FileMode()
914         with self.assertRaises(black.NothingChanged):
915             black.format_file_contents(empty, mode=mode, fast=False)
916         just_nl = "\n"
917         with self.assertRaises(black.NothingChanged):
918             black.format_file_contents(just_nl, mode=mode, fast=False)
919         same = "l = [1, 2, 3]\n"
920         with self.assertRaises(black.NothingChanged):
921             black.format_file_contents(same, mode=mode, fast=False)
922         different = "l = [1,2,3]"
923         expected = same
924         actual = black.format_file_contents(different, mode=mode, fast=False)
925         self.assertEqual(expected, actual)
926         invalid = "return if you can"
927         with self.assertRaises(black.InvalidInput) as e:
928             black.format_file_contents(invalid, mode=mode, fast=False)
929         self.assertEqual(str(e.exception), "Cannot parse: 1:7: return if you can")
930
931     def test_endmarker(self) -> None:
932         n = black.lib2to3_parse("\n")
933         self.assertEqual(n.type, black.syms.file_input)
934         self.assertEqual(len(n.children), 1)
935         self.assertEqual(n.children[0].type, black.token.ENDMARKER)
936
937     @unittest.skipIf(os.environ.get("SKIP_AST_PRINT"), "user set SKIP_AST_PRINT")
938     def test_assertFormatEqual(self) -> None:
939         out_lines = []
940         err_lines = []
941
942         def out(msg: str, **kwargs: Any) -> None:
943             out_lines.append(msg)
944
945         def err(msg: str, **kwargs: Any) -> None:
946             err_lines.append(msg)
947
948         with patch("black.out", out), patch("black.err", err):
949             with self.assertRaises(AssertionError):
950                 self.assertFormatEqual("l = [1, 2, 3]", "l = [1, 2, 3,]")
951
952         out_str = "".join(out_lines)
953         self.assertTrue("Expected tree:" in out_str)
954         self.assertTrue("Actual tree:" in out_str)
955         self.assertEqual("".join(err_lines), "")
956
957     def test_cache_broken_file(self) -> None:
958         mode = black.FileMode()
959         with cache_dir() as workspace:
960             cache_file = black.get_cache_file(mode)
961             with cache_file.open("w") as fobj:
962                 fobj.write("this is not a pickle")
963             self.assertEqual(black.read_cache(mode), {})
964             src = (workspace / "test.py").resolve()
965             with src.open("w") as fobj:
966                 fobj.write("print('hello')")
967             self.invokeBlack([str(src)])
968             cache = black.read_cache(mode)
969             self.assertIn(src, cache)
970
971     def test_cache_single_file_already_cached(self) -> None:
972         mode = black.FileMode()
973         with cache_dir() as workspace:
974             src = (workspace / "test.py").resolve()
975             with src.open("w") as fobj:
976                 fobj.write("print('hello')")
977             black.write_cache({}, [src], mode)
978             self.invokeBlack([str(src)])
979             with src.open("r") as fobj:
980                 self.assertEqual(fobj.read(), "print('hello')")
981
982     @event_loop(close=False)
983     def test_cache_multiple_files(self) -> None:
984         mode = black.FileMode()
985         with cache_dir() as workspace, patch(
986             "black.ProcessPoolExecutor", new=ThreadPoolExecutor
987         ):
988             one = (workspace / "one.py").resolve()
989             with one.open("w") as fobj:
990                 fobj.write("print('hello')")
991             two = (workspace / "two.py").resolve()
992             with two.open("w") as fobj:
993                 fobj.write("print('hello')")
994             black.write_cache({}, [one], mode)
995             self.invokeBlack([str(workspace)])
996             with one.open("r") as fobj:
997                 self.assertEqual(fobj.read(), "print('hello')")
998             with two.open("r") as fobj:
999                 self.assertEqual(fobj.read(), 'print("hello")\n')
1000             cache = black.read_cache(mode)
1001             self.assertIn(one, cache)
1002             self.assertIn(two, cache)
1003
1004     def test_no_cache_when_writeback_diff(self) -> None:
1005         mode = black.FileMode()
1006         with cache_dir() as workspace:
1007             src = (workspace / "test.py").resolve()
1008             with src.open("w") as fobj:
1009                 fobj.write("print('hello')")
1010             self.invokeBlack([str(src), "--diff"])
1011             cache_file = black.get_cache_file(mode)
1012             self.assertFalse(cache_file.exists())
1013
1014     def test_no_cache_when_stdin(self) -> None:
1015         mode = black.FileMode()
1016         with cache_dir():
1017             result = CliRunner().invoke(
1018                 black.main, ["-"], input=BytesIO(b"print('hello')")
1019             )
1020             self.assertEqual(result.exit_code, 0)
1021             cache_file = black.get_cache_file(mode)
1022             self.assertFalse(cache_file.exists())
1023
1024     def test_read_cache_no_cachefile(self) -> None:
1025         mode = black.FileMode()
1026         with cache_dir():
1027             self.assertEqual(black.read_cache(mode), {})
1028
1029     def test_write_cache_read_cache(self) -> None:
1030         mode = black.FileMode()
1031         with cache_dir() as workspace:
1032             src = (workspace / "test.py").resolve()
1033             src.touch()
1034             black.write_cache({}, [src], mode)
1035             cache = black.read_cache(mode)
1036             self.assertIn(src, cache)
1037             self.assertEqual(cache[src], black.get_cache_info(src))
1038
1039     def test_filter_cached(self) -> None:
1040         with TemporaryDirectory() as workspace:
1041             path = Path(workspace)
1042             uncached = (path / "uncached").resolve()
1043             cached = (path / "cached").resolve()
1044             cached_but_changed = (path / "changed").resolve()
1045             uncached.touch()
1046             cached.touch()
1047             cached_but_changed.touch()
1048             cache = {cached: black.get_cache_info(cached), cached_but_changed: (0.0, 0)}
1049             todo, done = black.filter_cached(
1050                 cache, {uncached, cached, cached_but_changed}
1051             )
1052             self.assertEqual(todo, {uncached, cached_but_changed})
1053             self.assertEqual(done, {cached})
1054
1055     def test_write_cache_creates_directory_if_needed(self) -> None:
1056         mode = black.FileMode()
1057         with cache_dir(exists=False) as workspace:
1058             self.assertFalse(workspace.exists())
1059             black.write_cache({}, [], mode)
1060             self.assertTrue(workspace.exists())
1061
1062     @event_loop(close=False)
1063     def test_failed_formatting_does_not_get_cached(self) -> None:
1064         mode = black.FileMode()
1065         with cache_dir() as workspace, patch(
1066             "black.ProcessPoolExecutor", new=ThreadPoolExecutor
1067         ):
1068             failing = (workspace / "failing.py").resolve()
1069             with failing.open("w") as fobj:
1070                 fobj.write("not actually python")
1071             clean = (workspace / "clean.py").resolve()
1072             with clean.open("w") as fobj:
1073                 fobj.write('print("hello")\n')
1074             self.invokeBlack([str(workspace)], exit_code=123)
1075             cache = black.read_cache(mode)
1076             self.assertNotIn(failing, cache)
1077             self.assertIn(clean, cache)
1078
1079     def test_write_cache_write_fail(self) -> None:
1080         mode = black.FileMode()
1081         with cache_dir(), patch.object(Path, "open") as mock:
1082             mock.side_effect = OSError
1083             black.write_cache({}, [], mode)
1084
1085     @event_loop(close=False)
1086     def test_check_diff_use_together(self) -> None:
1087         with cache_dir():
1088             # Files which will be reformatted.
1089             src1 = (THIS_DIR / "data" / "string_quotes.py").resolve()
1090             self.invokeBlack([str(src1), "--diff", "--check"], exit_code=1)
1091             # Files which will not be reformatted.
1092             src2 = (THIS_DIR / "data" / "composition.py").resolve()
1093             self.invokeBlack([str(src2), "--diff", "--check"])
1094             # Multi file command.
1095             self.invokeBlack([str(src1), str(src2), "--diff", "--check"], exit_code=1)
1096
1097     def test_no_files(self) -> None:
1098         with cache_dir():
1099             # Without an argument, black exits with error code 0.
1100             self.invokeBlack([])
1101
1102     def test_broken_symlink(self) -> None:
1103         with cache_dir() as workspace:
1104             symlink = workspace / "broken_link.py"
1105             try:
1106                 symlink.symlink_to("nonexistent.py")
1107             except OSError as e:
1108                 self.skipTest(f"Can't create symlinks: {e}")
1109             self.invokeBlack([str(workspace.resolve())])
1110
1111     def test_read_cache_line_lengths(self) -> None:
1112         mode = black.FileMode()
1113         short_mode = black.FileMode(line_length=1)
1114         with cache_dir() as workspace:
1115             path = (workspace / "file.py").resolve()
1116             path.touch()
1117             black.write_cache({}, [path], mode)
1118             one = black.read_cache(mode)
1119             self.assertIn(path, one)
1120             two = black.read_cache(short_mode)
1121             self.assertNotIn(path, two)
1122
1123     def test_single_file_force_pyi(self) -> None:
1124         reg_mode = black.FileMode()
1125         pyi_mode = black.FileMode(is_pyi=True)
1126         contents, expected = read_data("force_pyi")
1127         with cache_dir() as workspace:
1128             path = (workspace / "file.py").resolve()
1129             with open(path, "w") as fh:
1130                 fh.write(contents)
1131             self.invokeBlack([str(path), "--pyi"])
1132             with open(path, "r") as fh:
1133                 actual = fh.read()
1134             # verify cache with --pyi is separate
1135             pyi_cache = black.read_cache(pyi_mode)
1136             self.assertIn(path, pyi_cache)
1137             normal_cache = black.read_cache(reg_mode)
1138             self.assertNotIn(path, normal_cache)
1139         self.assertEqual(actual, expected)
1140
1141     @event_loop(close=False)
1142     def test_multi_file_force_pyi(self) -> None:
1143         reg_mode = black.FileMode()
1144         pyi_mode = black.FileMode(is_pyi=True)
1145         contents, expected = read_data("force_pyi")
1146         with cache_dir() as workspace:
1147             paths = [
1148                 (workspace / "file1.py").resolve(),
1149                 (workspace / "file2.py").resolve(),
1150             ]
1151             for path in paths:
1152                 with open(path, "w") as fh:
1153                     fh.write(contents)
1154             self.invokeBlack([str(p) for p in paths] + ["--pyi"])
1155             for path in paths:
1156                 with open(path, "r") as fh:
1157                     actual = fh.read()
1158                 self.assertEqual(actual, expected)
1159             # verify cache with --pyi is separate
1160             pyi_cache = black.read_cache(pyi_mode)
1161             normal_cache = black.read_cache(reg_mode)
1162             for path in paths:
1163                 self.assertIn(path, pyi_cache)
1164                 self.assertNotIn(path, normal_cache)
1165
1166     def test_pipe_force_pyi(self) -> None:
1167         source, expected = read_data("force_pyi")
1168         result = CliRunner().invoke(
1169             black.main, ["-", "-q", "--pyi"], input=BytesIO(source.encode("utf8"))
1170         )
1171         self.assertEqual(result.exit_code, 0)
1172         actual = result.output
1173         self.assertFormatEqual(actual, expected)
1174
1175     def test_single_file_force_py36(self) -> None:
1176         reg_mode = black.FileMode()
1177         py36_mode = black.FileMode(target_versions=black.PY36_VERSIONS)
1178         source, expected = read_data("force_py36")
1179         with cache_dir() as workspace:
1180             path = (workspace / "file.py").resolve()
1181             with open(path, "w") as fh:
1182                 fh.write(source)
1183             self.invokeBlack([str(path), *PY36_ARGS])
1184             with open(path, "r") as fh:
1185                 actual = fh.read()
1186             # verify cache with --target-version is separate
1187             py36_cache = black.read_cache(py36_mode)
1188             self.assertIn(path, py36_cache)
1189             normal_cache = black.read_cache(reg_mode)
1190             self.assertNotIn(path, normal_cache)
1191         self.assertEqual(actual, expected)
1192
1193     @event_loop(close=False)
1194     def test_multi_file_force_py36(self) -> None:
1195         reg_mode = black.FileMode()
1196         py36_mode = black.FileMode(target_versions=black.PY36_VERSIONS)
1197         source, expected = read_data("force_py36")
1198         with cache_dir() as workspace:
1199             paths = [
1200                 (workspace / "file1.py").resolve(),
1201                 (workspace / "file2.py").resolve(),
1202             ]
1203             for path in paths:
1204                 with open(path, "w") as fh:
1205                     fh.write(source)
1206             self.invokeBlack([str(p) for p in paths] + PY36_ARGS)
1207             for path in paths:
1208                 with open(path, "r") as fh:
1209                     actual = fh.read()
1210                 self.assertEqual(actual, expected)
1211             # verify cache with --target-version is separate
1212             pyi_cache = black.read_cache(py36_mode)
1213             normal_cache = black.read_cache(reg_mode)
1214             for path in paths:
1215                 self.assertIn(path, pyi_cache)
1216                 self.assertNotIn(path, normal_cache)
1217
1218     def test_pipe_force_py36(self) -> None:
1219         source, expected = read_data("force_py36")
1220         result = CliRunner().invoke(
1221             black.main,
1222             ["-", "-q", "--target-version=py36"],
1223             input=BytesIO(source.encode("utf8")),
1224         )
1225         self.assertEqual(result.exit_code, 0)
1226         actual = result.output
1227         self.assertFormatEqual(actual, expected)
1228
1229     def test_include_exclude(self) -> None:
1230         path = THIS_DIR / "data" / "include_exclude_tests"
1231         include = re.compile(r"\.pyi?$")
1232         exclude = re.compile(r"/exclude/|/\.definitely_exclude/")
1233         report = black.Report()
1234         sources: List[Path] = []
1235         expected = [
1236             Path(path / "b/dont_exclude/a.py"),
1237             Path(path / "b/dont_exclude/a.pyi"),
1238         ]
1239         this_abs = THIS_DIR.resolve()
1240         sources.extend(
1241             black.gen_python_files_in_dir(path, this_abs, include, exclude, report)
1242         )
1243         self.assertEqual(sorted(expected), sorted(sources))
1244
1245     def test_empty_include(self) -> None:
1246         path = THIS_DIR / "data" / "include_exclude_tests"
1247         report = black.Report()
1248         empty = re.compile(r"")
1249         sources: List[Path] = []
1250         expected = [
1251             Path(path / "b/exclude/a.pie"),
1252             Path(path / "b/exclude/a.py"),
1253             Path(path / "b/exclude/a.pyi"),
1254             Path(path / "b/dont_exclude/a.pie"),
1255             Path(path / "b/dont_exclude/a.py"),
1256             Path(path / "b/dont_exclude/a.pyi"),
1257             Path(path / "b/.definitely_exclude/a.pie"),
1258             Path(path / "b/.definitely_exclude/a.py"),
1259             Path(path / "b/.definitely_exclude/a.pyi"),
1260         ]
1261         this_abs = THIS_DIR.resolve()
1262         sources.extend(
1263             black.gen_python_files_in_dir(
1264                 path, this_abs, empty, re.compile(black.DEFAULT_EXCLUDES), report
1265             )
1266         )
1267         self.assertEqual(sorted(expected), sorted(sources))
1268
1269     def test_empty_exclude(self) -> None:
1270         path = THIS_DIR / "data" / "include_exclude_tests"
1271         report = black.Report()
1272         empty = re.compile(r"")
1273         sources: List[Path] = []
1274         expected = [
1275             Path(path / "b/dont_exclude/a.py"),
1276             Path(path / "b/dont_exclude/a.pyi"),
1277             Path(path / "b/exclude/a.py"),
1278             Path(path / "b/exclude/a.pyi"),
1279             Path(path / "b/.definitely_exclude/a.py"),
1280             Path(path / "b/.definitely_exclude/a.pyi"),
1281         ]
1282         this_abs = THIS_DIR.resolve()
1283         sources.extend(
1284             black.gen_python_files_in_dir(
1285                 path, this_abs, re.compile(black.DEFAULT_INCLUDES), empty, report
1286             )
1287         )
1288         self.assertEqual(sorted(expected), sorted(sources))
1289
1290     def test_invalid_include_exclude(self) -> None:
1291         for option in ["--include", "--exclude"]:
1292             self.invokeBlack(["-", option, "**()(!!*)"], exit_code=2)
1293
1294     def test_preserves_line_endings(self) -> None:
1295         with TemporaryDirectory() as workspace:
1296             test_file = Path(workspace) / "test.py"
1297             for nl in ["\n", "\r\n"]:
1298                 contents = nl.join(["def f(  ):", "    pass"])
1299                 test_file.write_bytes(contents.encode())
1300                 ff(test_file, write_back=black.WriteBack.YES)
1301                 updated_contents: bytes = test_file.read_bytes()
1302                 self.assertIn(nl.encode(), updated_contents)
1303                 if nl == "\n":
1304                     self.assertNotIn(b"\r\n", updated_contents)
1305
1306     def test_preserves_line_endings_via_stdin(self) -> None:
1307         for nl in ["\n", "\r\n"]:
1308             contents = nl.join(["def f(  ):", "    pass"])
1309             runner = BlackRunner()
1310             result = runner.invoke(
1311                 black.main, ["-", "--fast"], input=BytesIO(contents.encode("utf8"))
1312             )
1313             self.assertEqual(result.exit_code, 0)
1314             output = runner.stdout_bytes
1315             self.assertIn(nl.encode("utf8"), output)
1316             if nl == "\n":
1317                 self.assertNotIn(b"\r\n", output)
1318
1319     def test_assert_equivalent_different_asts(self) -> None:
1320         with self.assertRaises(AssertionError):
1321             black.assert_equivalent("{}", "None")
1322
1323     def test_symlink_out_of_root_directory(self) -> None:
1324         path = MagicMock()
1325         root = THIS_DIR
1326         child = MagicMock()
1327         include = re.compile(black.DEFAULT_INCLUDES)
1328         exclude = re.compile(black.DEFAULT_EXCLUDES)
1329         report = black.Report()
1330         # `child` should behave like a symlink which resolved path is clearly
1331         # outside of the `root` directory.
1332         path.iterdir.return_value = [child]
1333         child.resolve.return_value = Path("/a/b/c")
1334         child.is_symlink.return_value = True
1335         try:
1336             list(black.gen_python_files_in_dir(path, root, include, exclude, report))
1337         except ValueError as ve:
1338             self.fail(f"`get_python_files_in_dir()` failed: {ve}")
1339         path.iterdir.assert_called_once()
1340         child.resolve.assert_called_once()
1341         child.is_symlink.assert_called_once()
1342         # `child` should behave like a strange file which resolved path is clearly
1343         # outside of the `root` directory.
1344         child.is_symlink.return_value = False
1345         with self.assertRaises(ValueError):
1346             list(black.gen_python_files_in_dir(path, root, include, exclude, report))
1347         path.iterdir.assert_called()
1348         self.assertEqual(path.iterdir.call_count, 2)
1349         child.resolve.assert_called()
1350         self.assertEqual(child.resolve.call_count, 2)
1351         child.is_symlink.assert_called()
1352         self.assertEqual(child.is_symlink.call_count, 2)
1353
1354     def test_shhh_click(self) -> None:
1355         try:
1356             from click import _unicodefun  # type: ignore
1357         except ModuleNotFoundError:
1358             self.skipTest("Incompatible Click version")
1359         if not hasattr(_unicodefun, "_verify_python3_env"):
1360             self.skipTest("Incompatible Click version")
1361         # First, let's see if Click is crashing with a preferred ASCII charset.
1362         with patch("locale.getpreferredencoding") as gpe:
1363             gpe.return_value = "ASCII"
1364             with self.assertRaises(RuntimeError):
1365                 _unicodefun._verify_python3_env()
1366         # Now, let's silence Click...
1367         black.patch_click()
1368         # ...and confirm it's silent.
1369         with patch("locale.getpreferredencoding") as gpe:
1370             gpe.return_value = "ASCII"
1371             try:
1372                 _unicodefun._verify_python3_env()
1373             except RuntimeError as re:
1374                 self.fail(f"`patch_click()` failed, exception still raised: {re}")
1375
1376     def test_root_logger_not_used_directly(self) -> None:
1377         def fail(*args: Any, **kwargs: Any) -> None:
1378             self.fail("Record created with root logger")
1379
1380         with patch.multiple(
1381             logging.root,
1382             debug=fail,
1383             info=fail,
1384             warning=fail,
1385             error=fail,
1386             critical=fail,
1387             log=fail,
1388         ):
1389             ff(THIS_FILE)
1390
1391     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1392     @async_test
1393     async def test_blackd_request_needs_formatting(self) -> None:
1394         app = blackd.make_app()
1395         async with TestClient(TestServer(app)) as client:
1396             response = await client.post("/", data=b"print('hello world')")
1397             self.assertEqual(response.status, 200)
1398             self.assertEqual(response.charset, "utf8")
1399             self.assertEqual(await response.read(), b'print("hello world")\n')
1400
1401     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1402     @async_test
1403     async def test_blackd_request_no_change(self) -> None:
1404         app = blackd.make_app()
1405         async with TestClient(TestServer(app)) as client:
1406             response = await client.post("/", data=b'print("hello world")\n')
1407             self.assertEqual(response.status, 204)
1408             self.assertEqual(await response.read(), b"")
1409
1410     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1411     @async_test
1412     async def test_blackd_request_syntax_error(self) -> None:
1413         app = blackd.make_app()
1414         async with TestClient(TestServer(app)) as client:
1415             response = await client.post("/", data=b"what even ( is")
1416             self.assertEqual(response.status, 400)
1417             content = await response.text()
1418             self.assertTrue(
1419                 content.startswith("Cannot parse"),
1420                 msg=f"Expected error to start with 'Cannot parse', got {repr(content)}",
1421             )
1422
1423     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1424     @async_test
1425     async def test_blackd_unsupported_version(self) -> None:
1426         app = blackd.make_app()
1427         async with TestClient(TestServer(app)) as client:
1428             response = await client.post(
1429                 "/", data=b"what", headers={blackd.VERSION_HEADER: "2"}
1430             )
1431             self.assertEqual(response.status, 501)
1432
1433     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1434     @async_test
1435     async def test_blackd_supported_version(self) -> None:
1436         app = blackd.make_app()
1437         async with TestClient(TestServer(app)) as client:
1438             response = await client.post(
1439                 "/", data=b"what", headers={blackd.VERSION_HEADER: "1"}
1440             )
1441             self.assertEqual(response.status, 200)
1442
1443     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1444     @async_test
1445     async def test_blackd_invalid_python_variant(self) -> None:
1446         app = blackd.make_app()
1447         async with TestClient(TestServer(app)) as client:
1448
1449             async def check(header_value: str, expected_status: int = 400) -> None:
1450                 response = await client.post(
1451                     "/",
1452                     data=b"what",
1453                     headers={blackd.PYTHON_VARIANT_HEADER: header_value},
1454                 )
1455                 self.assertEqual(response.status, expected_status)
1456
1457             await check("lol")
1458             await check("ruby3.5")
1459             await check("pyi3.6")
1460             await check("py1.5")
1461             await check("2.8")
1462             await check("py2.8")
1463             await check("3.0")
1464             await check("pypy3.0")
1465             await check("jython3.4")
1466
1467     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1468     @async_test
1469     async def test_blackd_pyi(self) -> None:
1470         app = blackd.make_app()
1471         async with TestClient(TestServer(app)) as client:
1472             source, expected = read_data("stub.pyi")
1473             response = await client.post(
1474                 "/", data=source, headers={blackd.PYTHON_VARIANT_HEADER: "pyi"}
1475             )
1476             self.assertEqual(response.status, 200)
1477             self.assertEqual(await response.text(), expected)
1478
1479     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1480     @async_test
1481     async def test_blackd_python_variant(self) -> None:
1482         app = blackd.make_app()
1483         code = (
1484             "def f(\n"
1485             "    and_has_a_bunch_of,\n"
1486             "    very_long_arguments_too,\n"
1487             "    and_lots_of_them_as_well_lol,\n"
1488             "    **and_very_long_keyword_arguments\n"
1489             "):\n"
1490             "    pass\n"
1491         )
1492         async with TestClient(TestServer(app)) as client:
1493
1494             async def check(header_value: str, expected_status: int) -> None:
1495                 response = await client.post(
1496                     "/", data=code, headers={blackd.PYTHON_VARIANT_HEADER: header_value}
1497                 )
1498                 self.assertEqual(response.status, expected_status)
1499
1500             await check("3.6", 200)
1501             await check("py3.6", 200)
1502             await check("3.5,3.7", 200)
1503             await check("3.5,py3.7", 200)
1504
1505             await check("2", 204)
1506             await check("2.7", 204)
1507             await check("py2.7", 204)
1508             await check("3.4", 204)
1509             await check("py3.4", 204)
1510
1511     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1512     @async_test
1513     async def test_blackd_fast(self) -> None:
1514         with open(os.devnull, "w") as dn, redirect_stderr(dn):
1515             app = blackd.make_app()
1516             async with TestClient(TestServer(app)) as client:
1517                 response = await client.post("/", data=b"ur'hello'")
1518                 self.assertEqual(response.status, 500)
1519                 self.assertIn("failed to parse source file", await response.text())
1520                 response = await client.post(
1521                     "/", data=b"ur'hello'", headers={blackd.FAST_OR_SAFE_HEADER: "fast"}
1522                 )
1523                 self.assertEqual(response.status, 200)
1524
1525     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1526     @async_test
1527     async def test_blackd_line_length(self) -> None:
1528         app = blackd.make_app()
1529         async with TestClient(TestServer(app)) as client:
1530             response = await client.post(
1531                 "/", data=b'print("hello")\n', headers={blackd.LINE_LENGTH_HEADER: "7"}
1532             )
1533             self.assertEqual(response.status, 200)
1534
1535     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1536     @async_test
1537     async def test_blackd_invalid_line_length(self) -> None:
1538         app = blackd.make_app()
1539         async with TestClient(TestServer(app)) as client:
1540             response = await client.post(
1541                 "/",
1542                 data=b'print("hello")\n',
1543                 headers={blackd.LINE_LENGTH_HEADER: "NaN"},
1544             )
1545             self.assertEqual(response.status, 400)
1546
1547     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1548     def test_blackd_main(self) -> None:
1549         with patch("blackd.web.run_app"):
1550             result = CliRunner().invoke(blackd.main, [])
1551             if result.exception is not None:
1552                 raise result.exception
1553             self.assertEqual(result.exit_code, 0)
1554
1555
1556 if __name__ == "__main__":
1557     unittest.main(module="test_black")