]> git.madduck.net Git - etc/vim.git/blob - tests/test_black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access, can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Update README with missing change log, etc.
[etc/vim.git] / tests / test_black.py
1 #!/usr/bin/env python3
2 import asyncio
3 from concurrent.futures import ThreadPoolExecutor
4 from contextlib import contextmanager
5 from functools import partial
6 from io import BytesIO, TextIOWrapper
7 import os
8 from pathlib import Path
9 import re
10 import sys
11 from tempfile import TemporaryDirectory
12 from typing import Any, BinaryIO, Generator, List, Tuple, Iterator
13 import unittest
14 from unittest.mock import patch, MagicMock
15
16 from click import unstyle
17 from click.testing import CliRunner
18
19 import black
20
21
# Line length handed to black by every formatting helper and stability check here.
ll = 88
# Shorthands: ``ff`` reformats a file in place (with ``fast=True``) and ``fs``
# formats a string, both at line length ``ll``.
ff = partial(black.format_file_in_place, fast=True, line_length=ll)
fs = partial(black.format_str, line_length=ll)
THIS_FILE = Path(__file__)
THIS_DIR = THIS_FILE.parent
# read_data() deletes this marker from every fixture line.  It is assembled from
# two literals so that this definition itself survives intact when test_self
# reads this very file through read_data().
EMPTY_LINE = "# EMPTY LINE WITH WHITESPACE" + " (this comment will be removed)"
28
29
def dump_to_stderr(*output: str) -> str:
    """Stand-in for ``black.dump_to_file`` that returns the text instead.

    The pieces of ``output`` are joined with newlines and framed by a leading
    and a trailing newline, so failure messages show the content inline rather
    than the path of a temporary file.
    """
    body = "\n".join(output)
    return f"\n{body}\n"
32
33
def read_data(name: str, data: bool = True) -> Tuple[str, str]:
    """Return the ``(source, expected)`` pair stored in a test fixture file.

    ``name`` gets a ``.py`` suffix unless it already ends in ``.py``, ``.pyi``,
    ``.out`` or ``.diff``.  It is resolved inside ``tests/data`` when ``data`` is
    true, otherwise relative to the tests directory itself.  The EMPTY_LINE
    marker is removed from every line.  Lines before a line reading
    ``# output`` form the source; lines after it form the expected result.
    Without such a marker (or with nothing after it) the fixture is treated as
    already formatted and the expected result equals the source.  Both strings
    are stripped and end with exactly one newline.
    """
    if not name.endswith((".py", ".pyi", ".out", ".diff")):
        name += ".py"
    base_dir = THIS_DIR / "data" if data else THIS_DIR
    with open(base_dir / name, "r", encoding="utf8") as fixture:
        raw_lines = fixture.readlines()
    source_part: List[str] = []
    expected_part: List[str] = []
    seen_marker = False
    for raw in raw_lines:
        line = raw.replace(EMPTY_LINE, "")
        if line.rstrip() == "# output":
            seen_marker = True
            continue

        if seen_marker:
            expected_part.append(line)
        else:
            source_part.append(line)
    if source_part and not expected_part:
        # No expected section: the fixture must come back unchanged.
        expected_part = list(source_part)
    source = "".join(source_part).strip() + "\n"
    expected = "".join(expected_part).strip() + "\n"
    return source, expected
55
56
@contextmanager
def cache_dir(exists: bool = True) -> Iterator[Path]:
    """Patch ``black.CACHE_DIR`` to a throwaway location and yield that path.

    With ``exists=True`` the path is a fresh temporary directory; otherwise it
    is a ``new`` subdirectory of it that has not been created.  Everything is
    removed when the context exits.
    """
    with TemporaryDirectory() as workspace:
        root = Path(workspace)
        target = root if exists else root / "new"
        with patch("black.CACHE_DIR", target):
            yield target
65
66
@contextmanager
def event_loop(close: bool) -> Iterator[None]:
    """Run the body with a brand-new asyncio event loop set as the current one.

    Whatever loop was current beforehand is reinstated on exit.  The new loop
    is closed on exit only when ``close`` is true.
    """
    policy = asyncio.get_event_loop_policy()
    previous = policy.get_event_loop()
    fresh = policy.new_event_loop()
    asyncio.set_event_loop(fresh)
    try:
        yield

    finally:
        policy.set_event_loop(previous)
        if close:
            fresh.close()
80
81
class BlackRunner(CliRunner):
    """Modify CliRunner so that stderr is not merged with stdout.

    This is a hack that can be removed once we depend on Click 7.x"""

    def __init__(self, stderrbuf: BinaryIO) -> None:
        self.stderrbuf = stderrbuf
        super().__init__()

    @contextmanager
    def isolation(self, *args: Any, **kwargs: Any) -> Generator[BinaryIO, None, None]:
        """Run Click's isolation, but send ``sys.stderr`` to ``self.stderrbuf``."""
        with super().isolation(*args, **kwargs) as output:
            hold_stderr = sys.stderr
            stderr_wrapper = TextIOWrapper(self.stderrbuf, encoding=self.charset)
            sys.stderr = stderr_wrapper
            try:
                yield output
            finally:
                sys.stderr = hold_stderr
                # Push any buffered text into ``stderrbuf`` and detach it: a
                # TextIOWrapper closes its underlying buffer when it is closed
                # or garbage-collected, which would make ``stderrbuf`` unusable.
                if not stderr_wrapper.closed:
                    stderr_wrapper.flush()
                    stderr_wrapper.detach()
