git.madduck.net Git - etc/vim.git/blob - tests/test_black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access, can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Prefer https:// links where available (#485)
[etc/vim.git] / tests / test_black.py
1 #!/usr/bin/env python3
2 import asyncio
3 from concurrent.futures import ThreadPoolExecutor
4 from contextlib import contextmanager
5 from functools import partial
6 from io import BytesIO, TextIOWrapper
7 import os
8 from pathlib import Path
9 import re
10 import sys
11 from tempfile import TemporaryDirectory
12 from typing import Any, BinaryIO, Generator, List, Tuple, Iterator
13 import unittest
14 from unittest.mock import patch, MagicMock
15
16 from click import unstyle
17 from click.testing import CliRunner
18
19 import black
20
21
# Location of this test module and of the directory that contains it.
THIS_FILE = Path(__file__)
THIS_DIR = THIS_FILE.parent

# Line length shared by every formatting call in this file.
ll = 88
# Shorthands: format a source string / format a file in place (fast mode).
fs = partial(black.format_str, line_length=ll)
ff = partial(black.format_file_in_place, line_length=ll, fast=True)

# Marker text that read_data() removes from fixture lines.  The literal is
# deliberately built from two pieces -- presumably so that this file, which
# test_self passes through read_data(), never contains the marker verbatim.
EMPTY_LINE = "# EMPTY LINE WITH WHITESPACE" + " (this comment will be removed)"
28
29
def dump_to_stderr(*output: str) -> str:
    """Return the given strings joined by newlines, framed by one newline on each side.

    Nothing is written anywhere; the text is only returned.  Tests in this
    file patch this function over ``black.dump_to_file``.
    """
    joined = "\n".join(output)
    return f"\n{joined}\n"
32
33
def read_data(name: str, data: bool = True) -> Tuple[str, str]:
    """read_data('test_name') -> 'input', 'output'

    Reads ``name`` (``.py`` is appended unless it already ends in ``.py``,
    ``.pyi``, ``.out`` or ``.diff``) from ``THIS_DIR / "data"`` when `data` is
    true, otherwise from ``THIS_DIR``.  Lines before a ``# output`` marker line
    form the input, lines after it the expected output; the marker itself is
    dropped.  Every occurrence of ``EMPTY_LINE`` is removed from each line.
    Both returned strings are stripped and terminated by a single newline.
    """
    recognized_suffixes = (".py", ".pyi", ".out", ".diff")
    if not name.endswith(recognized_suffixes):
        name += ".py"
    base_dir = THIS_DIR / "data" if data else THIS_DIR
    with open(base_dir / name, "r", encoding="utf8") as fixture:
        raw_lines = fixture.readlines()

    before_marker: List[str] = []
    after_marker: List[str] = []
    bucket = before_marker
    for raw in raw_lines:
        cleaned = raw.replace(EMPTY_LINE, "")
        if cleaned.rstrip() == "# output":
            # Everything from here on belongs to the expected output.
            bucket = after_marker
        else:
            bucket.append(cleaned)

    if before_marker and not after_marker:
        # If there's no output marker, treat the entire file as already pre-formatted.
        after_marker = list(before_marker)

    source_text = "".join(before_marker).strip() + "\n"
    expected_text = "".join(after_marker).strip() + "\n"
    return source_text, expected_text
55
56
@contextmanager
def cache_dir(exists: bool = True) -> Iterator[Path]:
    """Patch ``black.CACHE_DIR`` to a location inside a fresh temporary directory.

    With ``exists=True`` the yielded path is the temporary directory itself;
    otherwise it is a not-yet-created ``new`` entry inside it.  The patch and
    the temporary directory are both undone when the context exits.
    """
    with TemporaryDirectory() as workspace:
        root = Path(workspace)
        location = root if exists else root / "new"
        with patch("black.CACHE_DIR", location):
            yield location
65
66
@contextmanager
def event_loop(close: bool) -> Iterator[None]:
    """Run the enclosed code with a brand-new event loop set as the current one.

    On exit the previously current loop is reinstated through the policy; the
    temporary loop is closed only when `close` is true.
    """
    policy = asyncio.get_event_loop_policy()
    previous_loop = policy.get_event_loop()
    temporary_loop = policy.new_event_loop()
    asyncio.set_event_loop(temporary_loop)
    try:
        yield

    finally:
        policy.set_event_loop(previous_loop)
        if close:
            temporary_loop.close()
80
81
class BlackRunner(CliRunner):
    """Modify CliRunner so that stderr is not merged with stdout.

    This is a hack that can be removed once we depend on Click 7.x

    After an invocation, ``stdout_bytes`` and ``stderr_bytes`` hold the raw
    bytes captured from the respective streams.
    """

    def __init__(self) -> None:
        self.stderrbuf = BytesIO()
        self.stdoutbuf = BytesIO()
        self.stdout_bytes = b""
        self.stderr_bytes = b""
        super().__init__()

    @contextmanager
    def isolation(self, *args: Any, **kwargs: Any) -> Generator[BinaryIO, None, None]:
        """Wrap ``CliRunner.isolation`` and additionally redirect ``sys.stderr``.

        While the context is active ``sys.stderr`` is a text wrapper around
        ``self.stderrbuf``.  On exit the captured bytes are stored on the
        runner and the original ``sys.stderr`` is put back.
        """
        with super().isolation(*args, **kwargs) as output:
            # Remember the real stream before entering the try block so the
            # restore below can never hit an unbound name.
            hold_stderr = sys.stderr
            sys.stderr = TextIOWrapper(self.stderrbuf, encoding=self.charset)
            try:
                yield output
            finally:
                try:
                    # Text wrappers buffer writes; flush so that everything
                    # written so far is visible in the underlying byte buffers.
                    sys.stdout.flush()
                    sys.stderr.flush()
                    self.stdout_bytes = sys.stdout.buffer.getvalue()  # type: ignore
                    self.stderr_bytes = sys.stderr.buffer.getvalue()  # type: ignore
                finally:
                    # Restore stderr even if capturing the output failed.
                    sys.stderr = hold_stderr
