git.madduck.net Git - etc/vim.git/blob - tests/test_black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access, can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Move tokenizer config onto grammar, rename flag
[etc/vim.git] / tests / test_black.py
1 #!/usr/bin/env python3
2 import asyncio
3 import logging
4 from concurrent.futures import ThreadPoolExecutor
5 from contextlib import contextmanager
6 from functools import partial, wraps
7 from io import BytesIO, TextIOWrapper
8 import os
9 from pathlib import Path
10 import re
11 import sys
12 from tempfile import TemporaryDirectory
13 from typing import (
14     Any,
15     BinaryIO,
16     Callable,
17     Coroutine,
18     Generator,
19     List,
20     Tuple,
21     Iterator,
22     TypeVar,
23 )
24 import unittest
25 from unittest.mock import patch, MagicMock
26
27 from click import unstyle
28 from click.testing import CliRunner
29
30 import black
31 from black import Feature, TargetVersion
32
33 try:
34     import blackd
35     from aiohttp.test_utils import TestClient, TestServer
36 except ImportError:
37     has_blackd_deps = False
38 else:
39     has_blackd_deps = True
40
# Shorthands for black's two formatting entry points, bound to a default file mode.
ff = partial(black.format_file_in_place, mode=black.FileMode(), fast=True)
fs = partial(black.format_str, mode=black.FileMode())
# Location of this test module; fixture paths are derived from it.
THIS_FILE = Path(__file__)
THIS_DIR = THIS_FILE.parent
# Marker text that read_data() removes from every fixture line.
# NOTE(review): presumably split in two so the complete marker never appears
# verbatim in this file, which is itself read through read_data() -- confirm.
EMPTY_LINE = "# EMPTY LINE WITH WHITESPACE" + " (this comment will be removed)"
# One `--target-version=<name>` option per entry of black.PY36_VERSIONS.
PY36_ARGS = [
    f"--target-version={version.name.lower()}" for version in black.PY36_VERSIONS
]
# Generic type variables for annotations.
T = TypeVar("T")
R = TypeVar("R")
51
52
def dump_to_stderr(*output: str) -> str:
    """Join `output` with newlines and wrap the result in one newline on each side.

    Tests in this module patch it in place of `black.dump_to_file`.
    """
    body = "\n".join(output)
    return f"\n{body}\n"
55
56
def read_data(name: str, data: bool = True) -> Tuple[str, str]:
    """read_data('test_name') -> 'input', 'output'

    `name` gets a ".py" suffix unless it already ends with ".py", ".pyi", ".out"
    or ".diff".  The file is looked up in the "data" directory next to this module
    when `data` is true, otherwise next to this module itself.

    Lines before a line consisting only of the output marker comment form the
    input and the lines after it form the output; marker lines are dropped.  When
    there are no output lines, the input doubles as the output.  Both results are
    stripped and terminated with a single newline.
    """
    if not name.endswith((".py", ".pyi", ".out", ".diff")):
        name = f"{name}.py"
    folder = THIS_DIR / "data" if data else THIS_DIR
    with open(folder / name, "r", encoding="utf8") as fixture:
        raw_lines = fixture.readlines()
    before: List[str] = []
    after: List[str] = []
    seen_marker = False
    for raw in raw_lines:
        text = raw.replace(EMPTY_LINE, "")
        if text.rstrip() == "# output":
            seen_marker = True
        elif seen_marker:
            after.append(text)
        else:
            before.append(text)
    if before and not after:
        # If there's no output marker, treat the entire file as already pre-formatted.
        after = before[:]
    return "".join(before).strip() + "\n", "".join(after).strip() + "\n"
78
79
@contextmanager
def cache_dir(exists: bool = True) -> Iterator[Path]:
    """Patch `black.CACHE_DIR` with a path inside a fresh temporary directory.

    With `exists` true the temporary directory itself is used; otherwise a "new"
    child path, which this helper does not create, is used instead.  The chosen
    path is yielded to the caller.
    """
    with TemporaryDirectory() as workspace:
        root = Path(workspace)
        target = root if exists else root / "new"
        with patch("black.CACHE_DIR", target):
            yield target
88
89
@contextmanager
def event_loop(close: bool) -> Iterator[None]:
    """Run the managed block with a newly created event loop set as current.

    On exit the previously current loop is reinstated through the policy, and the
    new loop is closed when `close` is true.
    """
    policy = asyncio.get_event_loop_policy()
    previous = policy.get_event_loop()
    fresh = policy.new_event_loop()
    asyncio.set_event_loop(fresh)
    try:
        yield

    finally:
        policy.set_event_loop(previous)
        if close:
            fresh.close()
103
104
def async_test(f: Callable[..., Coroutine[Any, None, R]]) -> Callable[..., None]:
    """Wrap coroutine function `f` in a synchronous callable.

    Each call runs inside `event_loop(close=True)` and drives the coroutine
    returned by `f` to completion on the current event loop.
    """

    @event_loop(close=True)
    @wraps(f)
    def wrapper(*args: Any, **kwargs: Any) -> None:
        loop = asyncio.get_event_loop()
        loop.run_until_complete(f(*args, **kwargs))

    return wrapper
112
113
class BlackRunner(CliRunner):
    """Modify CliRunner so that stderr is not merged with stdout.

    This is a hack that can be removed once we depend on Click 7.x"""

    def __init__(self) -> None:
        """Set up the capture buffers before initializing the base runner."""
        self.stderrbuf = BytesIO()
        self.stdoutbuf = BytesIO()
        self.stdout_bytes = b""
        self.stderr_bytes = b""
        super().__init__()

    @contextmanager
    def isolation(self, *args: Any, **kwargs: Any) -> Generator[BinaryIO, None, None]:
        """Extend the base isolation with a separate capture of `sys.stderr`.

        While the block runs, `sys.stderr` is a text wrapper around
        `self.stderrbuf`.  On exit the captured bytes are stored in
        `self.stdout_bytes` and `self.stderr_bytes`, and the previous
        `sys.stderr` is put back.
        """
        with super().isolation(*args, **kwargs) as output:
            hold_stderr = sys.stderr
            try:
                sys.stderr = TextIOWrapper(self.stderrbuf, encoding=self.charset)
                yield output
            finally:
                # TextIOWrapper buffers writes; push pending text down to the
                # underlying BytesIO so the snapshot below is complete.
                sys.stderr.flush()
                self.stdout_bytes = sys.stdout.buffer.getvalue()  # type: ignore
                self.stderr_bytes = sys.stderr.buffer.getvalue()  # type: ignore
                sys.stderr = hold_stderr