100
101
102 class BlackTestCase(unittest.TestCase):
103     maxDiff = None
104
105     def assertFormatEqual(self, expected: str, actual: str) -> None:
106         if actual != expected and not os.environ.get("SKIP_AST_PRINT"):
107             bdv: black.DebugVisitor[Any]
108             black.out("Expected tree:", fg="green")
109             try:
110                 exp_node = black.lib2to3_parse(expected)
111                 bdv = black.DebugVisitor()
112                 list(bdv.visit(exp_node))
113             except Exception as ve:
114                 black.err(str(ve))
115             black.out("Actual tree:", fg="red")
116             try:
117                 exp_node = black.lib2to3_parse(actual)
118                 bdv = black.DebugVisitor()
119                 list(bdv.visit(exp_node))
120             except Exception as ve:
121                 black.err(str(ve))
122         self.assertEqual(expected, actual)
123
124     @patch("black.dump_to_file", dump_to_stderr)
125     def test_empty(self) -> None:
126         source = expected = ""
127         actual = fs(source)
128         self.assertFormatEqual(expected, actual)
129         black.assert_equivalent(source, actual)
130         black.assert_stable(source, actual, line_length=ll)
131
132     def test_empty_ff(self) -> None:
133         expected = ""
134         tmp_file = Path(black.dump_to_file())
135         try:
136             self.assertFalse(ff(tmp_file, write_back=black.WriteBack.YES))
137             with open(tmp_file, encoding="utf8") as f:
138                 actual = f.read()
139         finally:
140             os.unlink(tmp_file)
141         self.assertFormatEqual(expected, actual)
142
143     @patch("black.dump_to_file", dump_to_stderr)
144     def test_self(self) -> None:
145         source, expected = read_data("test_black", data=False)
146         actual = fs(source)
147         self.assertFormatEqual(expected, actual)
148         black.assert_equivalent(source, actual)
149         black.assert_stable(source, actual, line_length=ll)
150         self.assertFalse(ff(THIS_FILE))
151
152     @patch("black.dump_to_file", dump_to_stderr)
153     def test_black(self) -> None:
154         source, expected = read_data("../black", data=False)
155         actual = fs(source)
156         self.assertFormatEqual(expected, actual)
157         black.assert_equivalent(source, actual)
158         black.assert_stable(source, actual, line_length=ll)
159         self.assertFalse(ff(THIS_DIR / ".." / "black.py"))
160
161     def test_piping(self) -> None:
162         source, expected = read_data("../black", data=False)
163         stderrbuf = BytesIO()
164         result = BlackRunner(stderrbuf).invoke(
165             black.main,
166             ["-", "--fast", f"--line-length={ll}"],
167             input=BytesIO(source.encode("utf8")),
168         )
169         self.assertEqual(result.exit_code, 0)
170         self.assertFormatEqual(expected, result.output)
171         black.assert_equivalent(source, result.output)
172         black.assert_stable(source, result.output, line_length=ll)
173
174     def test_piping_diff(self) -> None:
175         diff_header = re.compile(
176             rf"(STDIN|STDOUT)\t\d\d\d\d-\d\d-\d\d "
177             rf"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
178         )
179         source, _ = read_data("expression.py")
180         expected, _ = read_data("expression.diff")
181         config = THIS_DIR / "data" / "empty_pyproject.toml"
182         stderrbuf = BytesIO()
183         args = ["-", "--fast", f"--line-length={ll}", "--diff", f"--config={config}"]
184         result = BlackRunner(stderrbuf).invoke(
185             black.main, args, input=BytesIO(source.encode("utf8"))
186         )
187         self.assertEqual(result.exit_code, 0)
188         actual = diff_header.sub("[Deterministic header]", result.output)
189         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
190         self.assertEqual(expected, actual)
191
192     @patch("black.dump_to_file", dump_to_stderr)
193     def test_setup(self) -> None:
194         source, expected = read_data("../setup", data=False)
195         actual = fs(source)
196         self.assertFormatEqual(expected, actual)
197         black.assert_equivalent(source, actual)
198         black.assert_stable(source, actual, line_length=ll)
199         self.assertFalse(ff(THIS_DIR / ".." / "setup.py"))
200
201     @patch("black.dump_to_file", dump_to_stderr)
202     def test_function(self) -> None:
203         source, expected = read_data("function")
204         actual = fs(source)
205         self.assertFormatEqual(expected, actual)
206         black.assert_equivalent(source, actual)
207         black.assert_stable(source, actual, line_length=ll)
208
209     @patch("black.dump_to_file", dump_to_stderr)
210     def test_function2(self) -> None:
211         source, expected = read_data("function2")
212         actual = fs(source)
213         self.assertFormatEqual(expected, actual)
214         black.assert_equivalent(source, actual)
215         black.assert_stable(source, actual, line_length=ll)
216
217     @patch("black.dump_to_file", dump_to_stderr)
218     def test_expression(self) -> None:
219         source, expected = read_data("expression")
220         actual = fs(source)
221         self.assertFormatEqual(expected, actual)
222         black.assert_equivalent(source, actual)
223         black.assert_stable(source, actual, line_length=ll)
224
225     def test_expression_ff(self) -> None:
226         source, expected = read_data("expression")
227         tmp_file = Path(black.dump_to_file(source))
228         try:
229             self.assertTrue(ff(tmp_file, write_back=black.WriteBack.YES))
230             with open(tmp_file, encoding="utf8") as f:
231                 actual = f.read()
232         finally:
233             os.unlink(tmp_file)
234         self.assertFormatEqual(expected, actual)
235         with patch("black.dump_to_file", dump_to_stderr):
236             black.assert_equivalent(source, actual)
237             black.assert_stable(source, actual, line_length=ll)
238
239     def test_expression_diff(self) -> None:
240         source, _ = read_data("expression.py")
241         expected, _ = read_data("expression.diff")
242         tmp_file = Path(black.dump_to_file(source))
243         diff_header = re.compile(
244             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
245             rf"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
246         )
247         stderrbuf = BytesIO()
248         try:
249             result = BlackRunner(stderrbuf).invoke(
250                 black.main, ["--diff", str(tmp_file)]
251             )
252             self.assertEqual(result.exit_code, 0)
253         finally:
254             os.unlink(tmp_file)
255         actual = result.output
256         actual = diff_header.sub("[Deterministic header]", actual)
257         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
258         if expected != actual:
259             dump = black.dump_to_file(actual)
260             msg = (
261                 f"Expected diff isn't equal to the actual. If you made changes "
262                 f"to expression.py and this is an anticipated difference, "
263                 f"overwrite tests/expression.diff with {dump}"
264             )
265             self.assertEqual(expected, actual, msg)
266
267     @patch("black.dump_to_file", dump_to_stderr)
268     def test_fstring(self) -> None:
269         source, expected = read_data("fstring")
270         actual = fs(source)
271         self.assertFormatEqual(expected, actual)
272         black.assert_equivalent(source, actual)
273         black.assert_stable(source, actual, line_length=ll)
274
275     @patch("black.dump_to_file", dump_to_stderr)
276     def test_string_quotes(self) -> None:
277         source, expected = read_data("string_quotes")
278         actual = fs(source)
279         self.assertFormatEqual(expected, actual)
280         black.assert_equivalent(source, actual)
281         black.assert_stable(source, actual, line_length=ll)
282         mode = black.FileMode.NO_STRING_NORMALIZATION
283         not_normalized = fs(source, mode=mode)
284         self.assertFormatEqual(source, not_normalized)
285         black.assert_equivalent(source, not_normalized)
286         black.assert_stable(source, not_normalized, line_length=ll, mode=mode)
287
288     @patch("black.dump_to_file", dump_to_stderr)
289     def test_slices(self) -> None:
290         source, expected = read_data("slices")
291         actual = fs(source)
292         self.assertFormatEqual(expected, actual)
293         black.assert_equivalent(source, actual)
294         black.assert_stable(source, actual, line_length=ll)
295
296     @patch("black.dump_to_file", dump_to_stderr)
297     def test_comments(self) -> None:
298         source, expected = read_data("comments")
299         actual = fs(source)
300         self.assertFormatEqual(expected, actual)
301         black.assert_equivalent(source, actual)
302         black.assert_stable(source, actual, line_length=ll)
303
304     @patch("black.dump_to_file", dump_to_stderr)
305     def test_comments2(self) -> None:
306         source, expected = read_data("comments2")
307         actual = fs(source)
308         self.assertFormatEqual(expected, actual)
309         black.assert_equivalent(source, actual)
310         black.assert_stable(source, actual, line_length=ll)
311
312     @patch("black.dump_to_file", dump_to_stderr)
313     def test_comments3(self) -> None:
314         source, expected = read_data("comments3")
315         actual = fs(source)
316         self.assertFormatEqual(expected, actual)
317         black.assert_equivalent(source, actual)
318         black.assert_stable(source, actual, line_length=ll)
319
320     @patch("black.dump_to_file", dump_to_stderr)
321     def test_comments4(self) -> None:
322         source, expected = read_data("comments4")
323         actual = fs(source)
324         self.assertFormatEqual(expected, actual)
325         black.assert_equivalent(source, actual)
326         black.assert_stable(source, actual, line_length=ll)
327
328     @patch("black.dump_to_file", dump_to_stderr)
329     def test_comments5(self) -> None:
330         source, expected = read_data("comments5")
331         actual = fs(source)
332         self.assertFormatEqual(expected, actual)
333         black.assert_equivalent(source, actual)
334         black.assert_stable(source, actual, line_length=ll)
335
336     @patch("black.dump_to_file", dump_to_stderr)
337     def test_cantfit(self) -> None:
338         source, expected = read_data("cantfit")
339         actual = fs(source)
340         self.assertFormatEqual(expected, actual)
341         black.assert_equivalent(source, actual)
342         black.assert_stable(source, actual, line_length=ll)
343
344     @patch("black.dump_to_file", dump_to_stderr)
345     def test_import_spacing(self) -> None:
346         source, expected = read_data("import_spacing")
347         actual = fs(source)
348         self.assertFormatEqual(expected, actual)
349         black.assert_equivalent(source, actual)
350         black.assert_stable(source, actual, line_length=ll)
351
352     @patch("black.dump_to_file", dump_to_stderr)
353     def test_composition(self) -> None:
354         source, expected = read_data("composition")
355         actual = fs(source)
356         self.assertFormatEqual(expected, actual)
357         black.assert_equivalent(source, actual)
358         black.assert_stable(source, actual, line_length=ll)
359
360     @patch("black.dump_to_file", dump_to_stderr)
361     def test_empty_lines(self) -> None:
362         source, expected = read_data("empty_lines")
363         actual = fs(source)
364         self.assertFormatEqual(expected, actual)
365         black.assert_equivalent(source, actual)
366         black.assert_stable(source, actual, line_length=ll)
367
368     @patch("black.dump_to_file", dump_to_stderr)
369     def test_string_prefixes(self) -> None:
370         source, expected = read_data("string_prefixes")
371         actual = fs(source)
372         self.assertFormatEqual(expected, actual)
373         black.assert_equivalent(source, actual)
374         black.assert_stable(source, actual, line_length=ll)
375
376     @patch("black.dump_to_file", dump_to_stderr)
377     def test_numeric_literals(self) -> None:
378         source, expected = read_data("numeric_literals")
379         actual = fs(source, mode=black.FileMode.PYTHON36)
380         self.assertFormatEqual(expected, actual)
381         black.assert_equivalent(source, actual)
382         black.assert_stable(source, actual, line_length=ll)
383
384     @patch("black.dump_to_file", dump_to_stderr)
385     def test_numeric_literals_py2(self) -> None:
386         source, expected = read_data("numeric_literals_py2")
387         actual = fs(source)
388         self.assertFormatEqual(expected, actual)
389         black.assert_stable(source, actual, line_length=ll)
390
391     @patch("black.dump_to_file", dump_to_stderr)
392     def test_python2(self) -> None:
393         source, expected = read_data("python2")
394         actual = fs(source)
395         self.assertFormatEqual(expected, actual)
396         # black.assert_equivalent(source, actual)
397         black.assert_stable(source, actual, line_length=ll)
398
399     @patch("black.dump_to_file", dump_to_stderr)
400     def test_python2_unicode_literals(self) -> None:
401         source, expected = read_data("python2_unicode_literals")
402         actual = fs(source)
403         self.assertFormatEqual(expected, actual)
404         black.assert_stable(source, actual, line_length=ll)
405
406     @patch("black.dump_to_file", dump_to_stderr)
407     def test_stub(self) -> None:
408         mode = black.FileMode.PYI
409         source, expected = read_data("stub.pyi")
410         actual = fs(source, mode=mode)
411         self.assertFormatEqual(expected, actual)
412         black.assert_stable(source, actual, line_length=ll, mode=mode)
413
414     @patch("black.dump_to_file", dump_to_stderr)
415     def test_fmtonoff(self) -> None:
416         source, expected = read_data("fmtonoff")
417         actual = fs(source)
418         self.assertFormatEqual(expected, actual)
419         black.assert_equivalent(source, actual)
420         black.assert_stable(source, actual, line_length=ll)
421
422     @patch("black.dump_to_file", dump_to_stderr)
423     def test_fmtonoff2(self) -> None:
424         source, expected = read_data("fmtonoff2")
425         actual = fs(source)
426         self.assertFormatEqual(expected, actual)
427         black.assert_equivalent(source, actual)
428         black.assert_stable(source, actual, line_length=ll)
429
430     @patch("black.dump_to_file", dump_to_stderr)
431     def test_remove_empty_parentheses_after_class(self) -> None:
432         source, expected = read_data("class_blank_parentheses")
433         actual = fs(source)
434         self.assertFormatEqual(expected, actual)
435         black.assert_equivalent(source, actual)
436         black.assert_stable(source, actual, line_length=ll)
437
438     @patch("black.dump_to_file", dump_to_stderr)
439     def test_new_line_between_class_and_code(self) -> None:
440         source, expected = read_data("class_methods_new_line")
441         actual = fs(source)
442         self.assertFormatEqual(expected, actual)
443         black.assert_equivalent(source, actual)
444         black.assert_stable(source, actual, line_length=ll)
445
    def test_report_verbose(self) -> None:
        """Exercise Report with ``verbose=True``, where every event prints a line.

        ``black.out`` and ``black.err`` are patched to collect messages so each
        step can check what was printed, the summary string and the return code.
        """
        report = black.Report(verbose=True)
        out_lines: List[str] = []
        err_lines: List[str] = []

        def out(msg: str, **kwargs: Any) -> None:
            out_lines.append(msg)

        def err(msg: str, **kwargs: Any) -> None:
            err_lines.append(msg)

        with patch("black.out", out), patch("black.err", err):
            report.done(Path("f1"), black.Changed.NO)
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "f1 already well formatted, good job.")
            self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
            self.assertEqual(report.return_code, 0)
            report.done(Path("f2"), black.Changed.YES)
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "reformatted f2")
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
            )
            # Cached files are counted as unchanged.
            report.done(Path("f3"), black.Changed.CACHED)
            self.assertEqual(len(out_lines), 3)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(
                out_lines[-1], "f3 wasn't modified on disk since last run."
            )
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
            )
            self.assertEqual(report.return_code, 0)
            # In check mode, having reformatted files turns the return code into 1.
            report.check = True
            self.assertEqual(report.return_code, 1)
            report.check = False
            # Any failure makes the return code 123.
            report.failed(Path("e1"), "boom")
            self.assertEqual(len(out_lines), 3)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
            self.assertEqual(
                unstyle(str(report)),
                "1 file reformatted, 2 files left unchanged, "
                "1 file failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f3"), black.Changed.YES)
            self.assertEqual(len(out_lines), 4)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(out_lines[-1], "reformatted f3")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, "
                "1 file failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.failed(Path("e2"), "boom")
            self.assertEqual(len(out_lines), 4)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, "
                "2 files failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Ignored paths are printed but not counted in the summary.
            report.path_ignored(Path("wat"), "no match")
            self.assertEqual(len(out_lines), 5)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(out_lines[-1], "wat ignored: no match")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, "
                "2 files failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f4"), black.Changed.NO)
            self.assertEqual(len(out_lines), 6)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(out_lines[-1], "f4 already well formatted, good job.")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 3 files left unchanged, "
                "2 files failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # In check mode the summary is phrased conditionally.
            report.check = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, "
                "2 files would fail to reformat.",
            )
