git.madduck.net Git - etc/vim.git/blob - tests/test_black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

b3f1f8262d3437e3203683e972db5cc64c8bf454
[etc/vim.git] / tests / test_black.py
1 #!/usr/bin/env python3
2 import asyncio
3 from concurrent.futures import ThreadPoolExecutor
4 from contextlib import contextmanager, redirect_stderr
5 from functools import partial, wraps
6 from io import BytesIO, TextIOWrapper
7 import os
8 from pathlib import Path
9 import re
10 import sys
11 from tempfile import TemporaryDirectory
12 from typing import (
13     Any,
14     BinaryIO,
15     Callable,
16     Coroutine,
17     Generator,
18     List,
19     Tuple,
20     Iterator,
21     TypeVar,
22 )
23 import unittest
24 from unittest.mock import patch, MagicMock
25
26 from click import unstyle
27 from click.testing import CliRunner
28
29 import black
30
31 try:
32     import blackd
33     from aiohttp.test_utils import TestClient, TestServer
34 except ImportError:
35     has_blackd_deps = False
36 else:
37     has_blackd_deps = True
38
39
# Line length shared by every formatting call in this module.
ll = 88
# In-place file formatter bound to the shared line length, with fast=True.
ff = partial(black.format_file_in_place, line_length=ll, fast=True)
# String formatter bound to the shared line length.
fs = partial(black.format_str, line_length=ll)
THIS_FILE = Path(__file__)
THIS_DIR = THIS_FILE.parent
# read_data() deletes this marker text from every line it reads.  The literal is
# split in two, presumably so this file's own source never contains the whole
# marker (test_self feeds this file through read_data) -- confirm before joining.
EMPTY_LINE = "# EMPTY LINE WITH WHITESPACE" + " (this comment will be removed)"
T = TypeVar("T")
R = TypeVar("R")
48
49
def dump_to_stderr(*output: str) -> str:
    """Return ``output`` joined by newlines, with one extra newline on each side.

    Tests patch ``black.dump_to_file`` with this function so that whatever black
    would have dumped is returned as a string instead of being written to a file.
    """
    joined = "\n".join(output)
    return f"\n{joined}\n"
52
53
def read_data(name: str, data: bool = True) -> Tuple[str, str]:
    """Load test case ``name`` and return its ``(input, expected_output)`` pair.

    ``.py`` is appended unless ``name`` already ends in ``.py``, ``.pyi``, ``.out``
    or ``.diff``.  The file is looked up in the ``data`` directory next to this
    module when ``data`` is true, otherwise next to this module directly.

    Every occurrence of ``EMPTY_LINE`` is deleted from each line.  Lines before
    a line reading ``# output`` are the input, lines after it the expected
    output; if nothing follows (or there is no marker) the input doubles as the
    expected output.  Both parts are stripped and end with exactly one newline.
    """
    suffixes = (".py", ".pyi", ".out", ".diff")
    if not name.endswith(suffixes):
        name += ".py"
    folder = THIS_DIR / "data" if data else THIS_DIR
    with open(folder / name, "r", encoding="utf8") as handle:
        raw_lines = handle.readlines()

    source_part: List[str] = []
    expected_part: List[str] = []
    sink = source_part
    for raw in raw_lines:
        text = raw.replace(EMPTY_LINE, "")
        if text.rstrip() == "# output":
            sink = expected_part
            continue
        sink.append(text)

    if source_part and not expected_part:
        # No output section: the file is expected to be formatted already.
        expected_part = source_part[:]

    source = "".join(source_part).strip() + "\n"
    expected = "".join(expected_part).strip() + "\n"
    return source, expected
75
76
@contextmanager
def cache_dir(exists: bool = True) -> Iterator[Path]:
    """Point ``black.CACHE_DIR`` at a throwaway directory for the duration.

    Yields the patched path.  With ``exists=False`` the yielded path is a
    ``new`` child of a temporary directory, and that child is not created here.
    """
    with TemporaryDirectory() as workspace:
        root = Path(workspace)
        target = root if exists else root / "new"
        with patch("black.CACHE_DIR", target):
            yield target
85
86
@contextmanager
def event_loop(close: bool) -> Iterator[None]:
    """Install a new event loop for the body, then restore the previous one.

    The new loop is closed on exit only when ``close`` is true.
    """
    policy = asyncio.get_event_loop_policy()
    previous = policy.get_event_loop()
    fresh = policy.new_event_loop()
    asyncio.set_event_loop(fresh)
    try:
        yield
    finally:
        policy.set_event_loop(previous)
        if close:
            fresh.close()
100
101
def async_test(f: Callable[..., Coroutine[Any, None, R]]) -> Callable[..., None]:
    """Turn coroutine function ``f`` into a plain synchronous callable.

    Each call runs ``f`` to completion on its own event loop (installed and
    closed by ``event_loop(close=True)``); ``f``'s return value is discarded.
    """

    @event_loop(close=True)
    @wraps(f)
    def run_sync(*args: Any, **kwargs: Any) -> None:
        coro = f(*args, **kwargs)
        asyncio.get_event_loop().run_until_complete(coro)

    return run_sync
109
110
class BlackRunner(CliRunner):
    """Modify CliRunner so that stderr is not merged with stdout.

    After each invocation the captured output is available as ``stdout_bytes``
    and ``stderr_bytes``.

    This is a hack that can be removed once we depend on Click 7.x"""

    def __init__(self) -> None:
        self.stderrbuf = BytesIO()
        self.stdoutbuf = BytesIO()
        self.stdout_bytes = b""
        self.stderr_bytes = b""
        super().__init__()

    @contextmanager
    def isolation(self, *args: Any, **kwargs: Any) -> Generator[BinaryIO, None, None]:
        """Run Click's isolation, additionally redirecting ``sys.stderr``.

        Text written to ``sys.stderr`` is encoded with ``self.charset`` into
        ``self.stderrbuf``.  On exit both streams are flushed before their bytes
        are snapshotted, and the stderr wrapper is detached (not closed) so that
        ``self.stderrbuf`` stays usable for later invocations.
        """
        with super().isolation(*args, **kwargs) as output:
            hold_stderr = sys.stderr
            err_stream = TextIOWrapper(self.stderrbuf, encoding=self.charset)
            sys.stderr = err_stream
            try:
                yield output
            finally:
                # TextIOWrapper buffers writes; flush so getvalue() sees all of them.
                sys.stdout.flush()
                err_stream.flush()
                self.stdout_bytes = sys.stdout.buffer.getvalue()  # type: ignore
                self.stderr_bytes = self.stderrbuf.getvalue()
                sys.stderr = hold_stderr
                # Dropping the wrapper would otherwise close self.stderrbuf.
                err_stream.detach()
134
135
136 class BlackTestCase(unittest.TestCase):
    # Never truncate assertion diffs; formatting mismatches need the full text.
    maxDiff = None