105
106
107 class BlackTestCase(unittest.TestCase):
108     maxDiff = None
109
110     def assertFormatEqual(self, expected: str, actual: str) -> None:
111         if actual != expected and not os.environ.get("SKIP_AST_PRINT"):
112             bdv: black.DebugVisitor[Any]
113             black.out("Expected tree:", fg="green")
114             try:
115                 exp_node = black.lib2to3_parse(expected)
116                 bdv = black.DebugVisitor()
117                 list(bdv.visit(exp_node))
118             except Exception as ve:
119                 black.err(str(ve))
120             black.out("Actual tree:", fg="red")
121             try:
122                 exp_node = black.lib2to3_parse(actual)
123                 bdv = black.DebugVisitor()
124                 list(bdv.visit(exp_node))
125             except Exception as ve:
126                 black.err(str(ve))
127         self.assertEqual(expected, actual)
128
129     @patch("black.dump_to_file", dump_to_stderr)
130     def test_empty(self) -> None:
131         source = expected = ""
132         actual = fs(source)
133         self.assertFormatEqual(expected, actual)
134         black.assert_equivalent(source, actual)
135         black.assert_stable(source, actual, line_length=ll)
136
137     def test_empty_ff(self) -> None:
138         expected = ""
139         tmp_file = Path(black.dump_to_file())
140         try:
141             self.assertFalse(ff(tmp_file, write_back=black.WriteBack.YES))
142             with open(tmp_file, encoding="utf8") as f:
143                 actual = f.read()
144         finally:
145             os.unlink(tmp_file)
146         self.assertFormatEqual(expected, actual)
147
148     @patch("black.dump_to_file", dump_to_stderr)
149     def test_self(self) -> None:
150         source, expected = read_data("test_black", data=False)
151         actual = fs(source)
152         self.assertFormatEqual(expected, actual)
153         black.assert_equivalent(source, actual)
154         black.assert_stable(source, actual, line_length=ll)
155         self.assertFalse(ff(THIS_FILE))
156
157     @patch("black.dump_to_file", dump_to_stderr)
158     def test_black(self) -> None:
159         source, expected = read_data("../black", data=False)
160         actual = fs(source)
161         self.assertFormatEqual(expected, actual)
162         black.assert_equivalent(source, actual)
163         black.assert_stable(source, actual, line_length=ll)
164         self.assertFalse(ff(THIS_DIR / ".." / "black.py"))
165
166     def test_piping(self) -> None:
167         source, expected = read_data("../black", data=False)
168         result = BlackRunner().invoke(
169             black.main,
170             ["-", "--fast", f"--line-length={ll}"],
171             input=BytesIO(source.encode("utf8")),
172         )
173         self.assertEqual(result.exit_code, 0)
174         self.assertFormatEqual(expected, result.output)
175         black.assert_equivalent(source, result.output)
176         black.assert_stable(source, result.output, line_length=ll)
177
178     def test_piping_diff(self) -> None:
179         diff_header = re.compile(
180             rf"(STDIN|STDOUT)\t\d\d\d\d-\d\d-\d\d "
181             rf"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
182         )
183         source, _ = read_data("expression.py")
184         expected, _ = read_data("expression.diff")
185         config = THIS_DIR / "data" / "empty_pyproject.toml"
186         args = ["-", "--fast", f"--line-length={ll}", "--diff", f"--config={config}"]
187         result = BlackRunner().invoke(
188             black.main, args, input=BytesIO(source.encode("utf8"))
189         )
190         self.assertEqual(result.exit_code, 0)
191         actual = diff_header.sub("[Deterministic header]", result.output)
192         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
193         self.assertEqual(expected, actual)
194
195     @patch("black.dump_to_file", dump_to_stderr)
196     def test_setup(self) -> None:
197         source, expected = read_data("../setup", data=False)
198         actual = fs(source)
199         self.assertFormatEqual(expected, actual)
200         black.assert_equivalent(source, actual)
201         black.assert_stable(source, actual, line_length=ll)
202         self.assertFalse(ff(THIS_DIR / ".." / "setup.py"))
203
204     @patch("black.dump_to_file", dump_to_stderr)
205     def test_function(self) -> None:
206         source, expected = read_data("function")
207         actual = fs(source)
208         self.assertFormatEqual(expected, actual)
209         black.assert_equivalent(source, actual)
210         black.assert_stable(source, actual, line_length=ll)
211
212     @patch("black.dump_to_file", dump_to_stderr)
213     def test_function2(self) -> None:
214         source, expected = read_data("function2")
215         actual = fs(source)
216         self.assertFormatEqual(expected, actual)
217         black.assert_equivalent(source, actual)
218         black.assert_stable(source, actual, line_length=ll)
219
220     @patch("black.dump_to_file", dump_to_stderr)
221     def test_expression(self) -> None:
222         source, expected = read_data("expression")
223         actual = fs(source)
224         self.assertFormatEqual(expected, actual)
225         black.assert_equivalent(source, actual)
226         black.assert_stable(source, actual, line_length=ll)
227
228     def test_expression_ff(self) -> None:
229         source, expected = read_data("expression")
230         tmp_file = Path(black.dump_to_file(source))
231         try:
232             self.assertTrue(ff(tmp_file, write_back=black.WriteBack.YES))
233             with open(tmp_file, encoding="utf8") as f:
234                 actual = f.read()
235         finally:
236             os.unlink(tmp_file)
237         self.assertFormatEqual(expected, actual)
238         with patch("black.dump_to_file", dump_to_stderr):
239             black.assert_equivalent(source, actual)
240             black.assert_stable(source, actual, line_length=ll)
241
242     def test_expression_diff(self) -> None:
243         source, _ = read_data("expression.py")
244         expected, _ = read_data("expression.diff")
245         tmp_file = Path(black.dump_to_file(source))
246         diff_header = re.compile(
247             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
248             rf"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
249         )
250         try:
251             result = BlackRunner().invoke(black.main, ["--diff", str(tmp_file)])
252             self.assertEqual(result.exit_code, 0)
253         finally:
254             os.unlink(tmp_file)
255         actual = result.output
256         actual = diff_header.sub("[Deterministic header]", actual)
257         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
258         if expected != actual:
259             dump = black.dump_to_file(actual)
260             msg = (
261                 f"Expected diff isn't equal to the actual. If you made changes "
262                 f"to expression.py and this is an anticipated difference, "
263                 f"overwrite tests/expression.diff with {dump}"
264             )
265             self.assertEqual(expected, actual, msg)
266
267     @patch("black.dump_to_file", dump_to_stderr)
268     def test_fstring(self) -> None:
269         source, expected = read_data("fstring")
270         actual = fs(source)
271         self.assertFormatEqual(expected, actual)
272         black.assert_equivalent(source, actual)
273         black.assert_stable(source, actual, line_length=ll)
274
275     @patch("black.dump_to_file", dump_to_stderr)
276     def test_string_quotes(self) -> None:
277         source, expected = read_data("string_quotes")
278         actual = fs(source)
279         self.assertFormatEqual(expected, actual)
280         black.assert_equivalent(source, actual)
281         black.assert_stable(source, actual, line_length=ll)
282         mode = black.FileMode.NO_STRING_NORMALIZATION
283         not_normalized = fs(source, mode=mode)
284         self.assertFormatEqual(source, not_normalized)
285         black.assert_equivalent(source, not_normalized)
286         black.assert_stable(source, not_normalized, line_length=ll, mode=mode)
287
288     @patch("black.dump_to_file", dump_to_stderr)
289     def test_slices(self) -> None:
290         source, expected = read_data("slices")
291         actual = fs(source)
292         self.assertFormatEqual(expected, actual)
293         black.assert_equivalent(source, actual)
294         black.assert_stable(source, actual, line_length=ll)
295
296     @patch("black.dump_to_file", dump_to_stderr)
297     def test_comments(self) -> None:
298         source, expected = read_data("comments")
299         actual = fs(source)
300         self.assertFormatEqual(expected, actual)
301         black.assert_equivalent(source, actual)
302         black.assert_stable(source, actual, line_length=ll)
303
304     @patch("black.dump_to_file", dump_to_stderr)
305     def test_comments2(self) -> None:
306         source, expected = read_data("comments2")
307         actual = fs(source)
308         self.assertFormatEqual(expected, actual)
309         black.assert_equivalent(source, actual)
310         black.assert_stable(source, actual, line_length=ll)
311
312     @patch("black.dump_to_file", dump_to_stderr)
313     def test_comments3(self) -> None:
314         source, expected = read_data("comments3")
315         actual = fs(source)
316         self.assertFormatEqual(expected, actual)
317         black.assert_equivalent(source, actual)
318         black.assert_stable(source, actual, line_length=ll)
319
320     @patch("black.dump_to_file", dump_to_stderr)
321     def test_comments4(self) -> None:
322         source, expected = read_data("comments4")
323         actual = fs(source)
324         self.assertFormatEqual(expected, actual)
325         black.assert_equivalent(source, actual)
326         black.assert_stable(source, actual, line_length=ll)
327
328     @patch("black.dump_to_file", dump_to_stderr)
329     def test_comments5(self) -> None:
330         source, expected = read_data("comments5")
331         actual = fs(source)
332         self.assertFormatEqual(expected, actual)
333         black.assert_equivalent(source, actual)
334         black.assert_stable(source, actual, line_length=ll)
335
336     @patch("black.dump_to_file", dump_to_stderr)
337     def test_cantfit(self) -> None:
338         source, expected = read_data("cantfit")
339         actual = fs(source)
340         self.assertFormatEqual(expected, actual)
341         black.assert_equivalent(source, actual)
342         black.assert_stable(source, actual, line_length=ll)
343
344     @patch("black.dump_to_file", dump_to_stderr)
345     def test_import_spacing(self) -> None:
346         source, expected = read_data("import_spacing")
347         actual = fs(source)
348         self.assertFormatEqual(expected, actual)
349         black.assert_equivalent(source, actual)
350         black.assert_stable(source, actual, line_length=ll)
351
352     @patch("black.dump_to_file", dump_to_stderr)
353     def test_composition(self) -> None:
354         source, expected = read_data("composition")
355         actual = fs(source)
356         self.assertFormatEqual(expected, actual)
357         black.assert_equivalent(source, actual)
358         black.assert_stable(source, actual, line_length=ll)
359
360     @patch("black.dump_to_file", dump_to_stderr)
361     def test_empty_lines(self) -> None:
362         source, expected = read_data("empty_lines")
363         actual = fs(source)
364         self.assertFormatEqual(expected, actual)
365         black.assert_equivalent(source, actual)
366         black.assert_stable(source, actual, line_length=ll)
367
368     @patch("black.dump_to_file", dump_to_stderr)
369     def test_string_prefixes(self) -> None:
370         source, expected = read_data("string_prefixes")
371         actual = fs(source)
372         self.assertFormatEqual(expected, actual)
373         black.assert_equivalent(source, actual)
374         black.assert_stable(source, actual, line_length=ll)
375
376     @patch("black.dump_to_file", dump_to_stderr)
377     def test_numeric_literals(self) -> None:
378         source, expected = read_data("numeric_literals")
379         actual = fs(source, mode=black.FileMode.PYTHON36)
380         self.assertFormatEqual(expected, actual)
381         black.assert_equivalent(source, actual)
382         black.assert_stable(source, actual, line_length=ll)
383
384     @patch("black.dump_to_file", dump_to_stderr)
385     def test_numeric_literals_py2(self) -> None:
386         source, expected = read_data("numeric_literals_py2")
387         actual = fs(source)
388         self.assertFormatEqual(expected, actual)
389         black.assert_stable(source, actual, line_length=ll)
390
391     @patch("black.dump_to_file", dump_to_stderr)
392     def test_python2(self) -> None:
393         source, expected = read_data("python2")
394         actual = fs(source)
395         self.assertFormatEqual(expected, actual)
396         # black.assert_equivalent(source, actual)
397         black.assert_stable(source, actual, line_length=ll)
398
399     @patch("black.dump_to_file", dump_to_stderr)
400     def test_python2_unicode_literals(self) -> None:
401         source, expected = read_data("python2_unicode_literals")
402         actual = fs(source)
403         self.assertFormatEqual(expected, actual)
404         black.assert_stable(source, actual, line_length=ll)
405
406     @patch("black.dump_to_file", dump_to_stderr)
407     def test_stub(self) -> None:
408         mode = black.FileMode.PYI
409         source, expected = read_data("stub.pyi")
410         actual = fs(source, mode=mode)
411         self.assertFormatEqual(expected, actual)
412         black.assert_stable(source, actual, line_length=ll, mode=mode)
413
414     @patch("black.dump_to_file", dump_to_stderr)
415     def test_python37(self) -> None:
416         source, expected = read_data("python37")
417         actual = fs(source)
418         self.assertFormatEqual(expected, actual)
419         major, minor = sys.version_info[:2]
420         if major > 3 or (major == 3 and minor >= 7):
421             black.assert_equivalent(source, actual)
422         black.assert_stable(source, actual, line_length=ll)
423
424     @patch("black.dump_to_file", dump_to_stderr)
425     def test_fmtonoff(self) -> None:
426         source, expected = read_data("fmtonoff")
427         actual = fs(source)
428         self.assertFormatEqual(expected, actual)
429         black.assert_equivalent(source, actual)
430         black.assert_stable(source, actual, line_length=ll)
431
432     @patch("black.dump_to_file", dump_to_stderr)
433     def test_fmtonoff2(self) -> None:
434         source, expected = read_data("fmtonoff2")
435         actual = fs(source)
436         self.assertFormatEqual(expected, actual)
437         black.assert_equivalent(source, actual)
438         black.assert_stable(source, actual, line_length=ll)
439
440     @patch("black.dump_to_file", dump_to_stderr)
441     def test_remove_empty_parentheses_after_class(self) -> None:
442         source, expected = read_data("class_blank_parentheses")
443         actual = fs(source)
444         self.assertFormatEqual(expected, actual)
445         black.assert_equivalent(source, actual)
446         black.assert_stable(source, actual, line_length=ll)
447
448     @patch("black.dump_to_file", dump_to_stderr)
449     def test_new_line_between_class_and_code(self) -> None:
450         source, expected = read_data("class_methods_new_line")
451         actual = fs(source)
452         self.assertFormatEqual(expected, actual)
453         black.assert_equivalent(source, actual)
454         black.assert_stable(source, actual, line_length=ll)
455
456     @patch("black.dump_to_file", dump_to_stderr)
457     def test_bracket_match(self) -> None:
458         source, expected = read_data("bracketmatch")
459         actual = fs(source)
460         self.assertFormatEqual(expected, actual)
461         black.assert_equivalent(source, actual)
462         black.assert_stable(source, actual, line_length=ll)
463
464     def test_report_verbose(self) -> None:
465         report = black.Report(verbose=True)
466         out_lines = []
467         err_lines = []
468
469         def out(msg: str, **kwargs: Any) -> None:
470             out_lines.append(msg)
471
472         def err(msg: str, **kwargs: Any) -> None:
473             err_lines.append(msg)
474
475         with patch("black.out", out), patch("black.err", err):
476             report.done(Path("f1"), black.Changed.NO)
477             self.assertEqual(len(out_lines), 1)
478             self.assertEqual(len(err_lines), 0)
479             self.assertEqual(out_lines[-1], "f1 already well formatted, good job.")
480             self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
481             self.assertEqual(report.return_code, 0)
482             report.done(Path("f2"), black.Changed.YES)
483             self.assertEqual(len(out_lines), 2)
484             self.assertEqual(len(err_lines), 0)
485             self.assertEqual(out_lines[-1], "reformatted f2")
486             self.assertEqual(
487                 unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
488             )
489             report.done(Path("f3"), black.Changed.CACHED)
490             self.assertEqual(len(out_lines), 3)
491             self.assertEqual(len(err_lines), 0)
492             self.assertEqual(
493                 out_lines[-1], "f3 wasn't modified on disk since last run."
494             )
495             self.assertEqual(
496                 unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
497             )
498             self.assertEqual(report.return_code, 0)
499             report.check = True
500             self.assertEqual(report.return_code, 1)
501             report.check = False
502             report.failed(Path("e1"), "boom")
503             self.assertEqual(len(out_lines), 3)
504             self.assertEqual(len(err_lines), 1)
505             self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
506             self.assertEqual(
507                 unstyle(str(report)),
508                 "1 file reformatted, 2 files left unchanged, "
509                 "1 file failed to reformat.",
510             )
511             self.assertEqual(report.return_code, 123)
512             report.done(Path("f3"), black.Changed.YES)
513             self.assertEqual(len(out_lines), 4)
514             self.assertEqual(len(err_lines), 1)
515             self.assertEqual(out_lines[-1], "reformatted f3")
516             self.assertEqual(
517                 unstyle(str(report)),
518                 "2 files reformatted, 2 files left unchanged, "
519                 "1 file failed to reformat.",
520             )
521             self.assertEqual(report.return_code, 123)
522             report.failed(Path("e2"), "boom")
523             self.assertEqual(len(out_lines), 4)
524             self.assertEqual(len(err_lines), 2)
525             self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
526             self.assertEqual(
527                 unstyle(str(report)),
528                 "2 files reformatted, 2 files left unchanged, "
529                 "2 files failed to reformat.",
530             )
531             self.assertEqual(report.return_code, 123)
532             report.path_ignored(Path("wat"), "no match")
533             self.assertEqual(len(out_lines), 5)
534             self.assertEqual(len(err_lines), 2)
535             self.assertEqual(out_lines[-1], "wat ignored: no match")
536             self.assertEqual(
537                 unstyle(str(report)),
538                 "2 files reformatted, 2 files left unchanged, "
539                 "2 files failed to reformat.",
540             )
541             self.assertEqual(report.return_code, 123)
542             report.done(Path("f4"), black.Changed.NO)
543             self.assertEqual(len(out_lines), 6)
544             self.assertEqual(len(err_lines), 2)
545             self.assertEqual(out_lines[-1], "f4 already well formatted, good job.")
546             self.assertEqual(
547                 unstyle(str(report)),
548                 "2 files reformatted, 3 files left unchanged, "
549                 "2 files failed to reformat.",
550             )
551             self.assertEqual(report.return_code, 123)
552             report.check = True
553             self.assertEqual(
554                 unstyle(str(report)),
555                 "2 files would be reformatted, 3 files would be left unchanged, "
556                 "2 files would fail to reformat.",
557             )
558
559     def test_report_quiet(self) -> None:
560         report = black.Report(quiet=True)
561         out_lines = []
562         err_lines = []
563
564         def out(msg: str, **kwargs: Any) -> None:
565             out_lines.append(msg)
566
567         def err(msg: str, **kwargs: Any) -> None:
568             err_lines.append(msg)
569
570         with patch("black.out", out), patch("black.err", err):
571             report.done(Path("f1"), black.Changed.NO)
572             self.assertEqual(len(out_lines), 0)
573             self.assertEqual(len(err_lines), 0)
574             self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
575             self.assertEqual(report.return_code, 0)
576             report.done(Path("f2"), black.Changed.YES)
577             self.assertEqual(len(out_lines), 0)
578             self.assertEqual(len(err_lines), 0)
579             self.assertEqual(
580                 unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
581             )
582             report.done(Path("f3"), black.Changed.CACHED)
583             self.assertEqual(len(out_lines), 0)
584             self.assertEqual(len(err_lines), 0)
585             self.assertEqual(
586                 unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
587             )
588             self.assertEqual(report.return_code, 0)
589             report.check = True
590             self.assertEqual(report.return_code, 1)
591             report.check = False
592             report.failed(Path("e1"), "boom")
593             self.assertEqual(len(out_lines), 0)
594             self.assertEqual(len(err_lines), 1)
595             self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
596             self.assertEqual(
597                 unstyle(str(report)),
598                 "1 file reformatted, 2 files left unchanged, "
599                 "1 file failed to reformat.",
600             )
601             self.assertEqual(report.return_code, 123)
602             report.done(Path("f3"), black.Changed.YES)
603             self.assertEqual(len(out_lines), 0)
604             self.assertEqual(len(err_lines), 1)
605             self.assertEqual(
606                 unstyle(str(report)),
607                 "2 files reformatted, 2 files left unchanged, "
608                 "1 file failed to reformat.",
609             )
610             self.assertEqual(report.return_code, 123)
611             report.failed(Path("e2"), "boom")
612             self.assertEqual(len(out_lines), 0)
613             self.assertEqual(len(err_lines), 2)
614             self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
615             self.assertEqual(
616                 unstyle(str(report)),
617                 "2 files reformatted, 2 files left unchanged, "
618                 "2 files failed to reformat.",
619             )
620             self.assertEqual(report.return_code, 123)
621             report.path_ignored(Path("wat"), "no match")
622             self.assertEqual(len(out_lines), 0)
623             self.assertEqual(len(err_lines), 2)
624             self.assertEqual(
625                 unstyle(str(report)),
626                 "2 files reformatted, 2 files left unchanged, "
627                 "2 files failed to reformat.",
628             )
629             self.assertEqual(report.return_code, 123)
630             report.done(Path("f4"), black.Changed.NO)
631             self.assertEqual(len(out_lines), 0)
632             self.assertEqual(len(err_lines), 2)
633             self.assertEqual(
634                 unstyle(str(report)),
635                 "2 files reformatted, 3 files left unchanged, "
636                 "2 files failed to reformat.",
637             )
638             self.assertEqual(report.return_code, 123)
639             report.check = True
640             self.assertEqual(
641                 unstyle(str(report)),
642                 "2 files would be reformatted, 3 files would be left unchanged, "
643                 "2 files would fail to reformat.",
644             )
645
646     def test_report_normal(self) -> None:
647         report = black.Report()
648         out_lines = []
649         err_lines = []
650
651         def out(msg: str, **kwargs: Any) -> None:
652             out_lines.append(msg)
653
654         def err(msg: str, **kwargs: Any) -> None:
655             err_lines.append(msg)
656
657         with patch("black.out", out), patch("black.err", err):
658             report.done(Path("f1"), black.Changed.NO)
659             self.assertEqual(len(out_lines), 0)
660             self.assertEqual(len(err_lines), 0)
661             self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
662             self.assertEqual(report.return_code, 0)
663             report.done(Path("f2"), black.Changed.YES)
664             self.assertEqual(len(out_lines), 1)
665             self.assertEqual(len(err_lines), 0)
666             self.assertEqual(out_lines[-1], "reformatted f2")
667             self.assertEqual(
668                 unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
669             )
670             report.done(Path("f3"), black.Changed.CACHED)
671             self.assertEqual(len(out_lines), 1)
672             self.assertEqual(len(err_lines), 0)
673             self.assertEqual(out_lines[-1], "reformatted f2")
674             self.assertEqual(
675                 unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
676             )
677             self.assertEqual(report.return_code, 0)
678             report.check = True
679             self.assertEqual(report.return_code, 1)
680             report.check = False
681             report.failed(Path("e1"), "boom")
682             self.assertEqual(len(out_lines), 1)
683             self.assertEqual(len(err_lines), 1)
684             self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
685             self.assertEqual(
686                 unstyle(str(report)),
687                 "1 file reformatted, 2 files left unchanged, "
688                 "1 file failed to reformat.",
689             )
690             self.assertEqual(report.return_code, 123)
691             report.done(Path("f3"), black.Changed.YES)
692             self.assertEqual(len(out_lines), 2)
693             self.assertEqual(len(err_lines), 1)
694             self.assertEqual(out_lines[-1], "reformatted f3")
695             self.assertEqual(
696                 unstyle(str(report)),
697                 "2 files reformatted, 2 files left unchanged, "
698                 "1 file failed to reformat.",
699             )
700             self.assertEqual(report.return_code, 123)
701             report.failed(Path("e2"), "boom")
702             self.assertEqual(len(out_lines), 2)
703             self.assertEqual(len(err_lines), 2)
704             self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
705             self.assertEqual(
706                 unstyle(str(report)),
707                 "2 files reformatted, 2 files left unchanged, "
708                 "2 files failed to reformat.",
709             )
710             self.assertEqual(report.return_code, 123)
711             report.path_ignored(Path("wat"), "no match")
712             self.assertEqual(len(out_lines), 2)
713             self.assertEqual(len(err_lines), 2)
714             self.assertEqual(
715                 unstyle(str(report)),
716                 "2 files reformatted, 2 files left unchanged, "
717                 "2 files failed to reformat.",
718             )
719             self.assertEqual(report.return_code, 123)
720             report.done(Path("f4"), black.Changed.NO)
721             self.assertEqual(len(out_lines), 2)
722             self.assertEqual(len(err_lines), 2)
723             self.assertEqual(
724                 unstyle(str(report)),
725                 "2 files reformatted, 3 files left unchanged, "
726                 "2 files failed to reformat.",
727             )
728             self.assertEqual(report.return_code, 123)
729             report.check = True
730             self.assertEqual(
731                 unstyle(str(report)),
732                 "2 files would be reformatted, 3 files would be left unchanged, "
733                 "2 files would fail to reformat.",
734             )
735
736     def test_is_python36(self) -> None:
737         node = black.lib2to3_parse("def f(*, arg): ...\n")
738         self.assertFalse(black.is_python36(node))
739         node = black.lib2to3_parse("def f(*, arg,): ...\n")
740         self.assertTrue(black.is_python36(node))
741         node = black.lib2to3_parse("def f(*, arg): f'string'\n")
742         self.assertTrue(black.is_python36(node))
743         node = black.lib2to3_parse("123_456\n")
744         self.assertTrue(black.is_python36(node))
745         node = black.lib2to3_parse("123456\n")
746         self.assertFalse(black.is_python36(node))
747         source, expected = read_data("function")
748         node = black.lib2to3_parse(source)
749         self.assertTrue(black.is_python36(node))
750         node = black.lib2to3_parse(expected)
751         self.assertTrue(black.is_python36(node))
752         source, expected = read_data("expression")
753         node = black.lib2to3_parse(source)
754         self.assertFalse(black.is_python36(node))
755         node = black.lib2to3_parse(expected)
756         self.assertFalse(black.is_python36(node))
757
758     def test_get_future_imports(self) -> None:
759         node = black.lib2to3_parse("\n")
760         self.assertEqual(set(), black.get_future_imports(node))
761         node = black.lib2to3_parse("from __future__ import black\n")
762         self.assertEqual({"black"}, black.get_future_imports(node))
763         node = black.lib2to3_parse("from __future__ import multiple, imports\n")
764         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
765         node = black.lib2to3_parse("from __future__ import (parenthesized, imports)\n")
766         self.assertEqual({"parenthesized", "imports"}, black.get_future_imports(node))
767         node = black.lib2to3_parse(
768             "from __future__ import multiple\nfrom __future__ import imports\n"
769         )
770         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
771         node = black.lib2to3_parse("# comment\nfrom __future__ import black\n")
772         self.assertEqual({"black"}, black.get_future_imports(node))
773         node = black.lib2to3_parse('"""docstring"""\nfrom __future__ import black\n')
774         self.assertEqual({"black"}, black.get_future_imports(node))
775         node = black.lib2to3_parse("some(other, code)\nfrom __future__ import black\n")
776         self.assertEqual(set(), black.get_future_imports(node))
777         node = black.lib2to3_parse("from some.module import black\n")
778         self.assertEqual(set(), black.get_future_imports(node))
779         node = black.lib2to3_parse(
780             "from __future__ import unicode_literals as _unicode_literals"
781         )
782         self.assertEqual({"unicode_literals"}, black.get_future_imports(node))
783         node = black.lib2to3_parse(
784             "from __future__ import unicode_literals as _lol, print"
785         )
786         self.assertEqual({"unicode_literals", "print"}, black.get_future_imports(node))
787
788     def test_debug_visitor(self) -> None:
789         source, _ = read_data("debug_visitor.py")
790         expected, _ = read_data("debug_visitor.out")
791         out_lines = []
792         err_lines = []
793
794         def out(msg: str, **kwargs: Any) -> None:
795             out_lines.append(msg)
796
797         def err(msg: str, **kwargs: Any) -> None:
798             err_lines.append(msg)
799
800         with patch("black.out", out), patch("black.err", err):
801             black.DebugVisitor.show(source)
802         actual = "\n".join(out_lines) + "\n"
803         log_name = ""
804         if expected != actual:
805             log_name = black.dump_to_file(*out_lines)
806         self.assertEqual(
807             expected,
808             actual,
809             f"AST print out is different. Actual version dumped to {log_name}",
810         )
811
812     def test_format_file_contents(self) -> None:
813         empty = ""
814         with self.assertRaises(black.NothingChanged):
815             black.format_file_contents(empty, line_length=ll, fast=False)
816         just_nl = "\n"
817         with self.assertRaises(black.NothingChanged):
818             black.format_file_contents(just_nl, line_length=ll, fast=False)
819         same = "l = [1, 2, 3]\n"
820         with self.assertRaises(black.NothingChanged):
821             black.format_file_contents(same, line_length=ll, fast=False)
822         different = "l = [1,2,3]"
823         expected = same
824         actual = black.format_file_contents(different, line_length=ll, fast=False)
825         self.assertEqual(expected, actual)
826         invalid = "return if you can"
827         with self.assertRaises(ValueError) as e:
828             black.format_file_contents(invalid, line_length=ll, fast=False)
829         self.assertEqual(str(e.exception), "Cannot parse: 1:7: return if you can")
830
831     def test_endmarker(self) -> None:
832         n = black.lib2to3_parse("\n")
833         self.assertEqual(n.type, black.syms.file_input)
834         self.assertEqual(len(n.children), 1)
835         self.assertEqual(n.children[0].type, black.token.ENDMARKER)
836
837     @unittest.skipIf(os.environ.get("SKIP_AST_PRINT"), "user set SKIP_AST_PRINT")
838     def test_assertFormatEqual(self) -> None:
839         out_lines = []
840         err_lines = []
841
842         def out(msg: str, **kwargs: Any) -> None:
843             out_lines.append(msg)
844
845         def err(msg: str, **kwargs: Any) -> None:
846             err_lines.append(msg)
847
848         with patch("black.out", out), patch("black.err", err):
849             with self.assertRaises(AssertionError):
850                 self.assertFormatEqual("l = [1, 2, 3]", "l = [1, 2, 3,]")
851
852         out_str = "".join(out_lines)
853         self.assertTrue("Expected tree:" in out_str)
854         self.assertTrue("Actual tree:" in out_str)
855         self.assertEqual("".join(err_lines), "")
856
857     def test_cache_broken_file(self) -> None:
858         mode = black.FileMode.AUTO_DETECT
859         with cache_dir() as workspace:
860             cache_file = black.get_cache_file(black.DEFAULT_LINE_LENGTH, mode)
861             with cache_file.open("w") as fobj:
862                 fobj.write("this is not a pickle")
863             self.assertEqual(black.read_cache(black.DEFAULT_LINE_LENGTH, mode), {})
864             src = (workspace / "test.py").resolve()
865             with src.open("w") as fobj:
866                 fobj.write("print('hello')")
867             result = CliRunner().invoke(black.main, [str(src)])
868             self.assertEqual(result.exit_code, 0)
869             cache = black.read_cache(black.DEFAULT_LINE_LENGTH, mode)
870             self.assertIn(src, cache)
871
872     def test_cache_single_file_already_cached(self) -> None:
873         mode = black.FileMode.AUTO_DETECT
874         with cache_dir() as workspace:
875             src = (workspace / "test.py").resolve()
876             with src.open("w") as fobj:
877                 fobj.write("print('hello')")
878             black.write_cache({}, [src], black.DEFAULT_LINE_LENGTH, mode)
879             result = CliRunner().invoke(black.main, [str(src)])
880             self.assertEqual(result.exit_code, 0)
881             with src.open("r") as fobj:
882                 self.assertEqual(fobj.read(), "print('hello')")
883
884     @event_loop(close=False)
885     def test_cache_multiple_files(self) -> None:
886         mode = black.FileMode.AUTO_DETECT
887         with cache_dir() as workspace, patch(
888             "black.ProcessPoolExecutor", new=ThreadPoolExecutor
889         ):
890             one = (workspace / "one.py").resolve()
891             with one.open("w") as fobj:
892                 fobj.write("print('hello')")
893             two = (workspace / "two.py").resolve()
894             with two.open("w") as fobj:
895                 fobj.write("print('hello')")
896             black.write_cache({}, [one], black.DEFAULT_LINE_LENGTH, mode)
897             result = CliRunner().invoke(black.main, [str(workspace)])
898             self.assertEqual(result.exit_code, 0)
899             with one.open("r") as fobj:
900                 self.assertEqual(fobj.read(), "print('hello')")
901             with two.open("r") as fobj:
902                 self.assertEqual(fobj.read(), 'print("hello")\n')
903             cache = black.read_cache(black.DEFAULT_LINE_LENGTH, mode)
904             self.assertIn(one, cache)
905             self.assertIn(two, cache)
906
907     def test_no_cache_when_writeback_diff(self) -> None:
908         mode = black.FileMode.AUTO_DETECT
909         with cache_dir() as workspace:
910             src = (workspace / "test.py").resolve()
911             with src.open("w") as fobj:
912                 fobj.write("print('hello')")
913             result = CliRunner().invoke(black.main, [str(src), "--diff"])
914             self.assertEqual(result.exit_code, 0)
915             cache_file = black.get_cache_file(black.DEFAULT_LINE_LENGTH, mode)
916             self.assertFalse(cache_file.exists())
917
918     def test_no_cache_when_stdin(self) -> None:
919         mode = black.FileMode.AUTO_DETECT
920         with cache_dir():
921             result = CliRunner().invoke(
922                 black.main, ["-"], input=BytesIO(b"print('hello')")
923             )
924             self.assertEqual(result.exit_code, 0)
925             cache_file = black.get_cache_file(black.DEFAULT_LINE_LENGTH, mode)
926             self.assertFalse(cache_file.exists())
927
928     def test_read_cache_no_cachefile(self) -> None:
929         mode = black.FileMode.AUTO_DETECT
930         with cache_dir():
931             self.assertEqual(black.read_cache(black.DEFAULT_LINE_LENGTH, mode), {})
932
933     def test_write_cache_read_cache(self) -> None:
934         mode = black.FileMode.AUTO_DETECT
935         with cache_dir() as workspace:
936             src = (workspace / "test.py").resolve()
937             src.touch()
938             black.write_cache({}, [src], black.DEFAULT_LINE_LENGTH, mode)
939             cache = black.read_cache(black.DEFAULT_LINE_LENGTH, mode)
940             self.assertIn(src, cache)
941             self.assertEqual(cache[src], black.get_cache_info(src))
942
943     def test_filter_cached(self) -> None:
944         with TemporaryDirectory() as workspace:
945             path = Path(workspace)
946             uncached = (path / "uncached").resolve()
947             cached = (path / "cached").resolve()
948             cached_but_changed = (path / "changed").resolve()
949             uncached.touch()
950             cached.touch()
951             cached_but_changed.touch()
952             cache = {cached: black.get_cache_info(cached), cached_but_changed: (0.0, 0)}
953             todo, done = black.filter_cached(
954                 cache, {uncached, cached, cached_but_changed}
955             )
956             self.assertEqual(todo, {uncached, cached_but_changed})
957             self.assertEqual(done, {cached})
958
959     def test_write_cache_creates_directory_if_needed(self) -> None:
960         mode = black.FileMode.AUTO_DETECT
961         with cache_dir(exists=False) as workspace:
962             self.assertFalse(workspace.exists())
963             black.write_cache({}, [], black.DEFAULT_LINE_LENGTH, mode)
964             self.assertTrue(workspace.exists())
965
966     @event_loop(close=False)
967     def test_failed_formatting_does_not_get_cached(self) -> None:
968         mode = black.FileMode.AUTO_DETECT
969         with cache_dir() as workspace, patch(
970             "black.ProcessPoolExecutor", new=ThreadPoolExecutor
971         ):
972             failing = (workspace / "failing.py").resolve()
973             with failing.open("w") as fobj:
974                 fobj.write("not actually python")
975             clean = (workspace / "clean.py").resolve()
976             with clean.open("w") as fobj:
977                 fobj.write('print("hello")\n')
978             result = CliRunner().invoke(black.main, [str(workspace)])
979             self.assertEqual(result.exit_code, 123)
980             cache = black.read_cache(black.DEFAULT_LINE_LENGTH, mode)
981             self.assertNotIn(failing, cache)
982             self.assertIn(clean, cache)
983
984     def test_write_cache_write_fail(self) -> None:
985         mode = black.FileMode.AUTO_DETECT
986         with cache_dir(), patch.object(Path, "open") as mock:
987             mock.side_effect = OSError
988             black.write_cache({}, [], black.DEFAULT_LINE_LENGTH, mode)
989
990     @event_loop(close=False)
991     def test_check_diff_use_together(self) -> None:
992         with cache_dir():
993             # Files which will be reformatted.
994             src1 = (THIS_DIR / "data" / "string_quotes.py").resolve()
995             result = CliRunner().invoke(black.main, [str(src1), "--diff", "--check"])
996             self.assertEqual(result.exit_code, 1, result.output)
997             # Files which will not be reformatted.
998             src2 = (THIS_DIR / "data" / "composition.py").resolve()
999             result = CliRunner().invoke(black.main, [str(src2), "--diff", "--check"])
1000             self.assertEqual(result.exit_code, 0, result.output)
1001             # Multi file command.
1002             result = CliRunner().invoke(
1003                 black.main, [str(src1), str(src2), "--diff", "--check"]
1004             )
1005             self.assertEqual(result.exit_code, 1, result.output)
1006
1007     def test_no_files(self) -> None:
1008         with cache_dir():
1009             # Without an argument, black exits with error code 0.
1010             result = CliRunner().invoke(black.main, [])
1011             self.assertEqual(result.exit_code, 0)
1012
1013     def test_broken_symlink(self) -> None:
1014         with cache_dir() as workspace:
1015             symlink = workspace / "broken_link.py"
1016             try:
1017                 symlink.symlink_to("nonexistent.py")
1018             except OSError as e:
1019                 self.skipTest(f"Can't create symlinks: {e}")
1020             result = CliRunner().invoke(black.main, [str(workspace.resolve())])
1021             self.assertEqual(result.exit_code, 0)
1022
1023     def test_read_cache_line_lengths(self) -> None:
1024         mode = black.FileMode.AUTO_DETECT
1025         with cache_dir() as workspace:
1026             path = (workspace / "file.py").resolve()
1027             path.touch()
1028             black.write_cache({}, [path], 1, mode)
1029             one = black.read_cache(1, mode)
1030             self.assertIn(path, one)
1031             two = black.read_cache(2, mode)
1032             self.assertNotIn(path, two)
1033
1034     def test_single_file_force_pyi(self) -> None:
1035         reg_mode = black.FileMode.AUTO_DETECT
1036         pyi_mode = black.FileMode.PYI
1037         contents, expected = read_data("force_pyi")
1038         with cache_dir() as workspace:
1039             path = (workspace / "file.py").resolve()
1040             with open(path, "w") as fh:
1041                 fh.write(contents)
1042             result = CliRunner().invoke(black.main, [str(path), "--pyi"])
1043             self.assertEqual(result.exit_code, 0)
1044             with open(path, "r") as fh:
1045                 actual = fh.read()
1046             # verify cache with --pyi is separate
1047             pyi_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, pyi_mode)
1048             self.assertIn(path, pyi_cache)
1049             normal_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, reg_mode)
1050             self.assertNotIn(path, normal_cache)
1051         self.assertEqual(actual, expected)
1052
1053     @event_loop(close=False)
1054     def test_multi_file_force_pyi(self) -> None:
1055         reg_mode = black.FileMode.AUTO_DETECT
1056         pyi_mode = black.FileMode.PYI
1057         contents, expected = read_data("force_pyi")
1058         with cache_dir() as workspace:
1059             paths = [
1060                 (workspace / "file1.py").resolve(),
1061                 (workspace / "file2.py").resolve(),
1062             ]
1063             for path in paths:
1064                 with open(path, "w") as fh:
1065                     fh.write(contents)
1066             result = CliRunner().invoke(black.main, [str(p) for p in paths] + ["--pyi"])
1067             self.assertEqual(result.exit_code, 0)
1068             for path in paths:
1069                 with open(path, "r") as fh:
1070                     actual = fh.read()
1071                 self.assertEqual(actual, expected)
1072             # verify cache with --pyi is separate
1073             pyi_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, pyi_mode)
1074             normal_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, reg_mode)
1075             for path in paths:
1076                 self.assertIn(path, pyi_cache)
1077                 self.assertNotIn(path, normal_cache)
1078
1079     def test_pipe_force_pyi(self) -> None:
1080         source, expected = read_data("force_pyi")
1081         result = CliRunner().invoke(
1082             black.main, ["-", "-q", "--pyi"], input=BytesIO(source.encode("utf8"))
1083         )
1084         self.assertEqual(result.exit_code, 0)
1085         actual = result.output
1086         self.assertFormatEqual(actual, expected)
1087
1088     def test_single_file_force_py36(self) -> None:
1089         reg_mode = black.FileMode.AUTO_DETECT
1090         py36_mode = black.FileMode.PYTHON36
1091         source, expected = read_data("force_py36")
1092         with cache_dir() as workspace:
1093             path = (workspace / "file.py").resolve()
1094             with open(path, "w") as fh:
1095                 fh.write(source)
1096             result = CliRunner().invoke(black.main, [str(path), "--py36"])
1097             self.assertEqual(result.exit_code, 0)
1098             with open(path, "r") as fh:
1099                 actual = fh.read()
1100             # verify cache with --py36 is separate
1101             py36_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, py36_mode)
1102             self.assertIn(path, py36_cache)
1103             normal_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, reg_mode)
1104             self.assertNotIn(path, normal_cache)
1105         self.assertEqual(actual, expected)
1106
1107     @event_loop(close=False)
1108     def test_multi_file_force_py36(self) -> None:
1109         reg_mode = black.FileMode.AUTO_DETECT
1110         py36_mode = black.FileMode.PYTHON36
1111         source, expected = read_data("force_py36")
1112         with cache_dir() as workspace:
1113             paths = [
1114                 (workspace / "file1.py").resolve(),
1115                 (workspace / "file2.py").resolve(),
1116             ]
1117             for path in paths:
1118                 with open(path, "w") as fh:
1119                     fh.write(source)
1120             result = CliRunner().invoke(
1121                 black.main, [str(p) for p in paths] + ["--py36"]
1122             )
1123             self.assertEqual(result.exit_code, 0)
1124             for path in paths:
1125                 with open(path, "r") as fh:
1126                     actual = fh.read()
1127                 self.assertEqual(actual, expected)
1128             # verify cache with --py36 is separate
1129             pyi_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, py36_mode)
1130             normal_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, reg_mode)
1131             for path in paths:
1132                 self.assertIn(path, pyi_cache)
1133                 self.assertNotIn(path, normal_cache)
1134
1135     def test_pipe_force_py36(self) -> None:
1136         source, expected = read_data("force_py36")
1137         result = CliRunner().invoke(
1138             black.main, ["-", "-q", "--py36"], input=BytesIO(source.encode("utf8"))
1139         )
1140         self.assertEqual(result.exit_code, 0)
1141         actual = result.output
1142         self.assertFormatEqual(actual, expected)
1143
1144     def test_include_exclude(self) -> None:
1145         path = THIS_DIR / "data" / "include_exclude_tests"
1146         include = re.compile(r"\.pyi?$")
1147         exclude = re.compile(r"/exclude/|/\.definitely_exclude/")
1148         report = black.Report()
1149         sources: List[Path] = []
1150         expected = [
1151             Path(path / "b/dont_exclude/a.py"),
1152             Path(path / "b/dont_exclude/a.pyi"),
1153         ]
1154         this_abs = THIS_DIR.resolve()
1155         sources.extend(
1156             black.gen_python_files_in_dir(path, this_abs, include, exclude, report)
1157         )
1158         self.assertEqual(sorted(expected), sorted(sources))
1159
1160     def test_empty_include(self) -> None:
1161         path = THIS_DIR / "data" / "include_exclude_tests"
1162         report = black.Report()
1163         empty = re.compile(r"")
1164         sources: List[Path] = []
1165         expected = [
1166             Path(path / "b/exclude/a.pie"),
1167             Path(path / "b/exclude/a.py"),
1168             Path(path / "b/exclude/a.pyi"),
1169             Path(path / "b/dont_exclude/a.pie"),
1170             Path(path / "b/dont_exclude/a.py"),
1171             Path(path / "b/dont_exclude/a.pyi"),
1172             Path(path / "b/.definitely_exclude/a.pie"),
1173             Path(path / "b/.definitely_exclude/a.py"),
1174             Path(path / "b/.definitely_exclude/a.pyi"),
1175         ]
1176         this_abs = THIS_DIR.resolve()
1177         sources.extend(
1178             black.gen_python_files_in_dir(
1179                 path, this_abs, empty, re.compile(black.DEFAULT_EXCLUDES), report
1180             )
1181         )
1182         self.assertEqual(sorted(expected), sorted(sources))
1183
1184     def test_empty_exclude(self) -> None:
1185         path = THIS_DIR / "data" / "include_exclude_tests"
1186         report = black.Report()
1187         empty = re.compile(r"")
1188         sources: List[Path] = []
1189         expected = [
1190             Path(path / "b/dont_exclude/a.py"),
1191             Path(path / "b/dont_exclude/a.pyi"),
1192             Path(path / "b/exclude/a.py"),
1193             Path(path / "b/exclude/a.pyi"),
1194             Path(path / "b/.definitely_exclude/a.py"),
1195             Path(path / "b/.definitely_exclude/a.pyi"),
1196         ]
1197         this_abs = THIS_DIR.resolve()
1198         sources.extend(
1199             black.gen_python_files_in_dir(
1200                 path, this_abs, re.compile(black.DEFAULT_INCLUDES), empty, report
1201             )
1202         )
1203         self.assertEqual(sorted(expected), sorted(sources))
1204
1205     def test_invalid_include_exclude(self) -> None:
1206         for option in ["--include", "--exclude"]:
1207             result = CliRunner().invoke(black.main, ["-", option, "**()(!!*)"])
1208             self.assertEqual(result.exit_code, 2)
1209
1210     def test_preserves_line_endings(self) -> None:
1211         with TemporaryDirectory() as workspace:
1212             test_file = Path(workspace) / "test.py"
1213             for nl in ["\n", "\r\n"]:
1214                 contents = nl.join(["def f(  ):", "    pass"])
1215                 test_file.write_bytes(contents.encode())
1216                 ff(test_file, write_back=black.WriteBack.YES)
1217                 updated_contents: bytes = test_file.read_bytes()
1218                 self.assertIn(nl.encode(), updated_contents)
1219                 if nl == "\n":
1220                     self.assertNotIn(b"\r\n", updated_contents)
1221
1222     def test_preserves_line_endings_via_stdin(self) -> None:
1223         for nl in ["\n", "\r\n"]:
1224             contents = nl.join(["def f(  ):", "    pass"])
1225             runner = BlackRunner()
1226             result = runner.invoke(
1227                 black.main, ["-", "--fast"], input=BytesIO(contents.encode("utf8"))
1228             )
1229             self.assertEqual(result.exit_code, 0)
1230             output = runner.stdout_bytes
1231             self.assertIn(nl.encode("utf8"), output)
1232             if nl == "\n":
1233                 self.assertNotIn(b"\r\n", output)
1234
1235     def test_assert_equivalent_different_asts(self) -> None:
1236         with self.assertRaises(AssertionError):
1237             black.assert_equivalent("{}", "None")
1238
1239     def test_symlink_out_of_root_directory(self) -> None:
1240         path = MagicMock()
1241         root = THIS_DIR
1242         child = MagicMock()
1243         include = re.compile(black.DEFAULT_INCLUDES)
1244         exclude = re.compile(black.DEFAULT_EXCLUDES)
1245         report = black.Report()
1246         # `child` should behave like a symlink which resolved path is clearly
1247         # outside of the `root` directory.
1248         path.iterdir.return_value = [child]
1249         child.resolve.return_value = Path("/a/b/c")
1250         child.is_symlink.return_value = True
1251         try:
1252             list(black.gen_python_files_in_dir(path, root, include, exclude, report))
1253         except ValueError as ve:
1254             self.fail("`get_python_files_in_dir()` failed: {ve}")
1255         path.iterdir.assert_called_once()
1256         child.resolve.assert_called_once()
1257         child.is_symlink.assert_called_once()
1258         # `child` should behave like a strange file which resolved path is clearly
1259         # outside of the `root` directory.
1260         child.is_symlink.return_value = False
1261         with self.assertRaises(ValueError):
1262             list(black.gen_python_files_in_dir(path, root, include, exclude, report))
1263         path.iterdir.assert_called()
1264         self.assertEqual(path.iterdir.call_count, 2)
1265         child.resolve.assert_called()
1266         self.assertEqual(child.resolve.call_count, 2)
1267         child.is_symlink.assert_called()
1268         self.assertEqual(child.is_symlink.call_count, 2)
1269
1270     def test_shhh_click(self) -> None:
1271         try:
1272             from click import _unicodefun  # type: ignore
1273         except ModuleNotFoundError:
1274             self.skipTest("Incompatible Click version")
1275         if not hasattr(_unicodefun, "_verify_python3_env"):
1276             self.skipTest("Incompatible Click version")
1277         # First, let's see if Click is crashing with a preferred ASCII charset.
1278         with patch("locale.getpreferredencoding") as gpe:
1279             gpe.return_value = "ASCII"
1280             with self.assertRaises(RuntimeError):
1281                 _unicodefun._verify_python3_env()
1282         # Now, let's silence Click...
1283         black.patch_click()
1284         # ...and confirm it's silent.
1285         with patch("locale.getpreferredencoding") as gpe:
1286             gpe.return_value = "ASCII"
1287             try:
1288                 _unicodefun._verify_python3_env()
1289             except RuntimeError as re:
1290                 self.fail(f"`patch_click()` failed, exception still raised: {re}")
1291
1292
# Run the tests through unittest when this file is executed as a script.
if __name__ == "__main__":
    unittest.main(module="test_black")