540
    def test_report_quiet(self) -> None:
        """Exercise Report with ``quiet=True``, where only errors are printed.

        ``black.out`` and ``black.err`` are patched to collect messages; the
        summary string and return code must match the other verbosity levels.
        """
        report = black.Report(quiet=True)
        out_lines: List[str] = []
        err_lines: List[str] = []

        def out(msg: str, **kwargs: Any) -> None:
            out_lines.append(msg)

        def err(msg: str, **kwargs: Any) -> None:
            err_lines.append(msg)

        with patch("black.out", out), patch("black.err", err):
            report.done(Path("f1"), black.Changed.NO)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
            self.assertEqual(report.return_code, 0)
            report.done(Path("f2"), black.Changed.YES)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
            )
            report.done(Path("f3"), black.Changed.CACHED)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
            )
            self.assertEqual(report.return_code, 0)
            # In check mode, having reformatted files turns the return code into 1.
            report.check = True
            self.assertEqual(report.return_code, 1)
            report.check = False
            # Failures are still reported on the error stream in quiet mode.
            report.failed(Path("e1"), "boom")
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
            self.assertEqual(
                unstyle(str(report)),
                "1 file reformatted, 2 files left unchanged, "
                "1 file failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f3"), black.Changed.YES)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, "
                "1 file failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.failed(Path("e2"), "boom")
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, "
                "2 files failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Ignored paths are neither printed nor counted in the summary.
            report.path_ignored(Path("wat"), "no match")
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, "
                "2 files failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f4"), black.Changed.NO)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 3 files left unchanged, "
                "2 files failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # In check mode the summary is phrased conditionally.
            report.check = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, "
                "2 files would fail to reformat.",
            )