138
139     def assertFormatEqual(self, expected: str, actual: str) -> None:
140         if actual != expected and not os.environ.get("SKIP_AST_PRINT"):
141             bdv: black.DebugVisitor[Any]
142             black.out("Expected tree:", fg="green")
143             try:
144                 exp_node = black.lib2to3_parse(expected)
145                 bdv = black.DebugVisitor()
146                 list(bdv.visit(exp_node))
147             except Exception as ve:
148                 black.err(str(ve))
149             black.out("Actual tree:", fg="red")
150             try:
151                 exp_node = black.lib2to3_parse(actual)
152                 bdv = black.DebugVisitor()
153                 list(bdv.visit(exp_node))
154             except Exception as ve:
155                 black.err(str(ve))
156         self.assertEqual(expected, actual)
157
158     @patch("black.dump_to_file", dump_to_stderr)
159     def test_empty(self) -> None:
160         source = expected = ""
161         actual = fs(source)
162         self.assertFormatEqual(expected, actual)
163         black.assert_equivalent(source, actual)
164         black.assert_stable(source, actual, line_length=ll)
165
166     def test_empty_ff(self) -> None:
167         expected = ""
168         tmp_file = Path(black.dump_to_file())
169         try:
170             self.assertFalse(ff(tmp_file, write_back=black.WriteBack.YES))
171             with open(tmp_file, encoding="utf8") as f:
172                 actual = f.read()
173         finally:
174             os.unlink(tmp_file)
175         self.assertFormatEqual(expected, actual)
176
177     @patch("black.dump_to_file", dump_to_stderr)
178     def test_self(self) -> None:
179         source, expected = read_data("test_black", data=False)
180         actual = fs(source)
181         self.assertFormatEqual(expected, actual)
182         black.assert_equivalent(source, actual)
183         black.assert_stable(source, actual, line_length=ll)
184         self.assertFalse(ff(THIS_FILE))
185
186     @patch("black.dump_to_file", dump_to_stderr)
187     def test_black(self) -> None:
188         source, expected = read_data("../black", data=False)
189         actual = fs(source)
190         self.assertFormatEqual(expected, actual)
191         black.assert_equivalent(source, actual)
192         black.assert_stable(source, actual, line_length=ll)
193         self.assertFalse(ff(THIS_DIR / ".." / "black.py"))
194
195     def test_piping(self) -> None:
196         source, expected = read_data("../black", data=False)
197         result = BlackRunner().invoke(
198             black.main,
199             ["-", "--fast", f"--line-length={ll}"],
200             input=BytesIO(source.encode("utf8")),
201         )
202         self.assertEqual(result.exit_code, 0)
203         self.assertFormatEqual(expected, result.output)
204         black.assert_equivalent(source, result.output)
205         black.assert_stable(source, result.output, line_length=ll)
206
207     def test_piping_diff(self) -> None:
208         diff_header = re.compile(
209             rf"(STDIN|STDOUT)\t\d\d\d\d-\d\d-\d\d "
210             rf"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
211         )
212         source, _ = read_data("expression.py")
213         expected, _ = read_data("expression.diff")
214         config = THIS_DIR / "data" / "empty_pyproject.toml"
215         args = ["-", "--fast", f"--line-length={ll}", "--diff", f"--config={config}"]
216         result = BlackRunner().invoke(
217             black.main, args, input=BytesIO(source.encode("utf8"))
218         )
219         self.assertEqual(result.exit_code, 0)
220         actual = diff_header.sub("[Deterministic header]", result.output)
221         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
222         self.assertEqual(expected, actual)
223
224     @patch("black.dump_to_file", dump_to_stderr)
225     def test_setup(self) -> None:
226         source, expected = read_data("../setup", data=False)
227         actual = fs(source)
228         self.assertFormatEqual(expected, actual)
229         black.assert_equivalent(source, actual)
230         black.assert_stable(source, actual, line_length=ll)
231         self.assertFalse(ff(THIS_DIR / ".." / "setup.py"))
232
233     @patch("black.dump_to_file", dump_to_stderr)
234     def test_function(self) -> None:
235         source, expected = read_data("function")
236         actual = fs(source)
237         self.assertFormatEqual(expected, actual)
238         black.assert_equivalent(source, actual)
239         black.assert_stable(source, actual, line_length=ll)
240
241     @patch("black.dump_to_file", dump_to_stderr)
242     def test_function2(self) -> None:
243         source, expected = read_data("function2")
244         actual = fs(source)
245         self.assertFormatEqual(expected, actual)
246         black.assert_equivalent(source, actual)
247         black.assert_stable(source, actual, line_length=ll)
248
249     @patch("black.dump_to_file", dump_to_stderr)
250     def test_expression(self) -> None:
251         source, expected = read_data("expression")
252         actual = fs(source)
253         self.assertFormatEqual(expected, actual)
254         black.assert_equivalent(source, actual)
255         black.assert_stable(source, actual, line_length=ll)
256
257     def test_expression_ff(self) -> None:
258         source, expected = read_data("expression")
259         tmp_file = Path(black.dump_to_file(source))
260         try:
261             self.assertTrue(ff(tmp_file, write_back=black.WriteBack.YES))
262             with open(tmp_file, encoding="utf8") as f:
263                 actual = f.read()
264         finally:
265             os.unlink(tmp_file)
266         self.assertFormatEqual(expected, actual)
267         with patch("black.dump_to_file", dump_to_stderr):
268             black.assert_equivalent(source, actual)
269             black.assert_stable(source, actual, line_length=ll)
270
271     def test_expression_diff(self) -> None:
272         source, _ = read_data("expression.py")
273         expected, _ = read_data("expression.diff")
274         tmp_file = Path(black.dump_to_file(source))
275         diff_header = re.compile(
276             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
277             rf"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
278         )
279         try:
280             result = BlackRunner().invoke(black.main, ["--diff", str(tmp_file)])
281             self.assertEqual(result.exit_code, 0)
282         finally:
283             os.unlink(tmp_file)
284         actual = result.output
285         actual = diff_header.sub("[Deterministic header]", actual)
286         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
287         if expected != actual:
288             dump = black.dump_to_file(actual)
289             msg = (
290                 f"Expected diff isn't equal to the actual. If you made changes "
291                 f"to expression.py and this is an anticipated difference, "
292                 f"overwrite tests/data/expression.diff with {dump}"
293             )
294             self.assertEqual(expected, actual, msg)
295
296     @patch("black.dump_to_file", dump_to_stderr)
297     def test_fstring(self) -> None:
298         source, expected = read_data("fstring")
299         actual = fs(source)
300         self.assertFormatEqual(expected, actual)
301         black.assert_equivalent(source, actual)
302         black.assert_stable(source, actual, line_length=ll)
303
304     @patch("black.dump_to_file", dump_to_stderr)
305     def test_string_quotes(self) -> None:
306         source, expected = read_data("string_quotes")
307         actual = fs(source)
308         self.assertFormatEqual(expected, actual)
309         black.assert_equivalent(source, actual)
310         black.assert_stable(source, actual, line_length=ll)
311         mode = black.FileMode.NO_STRING_NORMALIZATION
312         not_normalized = fs(source, mode=mode)
313         self.assertFormatEqual(source, not_normalized)
314         black.assert_equivalent(source, not_normalized)
315         black.assert_stable(source, not_normalized, line_length=ll, mode=mode)
316
317     @patch("black.dump_to_file", dump_to_stderr)
318     def test_slices(self) -> None:
319         source, expected = read_data("slices")
320         actual = fs(source)
321         self.assertFormatEqual(expected, actual)
322         black.assert_equivalent(source, actual)
323         black.assert_stable(source, actual, line_length=ll)
324
325     @patch("black.dump_to_file", dump_to_stderr)
326     def test_comments(self) -> None:
327         source, expected = read_data("comments")
328         actual = fs(source)
329         self.assertFormatEqual(expected, actual)
330         black.assert_equivalent(source, actual)
331         black.assert_stable(source, actual, line_length=ll)
332
333     @patch("black.dump_to_file", dump_to_stderr)
334     def test_comments2(self) -> None:
335         source, expected = read_data("comments2")
336         actual = fs(source)
337         self.assertFormatEqual(expected, actual)
338         black.assert_equivalent(source, actual)
339         black.assert_stable(source, actual, line_length=ll)
340
341     @patch("black.dump_to_file", dump_to_stderr)
342     def test_comments3(self) -> None:
343         source, expected = read_data("comments3")
344         actual = fs(source)
345         self.assertFormatEqual(expected, actual)
346         black.assert_equivalent(source, actual)
347         black.assert_stable(source, actual, line_length=ll)
348
349     @patch("black.dump_to_file", dump_to_stderr)
350     def test_comments4(self) -> None:
351         source, expected = read_data("comments4")
352         actual = fs(source)
353         self.assertFormatEqual(expected, actual)
354         black.assert_equivalent(source, actual)
355         black.assert_stable(source, actual, line_length=ll)
356
357     @patch("black.dump_to_file", dump_to_stderr)
358     def test_comments5(self) -> None:
359         source, expected = read_data("comments5")
360         actual = fs(source)
361         self.assertFormatEqual(expected, actual)
362         black.assert_equivalent(source, actual)
363         black.assert_stable(source, actual, line_length=ll)
364
365     @patch("black.dump_to_file", dump_to_stderr)
366     def test_comments6(self) -> None:
367         source, expected = read_data("comments6")
368         actual = fs(source)
369         self.assertFormatEqual(expected, actual)
370         black.assert_equivalent(source, actual)
371         black.assert_stable(source, actual, line_length=ll)
372
373     @patch("black.dump_to_file", dump_to_stderr)
374     def test_cantfit(self) -> None:
375         source, expected = read_data("cantfit")
376         actual = fs(source)
377         self.assertFormatEqual(expected, actual)
378         black.assert_equivalent(source, actual)
379         black.assert_stable(source, actual, line_length=ll)
380
381     @patch("black.dump_to_file", dump_to_stderr)
382     def test_import_spacing(self) -> None:
383         source, expected = read_data("import_spacing")
384         actual = fs(source)
385         self.assertFormatEqual(expected, actual)
386         black.assert_equivalent(source, actual)
387         black.assert_stable(source, actual, line_length=ll)
388
389     @patch("black.dump_to_file", dump_to_stderr)
390     def test_composition(self) -> None:
391         source, expected = read_data("composition")
392         actual = fs(source)
393         self.assertFormatEqual(expected, actual)
394         black.assert_equivalent(source, actual)
395         black.assert_stable(source, actual, line_length=ll)
396
397     @patch("black.dump_to_file", dump_to_stderr)
398     def test_empty_lines(self) -> None:
399         source, expected = read_data("empty_lines")
400         actual = fs(source)
401         self.assertFormatEqual(expected, actual)
402         black.assert_equivalent(source, actual)
403         black.assert_stable(source, actual, line_length=ll)
404
405     @patch("black.dump_to_file", dump_to_stderr)
406     def test_string_prefixes(self) -> None:
407         source, expected = read_data("string_prefixes")
408         actual = fs(source)
409         self.assertFormatEqual(expected, actual)
410         black.assert_equivalent(source, actual)
411         black.assert_stable(source, actual, line_length=ll)
412
413     @patch("black.dump_to_file", dump_to_stderr)
414     def test_numeric_literals(self) -> None:
415         source, expected = read_data("numeric_literals")
416         actual = fs(source, mode=black.FileMode.PYTHON36)
417         self.assertFormatEqual(expected, actual)
418         black.assert_equivalent(source, actual)
419         black.assert_stable(source, actual, line_length=ll)
420
421     @patch("black.dump_to_file", dump_to_stderr)
422     def test_numeric_literals_ignoring_underscores(self) -> None:
423         source, expected = read_data("numeric_literals_skip_underscores")
424         mode = (
425             black.FileMode.PYTHON36 | black.FileMode.NO_NUMERIC_UNDERSCORE_NORMALIZATION
426         )
427         actual = fs(source, mode=mode)
428         self.assertFormatEqual(expected, actual)
429         black.assert_equivalent(source, actual)
430         black.assert_stable(source, actual, line_length=ll, mode=mode)
431
432     @patch("black.dump_to_file", dump_to_stderr)
433     def test_numeric_literals_py2(self) -> None:
434         source, expected = read_data("numeric_literals_py2")
435         actual = fs(source)
436         self.assertFormatEqual(expected, actual)
437         black.assert_stable(source, actual, line_length=ll)
438
439     @patch("black.dump_to_file", dump_to_stderr)
440     def test_python2(self) -> None:
441         source, expected = read_data("python2")
442         actual = fs(source)
443         self.assertFormatEqual(expected, actual)
444         # black.assert_equivalent(source, actual)
445         black.assert_stable(source, actual, line_length=ll)
446
447     @patch("black.dump_to_file", dump_to_stderr)
448     def test_python2_unicode_literals(self) -> None:
449         source, expected = read_data("python2_unicode_literals")
450         actual = fs(source)
451         self.assertFormatEqual(expected, actual)
452         black.assert_stable(source, actual, line_length=ll)
453
454     @patch("black.dump_to_file", dump_to_stderr)
455     def test_stub(self) -> None:
456         mode = black.FileMode.PYI
457         source, expected = read_data("stub.pyi")
458         actual = fs(source, mode=mode)
459         self.assertFormatEqual(expected, actual)
460         black.assert_stable(source, actual, line_length=ll, mode=mode)
461
462     @patch("black.dump_to_file", dump_to_stderr)
463     def test_python37(self) -> None:
464         source, expected = read_data("python37")
465         actual = fs(source)
466         self.assertFormatEqual(expected, actual)
467         major, minor = sys.version_info[:2]
468         if major > 3 or (major == 3 and minor >= 7):
469             black.assert_equivalent(source, actual)
470         black.assert_stable(source, actual, line_length=ll)
471
472     @patch("black.dump_to_file", dump_to_stderr)
473     def test_fmtonoff(self) -> None:
474         source, expected = read_data("fmtonoff")
475         actual = fs(source)
476         self.assertFormatEqual(expected, actual)
477         black.assert_equivalent(source, actual)
478         black.assert_stable(source, actual, line_length=ll)
479
480     @patch("black.dump_to_file", dump_to_stderr)
481     def test_fmtonoff2(self) -> None:
482         source, expected = read_data("fmtonoff2")
483         actual = fs(source)
484         self.assertFormatEqual(expected, actual)
485         black.assert_equivalent(source, actual)
486         black.assert_stable(source, actual, line_length=ll)
487
488     @patch("black.dump_to_file", dump_to_stderr)
489     def test_remove_empty_parentheses_after_class(self) -> None:
490         source, expected = read_data("class_blank_parentheses")
491         actual = fs(source)
492         self.assertFormatEqual(expected, actual)
493         black.assert_equivalent(source, actual)
494         black.assert_stable(source, actual, line_length=ll)
495
496     @patch("black.dump_to_file", dump_to_stderr)
497     def test_new_line_between_class_and_code(self) -> None:
498         source, expected = read_data("class_methods_new_line")
499         actual = fs(source)
500         self.assertFormatEqual(expected, actual)
501         black.assert_equivalent(source, actual)
502         black.assert_stable(source, actual, line_length=ll)
503
504     @patch("black.dump_to_file", dump_to_stderr)
505     def test_bracket_match(self) -> None:
506         source, expected = read_data("bracketmatch")
507         actual = fs(source)
508         self.assertFormatEqual(expected, actual)
509         black.assert_equivalent(source, actual)
510         black.assert_stable(source, actual, line_length=ll)
511
    def test_report_verbose(self) -> None:
        """Walk a verbose Report through a sequence of file results.

        ``black.out`` and ``black.err`` are patched to collect messages into
        lists.  After each call the test checks the message counts, the last
        message, the ``str(report)`` summary and ``report.return_code``.  The
        counters are cumulative, so the order of the calls below matters.
        """
        report = black.Report(verbose=True)
        out_lines: List[str] = []
        err_lines: List[str] = []

        def out(msg: str, **kwargs: Any) -> None:
            out_lines.append(msg)

        def err(msg: str, **kwargs: Any) -> None:
            err_lines.append(msg)

        with patch("black.out", out), patch("black.err", err):
            # Verbose mode also reports files that needed no changes.
            report.done(Path("f1"), black.Changed.NO)
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "f1 already well formatted, good job.")
            self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
            self.assertEqual(report.return_code, 0)
            report.done(Path("f2"), black.Changed.YES)
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "reformatted f2")
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
            )
            # Cached files are counted as unchanged.
            report.done(Path("f3"), black.Changed.CACHED)
            self.assertEqual(len(out_lines), 3)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(
                out_lines[-1], "f3 wasn't modified on disk since last run."
            )
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
            )
            self.assertEqual(report.return_code, 0)
            # With check enabled, the reformatted file turns the return code to 1.
            report.check = True
            self.assertEqual(report.return_code, 1)
            report.check = False
            # A failure goes to err and sets return code 123.
            report.failed(Path("e1"), "boom")
            self.assertEqual(len(out_lines), 3)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
            self.assertEqual(
                unstyle(str(report)),
                "1 file reformatted, 2 files left unchanged, "
                "1 file failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f3"), black.Changed.YES)
            self.assertEqual(len(out_lines), 4)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(out_lines[-1], "reformatted f3")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, "
                "1 file failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.failed(Path("e2"), "boom")
            self.assertEqual(len(out_lines), 4)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, "
                "2 files failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Ignored paths are logged but leave the summary counts untouched.
            report.path_ignored(Path("wat"), "no match")
            self.assertEqual(len(out_lines), 5)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(out_lines[-1], "wat ignored: no match")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, "
                "2 files failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f4"), black.Changed.NO)
            self.assertEqual(len(out_lines), 6)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(out_lines[-1], "f4 already well formatted, good job.")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 3 files left unchanged, "
                "2 files failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # In check mode the summary switches to conditional wording.
            report.check = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, "
                "2 files would fail to reformat.",
            )