137
138
139 class BlackTestCase(unittest.TestCase):
140     maxDiff = None
141
142     def assertFormatEqual(self, expected: str, actual: str) -> None:
143         if actual != expected and not os.environ.get("SKIP_AST_PRINT"):
144             bdv: black.DebugVisitor[Any]
145             black.out("Expected tree:", fg="green")
146             try:
147                 exp_node = black.lib2to3_parse(expected)
148                 bdv = black.DebugVisitor()
149                 list(bdv.visit(exp_node))
150             except Exception as ve:
151                 black.err(str(ve))
152             black.out("Actual tree:", fg="red")
153             try:
154                 exp_node = black.lib2to3_parse(actual)
155                 bdv = black.DebugVisitor()
156                 list(bdv.visit(exp_node))
157             except Exception as ve:
158                 black.err(str(ve))
159         self.assertEqual(expected, actual)
160
161     def invokeBlack(
162         self, args: List[str], exit_code: int = 0, ignore_config: bool = True
163     ) -> None:
164         runner = BlackRunner()
165         if ignore_config:
166             args = ["--config", str(THIS_DIR / "empty.toml"), *args]
167         result = runner.invoke(black.main, args)
168         self.assertEqual(result.exit_code, exit_code, msg=runner.stderr_bytes.decode())
169
170     @patch("black.dump_to_file", dump_to_stderr)
171     def test_empty(self) -> None:
172         source = expected = ""
173         actual = fs(source)
174         self.assertFormatEqual(expected, actual)
175         black.assert_equivalent(source, actual)
176         black.assert_stable(source, actual, black.FileMode())
177
178     def test_empty_ff(self) -> None:
179         expected = ""
180         tmp_file = Path(black.dump_to_file())
181         try:
182             self.assertFalse(ff(tmp_file, write_back=black.WriteBack.YES))
183             with open(tmp_file, encoding="utf8") as f:
184                 actual = f.read()
185         finally:
186             os.unlink(tmp_file)
187         self.assertFormatEqual(expected, actual)
188
189     @patch("black.dump_to_file", dump_to_stderr)
190     def test_self(self) -> None:
191         source, expected = read_data("test_black", data=False)
192         actual = fs(source)
193         self.assertFormatEqual(expected, actual)
194         black.assert_equivalent(source, actual)
195         black.assert_stable(source, actual, black.FileMode())
196         self.assertFalse(ff(THIS_FILE))
197
198     @patch("black.dump_to_file", dump_to_stderr)
199     def test_black(self) -> None:
200         source, expected = read_data("../black", data=False)
201         actual = fs(source)
202         self.assertFormatEqual(expected, actual)
203         black.assert_equivalent(source, actual)
204         black.assert_stable(source, actual, black.FileMode())
205         self.assertFalse(ff(THIS_DIR / ".." / "black.py"))
206
207     def test_piping(self) -> None:
208         source, expected = read_data("../black", data=False)
209         result = BlackRunner().invoke(
210             black.main,
211             ["-", "--fast", f"--line-length={black.DEFAULT_LINE_LENGTH}"],
212             input=BytesIO(source.encode("utf8")),
213         )
214         self.assertEqual(result.exit_code, 0)
215         self.assertFormatEqual(expected, result.output)
216         black.assert_equivalent(source, result.output)
217         black.assert_stable(source, result.output, black.FileMode())
218
219     def test_piping_diff(self) -> None:
220         diff_header = re.compile(
221             rf"(STDIN|STDOUT)\t\d\d\d\d-\d\d-\d\d "
222             rf"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
223         )
224         source, _ = read_data("expression.py")
225         expected, _ = read_data("expression.diff")
226         config = THIS_DIR / "data" / "empty_pyproject.toml"
227         args = [
228             "-",
229             "--fast",
230             f"--line-length={black.DEFAULT_LINE_LENGTH}",
231             "--diff",
232             f"--config={config}",
233         ]
234         result = BlackRunner().invoke(
235             black.main, args, input=BytesIO(source.encode("utf8"))
236         )
237         self.assertEqual(result.exit_code, 0)
238         actual = diff_header.sub("[Deterministic header]", result.output)
239         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
240         self.assertEqual(expected, actual)
241
242     @patch("black.dump_to_file", dump_to_stderr)
243     def test_setup(self) -> None:
244         source, expected = read_data("../setup", data=False)
245         actual = fs(source)
246         self.assertFormatEqual(expected, actual)
247         black.assert_equivalent(source, actual)
248         black.assert_stable(source, actual, black.FileMode())
249         self.assertFalse(ff(THIS_DIR / ".." / "setup.py"))
250
251     @patch("black.dump_to_file", dump_to_stderr)
252     def test_function(self) -> None:
253         source, expected = read_data("function")
254         actual = fs(source)
255         self.assertFormatEqual(expected, actual)
256         black.assert_equivalent(source, actual)
257         black.assert_stable(source, actual, black.FileMode())
258
259     @patch("black.dump_to_file", dump_to_stderr)
260     def test_function2(self) -> None:
261         source, expected = read_data("function2")
262         actual = fs(source)
263         self.assertFormatEqual(expected, actual)
264         black.assert_equivalent(source, actual)
265         black.assert_stable(source, actual, black.FileMode())
266
267     @patch("black.dump_to_file", dump_to_stderr)
268     def test_expression(self) -> None:
269         source, expected = read_data("expression")
270         actual = fs(source)
271         self.assertFormatEqual(expected, actual)
272         black.assert_equivalent(source, actual)
273         black.assert_stable(source, actual, black.FileMode())
274
275     def test_expression_ff(self) -> None:
276         source, expected = read_data("expression")
277         tmp_file = Path(black.dump_to_file(source))
278         try:
279             self.assertTrue(ff(tmp_file, write_back=black.WriteBack.YES))
280             with open(tmp_file, encoding="utf8") as f:
281                 actual = f.read()
282         finally:
283             os.unlink(tmp_file)
284         self.assertFormatEqual(expected, actual)
285         with patch("black.dump_to_file", dump_to_stderr):
286             black.assert_equivalent(source, actual)
287             black.assert_stable(source, actual, black.FileMode())
288
289     def test_expression_diff(self) -> None:
290         source, _ = read_data("expression.py")
291         expected, _ = read_data("expression.diff")
292         tmp_file = Path(black.dump_to_file(source))
293         diff_header = re.compile(
294             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
295             rf"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
296         )
297         try:
298             result = BlackRunner().invoke(black.main, ["--diff", str(tmp_file)])
299             self.assertEqual(result.exit_code, 0)
300         finally:
301             os.unlink(tmp_file)
302         actual = result.output
303         actual = diff_header.sub("[Deterministic header]", actual)
304         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
305         if expected != actual:
306             dump = black.dump_to_file(actual)
307             msg = (
308                 f"Expected diff isn't equal to the actual. If you made changes "
309                 f"to expression.py and this is an anticipated difference, "
310                 f"overwrite tests/data/expression.diff with {dump}"
311             )
312             self.assertEqual(expected, actual, msg)
313
314     @patch("black.dump_to_file", dump_to_stderr)
315     def test_fstring(self) -> None:
316         source, expected = read_data("fstring")
317         actual = fs(source)
318         self.assertFormatEqual(expected, actual)
319         black.assert_equivalent(source, actual)
320         black.assert_stable(source, actual, black.FileMode())
321
322     @patch("black.dump_to_file", dump_to_stderr)
323     def test_string_quotes(self) -> None:
324         source, expected = read_data("string_quotes")
325         actual = fs(source)
326         self.assertFormatEqual(expected, actual)
327         black.assert_equivalent(source, actual)
328         black.assert_stable(source, actual, black.FileMode())
329         mode = black.FileMode(string_normalization=False)
330         not_normalized = fs(source, mode=mode)
331         self.assertFormatEqual(source, not_normalized)
332         black.assert_equivalent(source, not_normalized)
333         black.assert_stable(source, not_normalized, mode=mode)
334
335     @patch("black.dump_to_file", dump_to_stderr)
336     def test_slices(self) -> None:
337         source, expected = read_data("slices")
338         actual = fs(source)
339         self.assertFormatEqual(expected, actual)
340         black.assert_equivalent(source, actual)
341         black.assert_stable(source, actual, black.FileMode())
342
343     @patch("black.dump_to_file", dump_to_stderr)
344     def test_comments(self) -> None:
345         source, expected = read_data("comments")
346         actual = fs(source)
347         self.assertFormatEqual(expected, actual)
348         black.assert_equivalent(source, actual)
349         black.assert_stable(source, actual, black.FileMode())
350
351     @patch("black.dump_to_file", dump_to_stderr)
352     def test_comments2(self) -> None:
353         source, expected = read_data("comments2")
354         actual = fs(source)
355         self.assertFormatEqual(expected, actual)
356         black.assert_equivalent(source, actual)
357         black.assert_stable(source, actual, black.FileMode())
358
359     @patch("black.dump_to_file", dump_to_stderr)
360     def test_comments3(self) -> None:
361         source, expected = read_data("comments3")
362         actual = fs(source)
363         self.assertFormatEqual(expected, actual)
364         black.assert_equivalent(source, actual)
365         black.assert_stable(source, actual, black.FileMode())
366
367     @patch("black.dump_to_file", dump_to_stderr)
368     def test_comments4(self) -> None:
369         source, expected = read_data("comments4")
370         actual = fs(source)
371         self.assertFormatEqual(expected, actual)
372         black.assert_equivalent(source, actual)
373         black.assert_stable(source, actual, black.FileMode())
374
375     @patch("black.dump_to_file", dump_to_stderr)
376     def test_comments5(self) -> None:
377         source, expected = read_data("comments5")
378         actual = fs(source)
379         self.assertFormatEqual(expected, actual)
380         black.assert_equivalent(source, actual)
381         black.assert_stable(source, actual, black.FileMode())
382
383     @patch("black.dump_to_file", dump_to_stderr)
384     def test_comments6(self) -> None:
385         source, expected = read_data("comments6")
386         actual = fs(source)
387         self.assertFormatEqual(expected, actual)
388         black.assert_equivalent(source, actual)
389         black.assert_stable(source, actual, black.FileMode())
390
391     @patch("black.dump_to_file", dump_to_stderr)
392     def test_comments7(self) -> None:
393         source, expected = read_data("comments7")
394         actual = fs(source)
395         self.assertFormatEqual(expected, actual)
396         black.assert_equivalent(source, actual)
397         black.assert_stable(source, actual, black.FileMode())
398
399     @patch("black.dump_to_file", dump_to_stderr)
400     def test_comment_after_escaped_newline(self) -> None:
401         source, expected = read_data("comment_after_escaped_newline")
402         actual = fs(source)
403         self.assertFormatEqual(expected, actual)
404         black.assert_equivalent(source, actual)
405         black.assert_stable(source, actual, black.FileMode())
406
407     @patch("black.dump_to_file", dump_to_stderr)
408     def test_cantfit(self) -> None:
409         source, expected = read_data("cantfit")
410         actual = fs(source)
411         self.assertFormatEqual(expected, actual)
412         black.assert_equivalent(source, actual)
413         black.assert_stable(source, actual, black.FileMode())
414
415     @patch("black.dump_to_file", dump_to_stderr)
416     def test_import_spacing(self) -> None:
417         source, expected = read_data("import_spacing")
418         actual = fs(source)
419         self.assertFormatEqual(expected, actual)
420         black.assert_equivalent(source, actual)
421         black.assert_stable(source, actual, black.FileMode())
422
423     @patch("black.dump_to_file", dump_to_stderr)
424     def test_composition(self) -> None:
425         source, expected = read_data("composition")
426         actual = fs(source)
427         self.assertFormatEqual(expected, actual)
428         black.assert_equivalent(source, actual)
429         black.assert_stable(source, actual, black.FileMode())
430
431     @patch("black.dump_to_file", dump_to_stderr)
432     def test_empty_lines(self) -> None:
433         source, expected = read_data("empty_lines")
434         actual = fs(source)
435         self.assertFormatEqual(expected, actual)
436         black.assert_equivalent(source, actual)
437         black.assert_stable(source, actual, black.FileMode())
438
439     @patch("black.dump_to_file", dump_to_stderr)
440     def test_string_prefixes(self) -> None:
441         source, expected = read_data("string_prefixes")
442         actual = fs(source)
443         self.assertFormatEqual(expected, actual)
444         black.assert_equivalent(source, actual)
445         black.assert_stable(source, actual, black.FileMode())
446
447     @patch("black.dump_to_file", dump_to_stderr)
448     def test_numeric_literals(self) -> None:
449         source, expected = read_data("numeric_literals")
450         mode = black.FileMode(target_versions=black.PY36_VERSIONS)
451         actual = fs(source, mode=mode)
452         self.assertFormatEqual(expected, actual)
453         black.assert_equivalent(source, actual)
454         black.assert_stable(source, actual, mode)
455
456     @patch("black.dump_to_file", dump_to_stderr)
457     def test_numeric_literals_ignoring_underscores(self) -> None:
458         source, expected = read_data("numeric_literals_skip_underscores")
459         mode = black.FileMode(target_versions=black.PY36_VERSIONS)
460         actual = fs(source, mode=mode)
461         self.assertFormatEqual(expected, actual)
462         black.assert_equivalent(source, actual)
463         black.assert_stable(source, actual, mode)
464
465     @patch("black.dump_to_file", dump_to_stderr)
466     def test_numeric_literals_py2(self) -> None:
467         source, expected = read_data("numeric_literals_py2")
468         actual = fs(source)
469         self.assertFormatEqual(expected, actual)
470         black.assert_stable(source, actual, black.FileMode())
471
472     @patch("black.dump_to_file", dump_to_stderr)
473     def test_python2(self) -> None:
474         source, expected = read_data("python2")
475         actual = fs(source)
476         self.assertFormatEqual(expected, actual)
477         black.assert_equivalent(source, actual)
478         black.assert_stable(source, actual, black.FileMode())
479
480     @patch("black.dump_to_file", dump_to_stderr)
481     def test_python2_print_function(self) -> None:
482         source, expected = read_data("python2_print_function")
483         mode = black.FileMode(target_versions={TargetVersion.PY27})
484         actual = fs(source, mode=mode)
485         self.assertFormatEqual(expected, actual)
486         black.assert_equivalent(source, actual)
487         black.assert_stable(source, actual, mode)
488
489     @patch("black.dump_to_file", dump_to_stderr)
490     def test_python2_unicode_literals(self) -> None:
491         source, expected = read_data("python2_unicode_literals")
492         actual = fs(source)
493         self.assertFormatEqual(expected, actual)
494         black.assert_equivalent(source, actual)
495         black.assert_stable(source, actual, black.FileMode())
496
497     @patch("black.dump_to_file", dump_to_stderr)
498     def test_stub(self) -> None:
499         mode = black.FileMode(is_pyi=True)
500         source, expected = read_data("stub.pyi")
501         actual = fs(source, mode=mode)
502         self.assertFormatEqual(expected, actual)
503         black.assert_stable(source, actual, mode)
504
505     @patch("black.dump_to_file", dump_to_stderr)
506     def test_async_as_identifier(self) -> None:
507         source_path = (THIS_DIR / "data" / "async_as_identifier.py").resolve()
508         source, expected = read_data("async_as_identifier")
509         actual = fs(source)
510         self.assertFormatEqual(expected, actual)
511         major, minor = sys.version_info[:2]
512         if major < 3 or (major <= 3 and minor < 7):
513             black.assert_equivalent(source, actual)
514         black.assert_stable(source, actual, black.FileMode())
515         # ensure black can parse this when the target is 3.6
516         self.invokeBlack([str(source_path), "--target-version", "py36"])
517         # but not on 3.7, because async/await is no longer an identifier
518         self.invokeBlack([str(source_path), "--target-version", "py37"], exit_code=123)
519
520     @patch("black.dump_to_file", dump_to_stderr)
521     def test_python37(self) -> None:
522         source_path = (THIS_DIR / "data" / "python37.py").resolve()
523         source, expected = read_data("python37")
524         actual = fs(source)
525         self.assertFormatEqual(expected, actual)
526         major, minor = sys.version_info[:2]
527         if major > 3 or (major == 3 and minor >= 7):
528             black.assert_equivalent(source, actual)
529         black.assert_stable(source, actual, black.FileMode())
530         # ensure black can parse this when the target is 3.7
531         self.invokeBlack([str(source_path), "--target-version", "py37"])
532         # but not on 3.6, because we use async as a reserved keyword
533         self.invokeBlack([str(source_path), "--target-version", "py36"], exit_code=123)
534
535     @patch("black.dump_to_file", dump_to_stderr)
536     def test_fmtonoff(self) -> None:
537         source, expected = read_data("fmtonoff")
538         actual = fs(source)
539         self.assertFormatEqual(expected, actual)
540         black.assert_equivalent(source, actual)
541         black.assert_stable(source, actual, black.FileMode())
542
543     @patch("black.dump_to_file", dump_to_stderr)
544     def test_fmtonoff2(self) -> None:
545         source, expected = read_data("fmtonoff2")
546         actual = fs(source)
547         self.assertFormatEqual(expected, actual)
548         black.assert_equivalent(source, actual)
549         black.assert_stable(source, actual, black.FileMode())
550
551     @patch("black.dump_to_file", dump_to_stderr)
552     def test_remove_empty_parentheses_after_class(self) -> None:
553         source, expected = read_data("class_blank_parentheses")
554         actual = fs(source)
555         self.assertFormatEqual(expected, actual)
556         black.assert_equivalent(source, actual)
557         black.assert_stable(source, actual, black.FileMode())
558
559     @patch("black.dump_to_file", dump_to_stderr)
560     def test_new_line_between_class_and_code(self) -> None:
561         source, expected = read_data("class_methods_new_line")
562         actual = fs(source)
563         self.assertFormatEqual(expected, actual)
564         black.assert_equivalent(source, actual)
565         black.assert_stable(source, actual, black.FileMode())
566
567     @patch("black.dump_to_file", dump_to_stderr)
568     def test_bracket_match(self) -> None:
569         source, expected = read_data("bracketmatch")
570         actual = fs(source)
571         self.assertFormatEqual(expected, actual)
572         black.assert_equivalent(source, actual)
573         black.assert_stable(source, actual, black.FileMode())
574
575     @patch("black.dump_to_file", dump_to_stderr)
576     def test_tuple_assign(self) -> None:
577         source, expected = read_data("tupleassign")
578         actual = fs(source)
579         self.assertFormatEqual(expected, actual)
580         black.assert_equivalent(source, actual)
581         black.assert_stable(source, actual, black.FileMode())
582
583     def test_tab_comment_indentation(self) -> None:
584         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t# comment\n\tpass\n"
585         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
586         self.assertFormatEqual(contents_spc, fs(contents_spc))
587         self.assertFormatEqual(contents_spc, fs(contents_tab))
588
589         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t\t# comment\n\tpass\n"
590         contents_spc = "if 1:\n    if 2:\n        pass\n        # comment\n    pass\n"
591         self.assertFormatEqual(contents_spc, fs(contents_spc))
592         self.assertFormatEqual(contents_spc, fs(contents_tab))
593
594         # mixed tabs and spaces (valid Python 2 code)
595         contents_tab = "if 1:\n        if 2:\n\t\tpass\n\t# comment\n        pass\n"
596         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
597         self.assertFormatEqual(contents_spc, fs(contents_spc))
598         self.assertFormatEqual(contents_spc, fs(contents_tab))
599
600         contents_tab = "if 1:\n        if 2:\n\t\tpass\n\t\t# comment\n        pass\n"
601         contents_spc = "if 1:\n    if 2:\n        pass\n        # comment\n    pass\n"
602         self.assertFormatEqual(contents_spc, fs(contents_spc))
603         self.assertFormatEqual(contents_spc, fs(contents_tab))
604
605     def test_report_verbose(self) -> None:
606         report = black.Report(verbose=True)
607         out_lines = []
608         err_lines = []
609
610         def out(msg: str, **kwargs: Any) -> None:
611             out_lines.append(msg)
612
613         def err(msg: str, **kwargs: Any) -> None:
614             err_lines.append(msg)
615
616         with patch("black.out", out), patch("black.err", err):
617             report.done(Path("f1"), black.Changed.NO)
618             self.assertEqual(len(out_lines), 1)
619             self.assertEqual(len(err_lines), 0)
620             self.assertEqual(out_lines[-1], "f1 already well formatted, good job.")
621             self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
622             self.assertEqual(report.return_code, 0)
623             report.done(Path("f2"), black.Changed.YES)
624             self.assertEqual(len(out_lines), 2)
625             self.assertEqual(len(err_lines), 0)
626             self.assertEqual(out_lines[-1], "reformatted f2")
627             self.assertEqual(
628                 unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
629             )
630             report.done(Path("f3"), black.Changed.CACHED)
631             self.assertEqual(len(out_lines), 3)
632             self.assertEqual(len(err_lines), 0)
633             self.assertEqual(
634                 out_lines[-1], "f3 wasn't modified on disk since last run."
635             )
636             self.assertEqual(
637                 unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
638             )
639             self.assertEqual(report.return_code, 0)
640             report.check = True
641             self.assertEqual(report.return_code, 1)
642             report.check = False
643             report.failed(Path("e1"), "boom")
644             self.assertEqual(len(out_lines), 3)
645             self.assertEqual(len(err_lines), 1)
646             self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
647             self.assertEqual(
648                 unstyle(str(report)),
649                 "1 file reformatted, 2 files left unchanged, "
650                 "1 file failed to reformat.",
651             )
652             self.assertEqual(report.return_code, 123)
653             report.done(Path("f3"), black.Changed.YES)
654             self.assertEqual(len(out_lines), 4)
655             self.assertEqual(len(err_lines), 1)
656             self.assertEqual(out_lines[-1], "reformatted f3")
657             self.assertEqual(
658                 unstyle(str(report)),
659                 "2 files reformatted, 2 files left unchanged, "
660                 "1 file failed to reformat.",
661             )
662             self.assertEqual(report.return_code, 123)
663             report.failed(Path("e2"), "boom")
664             self.assertEqual(len(out_lines), 4)
665             self.assertEqual(len(err_lines), 2)
666             self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
667             self.assertEqual(
668                 unstyle(str(report)),
669                 "2 files reformatted, 2 files left unchanged, "
670                 "2 files failed to reformat.",
671             )
672             self.assertEqual(report.return_code, 123)
673             report.path_ignored(Path("wat"), "no match")
674             self.assertEqual(len(out_lines), 5)
675             self.assertEqual(len(err_lines), 2)
676             self.assertEqual(out_lines[-1], "wat ignored: no match")
677             self.assertEqual(
678                 unstyle(str(report)),
679                 "2 files reformatted, 2 files left unchanged, "
680                 "2 files failed to reformat.",
681             )
682             self.assertEqual(report.return_code, 123)
683             report.done(Path("f4"), black.Changed.NO)
684             self.assertEqual(len(out_lines), 6)
685             self.assertEqual(len(err_lines), 2)
686             self.assertEqual(out_lines[-1], "f4 already well formatted, good job.")
687             self.assertEqual(
688                 unstyle(str(report)),
689                 "2 files reformatted, 3 files left unchanged, "
690                 "2 files failed to reformat.",
691             )
692             self.assertEqual(report.return_code, 123)
693             report.check = True
694             self.assertEqual(
695                 unstyle(str(report)),
696                 "2 files would be reformatted, 3 files would be left unchanged, "
697                 "2 files would fail to reformat.",
698             )
699
    def test_report_quiet(self) -> None:
        """Exercise `black.Report(quiet=True)` through a sequence of events.

        `black.out` and `black.err` are replaced by collectors.  In quiet mode
        the test expects nothing to ever reach `out`, while every failure is
        still sent to `err`.  After each event the summary text and the
        return code are checked.
        """
        report = black.Report(quiet=True)
        out_lines: List[str] = []
        err_lines: List[str] = []

        def out(msg: str, **kwargs: Any) -> None:
            out_lines.append(msg)

        def err(msg: str, **kwargs: Any) -> None:
            err_lines.append(msg)

        # Collect messages in the lists above instead of printing them.
        with patch("black.out", out), patch("black.err", err):
            # An unchanged file: no output, return code 0.
            report.done(Path("f1"), black.Changed.NO)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
            self.assertEqual(report.return_code, 0)
            # A reformatted file is counted but not announced in quiet mode.
            report.done(Path("f2"), black.Changed.YES)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
            )
            # A cached file is counted among the unchanged ones.
            report.done(Path("f3"), black.Changed.CACHED)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
            )
            self.assertEqual(report.return_code, 0)
            # With `check` set, having reformatted files yields return code 1.
            report.check = True
            self.assertEqual(report.return_code, 1)
            report.check = False
            # A failure is reported on `err` even when quiet; return code is 123.
            report.failed(Path("e1"), "boom")
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
            self.assertEqual(
                unstyle(str(report)),
                "1 file reformatted, 2 files left unchanged, "
                "1 file failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Further successes do not clear the failure return code.
            report.done(Path("f3"), black.Changed.YES)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, "
                "1 file failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.failed(Path("e2"), "boom")
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, "
                "2 files failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # An ignored path changes neither the output nor the summary.
            report.path_ignored(Path("wat"), "no match")
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, "
                "2 files failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f4"), black.Changed.NO)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 3 files left unchanged, "
                "2 files failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # In check mode the summary is phrased conditionally ("would be").
            report.check = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, "
                "2 files would fail to reformat.",
            )
