git.madduck.net Git - etc/vim.git/blob - tests/test_black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Add trailing comma when a single import doesn't fit on a line. (#504)
[etc/vim.git] / tests / test_black.py
1 #!/usr/bin/env python3
2 import asyncio
3 from concurrent.futures import ThreadPoolExecutor
4 from contextlib import contextmanager
5 from functools import partial, wraps
6 from io import BytesIO, TextIOWrapper
7 import os
8 from pathlib import Path
9 import re
10 import sys
11 from tempfile import TemporaryDirectory
12 from typing import (
13     Any,
14     BinaryIO,
15     Callable,
16     Coroutine,
17     Generator,
18     List,
19     Tuple,
20     Iterator,
21     TypeVar,
22 )
23 import unittest
24 from unittest.mock import patch, MagicMock
25
26 from click import unstyle
27 from click.testing import CliRunner
28
29 import black
30
31 try:
32     import blackd
33     from aiohttp.test_utils import TestClient, TestServer
34 except ImportError:
35     has_blackd_deps = False
36 else:
37     has_blackd_deps = True
38
39
# Line length passed to every formatting helper used in this module.
ll = 88
# Shorthands: ``fs`` formats a string, ``ff`` formats a file in place (fast mode).
fs = partial(black.format_str, line_length=ll)
ff = partial(black.format_file_in_place, fast=True, line_length=ll)
THIS_FILE = Path(__file__)
THIS_DIR = THIS_FILE.parent
# Presumably built from two literals so the complete marker never appears verbatim
# in this file, which test_self feeds through read_data (and read_data strips it).
EMPTY_LINE = "# EMPTY LINE WITH WHITESPACE" + " (this comment will be removed)"
T = TypeVar("T")
R = TypeVar("R")
48
49
def dump_to_stderr(*output: str) -> str:
    """Return *output* joined by newlines, with one extra newline before and after.

    The tests patch ``black.dump_to_file`` with this function, so the text that
    would otherwise be dumped to a file is returned as-is instead.
    """
    body = "\n".join(output)
    return f"\n{body}\n"
52
53
def read_data(name: str, data: bool = True) -> Tuple[str, str]:
    """Load a test case and return its ``(input, expected_output)`` pair.

    ``name`` gets a ``.py`` suffix unless it already ends in .py/.pyi/.out/.diff.
    It is resolved inside ``THIS_DIR / "data"``, or directly in ``THIS_DIR`` when
    ``data`` is False.  The EMPTY_LINE marker is removed from every line.  Lines
    before a line consisting of ``# output`` form the input and lines after it
    the expected output; when no such section exists the input doubles as the
    expected output.  Both halves are stripped and end in exactly one newline.
    """
    suffixes = (".py", ".pyi", ".out", ".diff")
    if not name.endswith(suffixes):
        name += ".py"
    folder = THIS_DIR / "data" if data else THIS_DIR
    with open(folder / name, "r", encoding="utf8") as handle:
        raw_lines = handle.readlines()
    before: List[str] = []
    after: List[str] = []
    in_output = False
    for raw in raw_lines:
        cleaned = raw.replace(EMPTY_LINE, "")
        if cleaned.rstrip() == "# output":
            in_output = True
            continue
        (after if in_output else before).append(cleaned)
    if before and not after:
        after = list(before)
    return "".join(before).strip() + "\n", "".join(after).strip() + "\n"
75
76
@contextmanager
def cache_dir(exists: bool = True) -> Iterator[Path]:
    """Patch ``black.CACHE_DIR`` to a temporary directory while the block runs.

    With ``exists=False`` the patched path is a ``new`` subdirectory of the
    temporary directory that is not created here.  Yields the patched path.
    """
    with TemporaryDirectory() as workspace:
        root = Path(workspace)
        target = root if exists else root / "new"
        with patch("black.CACHE_DIR", target):
            yield target
85
86
@contextmanager
def event_loop(close: bool) -> Iterator[None]:
    """Install a fresh event loop for the duration of the block.

    The loop that was current beforehand is reinstated on exit, and the fresh
    loop is closed afterwards only when ``close`` is true.  Being built with
    ``contextmanager``, it can also be applied as a decorator.
    """
    policy = asyncio.get_event_loop_policy()
    previous = policy.get_event_loop()
    fresh = policy.new_event_loop()
    asyncio.set_event_loop(fresh)
    try:
        yield

    finally:
        policy.set_event_loop(previous)
        if close:
            fresh.close()
100
101
def async_test(f: Callable[..., Coroutine[Any, None, R]]) -> Callable[..., None]:
    """Turn the coroutine function *f* into a plain synchronous test function.

    Every call runs ``f(*args, **kwargs)`` to completion on the fresh loop that
    ``event_loop(close=True)`` installs for that call and closes afterwards.
    """

    def wrapper(*args: Any, **kwargs: Any) -> None:
        coro = f(*args, **kwargs)
        asyncio.get_event_loop().run_until_complete(coro)

    return event_loop(close=True)(wraps(f)(wrapper))
109
110
class BlackRunner(CliRunner):
    """Modify CliRunner so that stderr is not merged with stdout.

    During ``isolation`` stderr text goes into ``stderrbuf``; when the isolated
    block ends, the raw captured bytes are stored in ``stdout_bytes`` and
    ``stderr_bytes``.

    This is a hack that can be removed once we depend on Click 7.x"""

    def __init__(self) -> None:
        self.stderrbuf = BytesIO()
        self.stdoutbuf = BytesIO()
        self.stdout_bytes = b""
        self.stderr_bytes = b""
        super().__init__()

    @contextmanager
    def isolation(self, *args: Any, **kwargs: Any) -> Generator[BinaryIO, None, None]:
        with super().isolation(*args, **kwargs) as output:
            hold_stderr = sys.stderr
            sys.stderr = TextIOWrapper(self.stderrbuf, encoding=self.charset)
            try:
                yield output
            finally:
                try:
                    # TextIOWrapper keeps pending text in its own buffer; flush so
                    # everything written so far reaches the byte buffers read below.
                    sys.stdout.flush()
                    sys.stderr.flush()
                    self.stdout_bytes = sys.stdout.buffer.getvalue()  # type: ignore
                    self.stderr_bytes = sys.stderr.buffer.getvalue()  # type: ignore
                finally:
                    # Always hand the real stderr back, even if capturing failed.
                    sys.stderr = hold_stderr