627
    def test_report_normal(self) -> None:
        """Exercise Report with default verbosity.

        Reformatted files and errors are printed; unchanged, cached and ignored
        files are not.  ``black.out`` and ``black.err`` are patched to collect
        messages so each step can check exactly what was printed.
        """
        report = black.Report()
        out_lines: List[str] = []
        err_lines: List[str] = []

        def out(msg: str, **kwargs: Any) -> None:
            out_lines.append(msg)

        def err(msg: str, **kwargs: Any) -> None:
            err_lines.append(msg)

        with patch("black.out", out), patch("black.err", err):
            report.done(Path("f1"), black.Changed.NO)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
            self.assertEqual(report.return_code, 0)
            report.done(Path("f2"), black.Changed.YES)
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "reformatted f2")
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
            )
            # A cached file prints nothing, so the last output line is unchanged.
            report.done(Path("f3"), black.Changed.CACHED)
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "reformatted f2")
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
            )
            self.assertEqual(report.return_code, 0)
            # In check mode, having reformatted files turns the return code into 1.
            report.check = True
            self.assertEqual(report.return_code, 1)
            report.check = False
            # Any failure makes the return code 123.
            report.failed(Path("e1"), "boom")
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
            self.assertEqual(
                unstyle(str(report)),
                "1 file reformatted, 2 files left unchanged, "
                "1 file failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f3"), black.Changed.YES)
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(out_lines[-1], "reformatted f3")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, "
                "1 file failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.failed(Path("e2"), "boom")
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, "
                "2 files failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Ignored paths are neither printed nor counted in the summary.
            report.path_ignored(Path("wat"), "no match")
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, "
                "2 files failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f4"), black.Changed.NO)
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 3 files left unchanged, "
                "2 files failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # In check mode the summary is phrased conditionally.
            report.check = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, "
                "2 files would fail to reformat.",
            )