786
    def test_report_normal(self) -> None:
        """Exercise a default `black.Report()` through a sequence of events.

        `black.out` and `black.err` are replaced by collectors.  With default
        settings the test expects reformatted files to be announced on `out`,
        failures to be sent to `err`, and unchanged, cached or ignored paths
        to produce no message at all.
        """
        report = black.Report()
        out_lines: List[str] = []
        err_lines: List[str] = []

        def out(msg: str, **kwargs: Any) -> None:
            out_lines.append(msg)

        def err(msg: str, **kwargs: Any) -> None:
            err_lines.append(msg)

        # Collect messages in the lists above instead of printing them.
        with patch("black.out", out), patch("black.err", err):
            # An unchanged file is not announced.
            report.done(Path("f1"), black.Changed.NO)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
            self.assertEqual(report.return_code, 0)
            # A reformatted file produces one line on `out`.
            report.done(Path("f2"), black.Changed.YES)
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "reformatted f2")
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
            )
            # A cached file adds no output and is counted as unchanged.
            report.done(Path("f3"), black.Changed.CACHED)
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "reformatted f2")
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
            )
            self.assertEqual(report.return_code, 0)
            # With `check` set, having reformatted files yields return code 1.
            report.check = True
            self.assertEqual(report.return_code, 1)
            report.check = False
            # A failure goes to `err` and makes the return code 123.
            report.failed(Path("e1"), "boom")
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
            self.assertEqual(
                unstyle(str(report)),
                "1 file reformatted, 2 files left unchanged, "
                "1 file failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Further successes do not clear the failure return code.
            report.done(Path("f3"), black.Changed.YES)
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(out_lines[-1], "reformatted f3")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, "
                "1 file failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.failed(Path("e2"), "boom")
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, "
                "2 files failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # An ignored path changes neither the output nor the summary.
            report.path_ignored(Path("wat"), "no match")
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, "
                "2 files failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f4"), black.Changed.NO)
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 3 files left unchanged, "
                "2 files failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # In check mode the summary is phrased conditionally ("would be").
            report.check = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, "
                "2 files would fail to reformat.",
            )