606
    def test_report_quiet(self) -> None:
        """Walk a quiet Report through a sequence of file results.

        ``black.out`` and ``black.err`` are patched to collect messages into
        lists.  In quiet mode ``black.out`` is never called; only failures reach
        ``black.err``.  The counters are cumulative, so call order matters.
        """
        report = black.Report(quiet=True)
        out_lines: List[str] = []
        err_lines: List[str] = []

        def out(msg: str, **kwargs: Any) -> None:
            out_lines.append(msg)

        def err(msg: str, **kwargs: Any) -> None:
            err_lines.append(msg)

        with patch("black.out", out), patch("black.err", err):
            report.done(Path("f1"), black.Changed.NO)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
            self.assertEqual(report.return_code, 0)
            report.done(Path("f2"), black.Changed.YES)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
            )
            report.done(Path("f3"), black.Changed.CACHED)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
            )
            self.assertEqual(report.return_code, 0)
            # With check enabled, the reformatted file turns the return code to 1.
            report.check = True
            self.assertEqual(report.return_code, 1)
            report.check = False
            # Failures are still reported on err even in quiet mode.
            report.failed(Path("e1"), "boom")
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
            self.assertEqual(
                unstyle(str(report)),
                "1 file reformatted, 2 files left unchanged, "
                "1 file failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f3"), black.Changed.YES)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, "
                "1 file failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.failed(Path("e2"), "boom")
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, "
                "2 files failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Ignored paths produce no output and leave the summary untouched.
            report.path_ignored(Path("wat"), "no match")
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, "
                "2 files failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f4"), black.Changed.NO)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 3 files left unchanged, "
                "2 files failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # In check mode the summary switches to conditional wording.
            report.check = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, "
                "2 files would fail to reformat.",
            )
