git.madduck.net Git - etc/vim.git/blob - tests/test_black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access, can be arranged individually.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

fix misformatting of floats with leading zeros (#464)
[etc/vim.git] / tests / test_black.py
1 #!/usr/bin/env python3
2 import asyncio
3 from concurrent.futures import ThreadPoolExecutor
4 from contextlib import contextmanager
5 from functools import partial
6 from io import BytesIO, TextIOWrapper
7 import os
8 from pathlib import Path
9 import re
10 import sys
11 from tempfile import TemporaryDirectory
12 from typing import Any, BinaryIO, Generator, List, Tuple, Iterator
13 import unittest
14 from unittest.mock import patch, MagicMock
15
16 from click import unstyle
17 from click.testing import CliRunner
18
19 import black
20
21
# Line length passed to every formatting call in this module.
ll = 88
# Pre-bound helpers: `fs` formats a string, `ff` formats a file in place.
fs = partial(black.format_str, line_length=ll)
ff = partial(black.format_file_in_place, fast=True, line_length=ll)
THIS_FILE = Path(__file__)
THIS_DIR = THIS_FILE.parent
# read_data() strips this marker from every line it reads. It is assembled from
# two literals, presumably so this module (read back by test_self) never
# contains the complete marker itself.
EMPTY_LINE = "# EMPTY LINE WITH WHITESPACE" + " (this comment will be removed)"
28
29
def dump_to_stderr(*output: str) -> str:
    """Stand-in for black.dump_to_file, installed with @patch in these tests.

    Instead of writing the pieces to a file and returning its path, return the
    pieces joined by newlines and framed by a leading and trailing newline.
    """
    body = "\n".join(output)
    return f"\n{body}\n"
32
33
def read_data(name: str, data: bool = True) -> Tuple[str, str]:
    """read_data('test_name') -> 'input', 'output'

    Read a test case from THIS_DIR / "data" (or THIS_DIR itself when `data` is
    false). ".py" is appended unless the name already ends in .py, .pyi, .out
    or .diff. Every line has EMPTY_LINE removed. Lines reading "# output"
    (trailing whitespace allowed) are dropped and switch collection from the
    input half to the expected-output half. If no output lines are collected,
    the input doubles as the expected output. Both halves are stripped and
    returned with exactly one trailing newline.
    """
    if not name.endswith((".py", ".pyi", ".out", ".diff")):
        name += ".py"
    directory = THIS_DIR / "data" if data else THIS_DIR
    with open(directory / name, "r", encoding="utf8") as handle:
        raw_lines = handle.readlines()
    before: List[str] = []
    after: List[str] = []
    bucket = before
    for raw_line in raw_lines:
        cleaned = raw_line.replace(EMPTY_LINE, "")
        if cleaned.rstrip() == "# output":
            # Everything after the marker belongs to the expected output.
            bucket = after
        else:
            bucket.append(cleaned)
    if before and not after:
        # No output section: the file is expected to be left as it is.
        after = before[:]
    source = "".join(before).strip() + "\n"
    expected = "".join(after).strip() + "\n"
    return source, expected
55
56
@contextmanager
def cache_dir(exists: bool = True) -> Iterator[Path]:
    """Patch black.CACHE_DIR to point into a temporary directory while active.

    With `exists` true, the temporary directory itself is used. Otherwise a
    not-yet-created "new" subdirectory inside it is used. The patched path is
    yielded to the caller.
    """
    with TemporaryDirectory() as workspace:
        root = Path(workspace)
        target = root if exists else root / "new"
        with patch("black.CACHE_DIR", target):
            yield target
65
66
@contextmanager
def event_loop(close: bool) -> Iterator[None]:
    """Install a fresh asyncio event loop for the duration of the block.

    The loop that was current on entry is reinstated on exit. The fresh loop
    is closed afterwards only when `close` is true.
    """
    policy = asyncio.get_event_loop_policy()
    previous = policy.get_event_loop()
    fresh = policy.new_event_loop()
    policy.set_event_loop(fresh)
    try:
        yield
    finally:
        policy.set_event_loop(previous)
        if close:
            fresh.close()
80
81
class BlackRunner(CliRunner):
    """Modify CliRunner so that stderr is not merged with stdout.

    This is a hack that can be removed once we depend on Click 7.x"""

    def __init__(self, stderrbuf: BinaryIO) -> None:
        # Binary buffer that receives everything written to sys.stderr
        # while a command runs under isolation().
        self.stderrbuf = stderrbuf
        super().__init__()

    @contextmanager
    def isolation(self, *args: Any, **kwargs: Any) -> Generator[BinaryIO, None, None]:
        """Run the parent isolation(), redirecting sys.stderr to stderrbuf.

        Yields whatever the parent isolation() yields. sys.stderr is restored
        on exit, even if the body raises.
        """
        with super().isolation(*args, **kwargs) as output:
            hold_stderr = sys.stderr
            stderr_wrapper = TextIOWrapper(self.stderrbuf, encoding=self.charset)
            sys.stderr = stderr_wrapper
            try:
                yield output
            finally:
                sys.stderr = hold_stderr
                # Detach (which also flushes) rather than just dropping the
                # wrapper: a garbage-collected TextIOWrapper closes its
                # underlying buffer, which would leave stderrbuf unreadable.
                if not stderr_wrapper.closed:
                    stderr_wrapper.detach()
100
101
102 class BlackTestCase(unittest.TestCase):
103     maxDiff = None
104
105     def assertFormatEqual(self, expected: str, actual: str) -> None:
106         if actual != expected and not os.environ.get("SKIP_AST_PRINT"):
107             bdv: black.DebugVisitor[Any]
108             black.out("Expected tree:", fg="green")
109             try:
110                 exp_node = black.lib2to3_parse(expected)
111                 bdv = black.DebugVisitor()
112                 list(bdv.visit(exp_node))
113             except Exception as ve:
114                 black.err(str(ve))
115             black.out("Actual tree:", fg="red")
116             try:
117                 exp_node = black.lib2to3_parse(actual)
118                 bdv = black.DebugVisitor()
119                 list(bdv.visit(exp_node))
120             except Exception as ve:
121                 black.err(str(ve))
122         self.assertEqual(expected, actual)
123
124     @patch("black.dump_to_file", dump_to_stderr)
125     def test_empty(self) -> None:
126         source = expected = ""
127         actual = fs(source)
128         self.assertFormatEqual(expected, actual)
129         black.assert_equivalent(source, actual)
130         black.assert_stable(source, actual, line_length=ll)
131
132     def test_empty_ff(self) -> None:
133         expected = ""
134         tmp_file = Path(black.dump_to_file())
135         try:
136             self.assertFalse(ff(tmp_file, write_back=black.WriteBack.YES))
137             with open(tmp_file, encoding="utf8") as f:
138                 actual = f.read()
139         finally:
140             os.unlink(tmp_file)
141         self.assertFormatEqual(expected, actual)
142
143     @patch("black.dump_to_file", dump_to_stderr)
144     def test_self(self) -> None:
145         source, expected = read_data("test_black", data=False)
146         actual = fs(source)
147         self.assertFormatEqual(expected, actual)
148         black.assert_equivalent(source, actual)
149         black.assert_stable(source, actual, line_length=ll)
150         self.assertFalse(ff(THIS_FILE))
151
152     @patch("black.dump_to_file", dump_to_stderr)
153     def test_black(self) -> None:
154         source, expected = read_data("../black", data=False)
155         actual = fs(source)
156         self.assertFormatEqual(expected, actual)
157         black.assert_equivalent(source, actual)
158         black.assert_stable(source, actual, line_length=ll)
159         self.assertFalse(ff(THIS_DIR / ".." / "black.py"))
160
161     def test_piping(self) -> None:
162         source, expected = read_data("../black", data=False)
163         stderrbuf = BytesIO()
164         result = BlackRunner(stderrbuf).invoke(
165             black.main,
166             ["-", "--fast", f"--line-length={ll}"],
167             input=BytesIO(source.encode("utf8")),
168         )
169         self.assertEqual(result.exit_code, 0)
170         self.assertFormatEqual(expected, result.output)
171         black.assert_equivalent(source, result.output)
172         black.assert_stable(source, result.output, line_length=ll)
173
174     def test_piping_diff(self) -> None:
175         diff_header = re.compile(
176             rf"(STDIN|STDOUT)\t\d\d\d\d-\d\d-\d\d "
177             rf"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
178         )
179         source, _ = read_data("expression.py")
180         expected, _ = read_data("expression.diff")
181         config = THIS_DIR / "data" / "empty_pyproject.toml"
182         stderrbuf = BytesIO()
183         args = ["-", "--fast", f"--line-length={ll}", "--diff", f"--config={config}"]
184         result = BlackRunner(stderrbuf).invoke(
185             black.main, args, input=BytesIO(source.encode("utf8"))
186         )
187         self.assertEqual(result.exit_code, 0)
188         actual = diff_header.sub("[Deterministic header]", result.output)
189         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
190         self.assertEqual(expected, actual)
191
192     @patch("black.dump_to_file", dump_to_stderr)
193     def test_setup(self) -> None:
194         source, expected = read_data("../setup", data=False)
195         actual = fs(source)
196         self.assertFormatEqual(expected, actual)
197         black.assert_equivalent(source, actual)
198         black.assert_stable(source, actual, line_length=ll)
199         self.assertFalse(ff(THIS_DIR / ".." / "setup.py"))
200
201     @patch("black.dump_to_file", dump_to_stderr)
202     def test_function(self) -> None:
203         source, expected = read_data("function")
204         actual = fs(source)
205         self.assertFormatEqual(expected, actual)
206         black.assert_equivalent(source, actual)
207         black.assert_stable(source, actual, line_length=ll)
208
209     @patch("black.dump_to_file", dump_to_stderr)
210     def test_function2(self) -> None:
211         source, expected = read_data("function2")
212         actual = fs(source)
213         self.assertFormatEqual(expected, actual)
214         black.assert_equivalent(source, actual)
215         black.assert_stable(source, actual, line_length=ll)
216
217     @patch("black.dump_to_file", dump_to_stderr)
218     def test_expression(self) -> None:
219         source, expected = read_data("expression")
220         actual = fs(source)
221         self.assertFormatEqual(expected, actual)
222         black.assert_equivalent(source, actual)
223         black.assert_stable(source, actual, line_length=ll)
224
225     def test_expression_ff(self) -> None:
226         source, expected = read_data("expression")
227         tmp_file = Path(black.dump_to_file(source))
228         try:
229             self.assertTrue(ff(tmp_file, write_back=black.WriteBack.YES))
230             with open(tmp_file, encoding="utf8") as f:
231                 actual = f.read()
232         finally:
233             os.unlink(tmp_file)
234         self.assertFormatEqual(expected, actual)
235         with patch("black.dump_to_file", dump_to_stderr):
236             black.assert_equivalent(source, actual)
237             black.assert_stable(source, actual, line_length=ll)
238
239     def test_expression_diff(self) -> None:
240         source, _ = read_data("expression.py")
241         expected, _ = read_data("expression.diff")
242         tmp_file = Path(black.dump_to_file(source))
243         diff_header = re.compile(
244             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
245             rf"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
246         )
247         stderrbuf = BytesIO()
248         try:
249             result = BlackRunner(stderrbuf).invoke(
250                 black.main, ["--diff", str(tmp_file)]
251             )
252             self.assertEqual(result.exit_code, 0)
253         finally:
254             os.unlink(tmp_file)
255         actual = result.output
256         actual = diff_header.sub("[Deterministic header]", actual)
257         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
258         if expected != actual:
259             dump = black.dump_to_file(actual)
260             msg = (
261                 f"Expected diff isn't equal to the actual. If you made changes "
262                 f"to expression.py and this is an anticipated difference, "
263                 f"overwrite tests/expression.diff with {dump}"
264             )
265             self.assertEqual(expected, actual, msg)
266
267     @patch("black.dump_to_file", dump_to_stderr)
268     def test_fstring(self) -> None:
269         source, expected = read_data("fstring")
270         actual = fs(source)
271         self.assertFormatEqual(expected, actual)
272         black.assert_equivalent(source, actual)
273         black.assert_stable(source, actual, line_length=ll)
274
275     @patch("black.dump_to_file", dump_to_stderr)
276     def test_string_quotes(self) -> None:
277         source, expected = read_data("string_quotes")
278         actual = fs(source)
279         self.assertFormatEqual(expected, actual)
280         black.assert_equivalent(source, actual)
281         black.assert_stable(source, actual, line_length=ll)
282         mode = black.FileMode.NO_STRING_NORMALIZATION
283         not_normalized = fs(source, mode=mode)
284         self.assertFormatEqual(source, not_normalized)
285         black.assert_equivalent(source, not_normalized)
286         black.assert_stable(source, not_normalized, line_length=ll, mode=mode)
287
288     @patch("black.dump_to_file", dump_to_stderr)
289     def test_slices(self) -> None:
290         source, expected = read_data("slices")
291         actual = fs(source)
292         self.assertFormatEqual(expected, actual)
293         black.assert_equivalent(source, actual)
294         black.assert_stable(source, actual, line_length=ll)
295
296     @patch("black.dump_to_file", dump_to_stderr)
297     def test_comments(self) -> None:
298         source, expected = read_data("comments")
299         actual = fs(source)
300         self.assertFormatEqual(expected, actual)
301         black.assert_equivalent(source, actual)
302         black.assert_stable(source, actual, line_length=ll)
303
304     @patch("black.dump_to_file", dump_to_stderr)
305     def test_comments2(self) -> None:
306         source, expected = read_data("comments2")
307         actual = fs(source)
308         self.assertFormatEqual(expected, actual)
309         black.assert_equivalent(source, actual)
310         black.assert_stable(source, actual, line_length=ll)
311
312     @patch("black.dump_to_file", dump_to_stderr)
313     def test_comments3(self) -> None:
314         source, expected = read_data("comments3")
315         actual = fs(source)
316         self.assertFormatEqual(expected, actual)
317         black.assert_equivalent(source, actual)
318         black.assert_stable(source, actual, line_length=ll)
319
320     @patch("black.dump_to_file", dump_to_stderr)
321     def test_comments4(self) -> None:
322         source, expected = read_data("comments4")
323         actual = fs(source)
324         self.assertFormatEqual(expected, actual)
325         black.assert_equivalent(source, actual)
326         black.assert_stable(source, actual, line_length=ll)
327
328     @patch("black.dump_to_file", dump_to_stderr)
329     def test_comments5(self) -> None:
330         source, expected = read_data("comments5")
331         actual = fs(source)
332         self.assertFormatEqual(expected, actual)
333         black.assert_equivalent(source, actual)
334         black.assert_stable(source, actual, line_length=ll)
335
336     @patch("black.dump_to_file", dump_to_stderr)
337     def test_cantfit(self) -> None:
338         source, expected = read_data("cantfit")
339         actual = fs(source)
340         self.assertFormatEqual(expected, actual)
341         black.assert_equivalent(source, actual)
342         black.assert_stable(source, actual, line_length=ll)
343
344     @patch("black.dump_to_file", dump_to_stderr)
345     def test_import_spacing(self) -> None:
346         source, expected = read_data("import_spacing")
347         actual = fs(source)
348         self.assertFormatEqual(expected, actual)
349         black.assert_equivalent(source, actual)
350         black.assert_stable(source, actual, line_length=ll)
351
352     @patch("black.dump_to_file", dump_to_stderr)
353     def test_composition(self) -> None:
354         source, expected = read_data("composition")
355         actual = fs(source)
356         self.assertFormatEqual(expected, actual)
357         black.assert_equivalent(source, actual)
358         black.assert_stable(source, actual, line_length=ll)
359
360     @patch("black.dump_to_file", dump_to_stderr)
361     def test_empty_lines(self) -> None:
362         source, expected = read_data("empty_lines")
363         actual = fs(source)
364         self.assertFormatEqual(expected, actual)
365         black.assert_equivalent(source, actual)
366         black.assert_stable(source, actual, line_length=ll)
367
368     @patch("black.dump_to_file", dump_to_stderr)
369     def test_string_prefixes(self) -> None:
370         source, expected = read_data("string_prefixes")
371         actual = fs(source)
372         self.assertFormatEqual(expected, actual)
373         black.assert_equivalent(source, actual)
374         black.assert_stable(source, actual, line_length=ll)
375
376     @patch("black.dump_to_file", dump_to_stderr)
377     def test_numeric_literals(self) -> None:
378         source, expected = read_data("numeric_literals")
379         actual = fs(source, mode=black.FileMode.PYTHON36)
380         self.assertFormatEqual(expected, actual)
381         black.assert_equivalent(source, actual)
382         black.assert_stable(source, actual, line_length=ll)
383
384     @patch("black.dump_to_file", dump_to_stderr)
385     def test_numeric_literals_py2(self) -> None:
386         source, expected = read_data("numeric_literals_py2")
387         actual = fs(source)
388         self.assertFormatEqual(expected, actual)
389         black.assert_stable(source, actual, line_length=ll)
390
391     @patch("black.dump_to_file", dump_to_stderr)
392     def test_python2(self) -> None:
393         source, expected = read_data("python2")
394         actual = fs(source)
395         self.assertFormatEqual(expected, actual)
396         # black.assert_equivalent(source, actual)
397         black.assert_stable(source, actual, line_length=ll)
398
399     @patch("black.dump_to_file", dump_to_stderr)
400     def test_python2_unicode_literals(self) -> None:
401         source, expected = read_data("python2_unicode_literals")
402         actual = fs(source)
403         self.assertFormatEqual(expected, actual)
404         black.assert_stable(source, actual, line_length=ll)
405
406     @patch("black.dump_to_file", dump_to_stderr)
407     def test_stub(self) -> None:
408         mode = black.FileMode.PYI
409         source, expected = read_data("stub.pyi")
410         actual = fs(source, mode=mode)
411         self.assertFormatEqual(expected, actual)
412         black.assert_stable(source, actual, line_length=ll, mode=mode)
413
414     @patch("black.dump_to_file", dump_to_stderr)
415     def test_python37(self) -> None:
416         source, expected = read_data("python37")
417         actual = fs(source)
418         self.assertFormatEqual(expected, actual)
419         major, minor = sys.version_info[:2]
420         if major > 3 or (major == 3 and minor >= 7):
421             black.assert_equivalent(source, actual)
422         black.assert_stable(source, actual, line_length=ll)
423
424     @patch("black.dump_to_file", dump_to_stderr)
425     def test_fmtonoff(self) -> None:
426         source, expected = read_data("fmtonoff")
427         actual = fs(source)
428         self.assertFormatEqual(expected, actual)
429         black.assert_equivalent(source, actual)
430         black.assert_stable(source, actual, line_length=ll)
431
432     @patch("black.dump_to_file", dump_to_stderr)
433     def test_fmtonoff2(self) -> None:
434         source, expected = read_data("fmtonoff2")
435         actual = fs(source)
436         self.assertFormatEqual(expected, actual)
437         black.assert_equivalent(source, actual)
438         black.assert_stable(source, actual, line_length=ll)
439
440     @patch("black.dump_to_file", dump_to_stderr)
441     def test_remove_empty_parentheses_after_class(self) -> None:
442         source, expected = read_data("class_blank_parentheses")
443         actual = fs(source)
444         self.assertFormatEqual(expected, actual)
445         black.assert_equivalent(source, actual)
446         black.assert_stable(source, actual, line_length=ll)
447
448     @patch("black.dump_to_file", dump_to_stderr)
449     def test_new_line_between_class_and_code(self) -> None:
450         source, expected = read_data("class_methods_new_line")
451         actual = fs(source)
452         self.assertFormatEqual(expected, actual)
453         black.assert_equivalent(source, actual)
454         black.assert_stable(source, actual, line_length=ll)
455
    def test_report_verbose(self) -> None:
        """Exercise black.Report(verbose=True), where every event is echoed.

        black.out/black.err are patched to record messages. After each event
        the message counts, the last message, the summary string and
        return_code are checked.
        """
        report = black.Report(verbose=True)
        out_lines: List[str] = []
        err_lines: List[str] = []

        def out(msg: str, **kwargs: Any) -> None:
            out_lines.append(msg)

        def err(msg: str, **kwargs: Any) -> None:
            err_lines.append(msg)

        with patch("black.out", out), patch("black.err", err):
            report.done(Path("f1"), black.Changed.NO)
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "f1 already well formatted, good job.")
            self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
            self.assertEqual(report.return_code, 0)
            report.done(Path("f2"), black.Changed.YES)
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "reformatted f2")
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
            )
            # A cached file is counted as unchanged.
            report.done(Path("f3"), black.Changed.CACHED)
            self.assertEqual(len(out_lines), 3)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(
                out_lines[-1], "f3 wasn't modified on disk since last run."
            )
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
            )
            self.assertEqual(report.return_code, 0)
            # Switching on check mode changes return_code from 0 to 1 here.
            report.check = True
            self.assertEqual(report.return_code, 1)
            report.check = False
            # A failure is reported through black.err; return_code becomes 123.
            report.failed(Path("e1"), "boom")
            self.assertEqual(len(out_lines), 3)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
            self.assertEqual(
                unstyle(str(report)),
                "1 file reformatted, 2 files left unchanged, "
                "1 file failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f3"), black.Changed.YES)
            self.assertEqual(len(out_lines), 4)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(out_lines[-1], "reformatted f3")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, "
                "1 file failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.failed(Path("e2"), "boom")
            self.assertEqual(len(out_lines), 4)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, "
                "2 files failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Ignored paths are echoed in verbose mode but the summary is unchanged.
            report.path_ignored(Path("wat"), "no match")
            self.assertEqual(len(out_lines), 5)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(out_lines[-1], "wat ignored: no match")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, "
                "2 files failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f4"), black.Changed.NO)
            self.assertEqual(len(out_lines), 6)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(out_lines[-1], "f4 already well formatted, good job.")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 3 files left unchanged, "
                "2 files failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # In check mode the summary switches to conditional wording.
            report.check = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, "
                "2 files would fail to reformat.",
            )