134
135
136 class BlackTestCase(unittest.TestCase):
137     maxDiff = None
138
139     def assertFormatEqual(self, expected: str, actual: str) -> None:
140         if actual != expected and not os.environ.get("SKIP_AST_PRINT"):
141             bdv: black.DebugVisitor[Any]
142             black.out("Expected tree:", fg="green")
143             try:
144                 exp_node = black.lib2to3_parse(expected)
145                 bdv = black.DebugVisitor()
146                 list(bdv.visit(exp_node))
147             except Exception as ve:
148                 black.err(str(ve))
149             black.out("Actual tree:", fg="red")
150             try:
151                 exp_node = black.lib2to3_parse(actual)
152                 bdv = black.DebugVisitor()
153                 list(bdv.visit(exp_node))
154             except Exception as ve:
155                 black.err(str(ve))
156         self.assertEqual(expected, actual)
157
158     @patch("black.dump_to_file", dump_to_stderr)
159     def test_empty(self) -> None:
160         source = expected = ""
161         actual = fs(source)
162         self.assertFormatEqual(expected, actual)
163         black.assert_equivalent(source, actual)
164         black.assert_stable(source, actual, line_length=ll)
165
166     def test_empty_ff(self) -> None:
167         expected = ""
168         tmp_file = Path(black.dump_to_file())
169         try:
170             self.assertFalse(ff(tmp_file, write_back=black.WriteBack.YES))
171             with open(tmp_file, encoding="utf8") as f:
172                 actual = f.read()
173         finally:
174             os.unlink(tmp_file)
175         self.assertFormatEqual(expected, actual)
176
177     @patch("black.dump_to_file", dump_to_stderr)
178     def test_self(self) -> None:
179         source, expected = read_data("test_black", data=False)
180         actual = fs(source)
181         self.assertFormatEqual(expected, actual)
182         black.assert_equivalent(source, actual)
183         black.assert_stable(source, actual, line_length=ll)
184         self.assertFalse(ff(THIS_FILE))
185
186     @patch("black.dump_to_file", dump_to_stderr)
187     def test_black(self) -> None:
188         source, expected = read_data("../black", data=False)
189         actual = fs(source)
190         self.assertFormatEqual(expected, actual)
191         black.assert_equivalent(source, actual)
192         black.assert_stable(source, actual, line_length=ll)
193         self.assertFalse(ff(THIS_DIR / ".." / "black.py"))
194
195     def test_piping(self) -> None:
196         source, expected = read_data("../black", data=False)
197         result = BlackRunner().invoke(
198             black.main,
199             ["-", "--fast", f"--line-length={ll}"],
200             input=BytesIO(source.encode("utf8")),
201         )
202         self.assertEqual(result.exit_code, 0)
203         self.assertFormatEqual(expected, result.output)
204         black.assert_equivalent(source, result.output)
205         black.assert_stable(source, result.output, line_length=ll)
206
207     def test_piping_diff(self) -> None:
208         diff_header = re.compile(
209             rf"(STDIN|STDOUT)\t\d\d\d\d-\d\d-\d\d "
210             rf"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
211         )
212         source, _ = read_data("expression.py")
213         expected, _ = read_data("expression.diff")
214         config = THIS_DIR / "data" / "empty_pyproject.toml"
215         args = ["-", "--fast", f"--line-length={ll}", "--diff", f"--config={config}"]
216         result = BlackRunner().invoke(
217             black.main, args, input=BytesIO(source.encode("utf8"))
218         )
219         self.assertEqual(result.exit_code, 0)
220         actual = diff_header.sub("[Deterministic header]", result.output)
221         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
222         self.assertEqual(expected, actual)
223
224     @patch("black.dump_to_file", dump_to_stderr)
225     def test_setup(self) -> None:
226         source, expected = read_data("../setup", data=False)
227         actual = fs(source)
228         self.assertFormatEqual(expected, actual)
229         black.assert_equivalent(source, actual)
230         black.assert_stable(source, actual, line_length=ll)
231         self.assertFalse(ff(THIS_DIR / ".." / "setup.py"))
232
233     @patch("black.dump_to_file", dump_to_stderr)
234     def test_function(self) -> None:
235         source, expected = read_data("function")
236         actual = fs(source)
237         self.assertFormatEqual(expected, actual)
238         black.assert_equivalent(source, actual)
239         black.assert_stable(source, actual, line_length=ll)
240
241     @patch("black.dump_to_file", dump_to_stderr)
242     def test_function2(self) -> None:
243         source, expected = read_data("function2")
244         actual = fs(source)
245         self.assertFormatEqual(expected, actual)
246         black.assert_equivalent(source, actual)
247         black.assert_stable(source, actual, line_length=ll)
248
249     @patch("black.dump_to_file", dump_to_stderr)
250     def test_expression(self) -> None:
251         source, expected = read_data("expression")
252         actual = fs(source)
253         self.assertFormatEqual(expected, actual)
254         black.assert_equivalent(source, actual)
255         black.assert_stable(source, actual, line_length=ll)
256
257     def test_expression_ff(self) -> None:
258         source, expected = read_data("expression")
259         tmp_file = Path(black.dump_to_file(source))
260         try:
261             self.assertTrue(ff(tmp_file, write_back=black.WriteBack.YES))
262             with open(tmp_file, encoding="utf8") as f:
263                 actual = f.read()
264         finally:
265             os.unlink(tmp_file)
266         self.assertFormatEqual(expected, actual)
267         with patch("black.dump_to_file", dump_to_stderr):
268             black.assert_equivalent(source, actual)
269             black.assert_stable(source, actual, line_length=ll)
270
271     def test_expression_diff(self) -> None:
272         source, _ = read_data("expression.py")
273         expected, _ = read_data("expression.diff")
274         tmp_file = Path(black.dump_to_file(source))
275         diff_header = re.compile(
276             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
277             rf"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
278         )
279         try:
280             result = BlackRunner().invoke(black.main, ["--diff", str(tmp_file)])
281             self.assertEqual(result.exit_code, 0)
282         finally:
283             os.unlink(tmp_file)
284         actual = result.output
285         actual = diff_header.sub("[Deterministic header]", actual)
286         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
287         if expected != actual:
288             dump = black.dump_to_file(actual)
289             msg = (
290                 f"Expected diff isn't equal to the actual. If you made changes "
291                 f"to expression.py and this is an anticipated difference, "
292                 f"overwrite tests/expression.diff with {dump}"
293             )
294             self.assertEqual(expected, actual, msg)
295
296     @patch("black.dump_to_file", dump_to_stderr)
297     def test_fstring(self) -> None:
298         source, expected = read_data("fstring")
299         actual = fs(source)
300         self.assertFormatEqual(expected, actual)
301         black.assert_equivalent(source, actual)
302         black.assert_stable(source, actual, line_length=ll)
303
304     @patch("black.dump_to_file", dump_to_stderr)
305     def test_string_quotes(self) -> None:
306         source, expected = read_data("string_quotes")
307         actual = fs(source)
308         self.assertFormatEqual(expected, actual)
309         black.assert_equivalent(source, actual)
310         black.assert_stable(source, actual, line_length=ll)
311         mode = black.FileMode.NO_STRING_NORMALIZATION
312         not_normalized = fs(source, mode=mode)
313         self.assertFormatEqual(source, not_normalized)
314         black.assert_equivalent(source, not_normalized)
315         black.assert_stable(source, not_normalized, line_length=ll, mode=mode)
316
317     @patch("black.dump_to_file", dump_to_stderr)
318     def test_slices(self) -> None:
319         source, expected = read_data("slices")
320         actual = fs(source)
321         self.assertFormatEqual(expected, actual)
322         black.assert_equivalent(source, actual)
323         black.assert_stable(source, actual, line_length=ll)
324
325     @patch("black.dump_to_file", dump_to_stderr)
326     def test_comments(self) -> None:
327         source, expected = read_data("comments")
328         actual = fs(source)
329         self.assertFormatEqual(expected, actual)
330         black.assert_equivalent(source, actual)
331         black.assert_stable(source, actual, line_length=ll)
332
333     @patch("black.dump_to_file", dump_to_stderr)
334     def test_comments2(self) -> None:
335         source, expected = read_data("comments2")
336         actual = fs(source)
337         self.assertFormatEqual(expected, actual)
338         black.assert_equivalent(source, actual)
339         black.assert_stable(source, actual, line_length=ll)
340
341     @patch("black.dump_to_file", dump_to_stderr)
342     def test_comments3(self) -> None:
343         source, expected = read_data("comments3")
344         actual = fs(source)
345         self.assertFormatEqual(expected, actual)
346         black.assert_equivalent(source, actual)
347         black.assert_stable(source, actual, line_length=ll)
348
349     @patch("black.dump_to_file", dump_to_stderr)
350     def test_comments4(self) -> None:
351         source, expected = read_data("comments4")
352         actual = fs(source)
353         self.assertFormatEqual(expected, actual)
354         black.assert_equivalent(source, actual)
355         black.assert_stable(source, actual, line_length=ll)
356
357     @patch("black.dump_to_file", dump_to_stderr)
358     def test_comments5(self) -> None:
359         source, expected = read_data("comments5")
360         actual = fs(source)
361         self.assertFormatEqual(expected, actual)
362         black.assert_equivalent(source, actual)
363         black.assert_stable(source, actual, line_length=ll)
364
365     @patch("black.dump_to_file", dump_to_stderr)
366     def test_cantfit(self) -> None:
367         source, expected = read_data("cantfit")
368         actual = fs(source)
369         self.assertFormatEqual(expected, actual)
370         black.assert_equivalent(source, actual)
371         black.assert_stable(source, actual, line_length=ll)
372
373     @patch("black.dump_to_file", dump_to_stderr)
374     def test_import_spacing(self) -> None:
375         source, expected = read_data("import_spacing")
376         actual = fs(source)
377         self.assertFormatEqual(expected, actual)
378         black.assert_equivalent(source, actual)
379         black.assert_stable(source, actual, line_length=ll)
380
381     @patch("black.dump_to_file", dump_to_stderr)
382     def test_composition(self) -> None:
383         source, expected = read_data("composition")
384         actual = fs(source)
385         self.assertFormatEqual(expected, actual)
386         black.assert_equivalent(source, actual)
387         black.assert_stable(source, actual, line_length=ll)
388
389     @patch("black.dump_to_file", dump_to_stderr)
390     def test_empty_lines(self) -> None:
391         source, expected = read_data("empty_lines")
392         actual = fs(source)
393         self.assertFormatEqual(expected, actual)
394         black.assert_equivalent(source, actual)
395         black.assert_stable(source, actual, line_length=ll)
396
397     @patch("black.dump_to_file", dump_to_stderr)
398     def test_string_prefixes(self) -> None:
399         source, expected = read_data("string_prefixes")
400         actual = fs(source)
401         self.assertFormatEqual(expected, actual)
402         black.assert_equivalent(source, actual)
403         black.assert_stable(source, actual, line_length=ll)
404
405     @patch("black.dump_to_file", dump_to_stderr)
406     def test_numeric_literals(self) -> None:
407         source, expected = read_data("numeric_literals")
408         actual = fs(source, mode=black.FileMode.PYTHON36)
409         self.assertFormatEqual(expected, actual)
410         black.assert_equivalent(source, actual)
411         black.assert_stable(source, actual, line_length=ll)
412
413     @patch("black.dump_to_file", dump_to_stderr)
414     def test_numeric_literals_ignoring_underscores(self) -> None:
415         source, expected = read_data("numeric_literals_skip_underscores")
416         mode = (
417             black.FileMode.PYTHON36 | black.FileMode.NO_NUMERIC_UNDERSCORE_NORMALIZATION
418         )
419         actual = fs(source, mode=mode)
420         self.assertFormatEqual(expected, actual)
421         black.assert_equivalent(source, actual)
422         black.assert_stable(source, actual, line_length=ll, mode=mode)
423
424     @patch("black.dump_to_file", dump_to_stderr)
425     def test_numeric_literals_py2(self) -> None:
426         source, expected = read_data("numeric_literals_py2")
427         actual = fs(source)
428         self.assertFormatEqual(expected, actual)
429         black.assert_stable(source, actual, line_length=ll)
430
431     @patch("black.dump_to_file", dump_to_stderr)
432     def test_python2(self) -> None:
433         source, expected = read_data("python2")
434         actual = fs(source)
435         self.assertFormatEqual(expected, actual)
436         # black.assert_equivalent(source, actual)
437         black.assert_stable(source, actual, line_length=ll)
438
439     @patch("black.dump_to_file", dump_to_stderr)
440     def test_python2_unicode_literals(self) -> None:
441         source, expected = read_data("python2_unicode_literals")
442         actual = fs(source)
443         self.assertFormatEqual(expected, actual)
444         black.assert_stable(source, actual, line_length=ll)
445
446     @patch("black.dump_to_file", dump_to_stderr)
447     def test_stub(self) -> None:
448         mode = black.FileMode.PYI
449         source, expected = read_data("stub.pyi")
450         actual = fs(source, mode=mode)
451         self.assertFormatEqual(expected, actual)
452         black.assert_stable(source, actual, line_length=ll, mode=mode)
453
454     @patch("black.dump_to_file", dump_to_stderr)
455     def test_python37(self) -> None:
456         source, expected = read_data("python37")
457         actual = fs(source)
458         self.assertFormatEqual(expected, actual)
459         major, minor = sys.version_info[:2]
460         if major > 3 or (major == 3 and minor >= 7):
461             black.assert_equivalent(source, actual)
462         black.assert_stable(source, actual, line_length=ll)
463
464     @patch("black.dump_to_file", dump_to_stderr)
465     def test_fmtonoff(self) -> None:
466         source, expected = read_data("fmtonoff")
467         actual = fs(source)
468         self.assertFormatEqual(expected, actual)
469         black.assert_equivalent(source, actual)
470         black.assert_stable(source, actual, line_length=ll)
471
472     @patch("black.dump_to_file", dump_to_stderr)
473     def test_fmtonoff2(self) -> None:
474         source, expected = read_data("fmtonoff2")
475         actual = fs(source)
476         self.assertFormatEqual(expected, actual)
477         black.assert_equivalent(source, actual)
478         black.assert_stable(source, actual, line_length=ll)
479
480     @patch("black.dump_to_file", dump_to_stderr)
481     def test_remove_empty_parentheses_after_class(self) -> None:
482         source, expected = read_data("class_blank_parentheses")
483         actual = fs(source)
484         self.assertFormatEqual(expected, actual)
485         black.assert_equivalent(source, actual)
486         black.assert_stable(source, actual, line_length=ll)
487
488     @patch("black.dump_to_file", dump_to_stderr)
489     def test_new_line_between_class_and_code(self) -> None:
490         source, expected = read_data("class_methods_new_line")
491         actual = fs(source)
492         self.assertFormatEqual(expected, actual)
493         black.assert_equivalent(source, actual)
494         black.assert_stable(source, actual, line_length=ll)
495
496     @patch("black.dump_to_file", dump_to_stderr)
497     def test_bracket_match(self) -> None:
498         source, expected = read_data("bracketmatch")
499         actual = fs(source)
500         self.assertFormatEqual(expected, actual)
501         black.assert_equivalent(source, actual)
502         black.assert_stable(source, actual, line_length=ll)
503
504     def test_report_verbose(self) -> None:
505         report = black.Report(verbose=True)
506         out_lines = []
507         err_lines = []
508
509         def out(msg: str, **kwargs: Any) -> None:
510             out_lines.append(msg)
511
512         def err(msg: str, **kwargs: Any) -> None:
513             err_lines.append(msg)
514
515         with patch("black.out", out), patch("black.err", err):
516             report.done(Path("f1"), black.Changed.NO)
517             self.assertEqual(len(out_lines), 1)
518             self.assertEqual(len(err_lines), 0)
519             self.assertEqual(out_lines[-1], "f1 already well formatted, good job.")
520             self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
521             self.assertEqual(report.return_code, 0)
522             report.done(Path("f2"), black.Changed.YES)
523             self.assertEqual(len(out_lines), 2)
524             self.assertEqual(len(err_lines), 0)
525             self.assertEqual(out_lines[-1], "reformatted f2")
526             self.assertEqual(
527                 unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
528             )
529             report.done(Path("f3"), black.Changed.CACHED)
530             self.assertEqual(len(out_lines), 3)
531             self.assertEqual(len(err_lines), 0)
532             self.assertEqual(
533                 out_lines[-1], "f3 wasn't modified on disk since last run."
534             )
535             self.assertEqual(
536                 unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
537             )
538             self.assertEqual(report.return_code, 0)
539             report.check = True
540             self.assertEqual(report.return_code, 1)
541             report.check = False
542             report.failed(Path("e1"), "boom")
543             self.assertEqual(len(out_lines), 3)
544             self.assertEqual(len(err_lines), 1)
545             self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
546             self.assertEqual(
547                 unstyle(str(report)),
548                 "1 file reformatted, 2 files left unchanged, "
549                 "1 file failed to reformat.",
550             )
551             self.assertEqual(report.return_code, 123)
552             report.done(Path("f3"), black.Changed.YES)
553             self.assertEqual(len(out_lines), 4)
554             self.assertEqual(len(err_lines), 1)
555             self.assertEqual(out_lines[-1], "reformatted f3")
556             self.assertEqual(
557                 unstyle(str(report)),
558                 "2 files reformatted, 2 files left unchanged, "
559                 "1 file failed to reformat.",
560             )
561             self.assertEqual(report.return_code, 123)
562             report.failed(Path("e2"), "boom")
563             self.assertEqual(len(out_lines), 4)
564             self.assertEqual(len(err_lines), 2)
565             self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
566             self.assertEqual(
567                 unstyle(str(report)),
568                 "2 files reformatted, 2 files left unchanged, "
569                 "2 files failed to reformat.",
570             )
571             self.assertEqual(report.return_code, 123)
572             report.path_ignored(Path("wat"), "no match")
573             self.assertEqual(len(out_lines), 5)
574             self.assertEqual(len(err_lines), 2)
575             self.assertEqual(out_lines[-1], "wat ignored: no match")
576             self.assertEqual(
577                 unstyle(str(report)),
578                 "2 files reformatted, 2 files left unchanged, "
579                 "2 files failed to reformat.",
580             )
581             self.assertEqual(report.return_code, 123)
582             report.done(Path("f4"), black.Changed.NO)
583             self.assertEqual(len(out_lines), 6)
584             self.assertEqual(len(err_lines), 2)
585             self.assertEqual(out_lines[-1], "f4 already well formatted, good job.")
586             self.assertEqual(
587                 unstyle(str(report)),
588                 "2 files reformatted, 3 files left unchanged, "
589                 "2 files failed to reformat.",
590             )
591             self.assertEqual(report.return_code, 123)
592             report.check = True
593             self.assertEqual(
594                 unstyle(str(report)),
595                 "2 files would be reformatted, 3 files would be left unchanged, "
596                 "2 files would fail to reformat.",
597             )
598
599     def test_report_quiet(self) -> None:
600         report = black.Report(quiet=True)
601         out_lines = []
602         err_lines = []
603
604         def out(msg: str, **kwargs: Any) -> None:
605             out_lines.append(msg)
606
607         def err(msg: str, **kwargs: Any) -> None:
608             err_lines.append(msg)
609
610         with patch("black.out", out), patch("black.err", err):
611             report.done(Path("f1"), black.Changed.NO)
612             self.assertEqual(len(out_lines), 0)
613             self.assertEqual(len(err_lines), 0)
614             self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
615             self.assertEqual(report.return_code, 0)
616             report.done(Path("f2"), black.Changed.YES)
617             self.assertEqual(len(out_lines), 0)
618             self.assertEqual(len(err_lines), 0)
619             self.assertEqual(
620                 unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
621             )
622             report.done(Path("f3"), black.Changed.CACHED)
623             self.assertEqual(len(out_lines), 0)
624             self.assertEqual(len(err_lines), 0)
625             self.assertEqual(
626                 unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
627             )
628             self.assertEqual(report.return_code, 0)
629             report.check = True
630             self.assertEqual(report.return_code, 1)
631             report.check = False
632             report.failed(Path("e1"), "boom")
633             self.assertEqual(len(out_lines), 0)
634             self.assertEqual(len(err_lines), 1)
635             self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
636             self.assertEqual(
637                 unstyle(str(report)),
638                 "1 file reformatted, 2 files left unchanged, "
639                 "1 file failed to reformat.",
640             )
641             self.assertEqual(report.return_code, 123)
642             report.done(Path("f3"), black.Changed.YES)
643             self.assertEqual(len(out_lines), 0)
644             self.assertEqual(len(err_lines), 1)
645             self.assertEqual(
646                 unstyle(str(report)),
647                 "2 files reformatted, 2 files left unchanged, "
648                 "1 file failed to reformat.",
649             )
650             self.assertEqual(report.return_code, 123)
651             report.failed(Path("e2"), "boom")
652             self.assertEqual(len(out_lines), 0)
653             self.assertEqual(len(err_lines), 2)
654             self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
655             self.assertEqual(
656                 unstyle(str(report)),
657                 "2 files reformatted, 2 files left unchanged, "
658                 "2 files failed to reformat.",
659             )
660             self.assertEqual(report.return_code, 123)
661             report.path_ignored(Path("wat"), "no match")
662             self.assertEqual(len(out_lines), 0)
663             self.assertEqual(len(err_lines), 2)
664             self.assertEqual(
665                 unstyle(str(report)),
666                 "2 files reformatted, 2 files left unchanged, "
667                 "2 files failed to reformat.",
668             )
669             self.assertEqual(report.return_code, 123)
670             report.done(Path("f4"), black.Changed.NO)
671             self.assertEqual(len(out_lines), 0)
672             self.assertEqual(len(err_lines), 2)
673             self.assertEqual(
674                 unstyle(str(report)),
675                 "2 files reformatted, 3 files left unchanged, "
676                 "2 files failed to reformat.",
677             )
678             self.assertEqual(report.return_code, 123)
679             report.check = True
680             self.assertEqual(
681                 unstyle(str(report)),
682                 "2 files would be reformatted, 3 files would be left unchanged, "
683                 "2 files would fail to reformat.",
684             )
685
686     def test_report_normal(self) -> None:
687         report = black.Report()
688         out_lines = []
689         err_lines = []
690
691         def out(msg: str, **kwargs: Any) -> None:
692             out_lines.append(msg)
693
694         def err(msg: str, **kwargs: Any) -> None:
695             err_lines.append(msg)
696
697         with patch("black.out", out), patch("black.err", err):
698             report.done(Path("f1"), black.Changed.NO)
699             self.assertEqual(len(out_lines), 0)
700             self.assertEqual(len(err_lines), 0)
701             self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
702             self.assertEqual(report.return_code, 0)
703             report.done(Path("f2"), black.Changed.YES)
704             self.assertEqual(len(out_lines), 1)
705             self.assertEqual(len(err_lines), 0)
706             self.assertEqual(out_lines[-1], "reformatted f2")
707             self.assertEqual(
708                 unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
709             )
710             report.done(Path("f3"), black.Changed.CACHED)
711             self.assertEqual(len(out_lines), 1)
712             self.assertEqual(len(err_lines), 0)
713             self.assertEqual(out_lines[-1], "reformatted f2")
714             self.assertEqual(
715                 unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
716             )
717             self.assertEqual(report.return_code, 0)
718             report.check = True
719             self.assertEqual(report.return_code, 1)
720             report.check = False
721             report.failed(Path("e1"), "boom")
722             self.assertEqual(len(out_lines), 1)
723             self.assertEqual(len(err_lines), 1)
724             self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
725             self.assertEqual(
726                 unstyle(str(report)),
727                 "1 file reformatted, 2 files left unchanged, "
728                 "1 file failed to reformat.",
729             )
730             self.assertEqual(report.return_code, 123)
731             report.done(Path("f3"), black.Changed.YES)
732             self.assertEqual(len(out_lines), 2)
733             self.assertEqual(len(err_lines), 1)
734             self.assertEqual(out_lines[-1], "reformatted f3")
735             self.assertEqual(
736                 unstyle(str(report)),
737                 "2 files reformatted, 2 files left unchanged, "
738                 "1 file failed to reformat.",
739             )
740             self.assertEqual(report.return_code, 123)
741             report.failed(Path("e2"), "boom")
742             self.assertEqual(len(out_lines), 2)
743             self.assertEqual(len(err_lines), 2)
744             self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
745             self.assertEqual(
746                 unstyle(str(report)),
747                 "2 files reformatted, 2 files left unchanged, "
748                 "2 files failed to reformat.",
749             )
750             self.assertEqual(report.return_code, 123)
751             report.path_ignored(Path("wat"), "no match")
752             self.assertEqual(len(out_lines), 2)
753             self.assertEqual(len(err_lines), 2)
754             self.assertEqual(
755                 unstyle(str(report)),
756                 "2 files reformatted, 2 files left unchanged, "
757                 "2 files failed to reformat.",
758             )
759             self.assertEqual(report.return_code, 123)
760             report.done(Path("f4"), black.Changed.NO)
761             self.assertEqual(len(out_lines), 2)
762             self.assertEqual(len(err_lines), 2)
763             self.assertEqual(
764                 unstyle(str(report)),
765                 "2 files reformatted, 3 files left unchanged, "
766                 "2 files failed to reformat.",
767             )
768             self.assertEqual(report.return_code, 123)
769             report.check = True
770             self.assertEqual(
771                 unstyle(str(report)),
772                 "2 files would be reformatted, 3 files would be left unchanged, "
773                 "2 files would fail to reformat.",
774             )
775
776     def test_is_python36(self) -> None:
777         node = black.lib2to3_parse("def f(*, arg): ...\n")
778         self.assertFalse(black.is_python36(node))
779         node = black.lib2to3_parse("def f(*, arg,): ...\n")
780         self.assertTrue(black.is_python36(node))
781         node = black.lib2to3_parse("def f(*, arg): f'string'\n")
782         self.assertTrue(black.is_python36(node))
783         node = black.lib2to3_parse("123_456\n")
784         self.assertTrue(black.is_python36(node))
785         node = black.lib2to3_parse("123456\n")
786         self.assertFalse(black.is_python36(node))
787         source, expected = read_data("function")
788         node = black.lib2to3_parse(source)
789         self.assertTrue(black.is_python36(node))
790         node = black.lib2to3_parse(expected)
791         self.assertTrue(black.is_python36(node))
792         source, expected = read_data("expression")
793         node = black.lib2to3_parse(source)
794         self.assertFalse(black.is_python36(node))
795         node = black.lib2to3_parse(expected)
796         self.assertFalse(black.is_python36(node))
797
798     def test_get_future_imports(self) -> None:
799         node = black.lib2to3_parse("\n")
800         self.assertEqual(set(), black.get_future_imports(node))
801         node = black.lib2to3_parse("from __future__ import black\n")
802         self.assertEqual({"black"}, black.get_future_imports(node))
803         node = black.lib2to3_parse("from __future__ import multiple, imports\n")
804         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
805         node = black.lib2to3_parse("from __future__ import (parenthesized, imports)\n")
806         self.assertEqual({"parenthesized", "imports"}, black.get_future_imports(node))
807         node = black.lib2to3_parse(
808             "from __future__ import multiple\nfrom __future__ import imports\n"
809         )
810         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
811         node = black.lib2to3_parse("# comment\nfrom __future__ import black\n")
812         self.assertEqual({"black"}, black.get_future_imports(node))
813         node = black.lib2to3_parse('"""docstring"""\nfrom __future__ import black\n')
814         self.assertEqual({"black"}, black.get_future_imports(node))
815         node = black.lib2to3_parse("some(other, code)\nfrom __future__ import black\n")
816         self.assertEqual(set(), black.get_future_imports(node))
817         node = black.lib2to3_parse("from some.module import black\n")
818         self.assertEqual(set(), black.get_future_imports(node))
819         node = black.lib2to3_parse(
820             "from __future__ import unicode_literals as _unicode_literals"
821         )
822         self.assertEqual({"unicode_literals"}, black.get_future_imports(node))
823         node = black.lib2to3_parse(
824             "from __future__ import unicode_literals as _lol, print"
825         )
826         self.assertEqual({"unicode_literals", "print"}, black.get_future_imports(node))
827
828     def test_debug_visitor(self) -> None:
829         source, _ = read_data("debug_visitor.py")
830         expected, _ = read_data("debug_visitor.out")
831         out_lines = []
832         err_lines = []
833
834         def out(msg: str, **kwargs: Any) -> None:
835             out_lines.append(msg)
836
837         def err(msg: str, **kwargs: Any) -> None:
838             err_lines.append(msg)
839
840         with patch("black.out", out), patch("black.err", err):
841             black.DebugVisitor.show(source)
842         actual = "\n".join(out_lines) + "\n"
843         log_name = ""
844         if expected != actual:
845             log_name = black.dump_to_file(*out_lines)
846         self.assertEqual(
847             expected,
848             actual,
849             f"AST print out is different. Actual version dumped to {log_name}",
850         )
851
852     def test_format_file_contents(self) -> None:
853         empty = ""
854         with self.assertRaises(black.NothingChanged):
855             black.format_file_contents(empty, line_length=ll, fast=False)
856         just_nl = "\n"
857         with self.assertRaises(black.NothingChanged):
858             black.format_file_contents(just_nl, line_length=ll, fast=False)
859         same = "l = [1, 2, 3]\n"
860         with self.assertRaises(black.NothingChanged):
861             black.format_file_contents(same, line_length=ll, fast=False)
862         different = "l = [1,2,3]"
863         expected = same
864         actual = black.format_file_contents(different, line_length=ll, fast=False)
865         self.assertEqual(expected, actual)
866         invalid = "return if you can"
867         with self.assertRaises(black.InvalidInput) as e:
868             black.format_file_contents(invalid, line_length=ll, fast=False)
869         self.assertEqual(str(e.exception), "Cannot parse: 1:7: return if you can")
870
871     def test_endmarker(self) -> None:
872         n = black.lib2to3_parse("\n")
873         self.assertEqual(n.type, black.syms.file_input)
874         self.assertEqual(len(n.children), 1)
875         self.assertEqual(n.children[0].type, black.token.ENDMARKER)
876
877     @unittest.skipIf(os.environ.get("SKIP_AST_PRINT"), "user set SKIP_AST_PRINT")
878     def test_assertFormatEqual(self) -> None:
879         out_lines = []
880         err_lines = []
881
882         def out(msg: str, **kwargs: Any) -> None:
883             out_lines.append(msg)
884
885         def err(msg: str, **kwargs: Any) -> None:
886             err_lines.append(msg)
887
888         with patch("black.out", out), patch("black.err", err):
889             with self.assertRaises(AssertionError):
890                 self.assertFormatEqual("l = [1, 2, 3]", "l = [1, 2, 3,]")
891
892         out_str = "".join(out_lines)
893         self.assertTrue("Expected tree:" in out_str)
894         self.assertTrue("Actual tree:" in out_str)
895         self.assertEqual("".join(err_lines), "")
896
897     def test_cache_broken_file(self) -> None:
898         mode = black.FileMode.AUTO_DETECT
899         with cache_dir() as workspace:
900             cache_file = black.get_cache_file(black.DEFAULT_LINE_LENGTH, mode)
901             with cache_file.open("w") as fobj:
902                 fobj.write("this is not a pickle")
903             self.assertEqual(black.read_cache(black.DEFAULT_LINE_LENGTH, mode), {})
904             src = (workspace / "test.py").resolve()
905             with src.open("w") as fobj:
906                 fobj.write("print('hello')")
907             result = CliRunner().invoke(black.main, [str(src)])
908             self.assertEqual(result.exit_code, 0)
909             cache = black.read_cache(black.DEFAULT_LINE_LENGTH, mode)
910             self.assertIn(src, cache)
911
912     def test_cache_single_file_already_cached(self) -> None:
913         mode = black.FileMode.AUTO_DETECT
914         with cache_dir() as workspace:
915             src = (workspace / "test.py").resolve()
916             with src.open("w") as fobj:
917                 fobj.write("print('hello')")
918             black.write_cache({}, [src], black.DEFAULT_LINE_LENGTH, mode)
919             result = CliRunner().invoke(black.main, [str(src)])
920             self.assertEqual(result.exit_code, 0)
921             with src.open("r") as fobj:
922                 self.assertEqual(fobj.read(), "print('hello')")
923
924     @event_loop(close=False)
925     def test_cache_multiple_files(self) -> None:
926         mode = black.FileMode.AUTO_DETECT
927         with cache_dir() as workspace, patch(
928             "black.ProcessPoolExecutor", new=ThreadPoolExecutor
929         ):
930             one = (workspace / "one.py").resolve()
931             with one.open("w") as fobj:
932                 fobj.write("print('hello')")
933             two = (workspace / "two.py").resolve()
934             with two.open("w") as fobj:
935                 fobj.write("print('hello')")
936             black.write_cache({}, [one], black.DEFAULT_LINE_LENGTH, mode)
937             result = CliRunner().invoke(black.main, [str(workspace)])
938             self.assertEqual(result.exit_code, 0)
939             with one.open("r") as fobj:
940                 self.assertEqual(fobj.read(), "print('hello')")
941             with two.open("r") as fobj:
942                 self.assertEqual(fobj.read(), 'print("hello")\n')
943             cache = black.read_cache(black.DEFAULT_LINE_LENGTH, mode)
944             self.assertIn(one, cache)
945             self.assertIn(two, cache)
946
947     def test_no_cache_when_writeback_diff(self) -> None:
948         mode = black.FileMode.AUTO_DETECT
949         with cache_dir() as workspace:
950             src = (workspace / "test.py").resolve()
951             with src.open("w") as fobj:
952                 fobj.write("print('hello')")
953             result = CliRunner().invoke(black.main, [str(src), "--diff"])
954             self.assertEqual(result.exit_code, 0)
955             cache_file = black.get_cache_file(black.DEFAULT_LINE_LENGTH, mode)
956             self.assertFalse(cache_file.exists())
957
958     def test_no_cache_when_stdin(self) -> None:
959         mode = black.FileMode.AUTO_DETECT
960         with cache_dir():
961             result = CliRunner().invoke(
962                 black.main, ["-"], input=BytesIO(b"print('hello')")
963             )
964             self.assertEqual(result.exit_code, 0)
965             cache_file = black.get_cache_file(black.DEFAULT_LINE_LENGTH, mode)
966             self.assertFalse(cache_file.exists())
967
968     def test_read_cache_no_cachefile(self) -> None:
969         mode = black.FileMode.AUTO_DETECT
970         with cache_dir():
971             self.assertEqual(black.read_cache(black.DEFAULT_LINE_LENGTH, mode), {})
972
973     def test_write_cache_read_cache(self) -> None:
974         mode = black.FileMode.AUTO_DETECT
975         with cache_dir() as workspace:
976             src = (workspace / "test.py").resolve()
977             src.touch()
978             black.write_cache({}, [src], black.DEFAULT_LINE_LENGTH, mode)
979             cache = black.read_cache(black.DEFAULT_LINE_LENGTH, mode)
980             self.assertIn(src, cache)
981             self.assertEqual(cache[src], black.get_cache_info(src))
982
983     def test_filter_cached(self) -> None:
984         with TemporaryDirectory() as workspace:
985             path = Path(workspace)
986             uncached = (path / "uncached").resolve()
987             cached = (path / "cached").resolve()
988             cached_but_changed = (path / "changed").resolve()
989             uncached.touch()
990             cached.touch()
991             cached_but_changed.touch()
992             cache = {cached: black.get_cache_info(cached), cached_but_changed: (0.0, 0)}
993             todo, done = black.filter_cached(
994                 cache, {uncached, cached, cached_but_changed}
995             )
996             self.assertEqual(todo, {uncached, cached_but_changed})
997             self.assertEqual(done, {cached})
998
999     def test_write_cache_creates_directory_if_needed(self) -> None:
1000         mode = black.FileMode.AUTO_DETECT
1001         with cache_dir(exists=False) as workspace:
1002             self.assertFalse(workspace.exists())
1003             black.write_cache({}, [], black.DEFAULT_LINE_LENGTH, mode)
1004             self.assertTrue(workspace.exists())
1005
1006     @event_loop(close=False)
1007     def test_failed_formatting_does_not_get_cached(self) -> None:
1008         mode = black.FileMode.AUTO_DETECT
1009         with cache_dir() as workspace, patch(
1010             "black.ProcessPoolExecutor", new=ThreadPoolExecutor
1011         ):
1012             failing = (workspace / "failing.py").resolve()
1013             with failing.open("w") as fobj:
1014                 fobj.write("not actually python")
1015             clean = (workspace / "clean.py").resolve()
1016             with clean.open("w") as fobj:
1017                 fobj.write('print("hello")\n')
1018             result = CliRunner().invoke(black.main, [str(workspace)])
1019             self.assertEqual(result.exit_code, 123)
1020             cache = black.read_cache(black.DEFAULT_LINE_LENGTH, mode)
1021             self.assertNotIn(failing, cache)
1022             self.assertIn(clean, cache)
1023
1024     def test_write_cache_write_fail(self) -> None:
1025         mode = black.FileMode.AUTO_DETECT
1026         with cache_dir(), patch.object(Path, "open") as mock:
1027             mock.side_effect = OSError
1028             black.write_cache({}, [], black.DEFAULT_LINE_LENGTH, mode)
1029
1030     @event_loop(close=False)
1031     def test_check_diff_use_together(self) -> None:
1032         with cache_dir():
1033             # Files which will be reformatted.
1034             src1 = (THIS_DIR / "data" / "string_quotes.py").resolve()
1035             result = CliRunner().invoke(black.main, [str(src1), "--diff", "--check"])
1036             self.assertEqual(result.exit_code, 1, result.output)
1037             # Files which will not be reformatted.
1038             src2 = (THIS_DIR / "data" / "composition.py").resolve()
1039             result = CliRunner().invoke(black.main, [str(src2), "--diff", "--check"])
1040             self.assertEqual(result.exit_code, 0, result.output)
1041             # Multi file command.
1042             result = CliRunner().invoke(
1043                 black.main, [str(src1), str(src2), "--diff", "--check"]
1044             )
1045             self.assertEqual(result.exit_code, 1, result.output)
1046
1047     def test_no_files(self) -> None:
1048         with cache_dir():
1049             # Without an argument, black exits with error code 0.
1050             result = CliRunner().invoke(black.main, [])
1051             self.assertEqual(result.exit_code, 0)
1052
1053     def test_broken_symlink(self) -> None:
1054         with cache_dir() as workspace:
1055             symlink = workspace / "broken_link.py"
1056             try:
1057                 symlink.symlink_to("nonexistent.py")
1058             except OSError as e:
1059                 self.skipTest(f"Can't create symlinks: {e}")
1060             result = CliRunner().invoke(black.main, [str(workspace.resolve())])
1061             self.assertEqual(result.exit_code, 0)
1062
1063     def test_read_cache_line_lengths(self) -> None:
1064         mode = black.FileMode.AUTO_DETECT
1065         with cache_dir() as workspace:
1066             path = (workspace / "file.py").resolve()
1067             path.touch()
1068             black.write_cache({}, [path], 1, mode)
1069             one = black.read_cache(1, mode)
1070             self.assertIn(path, one)
1071             two = black.read_cache(2, mode)
1072             self.assertNotIn(path, two)
1073
1074     def test_single_file_force_pyi(self) -> None:
1075         reg_mode = black.FileMode.AUTO_DETECT
1076         pyi_mode = black.FileMode.PYI
1077         contents, expected = read_data("force_pyi")
1078         with cache_dir() as workspace:
1079             path = (workspace / "file.py").resolve()
1080             with open(path, "w") as fh:
1081                 fh.write(contents)
1082             result = CliRunner().invoke(black.main, [str(path), "--pyi"])
1083             self.assertEqual(result.exit_code, 0)
1084             with open(path, "r") as fh:
1085                 actual = fh.read()
1086             # verify cache with --pyi is separate
1087             pyi_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, pyi_mode)
1088             self.assertIn(path, pyi_cache)
1089             normal_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, reg_mode)
1090             self.assertNotIn(path, normal_cache)
1091         self.assertEqual(actual, expected)
1092
1093     @event_loop(close=False)
1094     def test_multi_file_force_pyi(self) -> None:
1095         reg_mode = black.FileMode.AUTO_DETECT
1096         pyi_mode = black.FileMode.PYI
1097         contents, expected = read_data("force_pyi")
1098         with cache_dir() as workspace:
1099             paths = [
1100                 (workspace / "file1.py").resolve(),
1101                 (workspace / "file2.py").resolve(),
1102             ]
1103             for path in paths:
1104                 with open(path, "w") as fh:
1105                     fh.write(contents)
1106             result = CliRunner().invoke(black.main, [str(p) for p in paths] + ["--pyi"])
1107             self.assertEqual(result.exit_code, 0)
1108             for path in paths:
1109                 with open(path, "r") as fh:
1110                     actual = fh.read()
1111                 self.assertEqual(actual, expected)
1112             # verify cache with --pyi is separate
1113             pyi_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, pyi_mode)
1114             normal_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, reg_mode)
1115             for path in paths:
1116                 self.assertIn(path, pyi_cache)
1117                 self.assertNotIn(path, normal_cache)
1118
1119     def test_pipe_force_pyi(self) -> None:
1120         source, expected = read_data("force_pyi")
1121         result = CliRunner().invoke(
1122             black.main, ["-", "-q", "--pyi"], input=BytesIO(source.encode("utf8"))
1123         )
1124         self.assertEqual(result.exit_code, 0)
1125         actual = result.output
1126         self.assertFormatEqual(actual, expected)
1127
1128     def test_single_file_force_py36(self) -> None:
1129         reg_mode = black.FileMode.AUTO_DETECT
1130         py36_mode = black.FileMode.PYTHON36
1131         source, expected = read_data("force_py36")
1132         with cache_dir() as workspace:
1133             path = (workspace / "file.py").resolve()
1134             with open(path, "w") as fh:
1135                 fh.write(source)
1136             result = CliRunner().invoke(black.main, [str(path), "--py36"])
1137             self.assertEqual(result.exit_code, 0)
1138             with open(path, "r") as fh:
1139                 actual = fh.read()
1140             # verify cache with --py36 is separate
1141             py36_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, py36_mode)
1142             self.assertIn(path, py36_cache)
1143             normal_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, reg_mode)
1144             self.assertNotIn(path, normal_cache)
1145         self.assertEqual(actual, expected)
1146
1147     @event_loop(close=False)
1148     def test_multi_file_force_py36(self) -> None:
1149         reg_mode = black.FileMode.AUTO_DETECT
1150         py36_mode = black.FileMode.PYTHON36
1151         source, expected = read_data("force_py36")
1152         with cache_dir() as workspace:
1153             paths = [
1154                 (workspace / "file1.py").resolve(),
1155                 (workspace / "file2.py").resolve(),
1156             ]
1157             for path in paths:
1158                 with open(path, "w") as fh:
1159                     fh.write(source)
1160             result = CliRunner().invoke(
1161                 black.main, [str(p) for p in paths] + ["--py36"]
1162             )
1163             self.assertEqual(result.exit_code, 0)
1164             for path in paths:
1165                 with open(path, "r") as fh:
1166                     actual = fh.read()
1167                 self.assertEqual(actual, expected)
1168             # verify cache with --py36 is separate
1169             pyi_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, py36_mode)
1170             normal_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, reg_mode)
1171             for path in paths:
1172                 self.assertIn(path, pyi_cache)
1173                 self.assertNotIn(path, normal_cache)
1174
1175     def test_pipe_force_py36(self) -> None:
1176         source, expected = read_data("force_py36")
1177         result = CliRunner().invoke(
1178             black.main, ["-", "-q", "--py36"], input=BytesIO(source.encode("utf8"))
1179         )
1180         self.assertEqual(result.exit_code, 0)
1181         actual = result.output
1182         self.assertFormatEqual(actual, expected)
1183
1184     def test_include_exclude(self) -> None:
1185         path = THIS_DIR / "data" / "include_exclude_tests"
1186         include = re.compile(r"\.pyi?$")
1187         exclude = re.compile(r"/exclude/|/\.definitely_exclude/")
1188         report = black.Report()
1189         sources: List[Path] = []
1190         expected = [
1191             Path(path / "b/dont_exclude/a.py"),
1192             Path(path / "b/dont_exclude/a.pyi"),
1193         ]
1194         this_abs = THIS_DIR.resolve()
1195         sources.extend(
1196             black.gen_python_files_in_dir(path, this_abs, include, exclude, report)
1197         )
1198         self.assertEqual(sorted(expected), sorted(sources))
1199
1200     def test_empty_include(self) -> None:
1201         path = THIS_DIR / "data" / "include_exclude_tests"
1202         report = black.Report()
1203         empty = re.compile(r"")
1204         sources: List[Path] = []
1205         expected = [
1206             Path(path / "b/exclude/a.pie"),
1207             Path(path / "b/exclude/a.py"),
1208             Path(path / "b/exclude/a.pyi"),
1209             Path(path / "b/dont_exclude/a.pie"),
1210             Path(path / "b/dont_exclude/a.py"),
1211             Path(path / "b/dont_exclude/a.pyi"),
1212             Path(path / "b/.definitely_exclude/a.pie"),
1213             Path(path / "b/.definitely_exclude/a.py"),
1214             Path(path / "b/.definitely_exclude/a.pyi"),
1215         ]
1216         this_abs = THIS_DIR.resolve()
1217         sources.extend(
1218             black.gen_python_files_in_dir(
1219                 path, this_abs, empty, re.compile(black.DEFAULT_EXCLUDES), report
1220             )
1221         )
1222         self.assertEqual(sorted(expected), sorted(sources))
1223
1224     def test_empty_exclude(self) -> None:
1225         path = THIS_DIR / "data" / "include_exclude_tests"
1226         report = black.Report()
1227         empty = re.compile(r"")
1228         sources: List[Path] = []
1229         expected = [
1230             Path(path / "b/dont_exclude/a.py"),
1231             Path(path / "b/dont_exclude/a.pyi"),
1232             Path(path / "b/exclude/a.py"),
1233             Path(path / "b/exclude/a.pyi"),
1234             Path(path / "b/.definitely_exclude/a.py"),
1235             Path(path / "b/.definitely_exclude/a.pyi"),
1236         ]
1237         this_abs = THIS_DIR.resolve()
1238         sources.extend(
1239             black.gen_python_files_in_dir(
1240                 path, this_abs, re.compile(black.DEFAULT_INCLUDES), empty, report
1241             )
1242         )
1243         self.assertEqual(sorted(expected), sorted(sources))
1244
1245     def test_invalid_include_exclude(self) -> None:
1246         for option in ["--include", "--exclude"]:
1247             result = CliRunner().invoke(black.main, ["-", option, "**()(!!*)"])
1248             self.assertEqual(result.exit_code, 2)
1249
1250     def test_preserves_line_endings(self) -> None:
1251         with TemporaryDirectory() as workspace:
1252             test_file = Path(workspace) / "test.py"
1253             for nl in ["\n", "\r\n"]:
1254                 contents = nl.join(["def f(  ):", "    pass"])
1255                 test_file.write_bytes(contents.encode())
1256                 ff(test_file, write_back=black.WriteBack.YES)
1257                 updated_contents: bytes = test_file.read_bytes()
1258                 self.assertIn(nl.encode(), updated_contents)
1259                 if nl == "\n":
1260                     self.assertNotIn(b"\r\n", updated_contents)
1261
1262     def test_preserves_line_endings_via_stdin(self) -> None:
1263         for nl in ["\n", "\r\n"]:
1264             contents = nl.join(["def f(  ):", "    pass"])
1265             runner = BlackRunner()
1266             result = runner.invoke(
1267                 black.main, ["-", "--fast"], input=BytesIO(contents.encode("utf8"))
1268             )
1269             self.assertEqual(result.exit_code, 0)
1270             output = runner.stdout_bytes
1271             self.assertIn(nl.encode("utf8"), output)
1272             if nl == "\n":
1273                 self.assertNotIn(b"\r\n", output)
1274
1275     def test_assert_equivalent_different_asts(self) -> None:
1276         with self.assertRaises(AssertionError):
1277             black.assert_equivalent("{}", "None")
1278
1279     def test_symlink_out_of_root_directory(self) -> None:
1280         path = MagicMock()
1281         root = THIS_DIR
1282         child = MagicMock()
1283         include = re.compile(black.DEFAULT_INCLUDES)
1284         exclude = re.compile(black.DEFAULT_EXCLUDES)
1285         report = black.Report()
1286         # `child` should behave like a symlink which resolved path is clearly
1287         # outside of the `root` directory.
1288         path.iterdir.return_value = [child]
1289         child.resolve.return_value = Path("/a/b/c")
1290         child.is_symlink.return_value = True
1291         try:
1292             list(black.gen_python_files_in_dir(path, root, include, exclude, report))
1293         except ValueError as ve:
1294             self.fail("`get_python_files_in_dir()` failed: {ve}")
1295         path.iterdir.assert_called_once()
1296         child.resolve.assert_called_once()
1297         child.is_symlink.assert_called_once()
1298         # `child` should behave like a strange file which resolved path is clearly
1299         # outside of the `root` directory.
1300         child.is_symlink.return_value = False
1301         with self.assertRaises(ValueError):
1302             list(black.gen_python_files_in_dir(path, root, include, exclude, report))
1303         path.iterdir.assert_called()
1304         self.assertEqual(path.iterdir.call_count, 2)
1305         child.resolve.assert_called()
1306         self.assertEqual(child.resolve.call_count, 2)
1307         child.is_symlink.assert_called()
1308         self.assertEqual(child.is_symlink.call_count, 2)
1309
1310     def test_shhh_click(self) -> None:
1311         try:
1312             from click import _unicodefun  # type: ignore
1313         except ModuleNotFoundError:
1314             self.skipTest("Incompatible Click version")
1315         if not hasattr(_unicodefun, "_verify_python3_env"):
1316             self.skipTest("Incompatible Click version")
1317         # First, let's see if Click is crashing with a preferred ASCII charset.
1318         with patch("locale.getpreferredencoding") as gpe:
1319             gpe.return_value = "ASCII"
1320             with self.assertRaises(RuntimeError):
1321                 _unicodefun._verify_python3_env()
1322         # Now, let's silence Click...
1323         black.patch_click()
1324         # ...and confirm it's silent.
1325         with patch("locale.getpreferredencoding") as gpe:
1326             gpe.return_value = "ASCII"
1327             try:
1328                 _unicodefun._verify_python3_env()
1329             except RuntimeError as re:
1330                 self.fail(f"`patch_click()` failed, exception still raised: {re}")
1331
1332     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1333     @async_test
1334     async def test_blackd_request_needs_formatting(self) -> None:
1335         app = blackd.make_app()
1336         async with TestClient(TestServer(app)) as client:
1337             response = await client.post("/", data=b"print('hello world')")
1338             self.assertEqual(response.status, 200)
1339             self.assertEqual(response.charset, "utf8")
1340             self.assertEqual(await response.read(), b'print("hello world")\n')
1341
1342     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1343     @async_test
1344     async def test_blackd_request_no_change(self) -> None:
1345         app = blackd.make_app()
1346         async with TestClient(TestServer(app)) as client:
1347             response = await client.post("/", data=b'print("hello world")\n')
1348             self.assertEqual(response.status, 204)
1349             self.assertEqual(await response.read(), b"")
1350
1351     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1352     @async_test
1353     async def test_blackd_request_syntax_error(self) -> None:
1354         app = blackd.make_app()
1355         async with TestClient(TestServer(app)) as client:
1356             response = await client.post("/", data=b"what even ( is")
1357             self.assertEqual(response.status, 400)
1358             content = await response.text()
1359             self.assertTrue(
1360                 content.startswith("Cannot parse"),
1361                 msg=f"Expected error to start with 'Cannot parse', got {repr(content)}",
1362             )
1363
1364     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1365     @async_test
1366     async def test_blackd_unsupported_version(self) -> None:
1367         app = blackd.make_app()
1368         async with TestClient(TestServer(app)) as client:
1369             response = await client.post(
1370                 "/", data=b"what", headers={blackd.VERSION_HEADER: "2"}
1371             )
1372             self.assertEqual(response.status, 501)
1373
1374     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1375     @async_test
1376     async def test_blackd_supported_version(self) -> None:
1377         app = blackd.make_app()
1378         async with TestClient(TestServer(app)) as client:
1379             response = await client.post(
1380                 "/", data=b"what", headers={blackd.VERSION_HEADER: "1"}
1381             )
1382             self.assertEqual(response.status, 200)
1383
1384     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1385     @async_test
1386     async def test_blackd_invalid_python_variant(self) -> None:
1387         app = blackd.make_app()
1388         async with TestClient(TestServer(app)) as client:
1389             response = await client.post(
1390                 "/", data=b"what", headers={blackd.PYTHON_VARIANT_HEADER: "lol"}
1391             )
1392             self.assertEqual(response.status, 400)
1393
1394     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1395     @async_test
1396     async def test_blackd_pyi(self) -> None:
1397         app = blackd.make_app()
1398         async with TestClient(TestServer(app)) as client:
1399             source, expected = read_data("stub.pyi")
1400             response = await client.post(
1401                 "/", data=source, headers={blackd.PYTHON_VARIANT_HEADER: "pyi"}
1402             )
1403             self.assertEqual(response.status, 200)
1404             self.assertEqual(await response.text(), expected)
1405
1406     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1407     @async_test
1408     async def test_blackd_py36(self) -> None:
1409         app = blackd.make_app()
1410         async with TestClient(TestServer(app)) as client:
1411             response = await client.post(
1412                 "/",
1413                 data=(
1414                     "def f(\n"
1415                     "    and_has_a_bunch_of,\n"
1416                     "    very_long_arguments_too,\n"
1417                     "    and_lots_of_them_as_well_lol,\n"
1418                     "    **and_very_long_keyword_arguments\n"
1419                     "):\n"
1420                     "    pass\n"
1421                 ),
1422                 headers={blackd.PYTHON_VARIANT_HEADER: "3.6"},
1423             )
1424             self.assertEqual(response.status, 200)
1425             response = await client.post(
1426                 "/",
1427                 data=(
1428                     "def f(\n"
1429                     "    and_has_a_bunch_of,\n"
1430                     "    very_long_arguments_too,\n"
1431                     "    and_lots_of_them_as_well_lol,\n"
1432                     "    **and_very_long_keyword_arguments\n"
1433                     "):\n"
1434                     "    pass\n"
1435                 ),
1436                 headers={blackd.PYTHON_VARIANT_HEADER: "3.5"},
1437             )
1438             self.assertEqual(response.status, 204)
1439             response = await client.post(
1440                 "/",
1441                 data=(
1442                     "def f(\n"
1443                     "    and_has_a_bunch_of,\n"
1444                     "    very_long_arguments_too,\n"
1445                     "    and_lots_of_them_as_well_lol,\n"
1446                     "    **and_very_long_keyword_arguments\n"
1447                     "):\n"
1448                     "    pass\n"
1449                 ),
1450                 headers={blackd.PYTHON_VARIANT_HEADER: "2"},
1451             )
1452             self.assertEqual(response.status, 204)
1453
1454     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1455     @async_test
1456     async def test_blackd_fast(self) -> None:
1457         app = blackd.make_app()
1458         async with TestClient(TestServer(app)) as client:
1459             response = await client.post("/", data=b"ur'hello'")
1460             self.assertEqual(response.status, 500)
1461             self.assertIn("failed to parse source file", await response.text())
1462             response = await client.post(
1463                 "/", data=b"ur'hello'", headers={blackd.FAST_OR_SAFE_HEADER: "fast"}
1464             )
1465             self.assertEqual(response.status, 200)
1466
1467     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1468     @async_test
1469     async def test_blackd_line_length(self) -> None:
1470         app = blackd.make_app()
1471         async with TestClient(TestServer(app)) as client:
1472             response = await client.post(
1473                 "/", data=b'print("hello")\n', headers={blackd.LINE_LENGTH_HEADER: "7"}
1474             )
1475             self.assertEqual(response.status, 200)
1476
1477     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1478     @async_test
1479     async def test_blackd_invalid_line_length(self) -> None:
1480         app = blackd.make_app()
1481         async with TestClient(TestServer(app)) as client:
1482             response = await client.post(
1483                 "/",
1484                 data=b'print("hello")\n',
1485                 headers={blackd.LINE_LENGTH_HEADER: "NaN"},
1486             )
1487             self.assertEqual(response.status, 400)
1488
1489     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1490     def test_blackd_main(self) -> None:
1491         with patch("blackd.web.run_app"):
1492             result = CliRunner().invoke(blackd.main, [])
1493             if result.exception is not None:
1494                 raise result.exception
1495             self.assertEqual(result.exit_code, 0)
1496
1497
if __name__ == "__main__":
    # Script entry point: load the tests by module name ("test_black") rather
    # than from ``__main__``.
    unittest.main(module="test_black")