693
    def test_report_normal(self) -> None:
        """Drive a default-constructed `black.Report` through a fixed event sequence.

        After each event the test checks how many messages reached the patched
        `black.out` / `black.err`, the most recent message where one was added,
        the summary text from `str(report)` and `report.return_code`. The report
        is stateful and every expectation is cumulative, so step order matters.
        """
        report = black.Report()
        # Collect emitted messages instead of printing them.
        out_lines: List[str] = []
        err_lines: List[str] = []

        def out(msg: str, **kwargs: Any) -> None:
            out_lines.append(msg)

        def err(msg: str, **kwargs: Any) -> None:
            err_lines.append(msg)

        with patch("black.out", out), patch("black.err", err):
            # An unchanged file is counted but produces no message.
            report.done(Path("f1"), black.Changed.NO)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
            self.assertEqual(report.return_code, 0)
            # A reformatted file is announced through `out`.
            report.done(Path("f2"), black.Changed.YES)
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "reformatted f2")
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
            )
            # A cached file is counted as "left unchanged" and stays silent.
            report.done(Path("f3"), black.Changed.CACHED)
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "reformatted f2")
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
            )
            self.assertEqual(report.return_code, 0)
            # In check mode, having a reformatted file yields return code 1.
            report.check = True
            self.assertEqual(report.return_code, 1)
            report.check = False
            # A failure is reported through `err` and switches the code to 123.
            report.failed(Path("e1"), "boom")
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
            self.assertEqual(
                unstyle(str(report)),
                "1 file reformatted, 2 files left unchanged, "
                "1 file failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Reporting "f3" again, now as changed, raises the reformatted count
            # while the unchanged count stays at 2 (counts are per event).
            report.done(Path("f3"), black.Changed.YES)
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(out_lines[-1], "reformatted f3")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, "
                "1 file failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # A second failure; the return code stays at 123.
            report.failed(Path("e2"), "boom")
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, "
                "2 files failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # An ignored path changes neither the messages nor the summary.
            report.path_ignored(Path("wat"), "no match")
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, "
                "2 files failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # One more silent, unchanged file.
            report.done(Path("f4"), black.Changed.NO)
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 3 files left unchanged, "
                "2 files failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Check mode rewords the whole summary conditionally.
            report.check = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, "
                "2 files would fail to reformat.",
            )