550
    def test_report_quiet(self) -> None:
        """Exercise black.Report(quiet=True), where only failures are echoed.

        black.out/black.err are patched to record messages. black.out is never
        called in this mode, while failures still reach black.err. The summary
        string and return_code are checked after each event.
        """
        report = black.Report(quiet=True)
        out_lines: List[str] = []
        err_lines: List[str] = []

        def out(msg: str, **kwargs: Any) -> None:
            out_lines.append(msg)

        def err(msg: str, **kwargs: Any) -> None:
            err_lines.append(msg)

        with patch("black.out", out), patch("black.err", err):
            report.done(Path("f1"), black.Changed.NO)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
            self.assertEqual(report.return_code, 0)
            report.done(Path("f2"), black.Changed.YES)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
            )
            report.done(Path("f3"), black.Changed.CACHED)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
            )
            self.assertEqual(report.return_code, 0)
            # Switching on check mode changes return_code from 0 to 1 here.
            report.check = True
            self.assertEqual(report.return_code, 1)
            report.check = False
            # Failures are still reported through black.err in quiet mode.
            report.failed(Path("e1"), "boom")
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
            self.assertEqual(
                unstyle(str(report)),
                "1 file reformatted, 2 files left unchanged, "
                "1 file failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f3"), black.Changed.YES)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, "
                "1 file failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.failed(Path("e2"), "boom")
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, "
                "2 files failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Ignored paths produce no output and do not change the summary.
            report.path_ignored(Path("wat"), "no match")
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, "
                "2 files failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f4"), black.Changed.NO)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 3 files left unchanged, "
                "2 files failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # In check mode the summary switches to conditional wording.
            report.check = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, "
                "2 files would fail to reformat.",
            )