717
718     def test_is_python36(self) -> None:
719         node = black.lib2to3_parse("def f(*, arg): ...\n")
720         self.assertFalse(black.is_python36(node))
721         node = black.lib2to3_parse("def f(*, arg,): ...\n")
722         self.assertTrue(black.is_python36(node))
723         node = black.lib2to3_parse("def f(*, arg): f'string'\n")
724         self.assertTrue(black.is_python36(node))
725         source, expected = read_data("function")
726         node = black.lib2to3_parse(source)
727         self.assertTrue(black.is_python36(node))
728         node = black.lib2to3_parse(expected)
729         self.assertTrue(black.is_python36(node))
730         source, expected = read_data("expression")
731         node = black.lib2to3_parse(source)
732         self.assertFalse(black.is_python36(node))
733         node = black.lib2to3_parse(expected)
734         self.assertFalse(black.is_python36(node))
735
736     def test_get_future_imports(self) -> None:
737         node = black.lib2to3_parse("\n")
738         self.assertEqual(set(), black.get_future_imports(node))
739         node = black.lib2to3_parse("from __future__ import black\n")
740         self.assertEqual({"black"}, black.get_future_imports(node))
741         node = black.lib2to3_parse("from __future__ import multiple, imports\n")
742         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
743         node = black.lib2to3_parse("from __future__ import (parenthesized, imports)\n")
744         self.assertEqual({"parenthesized", "imports"}, black.get_future_imports(node))
745         node = black.lib2to3_parse(
746             "from __future__ import multiple\nfrom __future__ import imports\n"
747         )
748         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
749         node = black.lib2to3_parse("# comment\nfrom __future__ import black\n")
750         self.assertEqual({"black"}, black.get_future_imports(node))
751         node = black.lib2to3_parse('"""docstring"""\nfrom __future__ import black\n')
752         self.assertEqual({"black"}, black.get_future_imports(node))
753         node = black.lib2to3_parse("some(other, code)\nfrom __future__ import black\n")
754         self.assertEqual(set(), black.get_future_imports(node))
755         node = black.lib2to3_parse("from some.module import black\n")
756         self.assertEqual(set(), black.get_future_imports(node))
757         node = black.lib2to3_parse(
758             "from __future__ import unicode_literals as _unicode_literals"
759         )
760         self.assertEqual({"unicode_literals"}, black.get_future_imports(node))
761         node = black.lib2to3_parse(
762             "from __future__ import unicode_literals as _lol, print"
763         )
764         self.assertEqual({"unicode_literals", "print"}, black.get_future_imports(node))
765
766     def test_debug_visitor(self) -> None:
767         source, _ = read_data("debug_visitor.py")
768         expected, _ = read_data("debug_visitor.out")
769         out_lines = []
770         err_lines = []
771
772         def out(msg: str, **kwargs: Any) -> None:
773             out_lines.append(msg)
774
775         def err(msg: str, **kwargs: Any) -> None:
776             err_lines.append(msg)
777
778         with patch("black.out", out), patch("black.err", err):
779             black.DebugVisitor.show(source)
780         actual = "\n".join(out_lines) + "\n"
781         log_name = ""
782         if expected != actual:
783             log_name = black.dump_to_file(*out_lines)
784         self.assertEqual(
785             expected,
786             actual,
787             f"AST print out is different. Actual version dumped to {log_name}",
788         )
789
790     def test_format_file_contents(self) -> None:
791         empty = ""
792         with self.assertRaises(black.NothingChanged):
793             black.format_file_contents(empty, line_length=ll, fast=False)
794         just_nl = "\n"
795         with self.assertRaises(black.NothingChanged):
796             black.format_file_contents(just_nl, line_length=ll, fast=False)
797         same = "l = [1, 2, 3]\n"
798         with self.assertRaises(black.NothingChanged):
799             black.format_file_contents(same, line_length=ll, fast=False)
800         different = "l = [1,2,3]"
801         expected = same
802         actual = black.format_file_contents(different, line_length=ll, fast=False)
803         self.assertEqual(expected, actual)
804         invalid = "return if you can"
805         with self.assertRaises(ValueError) as e:
806             black.format_file_contents(invalid, line_length=ll, fast=False)
807         self.assertEqual(str(e.exception), "Cannot parse: 1:7: return if you can")
808
809     def test_endmarker(self) -> None:
810         n = black.lib2to3_parse("\n")
811         self.assertEqual(n.type, black.syms.file_input)
812         self.assertEqual(len(n.children), 1)
813         self.assertEqual(n.children[0].type, black.token.ENDMARKER)
814
815     @unittest.skipIf(os.environ.get("SKIP_AST_PRINT"), "user set SKIP_AST_PRINT")
816     def test_assertFormatEqual(self) -> None:
817         out_lines = []
818         err_lines = []
819
820         def out(msg: str, **kwargs: Any) -> None:
821             out_lines.append(msg)
822
823         def err(msg: str, **kwargs: Any) -> None:
824             err_lines.append(msg)
825
826         with patch("black.out", out), patch("black.err", err):
827             with self.assertRaises(AssertionError):
828                 self.assertFormatEqual("l = [1, 2, 3]", "l = [1, 2, 3,]")
829
830         out_str = "".join(out_lines)
831         self.assertTrue("Expected tree:" in out_str)
832         self.assertTrue("Actual tree:" in out_str)
833         self.assertEqual("".join(err_lines), "")
834
835     def test_cache_broken_file(self) -> None:
836         mode = black.FileMode.AUTO_DETECT
837         with cache_dir() as workspace:
838             cache_file = black.get_cache_file(black.DEFAULT_LINE_LENGTH, mode)
839             with cache_file.open("w") as fobj:
840                 fobj.write("this is not a pickle")
841             self.assertEqual(black.read_cache(black.DEFAULT_LINE_LENGTH, mode), {})
842             src = (workspace / "test.py").resolve()
843             with src.open("w") as fobj:
844                 fobj.write("print('hello')")
845             result = CliRunner().invoke(black.main, [str(src)])
846             self.assertEqual(result.exit_code, 0)
847             cache = black.read_cache(black.DEFAULT_LINE_LENGTH, mode)
848             self.assertIn(src, cache)
849
850     def test_cache_single_file_already_cached(self) -> None:
851         mode = black.FileMode.AUTO_DETECT
852         with cache_dir() as workspace:
853             src = (workspace / "test.py").resolve()
854             with src.open("w") as fobj:
855                 fobj.write("print('hello')")
856             black.write_cache({}, [src], black.DEFAULT_LINE_LENGTH, mode)
857             result = CliRunner().invoke(black.main, [str(src)])
858             self.assertEqual(result.exit_code, 0)
859             with src.open("r") as fobj:
860                 self.assertEqual(fobj.read(), "print('hello')")
861
862     @event_loop(close=False)
863     def test_cache_multiple_files(self) -> None:
864         mode = black.FileMode.AUTO_DETECT
865         with cache_dir() as workspace, patch(
866             "black.ProcessPoolExecutor", new=ThreadPoolExecutor
867         ):
868             one = (workspace / "one.py").resolve()
869             with one.open("w") as fobj:
870                 fobj.write("print('hello')")
871             two = (workspace / "two.py").resolve()
872             with two.open("w") as fobj:
873                 fobj.write("print('hello')")
874             black.write_cache({}, [one], black.DEFAULT_LINE_LENGTH, mode)
875             result = CliRunner().invoke(black.main, [str(workspace)])
876             self.assertEqual(result.exit_code, 0)
877             with one.open("r") as fobj:
878                 self.assertEqual(fobj.read(), "print('hello')")
879             with two.open("r") as fobj:
880                 self.assertEqual(fobj.read(), 'print("hello")\n')
881             cache = black.read_cache(black.DEFAULT_LINE_LENGTH, mode)
882             self.assertIn(one, cache)
883             self.assertIn(two, cache)
884
885     def test_no_cache_when_writeback_diff(self) -> None:
886         mode = black.FileMode.AUTO_DETECT
887         with cache_dir() as workspace:
888             src = (workspace / "test.py").resolve()
889             with src.open("w") as fobj:
890                 fobj.write("print('hello')")
891             result = CliRunner().invoke(black.main, [str(src), "--diff"])
892             self.assertEqual(result.exit_code, 0)
893             cache_file = black.get_cache_file(black.DEFAULT_LINE_LENGTH, mode)
894             self.assertFalse(cache_file.exists())
895
896     def test_no_cache_when_stdin(self) -> None:
897         mode = black.FileMode.AUTO_DETECT
898         with cache_dir():
899             result = CliRunner().invoke(
900                 black.main, ["-"], input=BytesIO(b"print('hello')")
901             )
902             self.assertEqual(result.exit_code, 0)
903             cache_file = black.get_cache_file(black.DEFAULT_LINE_LENGTH, mode)
904             self.assertFalse(cache_file.exists())
905
906     def test_read_cache_no_cachefile(self) -> None:
907         mode = black.FileMode.AUTO_DETECT
908         with cache_dir():
909             self.assertEqual(black.read_cache(black.DEFAULT_LINE_LENGTH, mode), {})
910
911     def test_write_cache_read_cache(self) -> None:
912         mode = black.FileMode.AUTO_DETECT
913         with cache_dir() as workspace:
914             src = (workspace / "test.py").resolve()
915             src.touch()
916             black.write_cache({}, [src], black.DEFAULT_LINE_LENGTH, mode)
917             cache = black.read_cache(black.DEFAULT_LINE_LENGTH, mode)
918             self.assertIn(src, cache)
919             self.assertEqual(cache[src], black.get_cache_info(src))
920
921     def test_filter_cached(self) -> None:
922         with TemporaryDirectory() as workspace:
923             path = Path(workspace)
924             uncached = (path / "uncached").resolve()
925             cached = (path / "cached").resolve()
926             cached_but_changed = (path / "changed").resolve()
927             uncached.touch()
928             cached.touch()
929             cached_but_changed.touch()
930             cache = {cached: black.get_cache_info(cached), cached_but_changed: (0.0, 0)}
931             todo, done = black.filter_cached(
932                 cache, {uncached, cached, cached_but_changed}
933             )
934             self.assertEqual(todo, {uncached, cached_but_changed})
935             self.assertEqual(done, {cached})
936
937     def test_write_cache_creates_directory_if_needed(self) -> None:
938         mode = black.FileMode.AUTO_DETECT
939         with cache_dir(exists=False) as workspace:
940             self.assertFalse(workspace.exists())
941             black.write_cache({}, [], black.DEFAULT_LINE_LENGTH, mode)
942             self.assertTrue(workspace.exists())
943
944     @event_loop(close=False)
945     def test_failed_formatting_does_not_get_cached(self) -> None:
946         mode = black.FileMode.AUTO_DETECT
947         with cache_dir() as workspace, patch(
948             "black.ProcessPoolExecutor", new=ThreadPoolExecutor
949         ):
950             failing = (workspace / "failing.py").resolve()
951             with failing.open("w") as fobj:
952                 fobj.write("not actually python")
953             clean = (workspace / "clean.py").resolve()
954             with clean.open("w") as fobj:
955                 fobj.write('print("hello")\n')
956             result = CliRunner().invoke(black.main, [str(workspace)])
957             self.assertEqual(result.exit_code, 123)
958             cache = black.read_cache(black.DEFAULT_LINE_LENGTH, mode)
959             self.assertNotIn(failing, cache)
960             self.assertIn(clean, cache)
961
962     def test_write_cache_write_fail(self) -> None:
963         mode = black.FileMode.AUTO_DETECT
964         with cache_dir(), patch.object(Path, "open") as mock:
965             mock.side_effect = OSError
966             black.write_cache({}, [], black.DEFAULT_LINE_LENGTH, mode)
967
968     @event_loop(close=False)
969     def test_check_diff_use_together(self) -> None:
970         with cache_dir():
971             # Files which will be reformatted.
972             src1 = (THIS_DIR / "data" / "string_quotes.py").resolve()
973             result = CliRunner().invoke(black.main, [str(src1), "--diff", "--check"])
974             self.assertEqual(result.exit_code, 1, result.output)
975             # Files which will not be reformatted.
976             src2 = (THIS_DIR / "data" / "composition.py").resolve()
977             result = CliRunner().invoke(black.main, [str(src2), "--diff", "--check"])
978             self.assertEqual(result.exit_code, 0, result.output)
979             # Multi file command.
980             result = CliRunner().invoke(
981                 black.main, [str(src1), str(src2), "--diff", "--check"]
982             )
983             self.assertEqual(result.exit_code, 1, result.output)
984
985     def test_no_files(self) -> None:
986         with cache_dir():
987             # Without an argument, black exits with error code 0.
988             result = CliRunner().invoke(black.main, [])
989             self.assertEqual(result.exit_code, 0)
990
991     def test_broken_symlink(self) -> None:
992         with cache_dir() as workspace:
993             symlink = workspace / "broken_link.py"
994             try:
995                 symlink.symlink_to("nonexistent.py")
996             except OSError as e:
997                 self.skipTest(f"Can't create symlinks: {e}")
998             result = CliRunner().invoke(black.main, [str(workspace.resolve())])
999             self.assertEqual(result.exit_code, 0)
1000
1001     def test_read_cache_line_lengths(self) -> None:
1002         mode = black.FileMode.AUTO_DETECT
1003         with cache_dir() as workspace:
1004             path = (workspace / "file.py").resolve()
1005             path.touch()
1006             black.write_cache({}, [path], 1, mode)
1007             one = black.read_cache(1, mode)
1008             self.assertIn(path, one)
1009             two = black.read_cache(2, mode)
1010             self.assertNotIn(path, two)
1011
1012     def test_single_file_force_pyi(self) -> None:
1013         reg_mode = black.FileMode.AUTO_DETECT
1014         pyi_mode = black.FileMode.PYI
1015         contents, expected = read_data("force_pyi")
1016         with cache_dir() as workspace:
1017             path = (workspace / "file.py").resolve()
1018             with open(path, "w") as fh:
1019                 fh.write(contents)
1020             result = CliRunner().invoke(black.main, [str(path), "--pyi"])
1021             self.assertEqual(result.exit_code, 0)
1022             with open(path, "r") as fh:
1023                 actual = fh.read()
1024             # verify cache with --pyi is separate
1025             pyi_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, pyi_mode)
1026             self.assertIn(path, pyi_cache)
1027             normal_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, reg_mode)
1028             self.assertNotIn(path, normal_cache)
1029         self.assertEqual(actual, expected)
1030
1031     @event_loop(close=False)
1032     def test_multi_file_force_pyi(self) -> None:
1033         reg_mode = black.FileMode.AUTO_DETECT
1034         pyi_mode = black.FileMode.PYI
1035         contents, expected = read_data("force_pyi")
1036         with cache_dir() as workspace:
1037             paths = [
1038                 (workspace / "file1.py").resolve(),
1039                 (workspace / "file2.py").resolve(),
1040             ]
1041             for path in paths:
1042                 with open(path, "w") as fh:
1043                     fh.write(contents)
1044             result = CliRunner().invoke(black.main, [str(p) for p in paths] + ["--pyi"])
1045             self.assertEqual(result.exit_code, 0)
1046             for path in paths:
1047                 with open(path, "r") as fh:
1048                     actual = fh.read()
1049                 self.assertEqual(actual, expected)
1050             # verify cache with --pyi is separate
1051             pyi_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, pyi_mode)
1052             normal_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, reg_mode)
1053             for path in paths:
1054                 self.assertIn(path, pyi_cache)
1055                 self.assertNotIn(path, normal_cache)
1056
1057     def test_pipe_force_pyi(self) -> None:
1058         source, expected = read_data("force_pyi")
1059         result = CliRunner().invoke(
1060             black.main, ["-", "-q", "--pyi"], input=BytesIO(source.encode("utf8"))
1061         )
1062         self.assertEqual(result.exit_code, 0)
1063         actual = result.output
1064         self.assertFormatEqual(actual, expected)
1065
1066     def test_single_file_force_py36(self) -> None:
1067         reg_mode = black.FileMode.AUTO_DETECT
1068         py36_mode = black.FileMode.PYTHON36
1069         source, expected = read_data("force_py36")
1070         with cache_dir() as workspace:
1071             path = (workspace / "file.py").resolve()
1072             with open(path, "w") as fh:
1073                 fh.write(source)
1074             result = CliRunner().invoke(black.main, [str(path), "--py36"])
1075             self.assertEqual(result.exit_code, 0)
1076             with open(path, "r") as fh:
1077                 actual = fh.read()
1078             # verify cache with --py36 is separate
1079             py36_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, py36_mode)
1080             self.assertIn(path, py36_cache)
1081             normal_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, reg_mode)
1082             self.assertNotIn(path, normal_cache)
1083         self.assertEqual(actual, expected)
1084
1085     @event_loop(close=False)
1086     def test_multi_file_force_py36(self) -> None:
1087         reg_mode = black.FileMode.AUTO_DETECT
1088         py36_mode = black.FileMode.PYTHON36
1089         source, expected = read_data("force_py36")
1090         with cache_dir() as workspace:
1091             paths = [
1092                 (workspace / "file1.py").resolve(),
1093                 (workspace / "file2.py").resolve(),
1094             ]
1095             for path in paths:
1096                 with open(path, "w") as fh:
1097                     fh.write(source)
1098             result = CliRunner().invoke(
1099                 black.main, [str(p) for p in paths] + ["--py36"]
1100             )
1101             self.assertEqual(result.exit_code, 0)
1102             for path in paths:
1103                 with open(path, "r") as fh:
1104                     actual = fh.read()
1105                 self.assertEqual(actual, expected)
1106             # verify cache with --py36 is separate
1107             pyi_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, py36_mode)
1108             normal_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, reg_mode)
1109             for path in paths:
1110                 self.assertIn(path, pyi_cache)
1111                 self.assertNotIn(path, normal_cache)
1112
1113     def test_pipe_force_py36(self) -> None:
1114         source, expected = read_data("force_py36")
1115         result = CliRunner().invoke(
1116             black.main, ["-", "-q", "--py36"], input=BytesIO(source.encode("utf8"))
1117         )
1118         self.assertEqual(result.exit_code, 0)
1119         actual = result.output
1120         self.assertFormatEqual(actual, expected)
1121
1122     def test_include_exclude(self) -> None:
1123         path = THIS_DIR / "data" / "include_exclude_tests"
1124         include = re.compile(r"\.pyi?$")
1125         exclude = re.compile(r"/exclude/|/\.definitely_exclude/")
1126         report = black.Report()
1127         sources: List[Path] = []
1128         expected = [
1129             Path(path / "b/dont_exclude/a.py"),
1130             Path(path / "b/dont_exclude/a.pyi"),
1131         ]
1132         this_abs = THIS_DIR.resolve()
1133         sources.extend(
1134             black.gen_python_files_in_dir(path, this_abs, include, exclude, report)
1135         )
1136         self.assertEqual(sorted(expected), sorted(sources))
1137
1138     def test_empty_include(self) -> None:
1139         path = THIS_DIR / "data" / "include_exclude_tests"
1140         report = black.Report()
1141         empty = re.compile(r"")
1142         sources: List[Path] = []
1143         expected = [
1144             Path(path / "b/exclude/a.pie"),
1145             Path(path / "b/exclude/a.py"),
1146             Path(path / "b/exclude/a.pyi"),
1147             Path(path / "b/dont_exclude/a.pie"),
1148             Path(path / "b/dont_exclude/a.py"),
1149             Path(path / "b/dont_exclude/a.pyi"),
1150             Path(path / "b/.definitely_exclude/a.pie"),
1151             Path(path / "b/.definitely_exclude/a.py"),
1152             Path(path / "b/.definitely_exclude/a.pyi"),
1153         ]
1154         this_abs = THIS_DIR.resolve()
1155         sources.extend(
1156             black.gen_python_files_in_dir(
1157                 path, this_abs, empty, re.compile(black.DEFAULT_EXCLUDES), report
1158             )
1159         )
1160         self.assertEqual(sorted(expected), sorted(sources))
1161
1162     def test_empty_exclude(self) -> None:
1163         path = THIS_DIR / "data" / "include_exclude_tests"
1164         report = black.Report()
1165         empty = re.compile(r"")
1166         sources: List[Path] = []
1167         expected = [
1168             Path(path / "b/dont_exclude/a.py"),
1169             Path(path / "b/dont_exclude/a.pyi"),
1170             Path(path / "b/exclude/a.py"),
1171             Path(path / "b/exclude/a.pyi"),
1172             Path(path / "b/.definitely_exclude/a.py"),
1173             Path(path / "b/.definitely_exclude/a.pyi"),
1174         ]
1175         this_abs = THIS_DIR.resolve()
1176         sources.extend(
1177             black.gen_python_files_in_dir(
1178                 path, this_abs, re.compile(black.DEFAULT_INCLUDES), empty, report
1179             )
1180         )
1181         self.assertEqual(sorted(expected), sorted(sources))
1182
1183     def test_invalid_include_exclude(self) -> None:
1184         for option in ["--include", "--exclude"]:
1185             result = CliRunner().invoke(black.main, ["-", option, "**()(!!*)"])
1186             self.assertEqual(result.exit_code, 2)
1187
1188     def test_preserves_line_endings(self) -> None:
1189         with TemporaryDirectory() as workspace:
1190             test_file = Path(workspace) / "test.py"
1191             for nl in ["\n", "\r\n"]:
1192                 contents = nl.join(["def f(  ):", "    pass"])
1193                 test_file.write_bytes(contents.encode())
1194                 ff(test_file, write_back=black.WriteBack.YES)
1195                 updated_contents: bytes = test_file.read_bytes()
1196                 self.assertIn(nl.encode(), updated_contents)
1197                 if nl == "\n":
1198                     self.assertNotIn(b"\r\n", updated_contents)
1199
1200     def test_assert_equivalent_different_asts(self) -> None:
1201         with self.assertRaises(AssertionError):
1202             black.assert_equivalent("{}", "None")
1203
1204     def test_symlink_out_of_root_directory(self) -> None:
1205         path = MagicMock()
1206         root = THIS_DIR
1207         child = MagicMock()
1208         include = re.compile(black.DEFAULT_INCLUDES)
1209         exclude = re.compile(black.DEFAULT_EXCLUDES)
1210         report = black.Report()
1211         # `child` should behave like a symlink which resolved path is clearly
1212         # outside of the `root` directory.
1213         path.iterdir.return_value = [child]
1214         child.resolve.return_value = Path("/a/b/c")
1215         child.is_symlink.return_value = True
1216         try:
1217             list(black.gen_python_files_in_dir(path, root, include, exclude, report))
1218         except ValueError as ve:
1219             self.fail("`get_python_files_in_dir()` failed: {ve}")
1220         path.iterdir.assert_called_once()
1221         child.resolve.assert_called_once()
1222         child.is_symlink.assert_called_once()
1223         # `child` should behave like a strange file which resolved path is clearly
1224         # outside of the `root` directory.
1225         child.is_symlink.return_value = False
1226         with self.assertRaises(ValueError):
1227             list(black.gen_python_files_in_dir(path, root, include, exclude, report))
1228         path.iterdir.assert_called()
1229         self.assertEqual(path.iterdir.call_count, 2)
1230         child.resolve.assert_called()
1231         self.assertEqual(child.resolve.call_count, 2)
1232         child.is_symlink.assert_called()
1233         self.assertEqual(child.is_symlink.call_count, 2)
1234
1235     def test_shhh_click(self) -> None:
1236         try:
1237             from click import _unicodefun  # type: ignore
1238         except ModuleNotFoundError:
1239             self.skipTest("Incompatible Click version")
1240         if not hasattr(_unicodefun, "_verify_python3_env"):
1241             self.skipTest("Incompatible Click version")
1242         # First, let's see if Click is crashing with a preferred ASCII charset.
1243         with patch("locale.getpreferredencoding") as gpe:
1244             gpe.return_value = "ASCII"
1245             with self.assertRaises(RuntimeError):
1246                 _unicodefun._verify_python3_env()
1247         # Now, let's silence Click...
1248         black.patch_click()
1249         # ...and confirm it's silent.
1250         with patch("locale.getpreferredencoding") as gpe:
1251             gpe.return_value = "ASCII"
1252             try:
1253                 _unicodefun._verify_python3_env()
1254             except RuntimeError as re:
1255                 self.fail(f"`patch_click()` failed, exception still raised: {re}")
1256
1257
if __name__ == "__main__":
    # When executed as a script, load and run the test cases from the
    # module named "test_black".
    unittest.main(module="test_black")