783
784     def test_is_python36(self) -> None:
785         node = black.lib2to3_parse("def f(*, arg): ...\n")
786         self.assertFalse(black.is_python36(node))
787         node = black.lib2to3_parse("def f(*, arg,): ...\n")
788         self.assertTrue(black.is_python36(node))
789         node = black.lib2to3_parse("def f(*, arg): f'string'\n")
790         self.assertTrue(black.is_python36(node))
791         node = black.lib2to3_parse("123_456\n")
792         self.assertTrue(black.is_python36(node))
793         node = black.lib2to3_parse("123456\n")
794         self.assertFalse(black.is_python36(node))
795         source, expected = read_data("function")
796         node = black.lib2to3_parse(source)
797         self.assertTrue(black.is_python36(node))
798         node = black.lib2to3_parse(expected)
799         self.assertTrue(black.is_python36(node))
800         source, expected = read_data("expression")
801         node = black.lib2to3_parse(source)
802         self.assertFalse(black.is_python36(node))
803         node = black.lib2to3_parse(expected)
804         self.assertFalse(black.is_python36(node))
805
806     def test_get_future_imports(self) -> None:
807         node = black.lib2to3_parse("\n")
808         self.assertEqual(set(), black.get_future_imports(node))
809         node = black.lib2to3_parse("from __future__ import black\n")
810         self.assertEqual({"black"}, black.get_future_imports(node))
811         node = black.lib2to3_parse("from __future__ import multiple, imports\n")
812         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
813         node = black.lib2to3_parse("from __future__ import (parenthesized, imports)\n")
814         self.assertEqual({"parenthesized", "imports"}, black.get_future_imports(node))
815         node = black.lib2to3_parse(
816             "from __future__ import multiple\nfrom __future__ import imports\n"
817         )
818         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
819         node = black.lib2to3_parse("# comment\nfrom __future__ import black\n")
820         self.assertEqual({"black"}, black.get_future_imports(node))
821         node = black.lib2to3_parse('"""docstring"""\nfrom __future__ import black\n')
822         self.assertEqual({"black"}, black.get_future_imports(node))
823         node = black.lib2to3_parse("some(other, code)\nfrom __future__ import black\n")
824         self.assertEqual(set(), black.get_future_imports(node))
825         node = black.lib2to3_parse("from some.module import black\n")
826         self.assertEqual(set(), black.get_future_imports(node))
827         node = black.lib2to3_parse(
828             "from __future__ import unicode_literals as _unicode_literals"
829         )
830         self.assertEqual({"unicode_literals"}, black.get_future_imports(node))
831         node = black.lib2to3_parse(
832             "from __future__ import unicode_literals as _lol, print"
833         )
834         self.assertEqual({"unicode_literals", "print"}, black.get_future_imports(node))
835
836     def test_debug_visitor(self) -> None:
837         source, _ = read_data("debug_visitor.py")
838         expected, _ = read_data("debug_visitor.out")
839         out_lines = []
840         err_lines = []
841
842         def out(msg: str, **kwargs: Any) -> None:
843             out_lines.append(msg)
844
845         def err(msg: str, **kwargs: Any) -> None:
846             err_lines.append(msg)
847
848         with patch("black.out", out), patch("black.err", err):
849             black.DebugVisitor.show(source)
850         actual = "\n".join(out_lines) + "\n"
851         log_name = ""
852         if expected != actual:
853             log_name = black.dump_to_file(*out_lines)
854         self.assertEqual(
855             expected,
856             actual,
857             f"AST print out is different. Actual version dumped to {log_name}",
858         )
859
860     def test_format_file_contents(self) -> None:
861         empty = ""
862         with self.assertRaises(black.NothingChanged):
863             black.format_file_contents(empty, line_length=ll, fast=False)
864         just_nl = "\n"
865         with self.assertRaises(black.NothingChanged):
866             black.format_file_contents(just_nl, line_length=ll, fast=False)
867         same = "l = [1, 2, 3]\n"
868         with self.assertRaises(black.NothingChanged):
869             black.format_file_contents(same, line_length=ll, fast=False)
870         different = "l = [1,2,3]"
871         expected = same
872         actual = black.format_file_contents(different, line_length=ll, fast=False)
873         self.assertEqual(expected, actual)
874         invalid = "return if you can"
875         with self.assertRaises(black.InvalidInput) as e:
876             black.format_file_contents(invalid, line_length=ll, fast=False)
877         self.assertEqual(str(e.exception), "Cannot parse: 1:7: return if you can")
878
879     def test_endmarker(self) -> None:
880         n = black.lib2to3_parse("\n")
881         self.assertEqual(n.type, black.syms.file_input)
882         self.assertEqual(len(n.children), 1)
883         self.assertEqual(n.children[0].type, black.token.ENDMARKER)
884
885     @unittest.skipIf(os.environ.get("SKIP_AST_PRINT"), "user set SKIP_AST_PRINT")
886     def test_assertFormatEqual(self) -> None:
887         out_lines = []
888         err_lines = []
889
890         def out(msg: str, **kwargs: Any) -> None:
891             out_lines.append(msg)
892
893         def err(msg: str, **kwargs: Any) -> None:
894             err_lines.append(msg)
895
896         with patch("black.out", out), patch("black.err", err):
897             with self.assertRaises(AssertionError):
898                 self.assertFormatEqual("l = [1, 2, 3]", "l = [1, 2, 3,]")
899
900         out_str = "".join(out_lines)
901         self.assertTrue("Expected tree:" in out_str)
902         self.assertTrue("Actual tree:" in out_str)
903         self.assertEqual("".join(err_lines), "")
904
905     def test_cache_broken_file(self) -> None:
906         mode = black.FileMode.AUTO_DETECT
907         with cache_dir() as workspace:
908             cache_file = black.get_cache_file(black.DEFAULT_LINE_LENGTH, mode)
909             with cache_file.open("w") as fobj:
910                 fobj.write("this is not a pickle")
911             self.assertEqual(black.read_cache(black.DEFAULT_LINE_LENGTH, mode), {})
912             src = (workspace / "test.py").resolve()
913             with src.open("w") as fobj:
914                 fobj.write("print('hello')")
915             result = CliRunner().invoke(black.main, [str(src)])
916             self.assertEqual(result.exit_code, 0)
917             cache = black.read_cache(black.DEFAULT_LINE_LENGTH, mode)
918             self.assertIn(src, cache)
919
920     def test_cache_single_file_already_cached(self) -> None:
921         mode = black.FileMode.AUTO_DETECT
922         with cache_dir() as workspace:
923             src = (workspace / "test.py").resolve()
924             with src.open("w") as fobj:
925                 fobj.write("print('hello')")
926             black.write_cache({}, [src], black.DEFAULT_LINE_LENGTH, mode)
927             result = CliRunner().invoke(black.main, [str(src)])
928             self.assertEqual(result.exit_code, 0)
929             with src.open("r") as fobj:
930                 self.assertEqual(fobj.read(), "print('hello')")
931
932     @event_loop(close=False)
933     def test_cache_multiple_files(self) -> None:
934         mode = black.FileMode.AUTO_DETECT
935         with cache_dir() as workspace, patch(
936             "black.ProcessPoolExecutor", new=ThreadPoolExecutor
937         ):
938             one = (workspace / "one.py").resolve()
939             with one.open("w") as fobj:
940                 fobj.write("print('hello')")
941             two = (workspace / "two.py").resolve()
942             with two.open("w") as fobj:
943                 fobj.write("print('hello')")
944             black.write_cache({}, [one], black.DEFAULT_LINE_LENGTH, mode)
945             result = CliRunner().invoke(black.main, [str(workspace)])
946             self.assertEqual(result.exit_code, 0)
947             with one.open("r") as fobj:
948                 self.assertEqual(fobj.read(), "print('hello')")
949             with two.open("r") as fobj:
950                 self.assertEqual(fobj.read(), 'print("hello")\n')
951             cache = black.read_cache(black.DEFAULT_LINE_LENGTH, mode)
952             self.assertIn(one, cache)
953             self.assertIn(two, cache)
954
955     def test_no_cache_when_writeback_diff(self) -> None:
956         mode = black.FileMode.AUTO_DETECT
957         with cache_dir() as workspace:
958             src = (workspace / "test.py").resolve()
959             with src.open("w") as fobj:
960                 fobj.write("print('hello')")
961             result = CliRunner().invoke(black.main, [str(src), "--diff"])
962             self.assertEqual(result.exit_code, 0)
963             cache_file = black.get_cache_file(black.DEFAULT_LINE_LENGTH, mode)
964             self.assertFalse(cache_file.exists())
965
966     def test_no_cache_when_stdin(self) -> None:
967         mode = black.FileMode.AUTO_DETECT
968         with cache_dir():
969             result = CliRunner().invoke(
970                 black.main, ["-"], input=BytesIO(b"print('hello')")
971             )
972             self.assertEqual(result.exit_code, 0)
973             cache_file = black.get_cache_file(black.DEFAULT_LINE_LENGTH, mode)
974             self.assertFalse(cache_file.exists())
975
976     def test_read_cache_no_cachefile(self) -> None:
977         mode = black.FileMode.AUTO_DETECT
978         with cache_dir():
979             self.assertEqual(black.read_cache(black.DEFAULT_LINE_LENGTH, mode), {})
980
981     def test_write_cache_read_cache(self) -> None:
982         mode = black.FileMode.AUTO_DETECT
983         with cache_dir() as workspace:
984             src = (workspace / "test.py").resolve()
985             src.touch()
986             black.write_cache({}, [src], black.DEFAULT_LINE_LENGTH, mode)
987             cache = black.read_cache(black.DEFAULT_LINE_LENGTH, mode)
988             self.assertIn(src, cache)
989             self.assertEqual(cache[src], black.get_cache_info(src))
990
991     def test_filter_cached(self) -> None:
992         with TemporaryDirectory() as workspace:
993             path = Path(workspace)
994             uncached = (path / "uncached").resolve()
995             cached = (path / "cached").resolve()
996             cached_but_changed = (path / "changed").resolve()
997             uncached.touch()
998             cached.touch()
999             cached_but_changed.touch()
1000             cache = {cached: black.get_cache_info(cached), cached_but_changed: (0.0, 0)}
1001             todo, done = black.filter_cached(
1002                 cache, {uncached, cached, cached_but_changed}
1003             )
1004             self.assertEqual(todo, {uncached, cached_but_changed})
1005             self.assertEqual(done, {cached})
1006
1007     def test_write_cache_creates_directory_if_needed(self) -> None:
1008         mode = black.FileMode.AUTO_DETECT
1009         with cache_dir(exists=False) as workspace:
1010             self.assertFalse(workspace.exists())
1011             black.write_cache({}, [], black.DEFAULT_LINE_LENGTH, mode)
1012             self.assertTrue(workspace.exists())
1013
1014     @event_loop(close=False)
1015     def test_failed_formatting_does_not_get_cached(self) -> None:
1016         mode = black.FileMode.AUTO_DETECT
1017         with cache_dir() as workspace, patch(
1018             "black.ProcessPoolExecutor", new=ThreadPoolExecutor
1019         ):
1020             failing = (workspace / "failing.py").resolve()
1021             with failing.open("w") as fobj:
1022                 fobj.write("not actually python")
1023             clean = (workspace / "clean.py").resolve()
1024             with clean.open("w") as fobj:
1025                 fobj.write('print("hello")\n')
1026             result = CliRunner().invoke(black.main, [str(workspace)])
1027             self.assertEqual(result.exit_code, 123)
1028             cache = black.read_cache(black.DEFAULT_LINE_LENGTH, mode)
1029             self.assertNotIn(failing, cache)
1030             self.assertIn(clean, cache)
1031
1032     def test_write_cache_write_fail(self) -> None:
1033         mode = black.FileMode.AUTO_DETECT
1034         with cache_dir(), patch.object(Path, "open") as mock:
1035             mock.side_effect = OSError
1036             black.write_cache({}, [], black.DEFAULT_LINE_LENGTH, mode)
1037
1038     @event_loop(close=False)
1039     def test_check_diff_use_together(self) -> None:
1040         with cache_dir():
1041             # Files which will be reformatted.
1042             src1 = (THIS_DIR / "data" / "string_quotes.py").resolve()
1043             result = CliRunner().invoke(black.main, [str(src1), "--diff", "--check"])
1044             self.assertEqual(result.exit_code, 1, result.output)
1045             # Files which will not be reformatted.
1046             src2 = (THIS_DIR / "data" / "composition.py").resolve()
1047             result = CliRunner().invoke(black.main, [str(src2), "--diff", "--check"])
1048             self.assertEqual(result.exit_code, 0, result.output)
1049             # Multi file command.
1050             result = CliRunner().invoke(
1051                 black.main, [str(src1), str(src2), "--diff", "--check"]
1052             )
1053             self.assertEqual(result.exit_code, 1, result.output)
1054
1055     def test_no_files(self) -> None:
1056         with cache_dir():
1057             # Without an argument, black exits with error code 0.
1058             result = CliRunner().invoke(black.main, [])
1059             self.assertEqual(result.exit_code, 0)
1060
1061     def test_broken_symlink(self) -> None:
1062         with cache_dir() as workspace:
1063             symlink = workspace / "broken_link.py"
1064             try:
1065                 symlink.symlink_to("nonexistent.py")
1066             except OSError as e:
1067                 self.skipTest(f"Can't create symlinks: {e}")
1068             result = CliRunner().invoke(black.main, [str(workspace.resolve())])
1069             self.assertEqual(result.exit_code, 0)
1070
1071     def test_read_cache_line_lengths(self) -> None:
1072         mode = black.FileMode.AUTO_DETECT
1073         with cache_dir() as workspace:
1074             path = (workspace / "file.py").resolve()
1075             path.touch()
1076             black.write_cache({}, [path], 1, mode)
1077             one = black.read_cache(1, mode)
1078             self.assertIn(path, one)
1079             two = black.read_cache(2, mode)
1080             self.assertNotIn(path, two)
1081
1082     def test_single_file_force_pyi(self) -> None:
1083         reg_mode = black.FileMode.AUTO_DETECT
1084         pyi_mode = black.FileMode.PYI
1085         contents, expected = read_data("force_pyi")
1086         with cache_dir() as workspace:
1087             path = (workspace / "file.py").resolve()
1088             with open(path, "w") as fh:
1089                 fh.write(contents)
1090             result = CliRunner().invoke(black.main, [str(path), "--pyi"])
1091             self.assertEqual(result.exit_code, 0)
1092             with open(path, "r") as fh:
1093                 actual = fh.read()
1094             # verify cache with --pyi is separate
1095             pyi_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, pyi_mode)
1096             self.assertIn(path, pyi_cache)
1097             normal_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, reg_mode)
1098             self.assertNotIn(path, normal_cache)
1099         self.assertEqual(actual, expected)
1100
1101     @event_loop(close=False)
1102     def test_multi_file_force_pyi(self) -> None:
1103         reg_mode = black.FileMode.AUTO_DETECT
1104         pyi_mode = black.FileMode.PYI
1105         contents, expected = read_data("force_pyi")
1106         with cache_dir() as workspace:
1107             paths = [
1108                 (workspace / "file1.py").resolve(),
1109                 (workspace / "file2.py").resolve(),
1110             ]
1111             for path in paths:
1112                 with open(path, "w") as fh:
1113                     fh.write(contents)
1114             result = CliRunner().invoke(black.main, [str(p) for p in paths] + ["--pyi"])
1115             self.assertEqual(result.exit_code, 0)
1116             for path in paths:
1117                 with open(path, "r") as fh:
1118                     actual = fh.read()
1119                 self.assertEqual(actual, expected)
1120             # verify cache with --pyi is separate
1121             pyi_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, pyi_mode)
1122             normal_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, reg_mode)
1123             for path in paths:
1124                 self.assertIn(path, pyi_cache)
1125                 self.assertNotIn(path, normal_cache)
1126
1127     def test_pipe_force_pyi(self) -> None:
1128         source, expected = read_data("force_pyi")
1129         result = CliRunner().invoke(
1130             black.main, ["-", "-q", "--pyi"], input=BytesIO(source.encode("utf8"))
1131         )
1132         self.assertEqual(result.exit_code, 0)
1133         actual = result.output
1134         self.assertFormatEqual(actual, expected)
1135
1136     def test_single_file_force_py36(self) -> None:
1137         reg_mode = black.FileMode.AUTO_DETECT
1138         py36_mode = black.FileMode.PYTHON36
1139         source, expected = read_data("force_py36")
1140         with cache_dir() as workspace:
1141             path = (workspace / "file.py").resolve()
1142             with open(path, "w") as fh:
1143                 fh.write(source)
1144             result = CliRunner().invoke(black.main, [str(path), "--py36"])
1145             self.assertEqual(result.exit_code, 0)
1146             with open(path, "r") as fh:
1147                 actual = fh.read()
1148             # verify cache with --py36 is separate
1149             py36_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, py36_mode)
1150             self.assertIn(path, py36_cache)
1151             normal_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, reg_mode)
1152             self.assertNotIn(path, normal_cache)
1153         self.assertEqual(actual, expected)
1154
1155     @event_loop(close=False)
1156     def test_multi_file_force_py36(self) -> None:
1157         reg_mode = black.FileMode.AUTO_DETECT
1158         py36_mode = black.FileMode.PYTHON36
1159         source, expected = read_data("force_py36")
1160         with cache_dir() as workspace:
1161             paths = [
1162                 (workspace / "file1.py").resolve(),
1163                 (workspace / "file2.py").resolve(),
1164             ]
1165             for path in paths:
1166                 with open(path, "w") as fh:
1167                     fh.write(source)
1168             result = CliRunner().invoke(
1169                 black.main, [str(p) for p in paths] + ["--py36"]
1170             )
1171             self.assertEqual(result.exit_code, 0)
1172             for path in paths:
1173                 with open(path, "r") as fh:
1174                     actual = fh.read()
1175                 self.assertEqual(actual, expected)
1176             # verify cache with --py36 is separate
1177             pyi_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, py36_mode)
1178             normal_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, reg_mode)
1179             for path in paths:
1180                 self.assertIn(path, pyi_cache)
1181                 self.assertNotIn(path, normal_cache)
1182
1183     def test_pipe_force_py36(self) -> None:
1184         source, expected = read_data("force_py36")
1185         result = CliRunner().invoke(
1186             black.main, ["-", "-q", "--py36"], input=BytesIO(source.encode("utf8"))
1187         )
1188         self.assertEqual(result.exit_code, 0)
1189         actual = result.output
1190         self.assertFormatEqual(actual, expected)
1191
1192     def test_include_exclude(self) -> None:
1193         path = THIS_DIR / "data" / "include_exclude_tests"
1194         include = re.compile(r"\.pyi?$")
1195         exclude = re.compile(r"/exclude/|/\.definitely_exclude/")
1196         report = black.Report()
1197         sources: List[Path] = []
1198         expected = [
1199             Path(path / "b/dont_exclude/a.py"),
1200             Path(path / "b/dont_exclude/a.pyi"),
1201         ]
1202         this_abs = THIS_DIR.resolve()
1203         sources.extend(
1204             black.gen_python_files_in_dir(path, this_abs, include, exclude, report)
1205         )
1206         self.assertEqual(sorted(expected), sorted(sources))
1207
1208     def test_empty_include(self) -> None:
1209         path = THIS_DIR / "data" / "include_exclude_tests"
1210         report = black.Report()
1211         empty = re.compile(r"")
1212         sources: List[Path] = []
1213         expected = [
1214             Path(path / "b/exclude/a.pie"),
1215             Path(path / "b/exclude/a.py"),
1216             Path(path / "b/exclude/a.pyi"),
1217             Path(path / "b/dont_exclude/a.pie"),
1218             Path(path / "b/dont_exclude/a.py"),
1219             Path(path / "b/dont_exclude/a.pyi"),
1220             Path(path / "b/.definitely_exclude/a.pie"),
1221             Path(path / "b/.definitely_exclude/a.py"),
1222             Path(path / "b/.definitely_exclude/a.pyi"),
1223         ]
1224         this_abs = THIS_DIR.resolve()
1225         sources.extend(
1226             black.gen_python_files_in_dir(
1227                 path, this_abs, empty, re.compile(black.DEFAULT_EXCLUDES), report
1228             )
1229         )
1230         self.assertEqual(sorted(expected), sorted(sources))
1231
1232     def test_empty_exclude(self) -> None:
1233         path = THIS_DIR / "data" / "include_exclude_tests"
1234         report = black.Report()
1235         empty = re.compile(r"")
1236         sources: List[Path] = []
1237         expected = [
1238             Path(path / "b/dont_exclude/a.py"),
1239             Path(path / "b/dont_exclude/a.pyi"),
1240             Path(path / "b/exclude/a.py"),
1241             Path(path / "b/exclude/a.pyi"),
1242             Path(path / "b/.definitely_exclude/a.py"),
1243             Path(path / "b/.definitely_exclude/a.pyi"),
1244         ]
1245         this_abs = THIS_DIR.resolve()
1246         sources.extend(
1247             black.gen_python_files_in_dir(
1248                 path, this_abs, re.compile(black.DEFAULT_INCLUDES), empty, report
1249             )
1250         )
1251         self.assertEqual(sorted(expected), sorted(sources))
1252
1253     def test_invalid_include_exclude(self) -> None:
1254         for option in ["--include", "--exclude"]:
1255             result = CliRunner().invoke(black.main, ["-", option, "**()(!!*)"])
1256             self.assertEqual(result.exit_code, 2)
1257
1258     def test_preserves_line_endings(self) -> None:
1259         with TemporaryDirectory() as workspace:
1260             test_file = Path(workspace) / "test.py"
1261             for nl in ["\n", "\r\n"]:
1262                 contents = nl.join(["def f(  ):", "    pass"])
1263                 test_file.write_bytes(contents.encode())
1264                 ff(test_file, write_back=black.WriteBack.YES)
1265                 updated_contents: bytes = test_file.read_bytes()
1266                 self.assertIn(nl.encode(), updated_contents)
1267                 if nl == "\n":
1268                     self.assertNotIn(b"\r\n", updated_contents)
1269
1270     def test_preserves_line_endings_via_stdin(self) -> None:
1271         for nl in ["\n", "\r\n"]:
1272             contents = nl.join(["def f(  ):", "    pass"])
1273             runner = BlackRunner()
1274             result = runner.invoke(
1275                 black.main, ["-", "--fast"], input=BytesIO(contents.encode("utf8"))
1276             )
1277             self.assertEqual(result.exit_code, 0)
1278             output = runner.stdout_bytes
1279             self.assertIn(nl.encode("utf8"), output)
1280             if nl == "\n":
1281                 self.assertNotIn(b"\r\n", output)
1282
1283     def test_assert_equivalent_different_asts(self) -> None:
1284         with self.assertRaises(AssertionError):
1285             black.assert_equivalent("{}", "None")
1286
1287     def test_symlink_out_of_root_directory(self) -> None:
1288         path = MagicMock()
1289         root = THIS_DIR
1290         child = MagicMock()
1291         include = re.compile(black.DEFAULT_INCLUDES)
1292         exclude = re.compile(black.DEFAULT_EXCLUDES)
1293         report = black.Report()
1294         # `child` should behave like a symlink which resolved path is clearly
1295         # outside of the `root` directory.
1296         path.iterdir.return_value = [child]
1297         child.resolve.return_value = Path("/a/b/c")
1298         child.is_symlink.return_value = True
1299         try:
1300             list(black.gen_python_files_in_dir(path, root, include, exclude, report))
1301         except ValueError as ve:
1302             self.fail(f"`get_python_files_in_dir()` failed: {ve}")
1303         path.iterdir.assert_called_once()
1304         child.resolve.assert_called_once()
1305         child.is_symlink.assert_called_once()
1306         # `child` should behave like a strange file which resolved path is clearly
1307         # outside of the `root` directory.
1308         child.is_symlink.return_value = False
1309         with self.assertRaises(ValueError):
1310             list(black.gen_python_files_in_dir(path, root, include, exclude, report))
1311         path.iterdir.assert_called()
1312         self.assertEqual(path.iterdir.call_count, 2)
1313         child.resolve.assert_called()
1314         self.assertEqual(child.resolve.call_count, 2)
1315         child.is_symlink.assert_called()
1316         self.assertEqual(child.is_symlink.call_count, 2)
1317
1318     def test_shhh_click(self) -> None:
1319         try:
1320             from click import _unicodefun  # type: ignore
1321         except ModuleNotFoundError:
1322             self.skipTest("Incompatible Click version")
1323         if not hasattr(_unicodefun, "_verify_python3_env"):
1324             self.skipTest("Incompatible Click version")
1325         # First, let's see if Click is crashing with a preferred ASCII charset.
1326         with patch("locale.getpreferredencoding") as gpe:
1327             gpe.return_value = "ASCII"
1328             with self.assertRaises(RuntimeError):
1329                 _unicodefun._verify_python3_env()
1330         # Now, let's silence Click...
1331         black.patch_click()
1332         # ...and confirm it's silent.
1333         with patch("locale.getpreferredencoding") as gpe:
1334             gpe.return_value = "ASCII"
1335             try:
1336                 _unicodefun._verify_python3_env()
1337             except RuntimeError as re:
1338                 self.fail(f"`patch_click()` failed, exception still raised: {re}")
1339
1340     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1341     @async_test
1342     async def test_blackd_request_needs_formatting(self) -> None:
1343         app = blackd.make_app()
1344         async with TestClient(TestServer(app)) as client:
1345             response = await client.post("/", data=b"print('hello world')")
1346             self.assertEqual(response.status, 200)
1347             self.assertEqual(response.charset, "utf8")
1348             self.assertEqual(await response.read(), b'print("hello world")\n')
1349
1350     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1351     @async_test
1352     async def test_blackd_request_no_change(self) -> None:
1353         app = blackd.make_app()
1354         async with TestClient(TestServer(app)) as client:
1355             response = await client.post("/", data=b'print("hello world")\n')
1356             self.assertEqual(response.status, 204)
1357             self.assertEqual(await response.read(), b"")
1358
1359     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1360     @async_test
1361     async def test_blackd_request_syntax_error(self) -> None:
1362         app = blackd.make_app()
1363         async with TestClient(TestServer(app)) as client:
1364             response = await client.post("/", data=b"what even ( is")
1365             self.assertEqual(response.status, 400)
1366             content = await response.text()
1367             self.assertTrue(
1368                 content.startswith("Cannot parse"),
1369                 msg=f"Expected error to start with 'Cannot parse', got {repr(content)}",
1370             )
1371
1372     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1373     @async_test
1374     async def test_blackd_unsupported_version(self) -> None:
1375         app = blackd.make_app()
1376         async with TestClient(TestServer(app)) as client:
1377             response = await client.post(
1378                 "/", data=b"what", headers={blackd.VERSION_HEADER: "2"}
1379             )
1380             self.assertEqual(response.status, 501)
1381
1382     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1383     @async_test
1384     async def test_blackd_supported_version(self) -> None:
1385         app = blackd.make_app()
1386         async with TestClient(TestServer(app)) as client:
1387             response = await client.post(
1388                 "/", data=b"what", headers={blackd.VERSION_HEADER: "1"}
1389             )
1390             self.assertEqual(response.status, 200)
1391
1392     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1393     @async_test
1394     async def test_blackd_invalid_python_variant(self) -> None:
1395         app = blackd.make_app()
1396         async with TestClient(TestServer(app)) as client:
1397             response = await client.post(
1398                 "/", data=b"what", headers={blackd.PYTHON_VARIANT_HEADER: "lol"}
1399             )
1400             self.assertEqual(response.status, 400)
1401
1402     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1403     @async_test
1404     async def test_blackd_pyi(self) -> None:
1405         app = blackd.make_app()
1406         async with TestClient(TestServer(app)) as client:
1407             source, expected = read_data("stub.pyi")
1408             response = await client.post(
1409                 "/", data=source, headers={blackd.PYTHON_VARIANT_HEADER: "pyi"}
1410             )
1411             self.assertEqual(response.status, 200)
1412             self.assertEqual(await response.text(), expected)
1413
1414     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1415     @async_test
1416     async def test_blackd_py36(self) -> None:
1417         app = blackd.make_app()
1418         async with TestClient(TestServer(app)) as client:
1419             response = await client.post(
1420                 "/",
1421                 data=(
1422                     "def f(\n"
1423                     "    and_has_a_bunch_of,\n"
1424                     "    very_long_arguments_too,\n"
1425                     "    and_lots_of_them_as_well_lol,\n"
1426                     "    **and_very_long_keyword_arguments\n"
1427                     "):\n"
1428                     "    pass\n"
1429                 ),
1430                 headers={blackd.PYTHON_VARIANT_HEADER: "3.6"},
1431             )
1432             self.assertEqual(response.status, 200)
1433             response = await client.post(
1434                 "/",
1435                 data=(
1436                     "def f(\n"
1437                     "    and_has_a_bunch_of,\n"
1438                     "    very_long_arguments_too,\n"
1439                     "    and_lots_of_them_as_well_lol,\n"
1440                     "    **and_very_long_keyword_arguments\n"
1441                     "):\n"
1442                     "    pass\n"
1443                 ),
1444                 headers={blackd.PYTHON_VARIANT_HEADER: "3.5"},
1445             )
1446             self.assertEqual(response.status, 204)
1447             response = await client.post(
1448                 "/",
1449                 data=(
1450                     "def f(\n"
1451                     "    and_has_a_bunch_of,\n"
1452                     "    very_long_arguments_too,\n"
1453                     "    and_lots_of_them_as_well_lol,\n"
1454                     "    **and_very_long_keyword_arguments\n"
1455                     "):\n"
1456                     "    pass\n"
1457                 ),
1458                 headers={blackd.PYTHON_VARIANT_HEADER: "2"},
1459             )
1460             self.assertEqual(response.status, 204)
1461
1462     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1463     @async_test
1464     async def test_blackd_fast(self) -> None:
1465         with open(os.devnull, "w") as dn, redirect_stderr(dn):
1466             app = blackd.make_app()
1467             async with TestClient(TestServer(app)) as client:
1468                 response = await client.post("/", data=b"ur'hello'")
1469                 self.assertEqual(response.status, 500)
1470                 self.assertIn("failed to parse source file", await response.text())
1471                 response = await client.post(
1472                     "/", data=b"ur'hello'", headers={blackd.FAST_OR_SAFE_HEADER: "fast"}
1473                 )
1474                 self.assertEqual(response.status, 200)
1475
1476     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1477     @async_test
1478     async def test_blackd_line_length(self) -> None:
1479         app = blackd.make_app()
1480         async with TestClient(TestServer(app)) as client:
1481             response = await client.post(
1482                 "/", data=b'print("hello")\n', headers={blackd.LINE_LENGTH_HEADER: "7"}
1483             )
1484             self.assertEqual(response.status, 200)
1485
1486     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1487     @async_test
1488     async def test_blackd_invalid_line_length(self) -> None:
1489         app = blackd.make_app()
1490         async with TestClient(TestServer(app)) as client:
1491             response = await client.post(
1492                 "/",
1493                 data=b'print("hello")\n',
1494                 headers={blackd.LINE_LENGTH_HEADER: "NaN"},
1495             )
1496             self.assertEqual(response.status, 400)
1497
1498     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1499     def test_blackd_main(self) -> None:
1500         with patch("blackd.web.run_app"):
1501             result = CliRunner().invoke(blackd.main, [])
1502             if result.exception is not None:
1503                 raise result.exception
1504             self.assertEqual(result.exit_code, 0)
1505
1506
if __name__ == "__main__":
    # Allow running this file directly; unittest is told to load the tests
    # from the module named "test_black" rather than from __main__.
    unittest.main(module="test_black")