637
    def test_report_normal(self) -> None:
        """Exercise black.Report() with default verbosity.

        black.out/black.err are patched to record messages. Only reformatted
        files (black.out) and failures (black.err) are echoed. Unchanged,
        cached and ignored files are silent. The summary string and
        return_code are checked after each event.
        """
        report = black.Report()
        out_lines: List[str] = []
        err_lines: List[str] = []

        def out(msg: str, **kwargs: Any) -> None:
            out_lines.append(msg)

        def err(msg: str, **kwargs: Any) -> None:
            err_lines.append(msg)

        with patch("black.out", out), patch("black.err", err):
            report.done(Path("f1"), black.Changed.NO)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
            self.assertEqual(report.return_code, 0)
            report.done(Path("f2"), black.Changed.YES)
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "reformatted f2")
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
            )
            # A cached file is counted as unchanged and prints nothing.
            report.done(Path("f3"), black.Changed.CACHED)
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "reformatted f2")
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
            )
            self.assertEqual(report.return_code, 0)
            # Switching on check mode changes return_code from 0 to 1 here.
            report.check = True
            self.assertEqual(report.return_code, 1)
            report.check = False
            # A failure is reported through black.err; return_code becomes 123.
            report.failed(Path("e1"), "boom")
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
            self.assertEqual(
                unstyle(str(report)),
                "1 file reformatted, 2 files left unchanged, "
                "1 file failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f3"), black.Changed.YES)
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(out_lines[-1], "reformatted f3")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, "
                "1 file failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.failed(Path("e2"), "boom")
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, "
                "2 files failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Ignored paths produce no output and do not change the summary.
            report.path_ignored(Path("wat"), "no match")
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, "
                "2 files failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f4"), black.Changed.NO)
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 3 files left unchanged, "
                "2 files failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # In check mode the summary switches to conditional wording.
            report.check = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, "
                "2 files would fail to reformat.",
            )