876
877     def test_lib2to3_parse(self) -> None:
878         with self.assertRaises(black.InvalidInput):
879             black.lib2to3_parse("invalid syntax")
880
881         straddling = "x + y"
882         black.lib2to3_parse(straddling)
883         black.lib2to3_parse(straddling, {TargetVersion.PY27})
884         black.lib2to3_parse(straddling, {TargetVersion.PY36})
885         black.lib2to3_parse(straddling, {TargetVersion.PY27, TargetVersion.PY36})
886
887         py2_only = "print x"
888         black.lib2to3_parse(py2_only)
889         black.lib2to3_parse(py2_only, {TargetVersion.PY27})
890         with self.assertRaises(black.InvalidInput):
891             black.lib2to3_parse(py2_only, {TargetVersion.PY36})
892         with self.assertRaises(black.InvalidInput):
893             black.lib2to3_parse(py2_only, {TargetVersion.PY27, TargetVersion.PY36})
894
895         py3_only = "exec(x, end=y)"
896         black.lib2to3_parse(py3_only)
897         with self.assertRaises(black.InvalidInput):
898             black.lib2to3_parse(py3_only, {TargetVersion.PY27})
899         black.lib2to3_parse(py3_only, {TargetVersion.PY36})
900         black.lib2to3_parse(py3_only, {TargetVersion.PY27, TargetVersion.PY36})
901
902     def test_get_features_used(self) -> None:
903         node = black.lib2to3_parse("def f(*, arg): ...\n")
904         self.assertEqual(black.get_features_used(node), set())
905         node = black.lib2to3_parse("def f(*, arg,): ...\n")
906         self.assertEqual(black.get_features_used(node), {Feature.TRAILING_COMMA_IN_DEF})
907         node = black.lib2to3_parse("f(*arg,)\n")
908         self.assertEqual(
909             black.get_features_used(node), {Feature.TRAILING_COMMA_IN_CALL}
910         )
911         node = black.lib2to3_parse("def f(*, arg): f'string'\n")
912         self.assertEqual(black.get_features_used(node), {Feature.F_STRINGS})
913         node = black.lib2to3_parse("123_456\n")
914         self.assertEqual(black.get_features_used(node), {Feature.NUMERIC_UNDERSCORES})
915         node = black.lib2to3_parse("123456\n")
916         self.assertEqual(black.get_features_used(node), set())
917         source, expected = read_data("function")
918         node = black.lib2to3_parse(source)
919         expected_features = {
920             Feature.TRAILING_COMMA_IN_CALL,
921             Feature.TRAILING_COMMA_IN_DEF,
922             Feature.F_STRINGS,
923         }
924         self.assertEqual(black.get_features_used(node), expected_features)
925         node = black.lib2to3_parse(expected)
926         self.assertEqual(black.get_features_used(node), expected_features)
927         source, expected = read_data("expression")
928         node = black.lib2to3_parse(source)
929         self.assertEqual(black.get_features_used(node), set())
930         node = black.lib2to3_parse(expected)
931         self.assertEqual(black.get_features_used(node), set())
932
933     def test_get_future_imports(self) -> None:
934         node = black.lib2to3_parse("\n")
935         self.assertEqual(set(), black.get_future_imports(node))
936         node = black.lib2to3_parse("from __future__ import black\n")
937         self.assertEqual({"black"}, black.get_future_imports(node))
938         node = black.lib2to3_parse("from __future__ import multiple, imports\n")
939         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
940         node = black.lib2to3_parse("from __future__ import (parenthesized, imports)\n")
941         self.assertEqual({"parenthesized", "imports"}, black.get_future_imports(node))
942         node = black.lib2to3_parse(
943             "from __future__ import multiple\nfrom __future__ import imports\n"
944         )
945         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
946         node = black.lib2to3_parse("# comment\nfrom __future__ import black\n")
947         self.assertEqual({"black"}, black.get_future_imports(node))
948         node = black.lib2to3_parse('"""docstring"""\nfrom __future__ import black\n')
949         self.assertEqual({"black"}, black.get_future_imports(node))
950         node = black.lib2to3_parse("some(other, code)\nfrom __future__ import black\n")
951         self.assertEqual(set(), black.get_future_imports(node))
952         node = black.lib2to3_parse("from some.module import black\n")
953         self.assertEqual(set(), black.get_future_imports(node))
954         node = black.lib2to3_parse(
955             "from __future__ import unicode_literals as _unicode_literals"
956         )
957         self.assertEqual({"unicode_literals"}, black.get_future_imports(node))
958         node = black.lib2to3_parse(
959             "from __future__ import unicode_literals as _lol, print"
960         )
961         self.assertEqual({"unicode_literals", "print"}, black.get_future_imports(node))
962
963     def test_debug_visitor(self) -> None:
964         source, _ = read_data("debug_visitor.py")
965         expected, _ = read_data("debug_visitor.out")
966         out_lines = []
967         err_lines = []
968
969         def out(msg: str, **kwargs: Any) -> None:
970             out_lines.append(msg)
971
972         def err(msg: str, **kwargs: Any) -> None:
973             err_lines.append(msg)
974
975         with patch("black.out", out), patch("black.err", err):
976             black.DebugVisitor.show(source)
977         actual = "\n".join(out_lines) + "\n"
978         log_name = ""
979         if expected != actual:
980             log_name = black.dump_to_file(*out_lines)
981         self.assertEqual(
982             expected,
983             actual,
984             f"AST print out is different. Actual version dumped to {log_name}",
985         )
986
987     def test_format_file_contents(self) -> None:
988         empty = ""
989         mode = black.FileMode()
990         with self.assertRaises(black.NothingChanged):
991             black.format_file_contents(empty, mode=mode, fast=False)
992         just_nl = "\n"
993         with self.assertRaises(black.NothingChanged):
994             black.format_file_contents(just_nl, mode=mode, fast=False)
995         same = "l = [1, 2, 3]\n"
996         with self.assertRaises(black.NothingChanged):
997             black.format_file_contents(same, mode=mode, fast=False)
998         different = "l = [1,2,3]"
999         expected = same
1000         actual = black.format_file_contents(different, mode=mode, fast=False)
1001         self.assertEqual(expected, actual)
1002         invalid = "return if you can"
1003         with self.assertRaises(black.InvalidInput) as e:
1004             black.format_file_contents(invalid, mode=mode, fast=False)
1005         self.assertEqual(str(e.exception), "Cannot parse: 1:7: return if you can")
1006
1007     def test_endmarker(self) -> None:
1008         n = black.lib2to3_parse("\n")
1009         self.assertEqual(n.type, black.syms.file_input)
1010         self.assertEqual(len(n.children), 1)
1011         self.assertEqual(n.children[0].type, black.token.ENDMARKER)
1012
1013     @unittest.skipIf(os.environ.get("SKIP_AST_PRINT"), "user set SKIP_AST_PRINT")
1014     def test_assertFormatEqual(self) -> None:
1015         out_lines = []
1016         err_lines = []
1017
1018         def out(msg: str, **kwargs: Any) -> None:
1019             out_lines.append(msg)
1020
1021         def err(msg: str, **kwargs: Any) -> None:
1022             err_lines.append(msg)
1023
1024         with patch("black.out", out), patch("black.err", err):
1025             with self.assertRaises(AssertionError):
1026                 self.assertFormatEqual("l = [1, 2, 3]", "l = [1, 2, 3,]")
1027
1028         out_str = "".join(out_lines)
1029         self.assertTrue("Expected tree:" in out_str)
1030         self.assertTrue("Actual tree:" in out_str)
1031         self.assertEqual("".join(err_lines), "")
1032
1033     def test_cache_broken_file(self) -> None:
1034         mode = black.FileMode()
1035         with cache_dir() as workspace:
1036             cache_file = black.get_cache_file(mode)
1037             with cache_file.open("w") as fobj:
1038                 fobj.write("this is not a pickle")
1039             self.assertEqual(black.read_cache(mode), {})
1040             src = (workspace / "test.py").resolve()
1041             with src.open("w") as fobj:
1042                 fobj.write("print('hello')")
1043             self.invokeBlack([str(src)])
1044             cache = black.read_cache(mode)
1045             self.assertIn(src, cache)
1046
1047     def test_cache_single_file_already_cached(self) -> None:
1048         mode = black.FileMode()
1049         with cache_dir() as workspace:
1050             src = (workspace / "test.py").resolve()
1051             with src.open("w") as fobj:
1052                 fobj.write("print('hello')")
1053             black.write_cache({}, [src], mode)
1054             self.invokeBlack([str(src)])
1055             with src.open("r") as fobj:
1056                 self.assertEqual(fobj.read(), "print('hello')")
1057
1058     @event_loop(close=False)
1059     def test_cache_multiple_files(self) -> None:
1060         mode = black.FileMode()
1061         with cache_dir() as workspace, patch(
1062             "black.ProcessPoolExecutor", new=ThreadPoolExecutor
1063         ):
1064             one = (workspace / "one.py").resolve()
1065             with one.open("w") as fobj:
1066                 fobj.write("print('hello')")
1067             two = (workspace / "two.py").resolve()
1068             with two.open("w") as fobj:
1069                 fobj.write("print('hello')")
1070             black.write_cache({}, [one], mode)
1071             self.invokeBlack([str(workspace)])
1072             with one.open("r") as fobj:
1073                 self.assertEqual(fobj.read(), "print('hello')")
1074             with two.open("r") as fobj:
1075                 self.assertEqual(fobj.read(), 'print("hello")\n')
1076             cache = black.read_cache(mode)
1077             self.assertIn(one, cache)
1078             self.assertIn(two, cache)
1079
1080     def test_no_cache_when_writeback_diff(self) -> None:
1081         mode = black.FileMode()
1082         with cache_dir() as workspace:
1083             src = (workspace / "test.py").resolve()
1084             with src.open("w") as fobj:
1085                 fobj.write("print('hello')")
1086             self.invokeBlack([str(src), "--diff"])
1087             cache_file = black.get_cache_file(mode)
1088             self.assertFalse(cache_file.exists())
1089
1090     def test_no_cache_when_stdin(self) -> None:
1091         mode = black.FileMode()
1092         with cache_dir():
1093             result = CliRunner().invoke(
1094                 black.main, ["-"], input=BytesIO(b"print('hello')")
1095             )
1096             self.assertEqual(result.exit_code, 0)
1097             cache_file = black.get_cache_file(mode)
1098             self.assertFalse(cache_file.exists())
1099
1100     def test_read_cache_no_cachefile(self) -> None:
1101         mode = black.FileMode()
1102         with cache_dir():
1103             self.assertEqual(black.read_cache(mode), {})
1104
1105     def test_write_cache_read_cache(self) -> None:
1106         mode = black.FileMode()
1107         with cache_dir() as workspace:
1108             src = (workspace / "test.py").resolve()
1109             src.touch()
1110             black.write_cache({}, [src], mode)
1111             cache = black.read_cache(mode)
1112             self.assertIn(src, cache)
1113             self.assertEqual(cache[src], black.get_cache_info(src))
1114
1115     def test_filter_cached(self) -> None:
1116         with TemporaryDirectory() as workspace:
1117             path = Path(workspace)
1118             uncached = (path / "uncached").resolve()
1119             cached = (path / "cached").resolve()
1120             cached_but_changed = (path / "changed").resolve()
1121             uncached.touch()
1122             cached.touch()
1123             cached_but_changed.touch()
1124             cache = {cached: black.get_cache_info(cached), cached_but_changed: (0.0, 0)}
1125             todo, done = black.filter_cached(
1126                 cache, {uncached, cached, cached_but_changed}
1127             )
1128             self.assertEqual(todo, {uncached, cached_but_changed})
1129             self.assertEqual(done, {cached})
1130
1131     def test_write_cache_creates_directory_if_needed(self) -> None:
1132         mode = black.FileMode()
1133         with cache_dir(exists=False) as workspace:
1134             self.assertFalse(workspace.exists())
1135             black.write_cache({}, [], mode)
1136             self.assertTrue(workspace.exists())
1137
1138     @event_loop(close=False)
1139     def test_failed_formatting_does_not_get_cached(self) -> None:
1140         mode = black.FileMode()
1141         with cache_dir() as workspace, patch(
1142             "black.ProcessPoolExecutor", new=ThreadPoolExecutor
1143         ):
1144             failing = (workspace / "failing.py").resolve()
1145             with failing.open("w") as fobj:
1146                 fobj.write("not actually python")
1147             clean = (workspace / "clean.py").resolve()
1148             with clean.open("w") as fobj:
1149                 fobj.write('print("hello")\n')
1150             self.invokeBlack([str(workspace)], exit_code=123)
1151             cache = black.read_cache(mode)
1152             self.assertNotIn(failing, cache)
1153             self.assertIn(clean, cache)
1154
1155     def test_write_cache_write_fail(self) -> None:
1156         mode = black.FileMode()
1157         with cache_dir(), patch.object(Path, "open") as mock:
1158             mock.side_effect = OSError
1159             black.write_cache({}, [], mode)
1160
1161     @event_loop(close=False)
1162     def test_check_diff_use_together(self) -> None:
1163         with cache_dir():
1164             # Files which will be reformatted.
1165             src1 = (THIS_DIR / "data" / "string_quotes.py").resolve()
1166             self.invokeBlack([str(src1), "--diff", "--check"], exit_code=1)
1167             # Files which will not be reformatted.
1168             src2 = (THIS_DIR / "data" / "composition.py").resolve()
1169             self.invokeBlack([str(src2), "--diff", "--check"])
1170             # Multi file command.
1171             self.invokeBlack([str(src1), str(src2), "--diff", "--check"], exit_code=1)
1172
1173     def test_no_files(self) -> None:
1174         with cache_dir():
1175             # Without an argument, black exits with error code 0.
1176             self.invokeBlack([])
1177
1178     def test_broken_symlink(self) -> None:
1179         with cache_dir() as workspace:
1180             symlink = workspace / "broken_link.py"
1181             try:
1182                 symlink.symlink_to("nonexistent.py")
1183             except OSError as e:
1184                 self.skipTest(f"Can't create symlinks: {e}")
1185             self.invokeBlack([str(workspace.resolve())])
1186
1187     def test_read_cache_line_lengths(self) -> None:
1188         mode = black.FileMode()
1189         short_mode = black.FileMode(line_length=1)
1190         with cache_dir() as workspace:
1191             path = (workspace / "file.py").resolve()
1192             path.touch()
1193             black.write_cache({}, [path], mode)
1194             one = black.read_cache(mode)
1195             self.assertIn(path, one)
1196             two = black.read_cache(short_mode)
1197             self.assertNotIn(path, two)
1198
1199     def test_single_file_force_pyi(self) -> None:
1200         reg_mode = black.FileMode()
1201         pyi_mode = black.FileMode(is_pyi=True)
1202         contents, expected = read_data("force_pyi")
1203         with cache_dir() as workspace:
1204             path = (workspace / "file.py").resolve()
1205             with open(path, "w") as fh:
1206                 fh.write(contents)
1207             self.invokeBlack([str(path), "--pyi"])
1208             with open(path, "r") as fh:
1209                 actual = fh.read()
1210             # verify cache with --pyi is separate
1211             pyi_cache = black.read_cache(pyi_mode)
1212             self.assertIn(path, pyi_cache)
1213             normal_cache = black.read_cache(reg_mode)
1214             self.assertNotIn(path, normal_cache)
1215         self.assertEqual(actual, expected)
1216
1217     @event_loop(close=False)
1218     def test_multi_file_force_pyi(self) -> None:
1219         reg_mode = black.FileMode()
1220         pyi_mode = black.FileMode(is_pyi=True)
1221         contents, expected = read_data("force_pyi")
1222         with cache_dir() as workspace:
1223             paths = [
1224                 (workspace / "file1.py").resolve(),
1225                 (workspace / "file2.py").resolve(),
1226             ]
1227             for path in paths:
1228                 with open(path, "w") as fh:
1229                     fh.write(contents)
1230             self.invokeBlack([str(p) for p in paths] + ["--pyi"])
1231             for path in paths:
1232                 with open(path, "r") as fh:
1233                     actual = fh.read()
1234                 self.assertEqual(actual, expected)
1235             # verify cache with --pyi is separate
1236             pyi_cache = black.read_cache(pyi_mode)
1237             normal_cache = black.read_cache(reg_mode)
1238             for path in paths:
1239                 self.assertIn(path, pyi_cache)
1240                 self.assertNotIn(path, normal_cache)
1241
1242     def test_pipe_force_pyi(self) -> None:
1243         source, expected = read_data("force_pyi")
1244         result = CliRunner().invoke(
1245             black.main, ["-", "-q", "--pyi"], input=BytesIO(source.encode("utf8"))
1246         )
1247         self.assertEqual(result.exit_code, 0)
1248         actual = result.output
1249         self.assertFormatEqual(actual, expected)
1250
1251     def test_single_file_force_py36(self) -> None:
1252         reg_mode = black.FileMode()
1253         py36_mode = black.FileMode(target_versions=black.PY36_VERSIONS)
1254         source, expected = read_data("force_py36")
1255         with cache_dir() as workspace:
1256             path = (workspace / "file.py").resolve()
1257             with open(path, "w") as fh:
1258                 fh.write(source)
1259             self.invokeBlack([str(path), *PY36_ARGS])
1260             with open(path, "r") as fh:
1261                 actual = fh.read()
1262             # verify cache with --target-version is separate
1263             py36_cache = black.read_cache(py36_mode)
1264             self.assertIn(path, py36_cache)
1265             normal_cache = black.read_cache(reg_mode)
1266             self.assertNotIn(path, normal_cache)
1267         self.assertEqual(actual, expected)
1268
1269     @event_loop(close=False)
1270     def test_multi_file_force_py36(self) -> None:
1271         reg_mode = black.FileMode()
1272         py36_mode = black.FileMode(target_versions=black.PY36_VERSIONS)
1273         source, expected = read_data("force_py36")
1274         with cache_dir() as workspace:
1275             paths = [
1276                 (workspace / "file1.py").resolve(),
1277                 (workspace / "file2.py").resolve(),
1278             ]
1279             for path in paths:
1280                 with open(path, "w") as fh:
1281                     fh.write(source)
1282             self.invokeBlack([str(p) for p in paths] + PY36_ARGS)
1283             for path in paths:
1284                 with open(path, "r") as fh:
1285                     actual = fh.read()
1286                 self.assertEqual(actual, expected)
1287             # verify cache with --target-version is separate
1288             pyi_cache = black.read_cache(py36_mode)
1289             normal_cache = black.read_cache(reg_mode)
1290             for path in paths:
1291                 self.assertIn(path, pyi_cache)
1292                 self.assertNotIn(path, normal_cache)
1293
1294     def test_pipe_force_py36(self) -> None:
1295         source, expected = read_data("force_py36")
1296         result = CliRunner().invoke(
1297             black.main,
1298             ["-", "-q", "--target-version=py36"],
1299             input=BytesIO(source.encode("utf8")),
1300         )
1301         self.assertEqual(result.exit_code, 0)
1302         actual = result.output
1303         self.assertFormatEqual(actual, expected)
1304
1305     def test_include_exclude(self) -> None:
1306         path = THIS_DIR / "data" / "include_exclude_tests"
1307         include = re.compile(r"\.pyi?$")
1308         exclude = re.compile(r"/exclude/|/\.definitely_exclude/")
1309         report = black.Report()
1310         sources: List[Path] = []
1311         expected = [
1312             Path(path / "b/dont_exclude/a.py"),
1313             Path(path / "b/dont_exclude/a.pyi"),
1314         ]
1315         this_abs = THIS_DIR.resolve()
1316         sources.extend(
1317             black.gen_python_files_in_dir(path, this_abs, include, exclude, report)
1318         )
1319         self.assertEqual(sorted(expected), sorted(sources))
1320
1321     def test_empty_include(self) -> None:
1322         path = THIS_DIR / "data" / "include_exclude_tests"
1323         report = black.Report()
1324         empty = re.compile(r"")
1325         sources: List[Path] = []
1326         expected = [
1327             Path(path / "b/exclude/a.pie"),
1328             Path(path / "b/exclude/a.py"),
1329             Path(path / "b/exclude/a.pyi"),
1330             Path(path / "b/dont_exclude/a.pie"),
1331             Path(path / "b/dont_exclude/a.py"),
1332             Path(path / "b/dont_exclude/a.pyi"),
1333             Path(path / "b/.definitely_exclude/a.pie"),
1334             Path(path / "b/.definitely_exclude/a.py"),
1335             Path(path / "b/.definitely_exclude/a.pyi"),
1336         ]
1337         this_abs = THIS_DIR.resolve()
1338         sources.extend(
1339             black.gen_python_files_in_dir(
1340                 path, this_abs, empty, re.compile(black.DEFAULT_EXCLUDES), report
1341             )
1342         )
1343         self.assertEqual(sorted(expected), sorted(sources))
1344
1345     def test_empty_exclude(self) -> None:
1346         path = THIS_DIR / "data" / "include_exclude_tests"
1347         report = black.Report()
1348         empty = re.compile(r"")
1349         sources: List[Path] = []
1350         expected = [
1351             Path(path / "b/dont_exclude/a.py"),
1352             Path(path / "b/dont_exclude/a.pyi"),
1353             Path(path / "b/exclude/a.py"),
1354             Path(path / "b/exclude/a.pyi"),
1355             Path(path / "b/.definitely_exclude/a.py"),
1356             Path(path / "b/.definitely_exclude/a.pyi"),
1357         ]
1358         this_abs = THIS_DIR.resolve()
1359         sources.extend(
1360             black.gen_python_files_in_dir(
1361                 path, this_abs, re.compile(black.DEFAULT_INCLUDES), empty, report
1362             )
1363         )
1364         self.assertEqual(sorted(expected), sorted(sources))
1365
1366     def test_invalid_include_exclude(self) -> None:
1367         for option in ["--include", "--exclude"]:
1368             self.invokeBlack(["-", option, "**()(!!*)"], exit_code=2)
1369
1370     def test_preserves_line_endings(self) -> None:
1371         with TemporaryDirectory() as workspace:
1372             test_file = Path(workspace) / "test.py"
1373             for nl in ["\n", "\r\n"]:
1374                 contents = nl.join(["def f(  ):", "    pass"])
1375                 test_file.write_bytes(contents.encode())
1376                 ff(test_file, write_back=black.WriteBack.YES)
1377                 updated_contents: bytes = test_file.read_bytes()
1378                 self.assertIn(nl.encode(), updated_contents)
1379                 if nl == "\n":
1380                     self.assertNotIn(b"\r\n", updated_contents)
1381
1382     def test_preserves_line_endings_via_stdin(self) -> None:
1383         for nl in ["\n", "\r\n"]:
1384             contents = nl.join(["def f(  ):", "    pass"])
1385             runner = BlackRunner()
1386             result = runner.invoke(
1387                 black.main, ["-", "--fast"], input=BytesIO(contents.encode("utf8"))
1388             )
1389             self.assertEqual(result.exit_code, 0)
1390             output = runner.stdout_bytes
1391             self.assertIn(nl.encode("utf8"), output)
1392             if nl == "\n":
1393                 self.assertNotIn(b"\r\n", output)
1394
1395     def test_assert_equivalent_different_asts(self) -> None:
1396         with self.assertRaises(AssertionError):
1397             black.assert_equivalent("{}", "None")
1398
1399     def test_symlink_out_of_root_directory(self) -> None:
1400         path = MagicMock()
1401         root = THIS_DIR
1402         child = MagicMock()
1403         include = re.compile(black.DEFAULT_INCLUDES)
1404         exclude = re.compile(black.DEFAULT_EXCLUDES)
1405         report = black.Report()
1406         # `child` should behave like a symlink which resolved path is clearly
1407         # outside of the `root` directory.
1408         path.iterdir.return_value = [child]
1409         child.resolve.return_value = Path("/a/b/c")
1410         child.is_symlink.return_value = True
1411         try:
1412             list(black.gen_python_files_in_dir(path, root, include, exclude, report))
1413         except ValueError as ve:
1414             self.fail(f"`get_python_files_in_dir()` failed: {ve}")
1415         path.iterdir.assert_called_once()
1416         child.resolve.assert_called_once()
1417         child.is_symlink.assert_called_once()
1418         # `child` should behave like a strange file which resolved path is clearly
1419         # outside of the `root` directory.
1420         child.is_symlink.return_value = False
1421         with self.assertRaises(ValueError):
1422             list(black.gen_python_files_in_dir(path, root, include, exclude, report))
1423         path.iterdir.assert_called()
1424         self.assertEqual(path.iterdir.call_count, 2)
1425         child.resolve.assert_called()
1426         self.assertEqual(child.resolve.call_count, 2)
1427         child.is_symlink.assert_called()
1428         self.assertEqual(child.is_symlink.call_count, 2)
1429
1430     def test_shhh_click(self) -> None:
1431         try:
1432             from click import _unicodefun  # type: ignore
1433         except ModuleNotFoundError:
1434             self.skipTest("Incompatible Click version")
1435         if not hasattr(_unicodefun, "_verify_python3_env"):
1436             self.skipTest("Incompatible Click version")
1437         # First, let's see if Click is crashing with a preferred ASCII charset.
1438         with patch("locale.getpreferredencoding") as gpe:
1439             gpe.return_value = "ASCII"
1440             with self.assertRaises(RuntimeError):
1441                 _unicodefun._verify_python3_env()
1442         # Now, let's silence Click...
1443         black.patch_click()
1444         # ...and confirm it's silent.
1445         with patch("locale.getpreferredencoding") as gpe:
1446             gpe.return_value = "ASCII"
1447             try:
1448                 _unicodefun._verify_python3_env()
1449             except RuntimeError as re:
1450                 self.fail(f"`patch_click()` failed, exception still raised: {re}")
1451
1452     def test_root_logger_not_used_directly(self) -> None:
1453         def fail(*args: Any, **kwargs: Any) -> None:
1454             self.fail("Record created with root logger")
1455
1456         with patch.multiple(
1457             logging.root,
1458             debug=fail,
1459             info=fail,
1460             warning=fail,
1461             error=fail,
1462             critical=fail,
1463             log=fail,
1464         ):
1465             ff(THIS_FILE)
1466
1467     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1468     @async_test
1469     async def test_blackd_request_needs_formatting(self) -> None:
1470         app = blackd.make_app()
1471         async with TestClient(TestServer(app)) as client:
1472             response = await client.post("/", data=b"print('hello world')")
1473             self.assertEqual(response.status, 200)
1474             self.assertEqual(response.charset, "utf8")
1475             self.assertEqual(await response.read(), b'print("hello world")\n')
1476
1477     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1478     @async_test
1479     async def test_blackd_request_no_change(self) -> None:
1480         app = blackd.make_app()
1481         async with TestClient(TestServer(app)) as client:
1482             response = await client.post("/", data=b'print("hello world")\n')
1483             self.assertEqual(response.status, 204)
1484             self.assertEqual(await response.read(), b"")
1485
1486     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1487     @async_test
1488     async def test_blackd_request_syntax_error(self) -> None:
1489         app = blackd.make_app()
1490         async with TestClient(TestServer(app)) as client:
1491             response = await client.post("/", data=b"what even ( is")
1492             self.assertEqual(response.status, 400)
1493             content = await response.text()
1494             self.assertTrue(
1495                 content.startswith("Cannot parse"),
1496                 msg=f"Expected error to start with 'Cannot parse', got {repr(content)}",
1497             )
1498
1499     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1500     @async_test
1501     async def test_blackd_unsupported_version(self) -> None:
1502         app = blackd.make_app()
1503         async with TestClient(TestServer(app)) as client:
1504             response = await client.post(
1505                 "/", data=b"what", headers={blackd.VERSION_HEADER: "2"}
1506             )
1507             self.assertEqual(response.status, 501)
1508
1509     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1510     @async_test
1511     async def test_blackd_supported_version(self) -> None:
1512         app = blackd.make_app()
1513         async with TestClient(TestServer(app)) as client:
1514             response = await client.post(
1515                 "/", data=b"what", headers={blackd.VERSION_HEADER: "1"}
1516             )
1517             self.assertEqual(response.status, 200)
1518
1519     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1520     @async_test
1521     async def test_blackd_invalid_python_variant(self) -> None:
1522         app = blackd.make_app()
1523         async with TestClient(TestServer(app)) as client:
1524
1525             async def check(header_value: str, expected_status: int = 400) -> None:
1526                 response = await client.post(
1527                     "/",
1528                     data=b"what",
1529                     headers={blackd.PYTHON_VARIANT_HEADER: header_value},
1530                 )
1531                 self.assertEqual(response.status, expected_status)
1532
1533             await check("lol")
1534             await check("ruby3.5")
1535             await check("pyi3.6")
1536             await check("py1.5")
1537             await check("2.8")
1538             await check("py2.8")
1539             await check("3.0")
1540             await check("pypy3.0")
1541             await check("jython3.4")
1542
1543     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1544     @async_test
1545     async def test_blackd_pyi(self) -> None:
1546         app = blackd.make_app()
1547         async with TestClient(TestServer(app)) as client:
1548             source, expected = read_data("stub.pyi")
1549             response = await client.post(
1550                 "/", data=source, headers={blackd.PYTHON_VARIANT_HEADER: "pyi"}
1551             )
1552             self.assertEqual(response.status, 200)
1553             self.assertEqual(await response.text(), expected)
1554
1555     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1556     @async_test
1557     async def test_blackd_python_variant(self) -> None:
1558         app = blackd.make_app()
1559         code = (
1560             "def f(\n"
1561             "    and_has_a_bunch_of,\n"
1562             "    very_long_arguments_too,\n"
1563             "    and_lots_of_them_as_well_lol,\n"
1564             "    **and_very_long_keyword_arguments\n"
1565             "):\n"
1566             "    pass\n"
1567         )
1568         async with TestClient(TestServer(app)) as client:
1569
1570             async def check(header_value: str, expected_status: int) -> None:
1571                 response = await client.post(
1572                     "/", data=code, headers={blackd.PYTHON_VARIANT_HEADER: header_value}
1573                 )
1574                 self.assertEqual(response.status, expected_status)
1575
1576             await check("3.6", 200)
1577             await check("py3.6", 200)
1578             await check("3.6,3.7", 200)
1579             await check("3.6,py3.7", 200)
1580
1581             await check("2", 204)
1582             await check("2.7", 204)
1583             await check("py2.7", 204)
1584             await check("3.4", 204)
1585             await check("py3.4", 204)
1586
1587     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1588     @async_test
1589     async def test_blackd_line_length(self) -> None:
1590         app = blackd.make_app()
1591         async with TestClient(TestServer(app)) as client:
1592             response = await client.post(
1593                 "/", data=b'print("hello")\n', headers={blackd.LINE_LENGTH_HEADER: "7"}
1594             )
1595             self.assertEqual(response.status, 200)
1596
1597     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1598     @async_test
1599     async def test_blackd_invalid_line_length(self) -> None:
1600         app = blackd.make_app()
1601         async with TestClient(TestServer(app)) as client:
1602             response = await client.post(
1603                 "/",
1604                 data=b'print("hello")\n',
1605                 headers={blackd.LINE_LENGTH_HEADER: "NaN"},
1606             )
1607             self.assertEqual(response.status, 400)
1608
1609     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1610     def test_blackd_main(self) -> None:
1611         with patch("blackd.web.run_app"):
1612             result = CliRunner().invoke(blackd.main, [])
1613             if result.exception is not None:
1614                 raise result.exception
1615             self.assertEqual(result.exit_code, 0)
1616
1617
# Run the test suite when this file is executed directly; the module name is
# passed explicitly to `unittest.main()`.
if __name__ == "__main__":
    unittest.main(module="test_black")