727
728     def test_is_python36(self) -> None:
729         node = black.lib2to3_parse("def f(*, arg): ...\n")
730         self.assertFalse(black.is_python36(node))
731         node = black.lib2to3_parse("def f(*, arg,): ...\n")
732         self.assertTrue(black.is_python36(node))
733         node = black.lib2to3_parse("def f(*, arg): f'string'\n")
734         self.assertTrue(black.is_python36(node))
735         node = black.lib2to3_parse("123_456\n")
736         self.assertTrue(black.is_python36(node))
737         node = black.lib2to3_parse("123456\n")
738         self.assertFalse(black.is_python36(node))
739         source, expected = read_data("function")
740         node = black.lib2to3_parse(source)
741         self.assertTrue(black.is_python36(node))
742         node = black.lib2to3_parse(expected)
743         self.assertTrue(black.is_python36(node))
744         source, expected = read_data("expression")
745         node = black.lib2to3_parse(source)
746         self.assertFalse(black.is_python36(node))
747         node = black.lib2to3_parse(expected)
748         self.assertFalse(black.is_python36(node))
749
750     def test_get_future_imports(self) -> None:
751         node = black.lib2to3_parse("\n")
752         self.assertEqual(set(), black.get_future_imports(node))
753         node = black.lib2to3_parse("from __future__ import black\n")
754         self.assertEqual({"black"}, black.get_future_imports(node))
755         node = black.lib2to3_parse("from __future__ import multiple, imports\n")
756         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
757         node = black.lib2to3_parse("from __future__ import (parenthesized, imports)\n")
758         self.assertEqual({"parenthesized", "imports"}, black.get_future_imports(node))
759         node = black.lib2to3_parse(
760             "from __future__ import multiple\nfrom __future__ import imports\n"
761         )
762         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
763         node = black.lib2to3_parse("# comment\nfrom __future__ import black\n")
764         self.assertEqual({"black"}, black.get_future_imports(node))
765         node = black.lib2to3_parse('"""docstring"""\nfrom __future__ import black\n')
766         self.assertEqual({"black"}, black.get_future_imports(node))
767         node = black.lib2to3_parse("some(other, code)\nfrom __future__ import black\n")
768         self.assertEqual(set(), black.get_future_imports(node))
769         node = black.lib2to3_parse("from some.module import black\n")
770         self.assertEqual(set(), black.get_future_imports(node))
771         node = black.lib2to3_parse(
772             "from __future__ import unicode_literals as _unicode_literals"
773         )
774         self.assertEqual({"unicode_literals"}, black.get_future_imports(node))
775         node = black.lib2to3_parse(
776             "from __future__ import unicode_literals as _lol, print"
777         )
778         self.assertEqual({"unicode_literals", "print"}, black.get_future_imports(node))
779
780     def test_debug_visitor(self) -> None:
781         source, _ = read_data("debug_visitor.py")
782         expected, _ = read_data("debug_visitor.out")
783         out_lines = []
784         err_lines = []
785
786         def out(msg: str, **kwargs: Any) -> None:
787             out_lines.append(msg)
788
789         def err(msg: str, **kwargs: Any) -> None:
790             err_lines.append(msg)
791
792         with patch("black.out", out), patch("black.err", err):
793             black.DebugVisitor.show(source)
794         actual = "\n".join(out_lines) + "\n"
795         log_name = ""
796         if expected != actual:
797             log_name = black.dump_to_file(*out_lines)
798         self.assertEqual(
799             expected,
800             actual,
801             f"AST print out is different. Actual version dumped to {log_name}",
802         )
803
804     def test_format_file_contents(self) -> None:
805         empty = ""
806         with self.assertRaises(black.NothingChanged):
807             black.format_file_contents(empty, line_length=ll, fast=False)
808         just_nl = "\n"
809         with self.assertRaises(black.NothingChanged):
810             black.format_file_contents(just_nl, line_length=ll, fast=False)
811         same = "l = [1, 2, 3]\n"
812         with self.assertRaises(black.NothingChanged):
813             black.format_file_contents(same, line_length=ll, fast=False)
814         different = "l = [1,2,3]"
815         expected = same
816         actual = black.format_file_contents(different, line_length=ll, fast=False)
817         self.assertEqual(expected, actual)
818         invalid = "return if you can"
819         with self.assertRaises(ValueError) as e:
820             black.format_file_contents(invalid, line_length=ll, fast=False)
821         self.assertEqual(str(e.exception), "Cannot parse: 1:7: return if you can")
822
823     def test_endmarker(self) -> None:
824         n = black.lib2to3_parse("\n")
825         self.assertEqual(n.type, black.syms.file_input)
826         self.assertEqual(len(n.children), 1)
827         self.assertEqual(n.children[0].type, black.token.ENDMARKER)
828
829     @unittest.skipIf(os.environ.get("SKIP_AST_PRINT"), "user set SKIP_AST_PRINT")
830     def test_assertFormatEqual(self) -> None:
831         out_lines = []
832         err_lines = []
833
834         def out(msg: str, **kwargs: Any) -> None:
835             out_lines.append(msg)
836
837         def err(msg: str, **kwargs: Any) -> None:
838             err_lines.append(msg)
839
840         with patch("black.out", out), patch("black.err", err):
841             with self.assertRaises(AssertionError):
842                 self.assertFormatEqual("l = [1, 2, 3]", "l = [1, 2, 3,]")
843
844         out_str = "".join(out_lines)
845         self.assertTrue("Expected tree:" in out_str)
846         self.assertTrue("Actual tree:" in out_str)
847         self.assertEqual("".join(err_lines), "")
848
849     def test_cache_broken_file(self) -> None:
850         mode = black.FileMode.AUTO_DETECT
851         with cache_dir() as workspace:
852             cache_file = black.get_cache_file(black.DEFAULT_LINE_LENGTH, mode)
853             with cache_file.open("w") as fobj:
854                 fobj.write("this is not a pickle")
855             self.assertEqual(black.read_cache(black.DEFAULT_LINE_LENGTH, mode), {})
856             src = (workspace / "test.py").resolve()
857             with src.open("w") as fobj:
858                 fobj.write("print('hello')")
859             result = CliRunner().invoke(black.main, [str(src)])
860             self.assertEqual(result.exit_code, 0)
861             cache = black.read_cache(black.DEFAULT_LINE_LENGTH, mode)
862             self.assertIn(src, cache)
863
864     def test_cache_single_file_already_cached(self) -> None:
865         mode = black.FileMode.AUTO_DETECT
866         with cache_dir() as workspace:
867             src = (workspace / "test.py").resolve()
868             with src.open("w") as fobj:
869                 fobj.write("print('hello')")
870             black.write_cache({}, [src], black.DEFAULT_LINE_LENGTH, mode)
871             result = CliRunner().invoke(black.main, [str(src)])
872             self.assertEqual(result.exit_code, 0)
873             with src.open("r") as fobj:
874                 self.assertEqual(fobj.read(), "print('hello')")
875
876     @event_loop(close=False)
877     def test_cache_multiple_files(self) -> None:
878         mode = black.FileMode.AUTO_DETECT
879         with cache_dir() as workspace, patch(
880             "black.ProcessPoolExecutor", new=ThreadPoolExecutor
881         ):
882             one = (workspace / "one.py").resolve()
883             with one.open("w") as fobj:
884                 fobj.write("print('hello')")
885             two = (workspace / "two.py").resolve()
886             with two.open("w") as fobj:
887                 fobj.write("print('hello')")
888             black.write_cache({}, [one], black.DEFAULT_LINE_LENGTH, mode)
889             result = CliRunner().invoke(black.main, [str(workspace)])
890             self.assertEqual(result.exit_code, 0)
891             with one.open("r") as fobj:
892                 self.assertEqual(fobj.read(), "print('hello')")
893             with two.open("r") as fobj:
894                 self.assertEqual(fobj.read(), 'print("hello")\n')
895             cache = black.read_cache(black.DEFAULT_LINE_LENGTH, mode)
896             self.assertIn(one, cache)
897             self.assertIn(two, cache)
898
899     def test_no_cache_when_writeback_diff(self) -> None:
900         mode = black.FileMode.AUTO_DETECT
901         with cache_dir() as workspace:
902             src = (workspace / "test.py").resolve()
903             with src.open("w") as fobj:
904                 fobj.write("print('hello')")
905             result = CliRunner().invoke(black.main, [str(src), "--diff"])
906             self.assertEqual(result.exit_code, 0)
907             cache_file = black.get_cache_file(black.DEFAULT_LINE_LENGTH, mode)
908             self.assertFalse(cache_file.exists())
909
910     def test_no_cache_when_stdin(self) -> None:
911         mode = black.FileMode.AUTO_DETECT
912         with cache_dir():
913             result = CliRunner().invoke(
914                 black.main, ["-"], input=BytesIO(b"print('hello')")
915             )
916             self.assertEqual(result.exit_code, 0)
917             cache_file = black.get_cache_file(black.DEFAULT_LINE_LENGTH, mode)
918             self.assertFalse(cache_file.exists())
919
920     def test_read_cache_no_cachefile(self) -> None:
921         mode = black.FileMode.AUTO_DETECT
922         with cache_dir():
923             self.assertEqual(black.read_cache(black.DEFAULT_LINE_LENGTH, mode), {})
924
925     def test_write_cache_read_cache(self) -> None:
926         mode = black.FileMode.AUTO_DETECT
927         with cache_dir() as workspace:
928             src = (workspace / "test.py").resolve()
929             src.touch()
930             black.write_cache({}, [src], black.DEFAULT_LINE_LENGTH, mode)
931             cache = black.read_cache(black.DEFAULT_LINE_LENGTH, mode)
932             self.assertIn(src, cache)
933             self.assertEqual(cache[src], black.get_cache_info(src))
934
935     def test_filter_cached(self) -> None:
936         with TemporaryDirectory() as workspace:
937             path = Path(workspace)
938             uncached = (path / "uncached").resolve()
939             cached = (path / "cached").resolve()
940             cached_but_changed = (path / "changed").resolve()
941             uncached.touch()
942             cached.touch()
943             cached_but_changed.touch()
944             cache = {cached: black.get_cache_info(cached), cached_but_changed: (0.0, 0)}
945             todo, done = black.filter_cached(
946                 cache, {uncached, cached, cached_but_changed}
947             )
948             self.assertEqual(todo, {uncached, cached_but_changed})
949             self.assertEqual(done, {cached})
950
951     def test_write_cache_creates_directory_if_needed(self) -> None:
952         mode = black.FileMode.AUTO_DETECT
953         with cache_dir(exists=False) as workspace:
954             self.assertFalse(workspace.exists())
955             black.write_cache({}, [], black.DEFAULT_LINE_LENGTH, mode)
956             self.assertTrue(workspace.exists())
957
958     @event_loop(close=False)
959     def test_failed_formatting_does_not_get_cached(self) -> None:
960         mode = black.FileMode.AUTO_DETECT
961         with cache_dir() as workspace, patch(
962             "black.ProcessPoolExecutor", new=ThreadPoolExecutor
963         ):
964             failing = (workspace / "failing.py").resolve()
965             with failing.open("w") as fobj:
966                 fobj.write("not actually python")
967             clean = (workspace / "clean.py").resolve()
968             with clean.open("w") as fobj:
969                 fobj.write('print("hello")\n')
970             result = CliRunner().invoke(black.main, [str(workspace)])
971             self.assertEqual(result.exit_code, 123)
972             cache = black.read_cache(black.DEFAULT_LINE_LENGTH, mode)
973             self.assertNotIn(failing, cache)
974             self.assertIn(clean, cache)
975
976     def test_write_cache_write_fail(self) -> None:
977         mode = black.FileMode.AUTO_DETECT
978         with cache_dir(), patch.object(Path, "open") as mock:
979             mock.side_effect = OSError
980             black.write_cache({}, [], black.DEFAULT_LINE_LENGTH, mode)
981
982     @event_loop(close=False)
983     def test_check_diff_use_together(self) -> None:
984         with cache_dir():
985             # Files which will be reformatted.
986             src1 = (THIS_DIR / "data" / "string_quotes.py").resolve()
987             result = CliRunner().invoke(black.main, [str(src1), "--diff", "--check"])
988             self.assertEqual(result.exit_code, 1, result.output)
989             # Files which will not be reformatted.
990             src2 = (THIS_DIR / "data" / "composition.py").resolve()
991             result = CliRunner().invoke(black.main, [str(src2), "--diff", "--check"])
992             self.assertEqual(result.exit_code, 0, result.output)
993             # Multi file command.
994             result = CliRunner().invoke(
995                 black.main, [str(src1), str(src2), "--diff", "--check"]
996             )
997             self.assertEqual(result.exit_code, 1, result.output)
998
999     def test_no_files(self) -> None:
1000         with cache_dir():
1001             # Without an argument, black exits with error code 0.
1002             result = CliRunner().invoke(black.main, [])
1003             self.assertEqual(result.exit_code, 0)
1004
1005     def test_broken_symlink(self) -> None:
1006         with cache_dir() as workspace:
1007             symlink = workspace / "broken_link.py"
1008             try:
1009                 symlink.symlink_to("nonexistent.py")
1010             except OSError as e:
1011                 self.skipTest(f"Can't create symlinks: {e}")
1012             result = CliRunner().invoke(black.main, [str(workspace.resolve())])
1013             self.assertEqual(result.exit_code, 0)
1014
1015     def test_read_cache_line_lengths(self) -> None:
1016         mode = black.FileMode.AUTO_DETECT
1017         with cache_dir() as workspace:
1018             path = (workspace / "file.py").resolve()
1019             path.touch()
1020             black.write_cache({}, [path], 1, mode)
1021             one = black.read_cache(1, mode)
1022             self.assertIn(path, one)
1023             two = black.read_cache(2, mode)
1024             self.assertNotIn(path, two)
1025
1026     def test_single_file_force_pyi(self) -> None:
1027         reg_mode = black.FileMode.AUTO_DETECT
1028         pyi_mode = black.FileMode.PYI
1029         contents, expected = read_data("force_pyi")
1030         with cache_dir() as workspace:
1031             path = (workspace / "file.py").resolve()
1032             with open(path, "w") as fh:
1033                 fh.write(contents)
1034             result = CliRunner().invoke(black.main, [str(path), "--pyi"])
1035             self.assertEqual(result.exit_code, 0)
1036             with open(path, "r") as fh:
1037                 actual = fh.read()
1038             # verify cache with --pyi is separate
1039             pyi_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, pyi_mode)
1040             self.assertIn(path, pyi_cache)
1041             normal_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, reg_mode)
1042             self.assertNotIn(path, normal_cache)
1043         self.assertEqual(actual, expected)
1044
1045     @event_loop(close=False)
1046     def test_multi_file_force_pyi(self) -> None:
1047         reg_mode = black.FileMode.AUTO_DETECT
1048         pyi_mode = black.FileMode.PYI
1049         contents, expected = read_data("force_pyi")
1050         with cache_dir() as workspace:
1051             paths = [
1052                 (workspace / "file1.py").resolve(),
1053                 (workspace / "file2.py").resolve(),
1054             ]
1055             for path in paths:
1056                 with open(path, "w") as fh:
1057                     fh.write(contents)
1058             result = CliRunner().invoke(black.main, [str(p) for p in paths] + ["--pyi"])
1059             self.assertEqual(result.exit_code, 0)
1060             for path in paths:
1061                 with open(path, "r") as fh:
1062                     actual = fh.read()
1063                 self.assertEqual(actual, expected)
1064             # verify cache with --pyi is separate
1065             pyi_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, pyi_mode)
1066             normal_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, reg_mode)
1067             for path in paths:
1068                 self.assertIn(path, pyi_cache)
1069                 self.assertNotIn(path, normal_cache)
1070
1071     def test_pipe_force_pyi(self) -> None:
1072         source, expected = read_data("force_pyi")
1073         result = CliRunner().invoke(
1074             black.main, ["-", "-q", "--pyi"], input=BytesIO(source.encode("utf8"))
1075         )
1076         self.assertEqual(result.exit_code, 0)
1077         actual = result.output
1078         self.assertFormatEqual(actual, expected)
1079
1080     def test_single_file_force_py36(self) -> None:
1081         reg_mode = black.FileMode.AUTO_DETECT
1082         py36_mode = black.FileMode.PYTHON36
1083         source, expected = read_data("force_py36")
1084         with cache_dir() as workspace:
1085             path = (workspace / "file.py").resolve()
1086             with open(path, "w") as fh:
1087                 fh.write(source)
1088             result = CliRunner().invoke(black.main, [str(path), "--py36"])
1089             self.assertEqual(result.exit_code, 0)
1090             with open(path, "r") as fh:
1091                 actual = fh.read()
1092             # verify cache with --py36 is separate
1093             py36_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, py36_mode)
1094             self.assertIn(path, py36_cache)
1095             normal_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, reg_mode)
1096             self.assertNotIn(path, normal_cache)
1097         self.assertEqual(actual, expected)
1098
1099     @event_loop(close=False)
1100     def test_multi_file_force_py36(self) -> None:
1101         reg_mode = black.FileMode.AUTO_DETECT
1102         py36_mode = black.FileMode.PYTHON36
1103         source, expected = read_data("force_py36")
1104         with cache_dir() as workspace:
1105             paths = [
1106                 (workspace / "file1.py").resolve(),
1107                 (workspace / "file2.py").resolve(),
1108             ]
1109             for path in paths:
1110                 with open(path, "w") as fh:
1111                     fh.write(source)
1112             result = CliRunner().invoke(
1113                 black.main, [str(p) for p in paths] + ["--py36"]
1114             )
1115             self.assertEqual(result.exit_code, 0)
1116             for path in paths:
1117                 with open(path, "r") as fh:
1118                     actual = fh.read()
1119                 self.assertEqual(actual, expected)
1120             # verify cache with --py36 is separate
1121             pyi_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, py36_mode)
1122             normal_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, reg_mode)
1123             for path in paths:
1124                 self.assertIn(path, pyi_cache)
1125                 self.assertNotIn(path, normal_cache)
1126
1127     def test_pipe_force_py36(self) -> None:
1128         source, expected = read_data("force_py36")
1129         result = CliRunner().invoke(
1130             black.main, ["-", "-q", "--py36"], input=BytesIO(source.encode("utf8"))
1131         )
1132         self.assertEqual(result.exit_code, 0)
1133         actual = result.output
1134         self.assertFormatEqual(actual, expected)
1135
1136     def test_include_exclude(self) -> None:
1137         path = THIS_DIR / "data" / "include_exclude_tests"
1138         include = re.compile(r"\.pyi?$")
1139         exclude = re.compile(r"/exclude/|/\.definitely_exclude/")
1140         report = black.Report()
1141         sources: List[Path] = []
1142         expected = [
1143             Path(path / "b/dont_exclude/a.py"),
1144             Path(path / "b/dont_exclude/a.pyi"),
1145         ]
1146         this_abs = THIS_DIR.resolve()
1147         sources.extend(
1148             black.gen_python_files_in_dir(path, this_abs, include, exclude, report)
1149         )
1150         self.assertEqual(sorted(expected), sorted(sources))
1151
1152     def test_empty_include(self) -> None:
1153         path = THIS_DIR / "data" / "include_exclude_tests"
1154         report = black.Report()
1155         empty = re.compile(r"")
1156         sources: List[Path] = []
1157         expected = [
1158             Path(path / "b/exclude/a.pie"),
1159             Path(path / "b/exclude/a.py"),
1160             Path(path / "b/exclude/a.pyi"),
1161             Path(path / "b/dont_exclude/a.pie"),
1162             Path(path / "b/dont_exclude/a.py"),
1163             Path(path / "b/dont_exclude/a.pyi"),
1164             Path(path / "b/.definitely_exclude/a.pie"),
1165             Path(path / "b/.definitely_exclude/a.py"),
1166             Path(path / "b/.definitely_exclude/a.pyi"),
1167         ]
1168         this_abs = THIS_DIR.resolve()
1169         sources.extend(
1170             black.gen_python_files_in_dir(
1171                 path, this_abs, empty, re.compile(black.DEFAULT_EXCLUDES), report
1172             )
1173         )
1174         self.assertEqual(sorted(expected), sorted(sources))
1175
1176     def test_empty_exclude(self) -> None:
1177         path = THIS_DIR / "data" / "include_exclude_tests"
1178         report = black.Report()
1179         empty = re.compile(r"")
1180         sources: List[Path] = []
1181         expected = [
1182             Path(path / "b/dont_exclude/a.py"),
1183             Path(path / "b/dont_exclude/a.pyi"),
1184             Path(path / "b/exclude/a.py"),
1185             Path(path / "b/exclude/a.pyi"),
1186             Path(path / "b/.definitely_exclude/a.py"),
1187             Path(path / "b/.definitely_exclude/a.pyi"),
1188         ]
1189         this_abs = THIS_DIR.resolve()
1190         sources.extend(
1191             black.gen_python_files_in_dir(
1192                 path, this_abs, re.compile(black.DEFAULT_INCLUDES), empty, report
1193             )
1194         )
1195         self.assertEqual(sorted(expected), sorted(sources))
1196
1197     def test_invalid_include_exclude(self) -> None:
1198         for option in ["--include", "--exclude"]:
1199             result = CliRunner().invoke(black.main, ["-", option, "**()(!!*)"])
1200             self.assertEqual(result.exit_code, 2)
1201
1202     def test_preserves_line_endings(self) -> None:
1203         with TemporaryDirectory() as workspace:
1204             test_file = Path(workspace) / "test.py"
1205             for nl in ["\n", "\r\n"]:
1206                 contents = nl.join(["def f(  ):", "    pass"])
1207                 test_file.write_bytes(contents.encode())
1208                 ff(test_file, write_back=black.WriteBack.YES)
1209                 updated_contents: bytes = test_file.read_bytes()
1210                 self.assertIn(nl.encode(), updated_contents)
1211                 if nl == "\n":
1212                     self.assertNotIn(b"\r\n", updated_contents)
1213
1214     def test_assert_equivalent_different_asts(self) -> None:
1215         with self.assertRaises(AssertionError):
1216             black.assert_equivalent("{}", "None")
1217
1218     def test_symlink_out_of_root_directory(self) -> None:
1219         path = MagicMock()
1220         root = THIS_DIR
1221         child = MagicMock()
1222         include = re.compile(black.DEFAULT_INCLUDES)
1223         exclude = re.compile(black.DEFAULT_EXCLUDES)
1224         report = black.Report()
1225         # `child` should behave like a symlink which resolved path is clearly
1226         # outside of the `root` directory.
1227         path.iterdir.return_value = [child]
1228         child.resolve.return_value = Path("/a/b/c")
1229         child.is_symlink.return_value = True
1230         try:
1231             list(black.gen_python_files_in_dir(path, root, include, exclude, report))
1232         except ValueError as ve:
1233             self.fail("`get_python_files_in_dir()` failed: {ve}")
1234         path.iterdir.assert_called_once()
1235         child.resolve.assert_called_once()
1236         child.is_symlink.assert_called_once()
1237         # `child` should behave like a strange file which resolved path is clearly
1238         # outside of the `root` directory.
1239         child.is_symlink.return_value = False
1240         with self.assertRaises(ValueError):
1241             list(black.gen_python_files_in_dir(path, root, include, exclude, report))
1242         path.iterdir.assert_called()
1243         self.assertEqual(path.iterdir.call_count, 2)
1244         child.resolve.assert_called()
1245         self.assertEqual(child.resolve.call_count, 2)
1246         child.is_symlink.assert_called()
1247         self.assertEqual(child.is_symlink.call_count, 2)
1248
1249     def test_shhh_click(self) -> None:
1250         try:
1251             from click import _unicodefun  # type: ignore
1252         except ModuleNotFoundError:
1253             self.skipTest("Incompatible Click version")
1254         if not hasattr(_unicodefun, "_verify_python3_env"):
1255             self.skipTest("Incompatible Click version")
1256         # First, let's see if Click is crashing with a preferred ASCII charset.
1257         with patch("locale.getpreferredencoding") as gpe:
1258             gpe.return_value = "ASCII"
1259             with self.assertRaises(RuntimeError):
1260                 _unicodefun._verify_python3_env()
1261         # Now, let's silence Click...
1262         black.patch_click()
1263         # ...and confirm it's silent.
1264         with patch("locale.getpreferredencoding") as gpe:
1265             gpe.return_value = "ASCII"
1266             try:
1267                 _unicodefun._verify_python3_env()
1268             except RuntimeError as re:
1269                 self.fail(f"`patch_click()` failed, exception still raised: {re}")
1270
1271
if __name__ == "__main__":
    # Take the module name from this file's name instead of hard-coding
    # "test_black", so running the file directly keeps working after a rename.
    unittest.main(module=THIS_FILE.stem)