git.madduck.net Git - etc/vim.git/blob - tests/test_black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

python/black -> psf/black (#936)
[etc/vim.git] / tests / test_black.py
1 #!/usr/bin/env python3
2 import asyncio
3 import logging
4 from concurrent.futures import ThreadPoolExecutor
5 from contextlib import contextmanager
6 from functools import partial, wraps
7 from io import BytesIO, TextIOWrapper
8 import os
9 from pathlib import Path
10 import re
11 import sys
12 from tempfile import TemporaryDirectory
13 from typing import (
14     Any,
15     BinaryIO,
16     Callable,
17     Coroutine,
18     Generator,
19     List,
20     Tuple,
21     Iterator,
22     TypeVar,
23 )
24 import unittest
25 from unittest.mock import patch, MagicMock
26
27 from click import unstyle
28 from click.testing import CliRunner
29
30 import black
31 from black import Feature, TargetVersion
32
33 try:
34     import blackd
35     from aiohttp.test_utils import TestClient, TestServer
36 except ImportError:
37     has_blackd_deps = False
38 else:
39     has_blackd_deps = True
40
# Shorthands used throughout the tests: `ff` formats a file in place (fast=True,
# default FileMode) and `fs` formats a source string with the default FileMode.
ff = partial(black.format_file_in_place, mode=black.FileMode(), fast=True)
fs = partial(black.format_str, mode=black.FileMode())
THIS_FILE = Path(__file__)
THIS_DIR = THIS_FILE.parent
# Marker that read_data() deletes from every line it reads (presumably so data
# files can keep lines that would otherwise be whitespace-only).
# The literal is deliberately split in two: test_self() feeds this very file
# through read_data(), and a single literal would match and be stripped there.
EMPTY_LINE = "# EMPTY LINE WITH WHITESPACE" + " (this comment will be removed)"
# One "--target-version=<name>" CLI argument per version in black.PY36_VERSIONS.
PY36_ARGS = [
    f"--target-version={version.name.lower()}" for version in black.PY36_VERSIONS
]
T = TypeVar("T")
R = TypeVar("R")
51
52
def dump_to_stderr(*output: str) -> str:
    """Return *output* joined by newlines, framed by a leading and trailing newline.

    The tests patch ``black.dump_to_file`` with this function. Despite its name,
    it only builds and returns the text; it does not write anything.
    """
    body = "\n".join(output)
    return f"\n{body}\n"
55
56
def read_data(name: str, data: bool = True) -> Tuple[str, str]:
    """read_data('test_name') -> 'input', 'output'

    A ``.py`` suffix is appended unless *name* already ends in .py/.pyi/.out/.diff.
    The file is read from ``THIS_DIR / "data"`` when *data* is true, otherwise from
    ``THIS_DIR``. ``EMPTY_LINE`` is deleted from every line. Lines after a
    "# output" line form the expected output; the marker lines themselves are
    dropped. If the input is non-empty but no output lines were collected, the
    input doubles as the expected output. Both halves are stripped and then
    terminated with a single newline.
    """
    if not name.endswith((".py", ".pyi", ".out", ".diff")):
        name += ".py"
    base_dir = THIS_DIR / "data" if data else THIS_DIR
    with open(base_dir / name, "r", encoding="utf8") as test:
        raw_lines = test.readlines()
    sections: Tuple[List[str], List[str]] = ([], [])
    target = 0  # index into `sections`: 0 = input, 1 = expected output
    for raw in raw_lines:
        cleaned = raw.replace(EMPTY_LINE, "")
        if cleaned.rstrip() == "# output":
            target = 1
            continue
        sections[target].append(cleaned)
    _input, _output = sections
    if _input and not _output:
        # No expected-output section: the file is treated as already formatted.
        _output = _input[:]
    return "".join(_input).strip() + "\n", "".join(_output).strip() + "\n"
78
79
@contextmanager
def cache_dir(exists: bool = True) -> Iterator[Path]:
    """Patch ``black.CACHE_DIR`` to a temporary directory and yield that path.

    With ``exists=False`` the patched path is a ``new`` subdirectory of the
    temporary directory, which this helper does not create.
    """
    with TemporaryDirectory() as workspace:
        root = Path(workspace)
        target = root if exists else root / "new"
        with patch("black.CACHE_DIR", target):
            yield target
88
89
@contextmanager
def event_loop(close: bool) -> Iterator[None]:
    """Install a fresh asyncio event loop for the body of the ``with`` block.

    The loop that was current on entry is restored on exit. The fresh loop is
    closed on exit only when *close* is true.
    """
    policy = asyncio.get_event_loop_policy()
    previous = policy.get_event_loop()
    fresh = policy.new_event_loop()
    asyncio.set_event_loop(fresh)
    try:
        yield
    finally:
        policy.set_event_loop(previous)
        if close:
            fresh.close()
103
104
def async_test(f: Callable[..., Coroutine[Any, None, R]]) -> Callable[..., None]:
    """Turn the coroutine function *f* into a plain callable (e.g. a test method).

    Each call runs ``f(*args, **kwargs)`` to completion inside
    ``event_loop(close=True)``, i.e. on a fresh loop that is closed afterwards.
    """

    @wraps(f)
    def run_sync(*args: Any, **kwargs: Any) -> None:
        current_loop = asyncio.get_event_loop()
        current_loop.run_until_complete(f(*args, **kwargs))

    return event_loop(close=True)(run_sync)
112
113
class BlackRunner(CliRunner):
    """Modify CliRunner so that stderr is not merged with stdout.

    After each invocation the captured streams are available as raw bytes in
    ``stdout_bytes`` and ``stderr_bytes``.

    This is a hack that can be removed once we depend on Click 7.x"""

    def __init__(self) -> None:
        # Backing buffer for the stderr wrapper installed by isolation().
        self.stderrbuf = BytesIO()
        self.stdoutbuf = BytesIO()
        # Filled in when an isolation() context exits.
        self.stdout_bytes = b""
        self.stderr_bytes = b""
        super().__init__()

    @contextmanager
    def isolation(self, *args: Any, **kwargs: Any) -> Generator[BinaryIO, None, None]:
        """Wrap CliRunner.isolation, routing sys.stderr into ``self.stderrbuf``."""
        with super().isolation(*args, **kwargs) as output:
            hold_stderr = sys.stderr
            try:
                sys.stderr = TextIOWrapper(self.stderrbuf, encoding=self.charset)
                yield output
            finally:
                # Text written without an explicit flush still sits in the text
                # wrappers' internal buffers; flush so the byte capture is complete.
                sys.stdout.flush()
                sys.stderr.flush()
                self.stdout_bytes = sys.stdout.buffer.getvalue()  # type: ignore
                self.stderr_bytes = sys.stderr.buffer.getvalue()  # type: ignore
                sys.stderr = hold_stderr
137
138
139 class BlackTestCase(unittest.TestCase):
140     maxDiff = None
141
142     def assertFormatEqual(self, expected: str, actual: str) -> None:
143         if actual != expected and not os.environ.get("SKIP_AST_PRINT"):
144             bdv: black.DebugVisitor[Any]
145             black.out("Expected tree:", fg="green")
146             try:
147                 exp_node = black.lib2to3_parse(expected)
148                 bdv = black.DebugVisitor()
149                 list(bdv.visit(exp_node))
150             except Exception as ve:
151                 black.err(str(ve))
152             black.out("Actual tree:", fg="red")
153             try:
154                 exp_node = black.lib2to3_parse(actual)
155                 bdv = black.DebugVisitor()
156                 list(bdv.visit(exp_node))
157             except Exception as ve:
158                 black.err(str(ve))
159         self.assertEqual(expected, actual)
160
161     def invokeBlack(
162         self, args: List[str], exit_code: int = 0, ignore_config: bool = True
163     ) -> None:
164         runner = BlackRunner()
165         if ignore_config:
166             args = ["--config", str(THIS_DIR / "empty.toml"), *args]
167         result = runner.invoke(black.main, args)
168         self.assertEqual(result.exit_code, exit_code, msg=runner.stderr_bytes.decode())
169
170     @patch("black.dump_to_file", dump_to_stderr)
171     def test_empty(self) -> None:
172         source = expected = ""
173         actual = fs(source)
174         self.assertFormatEqual(expected, actual)
175         black.assert_equivalent(source, actual)
176         black.assert_stable(source, actual, black.FileMode())
177
178     def test_empty_ff(self) -> None:
179         expected = ""
180         tmp_file = Path(black.dump_to_file())
181         try:
182             self.assertFalse(ff(tmp_file, write_back=black.WriteBack.YES))
183             with open(tmp_file, encoding="utf8") as f:
184                 actual = f.read()
185         finally:
186             os.unlink(tmp_file)
187         self.assertFormatEqual(expected, actual)
188
189     @patch("black.dump_to_file", dump_to_stderr)
190     def test_self(self) -> None:
191         source, expected = read_data("test_black", data=False)
192         actual = fs(source)
193         self.assertFormatEqual(expected, actual)
194         black.assert_equivalent(source, actual)
195         black.assert_stable(source, actual, black.FileMode())
196         self.assertFalse(ff(THIS_FILE))
197
198     @patch("black.dump_to_file", dump_to_stderr)
199     def test_black(self) -> None:
200         source, expected = read_data("../black", data=False)
201         actual = fs(source)
202         self.assertFormatEqual(expected, actual)
203         black.assert_equivalent(source, actual)
204         black.assert_stable(source, actual, black.FileMode())
205         self.assertFalse(ff(THIS_DIR / ".." / "black.py"))
206
207     def test_piping(self) -> None:
208         source, expected = read_data("../black", data=False)
209         result = BlackRunner().invoke(
210             black.main,
211             ["-", "--fast", f"--line-length={black.DEFAULT_LINE_LENGTH}"],
212             input=BytesIO(source.encode("utf8")),
213         )
214         self.assertEqual(result.exit_code, 0)
215         self.assertFormatEqual(expected, result.output)
216         black.assert_equivalent(source, result.output)
217         black.assert_stable(source, result.output, black.FileMode())
218
219     def test_piping_diff(self) -> None:
220         diff_header = re.compile(
221             rf"(STDIN|STDOUT)\t\d\d\d\d-\d\d-\d\d "
222             rf"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
223         )
224         source, _ = read_data("expression.py")
225         expected, _ = read_data("expression.diff")
226         config = THIS_DIR / "data" / "empty_pyproject.toml"
227         args = [
228             "-",
229             "--fast",
230             f"--line-length={black.DEFAULT_LINE_LENGTH}",
231             "--diff",
232             f"--config={config}",
233         ]
234         result = BlackRunner().invoke(
235             black.main, args, input=BytesIO(source.encode("utf8"))
236         )
237         self.assertEqual(result.exit_code, 0)
238         actual = diff_header.sub("[Deterministic header]", result.output)
239         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
240         self.assertEqual(expected, actual)
241
242     @patch("black.dump_to_file", dump_to_stderr)
243     def test_setup(self) -> None:
244         source, expected = read_data("../setup", data=False)
245         actual = fs(source)
246         self.assertFormatEqual(expected, actual)
247         black.assert_equivalent(source, actual)
248         black.assert_stable(source, actual, black.FileMode())
249         self.assertFalse(ff(THIS_DIR / ".." / "setup.py"))
250
251     @patch("black.dump_to_file", dump_to_stderr)
252     def test_function(self) -> None:
253         source, expected = read_data("function")
254         actual = fs(source)
255         self.assertFormatEqual(expected, actual)
256         black.assert_equivalent(source, actual)
257         black.assert_stable(source, actual, black.FileMode())
258
259     @patch("black.dump_to_file", dump_to_stderr)
260     def test_function2(self) -> None:
261         source, expected = read_data("function2")
262         actual = fs(source)
263         self.assertFormatEqual(expected, actual)
264         black.assert_equivalent(source, actual)
265         black.assert_stable(source, actual, black.FileMode())
266
267     @patch("black.dump_to_file", dump_to_stderr)
268     def test_function_trailing_comma(self) -> None:
269         source, expected = read_data("function_trailing_comma")
270         actual = fs(source)
271         self.assertFormatEqual(expected, actual)
272         black.assert_equivalent(source, actual)
273         black.assert_stable(source, actual, black.FileMode())
274
275     @patch("black.dump_to_file", dump_to_stderr)
276     def test_expression(self) -> None:
277         source, expected = read_data("expression")
278         actual = fs(source)
279         self.assertFormatEqual(expected, actual)
280         black.assert_equivalent(source, actual)
281         black.assert_stable(source, actual, black.FileMode())
282
283     def test_expression_ff(self) -> None:
284         source, expected = read_data("expression")
285         tmp_file = Path(black.dump_to_file(source))
286         try:
287             self.assertTrue(ff(tmp_file, write_back=black.WriteBack.YES))
288             with open(tmp_file, encoding="utf8") as f:
289                 actual = f.read()
290         finally:
291             os.unlink(tmp_file)
292         self.assertFormatEqual(expected, actual)
293         with patch("black.dump_to_file", dump_to_stderr):
294             black.assert_equivalent(source, actual)
295             black.assert_stable(source, actual, black.FileMode())
296
297     def test_expression_diff(self) -> None:
298         source, _ = read_data("expression.py")
299         expected, _ = read_data("expression.diff")
300         tmp_file = Path(black.dump_to_file(source))
301         diff_header = re.compile(
302             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
303             rf"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
304         )
305         try:
306             result = BlackRunner().invoke(black.main, ["--diff", str(tmp_file)])
307             self.assertEqual(result.exit_code, 0)
308         finally:
309             os.unlink(tmp_file)
310         actual = result.output
311         actual = diff_header.sub("[Deterministic header]", actual)
312         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
313         if expected != actual:
314             dump = black.dump_to_file(actual)
315             msg = (
316                 f"Expected diff isn't equal to the actual. If you made changes "
317                 f"to expression.py and this is an anticipated difference, "
318                 f"overwrite tests/data/expression.diff with {dump}"
319             )
320             self.assertEqual(expected, actual, msg)
321
322     @patch("black.dump_to_file", dump_to_stderr)
323     def test_fstring(self) -> None:
324         source, expected = read_data("fstring")
325         actual = fs(source)
326         self.assertFormatEqual(expected, actual)
327         black.assert_equivalent(source, actual)
328         black.assert_stable(source, actual, black.FileMode())
329
330     @patch("black.dump_to_file", dump_to_stderr)
331     def test_string_quotes(self) -> None:
332         source, expected = read_data("string_quotes")
333         actual = fs(source)
334         self.assertFormatEqual(expected, actual)
335         black.assert_equivalent(source, actual)
336         black.assert_stable(source, actual, black.FileMode())
337         mode = black.FileMode(string_normalization=False)
338         not_normalized = fs(source, mode=mode)
339         self.assertFormatEqual(source, not_normalized)
340         black.assert_equivalent(source, not_normalized)
341         black.assert_stable(source, not_normalized, mode=mode)
342
343     @patch("black.dump_to_file", dump_to_stderr)
344     def test_slices(self) -> None:
345         source, expected = read_data("slices")
346         actual = fs(source)
347         self.assertFormatEqual(expected, actual)
348         black.assert_equivalent(source, actual)
349         black.assert_stable(source, actual, black.FileMode())
350
351     @patch("black.dump_to_file", dump_to_stderr)
352     def test_comments(self) -> None:
353         source, expected = read_data("comments")
354         actual = fs(source)
355         self.assertFormatEqual(expected, actual)
356         black.assert_equivalent(source, actual)
357         black.assert_stable(source, actual, black.FileMode())
358
359     @patch("black.dump_to_file", dump_to_stderr)
360     def test_comments2(self) -> None:
361         source, expected = read_data("comments2")
362         actual = fs(source)
363         self.assertFormatEqual(expected, actual)
364         black.assert_equivalent(source, actual)
365         black.assert_stable(source, actual, black.FileMode())
366
367     @patch("black.dump_to_file", dump_to_stderr)
368     def test_comments3(self) -> None:
369         source, expected = read_data("comments3")
370         actual = fs(source)
371         self.assertFormatEqual(expected, actual)
372         black.assert_equivalent(source, actual)
373         black.assert_stable(source, actual, black.FileMode())
374
375     @patch("black.dump_to_file", dump_to_stderr)
376     def test_comments4(self) -> None:
377         source, expected = read_data("comments4")
378         actual = fs(source)
379         self.assertFormatEqual(expected, actual)
380         black.assert_equivalent(source, actual)
381         black.assert_stable(source, actual, black.FileMode())
382
383     @patch("black.dump_to_file", dump_to_stderr)
384     def test_comments5(self) -> None:
385         source, expected = read_data("comments5")
386         actual = fs(source)
387         self.assertFormatEqual(expected, actual)
388         black.assert_equivalent(source, actual)
389         black.assert_stable(source, actual, black.FileMode())
390
391     @patch("black.dump_to_file", dump_to_stderr)
392     def test_comments6(self) -> None:
393         source, expected = read_data("comments6")
394         actual = fs(source)
395         self.assertFormatEqual(expected, actual)
396         black.assert_equivalent(source, actual)
397         black.assert_stable(source, actual, black.FileMode())
398
399     @patch("black.dump_to_file", dump_to_stderr)
400     def test_comments7(self) -> None:
401         source, expected = read_data("comments7")
402         actual = fs(source)
403         self.assertFormatEqual(expected, actual)
404         black.assert_equivalent(source, actual)
405         black.assert_stable(source, actual, black.FileMode())
406
407     @patch("black.dump_to_file", dump_to_stderr)
408     def test_comment_after_escaped_newline(self) -> None:
409         source, expected = read_data("comment_after_escaped_newline")
410         actual = fs(source)
411         self.assertFormatEqual(expected, actual)
412         black.assert_equivalent(source, actual)
413         black.assert_stable(source, actual, black.FileMode())
414
415     @patch("black.dump_to_file", dump_to_stderr)
416     def test_cantfit(self) -> None:
417         source, expected = read_data("cantfit")
418         actual = fs(source)
419         self.assertFormatEqual(expected, actual)
420         black.assert_equivalent(source, actual)
421         black.assert_stable(source, actual, black.FileMode())
422
423     @patch("black.dump_to_file", dump_to_stderr)
424     def test_import_spacing(self) -> None:
425         source, expected = read_data("import_spacing")
426         actual = fs(source)
427         self.assertFormatEqual(expected, actual)
428         black.assert_equivalent(source, actual)
429         black.assert_stable(source, actual, black.FileMode())
430
431     @patch("black.dump_to_file", dump_to_stderr)
432     def test_composition(self) -> None:
433         source, expected = read_data("composition")
434         actual = fs(source)
435         self.assertFormatEqual(expected, actual)
436         black.assert_equivalent(source, actual)
437         black.assert_stable(source, actual, black.FileMode())
438
439     @patch("black.dump_to_file", dump_to_stderr)
440     def test_empty_lines(self) -> None:
441         source, expected = read_data("empty_lines")
442         actual = fs(source)
443         self.assertFormatEqual(expected, actual)
444         black.assert_equivalent(source, actual)
445         black.assert_stable(source, actual, black.FileMode())
446
447     @patch("black.dump_to_file", dump_to_stderr)
448     def test_remove_parens(self) -> None:
449         source, expected = read_data("remove_parens")
450         actual = fs(source)
451         self.assertFormatEqual(expected, actual)
452         black.assert_equivalent(source, actual)
453         black.assert_stable(source, actual, black.FileMode())
454
455     @patch("black.dump_to_file", dump_to_stderr)
456     def test_string_prefixes(self) -> None:
457         source, expected = read_data("string_prefixes")
458         actual = fs(source)
459         self.assertFormatEqual(expected, actual)
460         black.assert_equivalent(source, actual)
461         black.assert_stable(source, actual, black.FileMode())
462
463     @patch("black.dump_to_file", dump_to_stderr)
464     def test_numeric_literals(self) -> None:
465         source, expected = read_data("numeric_literals")
466         mode = black.FileMode(target_versions=black.PY36_VERSIONS)
467         actual = fs(source, mode=mode)
468         self.assertFormatEqual(expected, actual)
469         black.assert_equivalent(source, actual)
470         black.assert_stable(source, actual, mode)
471
472     @patch("black.dump_to_file", dump_to_stderr)
473     def test_numeric_literals_ignoring_underscores(self) -> None:
474         source, expected = read_data("numeric_literals_skip_underscores")
475         mode = black.FileMode(target_versions=black.PY36_VERSIONS)
476         actual = fs(source, mode=mode)
477         self.assertFormatEqual(expected, actual)
478         black.assert_equivalent(source, actual)
479         black.assert_stable(source, actual, mode)
480
481     @patch("black.dump_to_file", dump_to_stderr)
482     def test_numeric_literals_py2(self) -> None:
483         source, expected = read_data("numeric_literals_py2")
484         actual = fs(source)
485         self.assertFormatEqual(expected, actual)
486         black.assert_stable(source, actual, black.FileMode())
487
488     @patch("black.dump_to_file", dump_to_stderr)
489     def test_python2(self) -> None:
490         source, expected = read_data("python2")
491         actual = fs(source)
492         self.assertFormatEqual(expected, actual)
493         black.assert_equivalent(source, actual)
494         black.assert_stable(source, actual, black.FileMode())
495
496     @patch("black.dump_to_file", dump_to_stderr)
497     def test_python2_print_function(self) -> None:
498         source, expected = read_data("python2_print_function")
499         mode = black.FileMode(target_versions={TargetVersion.PY27})
500         actual = fs(source, mode=mode)
501         self.assertFormatEqual(expected, actual)
502         black.assert_equivalent(source, actual)
503         black.assert_stable(source, actual, mode)
504
505     @patch("black.dump_to_file", dump_to_stderr)
506     def test_python2_unicode_literals(self) -> None:
507         source, expected = read_data("python2_unicode_literals")
508         actual = fs(source)
509         self.assertFormatEqual(expected, actual)
510         black.assert_equivalent(source, actual)
511         black.assert_stable(source, actual, black.FileMode())
512
513     @patch("black.dump_to_file", dump_to_stderr)
514     def test_stub(self) -> None:
515         mode = black.FileMode(is_pyi=True)
516         source, expected = read_data("stub.pyi")
517         actual = fs(source, mode=mode)
518         self.assertFormatEqual(expected, actual)
519         black.assert_stable(source, actual, mode)
520
521     @patch("black.dump_to_file", dump_to_stderr)
522     def test_async_as_identifier(self) -> None:
523         source_path = (THIS_DIR / "data" / "async_as_identifier.py").resolve()
524         source, expected = read_data("async_as_identifier")
525         actual = fs(source)
526         self.assertFormatEqual(expected, actual)
527         major, minor = sys.version_info[:2]
528         if major < 3 or (major <= 3 and minor < 7):
529             black.assert_equivalent(source, actual)
530         black.assert_stable(source, actual, black.FileMode())
531         # ensure black can parse this when the target is 3.6
532         self.invokeBlack([str(source_path), "--target-version", "py36"])
533         # but not on 3.7, because async/await is no longer an identifier
534         self.invokeBlack([str(source_path), "--target-version", "py37"], exit_code=123)
535
536     @patch("black.dump_to_file", dump_to_stderr)
537     def test_python37(self) -> None:
538         source_path = (THIS_DIR / "data" / "python37.py").resolve()
539         source, expected = read_data("python37")
540         actual = fs(source)
541         self.assertFormatEqual(expected, actual)
542         major, minor = sys.version_info[:2]
543         if major > 3 or (major == 3 and minor >= 7):
544             black.assert_equivalent(source, actual)
545         black.assert_stable(source, actual, black.FileMode())
546         # ensure black can parse this when the target is 3.7
547         self.invokeBlack([str(source_path), "--target-version", "py37"])
548         # but not on 3.6, because we use async as a reserved keyword
549         self.invokeBlack([str(source_path), "--target-version", "py36"], exit_code=123)
550
551     @patch("black.dump_to_file", dump_to_stderr)
552     def test_fmtonoff(self) -> None:
553         source, expected = read_data("fmtonoff")
554         actual = fs(source)
555         self.assertFormatEqual(expected, actual)
556         black.assert_equivalent(source, actual)
557         black.assert_stable(source, actual, black.FileMode())
558
559     @patch("black.dump_to_file", dump_to_stderr)
560     def test_fmtonoff2(self) -> None:
561         source, expected = read_data("fmtonoff2")
562         actual = fs(source)
563         self.assertFormatEqual(expected, actual)
564         black.assert_equivalent(source, actual)
565         black.assert_stable(source, actual, black.FileMode())
566
567     @patch("black.dump_to_file", dump_to_stderr)
568     def test_remove_empty_parentheses_after_class(self) -> None:
569         source, expected = read_data("class_blank_parentheses")
570         actual = fs(source)
571         self.assertFormatEqual(expected, actual)
572         black.assert_equivalent(source, actual)
573         black.assert_stable(source, actual, black.FileMode())
574
575     @patch("black.dump_to_file", dump_to_stderr)
576     def test_new_line_between_class_and_code(self) -> None:
577         source, expected = read_data("class_methods_new_line")
578         actual = fs(source)
579         self.assertFormatEqual(expected, actual)
580         black.assert_equivalent(source, actual)
581         black.assert_stable(source, actual, black.FileMode())
582
583     @patch("black.dump_to_file", dump_to_stderr)
584     def test_bracket_match(self) -> None:
585         source, expected = read_data("bracketmatch")
586         actual = fs(source)
587         self.assertFormatEqual(expected, actual)
588         black.assert_equivalent(source, actual)
589         black.assert_stable(source, actual, black.FileMode())
590
591     @patch("black.dump_to_file", dump_to_stderr)
592     def test_tuple_assign(self) -> None:
593         source, expected = read_data("tupleassign")
594         actual = fs(source)
595         self.assertFormatEqual(expected, actual)
596         black.assert_equivalent(source, actual)
597         black.assert_stable(source, actual, black.FileMode())
598
599     def test_tab_comment_indentation(self) -> None:
600         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t# comment\n\tpass\n"
601         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
602         self.assertFormatEqual(contents_spc, fs(contents_spc))
603         self.assertFormatEqual(contents_spc, fs(contents_tab))
604
605         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t\t# comment\n\tpass\n"
606         contents_spc = "if 1:\n    if 2:\n        pass\n        # comment\n    pass\n"
607         self.assertFormatEqual(contents_spc, fs(contents_spc))
608         self.assertFormatEqual(contents_spc, fs(contents_tab))
609
610         # mixed tabs and spaces (valid Python 2 code)
611         contents_tab = "if 1:\n        if 2:\n\t\tpass\n\t# comment\n        pass\n"
612         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
613         self.assertFormatEqual(contents_spc, fs(contents_spc))
614         self.assertFormatEqual(contents_spc, fs(contents_tab))
615
616         contents_tab = "if 1:\n        if 2:\n\t\tpass\n\t\t# comment\n        pass\n"
617         contents_spc = "if 1:\n    if 2:\n        pass\n        # comment\n    pass\n"
618         self.assertFormatEqual(contents_spc, fs(contents_spc))
619         self.assertFormatEqual(contents_spc, fs(contents_tab))
620
    def test_report_verbose(self) -> None:
        """Walk a verbose black.Report through a sequence of outcomes.

        ``black.out``/``black.err`` are patched with collectors so that each step
        can check how many lines went to each stream, the last message, the
        unstyled summary ``str(report)`` and ``report.return_code``.
        """
        report = black.Report(verbose=True)
        out_lines: List[str] = []
        err_lines: List[str] = []

        def out(msg: str, **kwargs: Any) -> None:
            out_lines.append(msg)

        def err(msg: str, **kwargs: Any) -> None:
            err_lines.append(msg)

        with patch("black.out", out), patch("black.err", err):
            # Unchanged file: one out() message, return code 0.
            report.done(Path("f1"), black.Changed.NO)
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "f1 already well formatted, good job.")
            self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
            self.assertEqual(report.return_code, 0)
            # Reformatted file.
            report.done(Path("f2"), black.Changed.YES)
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "reformatted f2")
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
            )
            # Cached file counts as "left unchanged" in the summary.
            report.done(Path("f3"), black.Changed.CACHED)
            self.assertEqual(len(out_lines), 3)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(
                out_lines[-1], "f3 wasn't modified on disk since last run."
            )
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
            )
            self.assertEqual(report.return_code, 0)
            # In check mode a reformatted file alone yields return code 1.
            report.check = True
            self.assertEqual(report.return_code, 1)
            report.check = False
            # A failure goes to err() and switches the return code to 123.
            report.failed(Path("e1"), "boom")
            self.assertEqual(len(out_lines), 3)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
            self.assertEqual(
                unstyle(str(report)),
                "1 file reformatted, 2 files left unchanged, "
                "1 file failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # f3 is reported again; the counters grow per call, not per path.
            report.done(Path("f3"), black.Changed.YES)
            self.assertEqual(len(out_lines), 4)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(out_lines[-1], "reformatted f3")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, "
                "1 file failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.failed(Path("e2"), "boom")
            self.assertEqual(len(out_lines), 4)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, "
                "2 files failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # An ignored path is announced on out() but not counted in the summary.
            report.path_ignored(Path("wat"), "no match")
            self.assertEqual(len(out_lines), 5)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(out_lines[-1], "wat ignored: no match")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, "
                "2 files failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f4"), black.Changed.NO)
            self.assertEqual(len(out_lines), 6)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(out_lines[-1], "f4 already well formatted, good job.")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 3 files left unchanged, "
                "2 files failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Check mode rewords the summary in the conditional ("would be ...").
            report.check = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, "
                "2 files would fail to reformat.",
            )
715
716     def test_report_quiet(self) -> None:
717         report = black.Report(quiet=True)
718         out_lines = []
719         err_lines = []
720
721         def out(msg: str, **kwargs: Any) -> None:
722             out_lines.append(msg)
723
724         def err(msg: str, **kwargs: Any) -> None:
725             err_lines.append(msg)
726
727         with patch("black.out", out), patch("black.err", err):
728             report.done(Path("f1"), black.Changed.NO)
729             self.assertEqual(len(out_lines), 0)
730             self.assertEqual(len(err_lines), 0)
731             self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
732             self.assertEqual(report.return_code, 0)
733             report.done(Path("f2"), black.Changed.YES)
734             self.assertEqual(len(out_lines), 0)
735             self.assertEqual(len(err_lines), 0)
736             self.assertEqual(
737                 unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
738             )
739             report.done(Path("f3"), black.Changed.CACHED)
740             self.assertEqual(len(out_lines), 0)
741             self.assertEqual(len(err_lines), 0)
742             self.assertEqual(
743                 unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
744             )
745             self.assertEqual(report.return_code, 0)
746             report.check = True
747             self.assertEqual(report.return_code, 1)
748             report.check = False
749             report.failed(Path("e1"), "boom")
750             self.assertEqual(len(out_lines), 0)
751             self.assertEqual(len(err_lines), 1)
752             self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
753             self.assertEqual(
754                 unstyle(str(report)),
755                 "1 file reformatted, 2 files left unchanged, "
756                 "1 file failed to reformat.",
757             )
758             self.assertEqual(report.return_code, 123)
759             report.done(Path("f3"), black.Changed.YES)
760             self.assertEqual(len(out_lines), 0)
761             self.assertEqual(len(err_lines), 1)
762             self.assertEqual(
763                 unstyle(str(report)),
764                 "2 files reformatted, 2 files left unchanged, "
765                 "1 file failed to reformat.",
766             )
767             self.assertEqual(report.return_code, 123)
768             report.failed(Path("e2"), "boom")
769             self.assertEqual(len(out_lines), 0)
770             self.assertEqual(len(err_lines), 2)
771             self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
772             self.assertEqual(
773                 unstyle(str(report)),
774                 "2 files reformatted, 2 files left unchanged, "
775                 "2 files failed to reformat.",
776             )
777             self.assertEqual(report.return_code, 123)
778             report.path_ignored(Path("wat"), "no match")
779             self.assertEqual(len(out_lines), 0)
780             self.assertEqual(len(err_lines), 2)
781             self.assertEqual(
782                 unstyle(str(report)),
783                 "2 files reformatted, 2 files left unchanged, "
784                 "2 files failed to reformat.",
785             )
786             self.assertEqual(report.return_code, 123)
787             report.done(Path("f4"), black.Changed.NO)
788             self.assertEqual(len(out_lines), 0)
789             self.assertEqual(len(err_lines), 2)
790             self.assertEqual(
791                 unstyle(str(report)),
792                 "2 files reformatted, 3 files left unchanged, "
793                 "2 files failed to reformat.",
794             )
795             self.assertEqual(report.return_code, 123)
796             report.check = True
797             self.assertEqual(
798                 unstyle(str(report)),
799                 "2 files would be reformatted, 3 files would be left unchanged, "
800                 "2 files would fail to reformat.",
801             )
802
803     def test_report_normal(self) -> None:
804         report = black.Report()
805         out_lines = []
806         err_lines = []
807
808         def out(msg: str, **kwargs: Any) -> None:
809             out_lines.append(msg)
810
811         def err(msg: str, **kwargs: Any) -> None:
812             err_lines.append(msg)
813
814         with patch("black.out", out), patch("black.err", err):
815             report.done(Path("f1"), black.Changed.NO)
816             self.assertEqual(len(out_lines), 0)
817             self.assertEqual(len(err_lines), 0)
818             self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
819             self.assertEqual(report.return_code, 0)
820             report.done(Path("f2"), black.Changed.YES)
821             self.assertEqual(len(out_lines), 1)
822             self.assertEqual(len(err_lines), 0)
823             self.assertEqual(out_lines[-1], "reformatted f2")
824             self.assertEqual(
825                 unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
826             )
827             report.done(Path("f3"), black.Changed.CACHED)
828             self.assertEqual(len(out_lines), 1)
829             self.assertEqual(len(err_lines), 0)
830             self.assertEqual(out_lines[-1], "reformatted f2")
831             self.assertEqual(
832                 unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
833             )
834             self.assertEqual(report.return_code, 0)
835             report.check = True
836             self.assertEqual(report.return_code, 1)
837             report.check = False
838             report.failed(Path("e1"), "boom")
839             self.assertEqual(len(out_lines), 1)
840             self.assertEqual(len(err_lines), 1)
841             self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
842             self.assertEqual(
843                 unstyle(str(report)),
844                 "1 file reformatted, 2 files left unchanged, "
845                 "1 file failed to reformat.",
846             )
847             self.assertEqual(report.return_code, 123)
848             report.done(Path("f3"), black.Changed.YES)
849             self.assertEqual(len(out_lines), 2)
850             self.assertEqual(len(err_lines), 1)
851             self.assertEqual(out_lines[-1], "reformatted f3")
852             self.assertEqual(
853                 unstyle(str(report)),
854                 "2 files reformatted, 2 files left unchanged, "
855                 "1 file failed to reformat.",
856             )
857             self.assertEqual(report.return_code, 123)
858             report.failed(Path("e2"), "boom")
859             self.assertEqual(len(out_lines), 2)
860             self.assertEqual(len(err_lines), 2)
861             self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
862             self.assertEqual(
863                 unstyle(str(report)),
864                 "2 files reformatted, 2 files left unchanged, "
865                 "2 files failed to reformat.",
866             )
867             self.assertEqual(report.return_code, 123)
868             report.path_ignored(Path("wat"), "no match")
869             self.assertEqual(len(out_lines), 2)
870             self.assertEqual(len(err_lines), 2)
871             self.assertEqual(
872                 unstyle(str(report)),
873                 "2 files reformatted, 2 files left unchanged, "
874                 "2 files failed to reformat.",
875             )
876             self.assertEqual(report.return_code, 123)
877             report.done(Path("f4"), black.Changed.NO)
878             self.assertEqual(len(out_lines), 2)
879             self.assertEqual(len(err_lines), 2)
880             self.assertEqual(
881                 unstyle(str(report)),
882                 "2 files reformatted, 3 files left unchanged, "
883                 "2 files failed to reformat.",
884             )
885             self.assertEqual(report.return_code, 123)
886             report.check = True
887             self.assertEqual(
888                 unstyle(str(report)),
889                 "2 files would be reformatted, 3 files would be left unchanged, "
890                 "2 files would fail to reformat.",
891             )
892
893     def test_lib2to3_parse(self) -> None:
894         with self.assertRaises(black.InvalidInput):
895             black.lib2to3_parse("invalid syntax")
896
897         straddling = "x + y"
898         black.lib2to3_parse(straddling)
899         black.lib2to3_parse(straddling, {TargetVersion.PY27})
900         black.lib2to3_parse(straddling, {TargetVersion.PY36})
901         black.lib2to3_parse(straddling, {TargetVersion.PY27, TargetVersion.PY36})
902
903         py2_only = "print x"
904         black.lib2to3_parse(py2_only)
905         black.lib2to3_parse(py2_only, {TargetVersion.PY27})
906         with self.assertRaises(black.InvalidInput):
907             black.lib2to3_parse(py2_only, {TargetVersion.PY36})
908         with self.assertRaises(black.InvalidInput):
909             black.lib2to3_parse(py2_only, {TargetVersion.PY27, TargetVersion.PY36})
910
911         py3_only = "exec(x, end=y)"
912         black.lib2to3_parse(py3_only)
913         with self.assertRaises(black.InvalidInput):
914             black.lib2to3_parse(py3_only, {TargetVersion.PY27})
915         black.lib2to3_parse(py3_only, {TargetVersion.PY36})
916         black.lib2to3_parse(py3_only, {TargetVersion.PY27, TargetVersion.PY36})
917
918     def test_get_features_used(self) -> None:
919         node = black.lib2to3_parse("def f(*, arg): ...\n")
920         self.assertEqual(black.get_features_used(node), set())
921         node = black.lib2to3_parse("def f(*, arg,): ...\n")
922         self.assertEqual(black.get_features_used(node), {Feature.TRAILING_COMMA_IN_DEF})
923         node = black.lib2to3_parse("f(*arg,)\n")
924         self.assertEqual(
925             black.get_features_used(node), {Feature.TRAILING_COMMA_IN_CALL}
926         )
927         node = black.lib2to3_parse("def f(*, arg): f'string'\n")
928         self.assertEqual(black.get_features_used(node), {Feature.F_STRINGS})
929         node = black.lib2to3_parse("123_456\n")
930         self.assertEqual(black.get_features_used(node), {Feature.NUMERIC_UNDERSCORES})
931         node = black.lib2to3_parse("123456\n")
932         self.assertEqual(black.get_features_used(node), set())
933         source, expected = read_data("function")
934         node = black.lib2to3_parse(source)
935         expected_features = {
936             Feature.TRAILING_COMMA_IN_CALL,
937             Feature.TRAILING_COMMA_IN_DEF,
938             Feature.F_STRINGS,
939         }
940         self.assertEqual(black.get_features_used(node), expected_features)
941         node = black.lib2to3_parse(expected)
942         self.assertEqual(black.get_features_used(node), expected_features)
943         source, expected = read_data("expression")
944         node = black.lib2to3_parse(source)
945         self.assertEqual(black.get_features_used(node), set())
946         node = black.lib2to3_parse(expected)
947         self.assertEqual(black.get_features_used(node), set())
948
949     def test_get_future_imports(self) -> None:
950         node = black.lib2to3_parse("\n")
951         self.assertEqual(set(), black.get_future_imports(node))
952         node = black.lib2to3_parse("from __future__ import black\n")
953         self.assertEqual({"black"}, black.get_future_imports(node))
954         node = black.lib2to3_parse("from __future__ import multiple, imports\n")
955         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
956         node = black.lib2to3_parse("from __future__ import (parenthesized, imports)\n")
957         self.assertEqual({"parenthesized", "imports"}, black.get_future_imports(node))
958         node = black.lib2to3_parse(
959             "from __future__ import multiple\nfrom __future__ import imports\n"
960         )
961         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
962         node = black.lib2to3_parse("# comment\nfrom __future__ import black\n")
963         self.assertEqual({"black"}, black.get_future_imports(node))
964         node = black.lib2to3_parse('"""docstring"""\nfrom __future__ import black\n')
965         self.assertEqual({"black"}, black.get_future_imports(node))
966         node = black.lib2to3_parse("some(other, code)\nfrom __future__ import black\n")
967         self.assertEqual(set(), black.get_future_imports(node))
968         node = black.lib2to3_parse("from some.module import black\n")
969         self.assertEqual(set(), black.get_future_imports(node))
970         node = black.lib2to3_parse(
971             "from __future__ import unicode_literals as _unicode_literals"
972         )
973         self.assertEqual({"unicode_literals"}, black.get_future_imports(node))
974         node = black.lib2to3_parse(
975             "from __future__ import unicode_literals as _lol, print"
976         )
977         self.assertEqual({"unicode_literals", "print"}, black.get_future_imports(node))
978
979     def test_debug_visitor(self) -> None:
980         source, _ = read_data("debug_visitor.py")
981         expected, _ = read_data("debug_visitor.out")
982         out_lines = []
983         err_lines = []
984
985         def out(msg: str, **kwargs: Any) -> None:
986             out_lines.append(msg)
987
988         def err(msg: str, **kwargs: Any) -> None:
989             err_lines.append(msg)
990
991         with patch("black.out", out), patch("black.err", err):
992             black.DebugVisitor.show(source)
993         actual = "\n".join(out_lines) + "\n"
994         log_name = ""
995         if expected != actual:
996             log_name = black.dump_to_file(*out_lines)
997         self.assertEqual(
998             expected,
999             actual,
1000             f"AST print out is different. Actual version dumped to {log_name}",
1001         )
1002
1003     def test_format_file_contents(self) -> None:
1004         empty = ""
1005         mode = black.FileMode()
1006         with self.assertRaises(black.NothingChanged):
1007             black.format_file_contents(empty, mode=mode, fast=False)
1008         just_nl = "\n"
1009         with self.assertRaises(black.NothingChanged):
1010             black.format_file_contents(just_nl, mode=mode, fast=False)
1011         same = "l = [1, 2, 3]\n"
1012         with self.assertRaises(black.NothingChanged):
1013             black.format_file_contents(same, mode=mode, fast=False)
1014         different = "l = [1,2,3]"
1015         expected = same
1016         actual = black.format_file_contents(different, mode=mode, fast=False)
1017         self.assertEqual(expected, actual)
1018         invalid = "return if you can"
1019         with self.assertRaises(black.InvalidInput) as e:
1020             black.format_file_contents(invalid, mode=mode, fast=False)
1021         self.assertEqual(str(e.exception), "Cannot parse: 1:7: return if you can")
1022
1023     def test_endmarker(self) -> None:
1024         n = black.lib2to3_parse("\n")
1025         self.assertEqual(n.type, black.syms.file_input)
1026         self.assertEqual(len(n.children), 1)
1027         self.assertEqual(n.children[0].type, black.token.ENDMARKER)
1028
1029     @unittest.skipIf(os.environ.get("SKIP_AST_PRINT"), "user set SKIP_AST_PRINT")
1030     def test_assertFormatEqual(self) -> None:
1031         out_lines = []
1032         err_lines = []
1033
1034         def out(msg: str, **kwargs: Any) -> None:
1035             out_lines.append(msg)
1036
1037         def err(msg: str, **kwargs: Any) -> None:
1038             err_lines.append(msg)
1039
1040         with patch("black.out", out), patch("black.err", err):
1041             with self.assertRaises(AssertionError):
1042                 self.assertFormatEqual("l = [1, 2, 3]", "l = [1, 2, 3,]")
1043
1044         out_str = "".join(out_lines)
1045         self.assertTrue("Expected tree:" in out_str)
1046         self.assertTrue("Actual tree:" in out_str)
1047         self.assertEqual("".join(err_lines), "")
1048
1049     def test_cache_broken_file(self) -> None:
1050         mode = black.FileMode()
1051         with cache_dir() as workspace:
1052             cache_file = black.get_cache_file(mode)
1053             with cache_file.open("w") as fobj:
1054                 fobj.write("this is not a pickle")
1055             self.assertEqual(black.read_cache(mode), {})
1056             src = (workspace / "test.py").resolve()
1057             with src.open("w") as fobj:
1058                 fobj.write("print('hello')")
1059             self.invokeBlack([str(src)])
1060             cache = black.read_cache(mode)
1061             self.assertIn(src, cache)
1062
1063     def test_cache_single_file_already_cached(self) -> None:
1064         mode = black.FileMode()
1065         with cache_dir() as workspace:
1066             src = (workspace / "test.py").resolve()
1067             with src.open("w") as fobj:
1068                 fobj.write("print('hello')")
1069             black.write_cache({}, [src], mode)
1070             self.invokeBlack([str(src)])
1071             with src.open("r") as fobj:
1072                 self.assertEqual(fobj.read(), "print('hello')")
1073
1074     @event_loop(close=False)
1075     def test_cache_multiple_files(self) -> None:
1076         mode = black.FileMode()
1077         with cache_dir() as workspace, patch(
1078             "black.ProcessPoolExecutor", new=ThreadPoolExecutor
1079         ):
1080             one = (workspace / "one.py").resolve()
1081             with one.open("w") as fobj:
1082                 fobj.write("print('hello')")
1083             two = (workspace / "two.py").resolve()
1084             with two.open("w") as fobj:
1085                 fobj.write("print('hello')")
1086             black.write_cache({}, [one], mode)
1087             self.invokeBlack([str(workspace)])
1088             with one.open("r") as fobj:
1089                 self.assertEqual(fobj.read(), "print('hello')")
1090             with two.open("r") as fobj:
1091                 self.assertEqual(fobj.read(), 'print("hello")\n')
1092             cache = black.read_cache(mode)
1093             self.assertIn(one, cache)
1094             self.assertIn(two, cache)
1095
1096     def test_no_cache_when_writeback_diff(self) -> None:
1097         mode = black.FileMode()
1098         with cache_dir() as workspace:
1099             src = (workspace / "test.py").resolve()
1100             with src.open("w") as fobj:
1101                 fobj.write("print('hello')")
1102             self.invokeBlack([str(src), "--diff"])
1103             cache_file = black.get_cache_file(mode)
1104             self.assertFalse(cache_file.exists())
1105
1106     def test_no_cache_when_stdin(self) -> None:
1107         mode = black.FileMode()
1108         with cache_dir():
1109             result = CliRunner().invoke(
1110                 black.main, ["-"], input=BytesIO(b"print('hello')")
1111             )
1112             self.assertEqual(result.exit_code, 0)
1113             cache_file = black.get_cache_file(mode)
1114             self.assertFalse(cache_file.exists())
1115
1116     def test_read_cache_no_cachefile(self) -> None:
1117         mode = black.FileMode()
1118         with cache_dir():
1119             self.assertEqual(black.read_cache(mode), {})
1120
1121     def test_write_cache_read_cache(self) -> None:
1122         mode = black.FileMode()
1123         with cache_dir() as workspace:
1124             src = (workspace / "test.py").resolve()
1125             src.touch()
1126             black.write_cache({}, [src], mode)
1127             cache = black.read_cache(mode)
1128             self.assertIn(src, cache)
1129             self.assertEqual(cache[src], black.get_cache_info(src))
1130
1131     def test_filter_cached(self) -> None:
1132         with TemporaryDirectory() as workspace:
1133             path = Path(workspace)
1134             uncached = (path / "uncached").resolve()
1135             cached = (path / "cached").resolve()
1136             cached_but_changed = (path / "changed").resolve()
1137             uncached.touch()
1138             cached.touch()
1139             cached_but_changed.touch()
1140             cache = {cached: black.get_cache_info(cached), cached_but_changed: (0.0, 0)}
1141             todo, done = black.filter_cached(
1142                 cache, {uncached, cached, cached_but_changed}
1143             )
1144             self.assertEqual(todo, {uncached, cached_but_changed})
1145             self.assertEqual(done, {cached})
1146
1147     def test_write_cache_creates_directory_if_needed(self) -> None:
1148         mode = black.FileMode()
1149         with cache_dir(exists=False) as workspace:
1150             self.assertFalse(workspace.exists())
1151             black.write_cache({}, [], mode)
1152             self.assertTrue(workspace.exists())
1153
1154     @event_loop(close=False)
1155     def test_failed_formatting_does_not_get_cached(self) -> None:
1156         mode = black.FileMode()
1157         with cache_dir() as workspace, patch(
1158             "black.ProcessPoolExecutor", new=ThreadPoolExecutor
1159         ):
1160             failing = (workspace / "failing.py").resolve()
1161             with failing.open("w") as fobj:
1162                 fobj.write("not actually python")
1163             clean = (workspace / "clean.py").resolve()
1164             with clean.open("w") as fobj:
1165                 fobj.write('print("hello")\n')
1166             self.invokeBlack([str(workspace)], exit_code=123)
1167             cache = black.read_cache(mode)
1168             self.assertNotIn(failing, cache)
1169             self.assertIn(clean, cache)
1170
1171     def test_write_cache_write_fail(self) -> None:
1172         mode = black.FileMode()
1173         with cache_dir(), patch.object(Path, "open") as mock:
1174             mock.side_effect = OSError
1175             black.write_cache({}, [], mode)
1176
1177     @event_loop(close=False)
1178     def test_check_diff_use_together(self) -> None:
1179         with cache_dir():
1180             # Files which will be reformatted.
1181             src1 = (THIS_DIR / "data" / "string_quotes.py").resolve()
1182             self.invokeBlack([str(src1), "--diff", "--check"], exit_code=1)
1183             # Files which will not be reformatted.
1184             src2 = (THIS_DIR / "data" / "composition.py").resolve()
1185             self.invokeBlack([str(src2), "--diff", "--check"])
1186             # Multi file command.
1187             self.invokeBlack([str(src1), str(src2), "--diff", "--check"], exit_code=1)
1188
1189     def test_no_files(self) -> None:
1190         with cache_dir():
1191             # Without an argument, black exits with error code 0.
1192             self.invokeBlack([])
1193
1194     def test_broken_symlink(self) -> None:
1195         with cache_dir() as workspace:
1196             symlink = workspace / "broken_link.py"
1197             try:
1198                 symlink.symlink_to("nonexistent.py")
1199             except OSError as e:
1200                 self.skipTest(f"Can't create symlinks: {e}")
1201             self.invokeBlack([str(workspace.resolve())])
1202
1203     def test_read_cache_line_lengths(self) -> None:
1204         mode = black.FileMode()
1205         short_mode = black.FileMode(line_length=1)
1206         with cache_dir() as workspace:
1207             path = (workspace / "file.py").resolve()
1208             path.touch()
1209             black.write_cache({}, [path], mode)
1210             one = black.read_cache(mode)
1211             self.assertIn(path, one)
1212             two = black.read_cache(short_mode)
1213             self.assertNotIn(path, two)
1214
1215     def test_single_file_force_pyi(self) -> None:
1216         reg_mode = black.FileMode()
1217         pyi_mode = black.FileMode(is_pyi=True)
1218         contents, expected = read_data("force_pyi")
1219         with cache_dir() as workspace:
1220             path = (workspace / "file.py").resolve()
1221             with open(path, "w") as fh:
1222                 fh.write(contents)
1223             self.invokeBlack([str(path), "--pyi"])
1224             with open(path, "r") as fh:
1225                 actual = fh.read()
1226             # verify cache with --pyi is separate
1227             pyi_cache = black.read_cache(pyi_mode)
1228             self.assertIn(path, pyi_cache)
1229             normal_cache = black.read_cache(reg_mode)
1230             self.assertNotIn(path, normal_cache)
1231         self.assertEqual(actual, expected)
1232
1233     @event_loop(close=False)
1234     def test_multi_file_force_pyi(self) -> None:
1235         reg_mode = black.FileMode()
1236         pyi_mode = black.FileMode(is_pyi=True)
1237         contents, expected = read_data("force_pyi")
1238         with cache_dir() as workspace:
1239             paths = [
1240                 (workspace / "file1.py").resolve(),
1241                 (workspace / "file2.py").resolve(),
1242             ]
1243             for path in paths:
1244                 with open(path, "w") as fh:
1245                     fh.write(contents)
1246             self.invokeBlack([str(p) for p in paths] + ["--pyi"])
1247             for path in paths:
1248                 with open(path, "r") as fh:
1249                     actual = fh.read()
1250                 self.assertEqual(actual, expected)
1251             # verify cache with --pyi is separate
1252             pyi_cache = black.read_cache(pyi_mode)
1253             normal_cache = black.read_cache(reg_mode)
1254             for path in paths:
1255                 self.assertIn(path, pyi_cache)
1256                 self.assertNotIn(path, normal_cache)
1257
1258     def test_pipe_force_pyi(self) -> None:
1259         source, expected = read_data("force_pyi")
1260         result = CliRunner().invoke(
1261             black.main, ["-", "-q", "--pyi"], input=BytesIO(source.encode("utf8"))
1262         )
1263         self.assertEqual(result.exit_code, 0)
1264         actual = result.output
1265         self.assertFormatEqual(actual, expected)
1266
1267     def test_single_file_force_py36(self) -> None:
1268         reg_mode = black.FileMode()
1269         py36_mode = black.FileMode(target_versions=black.PY36_VERSIONS)
1270         source, expected = read_data("force_py36")
1271         with cache_dir() as workspace:
1272             path = (workspace / "file.py").resolve()
1273             with open(path, "w") as fh:
1274                 fh.write(source)
1275             self.invokeBlack([str(path), *PY36_ARGS])
1276             with open(path, "r") as fh:
1277                 actual = fh.read()
1278             # verify cache with --target-version is separate
1279             py36_cache = black.read_cache(py36_mode)
1280             self.assertIn(path, py36_cache)
1281             normal_cache = black.read_cache(reg_mode)
1282             self.assertNotIn(path, normal_cache)
1283         self.assertEqual(actual, expected)
1284
1285     @event_loop(close=False)
1286     def test_multi_file_force_py36(self) -> None:
1287         reg_mode = black.FileMode()
1288         py36_mode = black.FileMode(target_versions=black.PY36_VERSIONS)
1289         source, expected = read_data("force_py36")
1290         with cache_dir() as workspace:
1291             paths = [
1292                 (workspace / "file1.py").resolve(),
1293                 (workspace / "file2.py").resolve(),
1294             ]
1295             for path in paths:
1296                 with open(path, "w") as fh:
1297                     fh.write(source)
1298             self.invokeBlack([str(p) for p in paths] + PY36_ARGS)
1299             for path in paths:
1300                 with open(path, "r") as fh:
1301                     actual = fh.read()
1302                 self.assertEqual(actual, expected)
1303             # verify cache with --target-version is separate
1304             pyi_cache = black.read_cache(py36_mode)
1305             normal_cache = black.read_cache(reg_mode)
1306             for path in paths:
1307                 self.assertIn(path, pyi_cache)
1308                 self.assertNotIn(path, normal_cache)
1309
1310     def test_pipe_force_py36(self) -> None:
1311         source, expected = read_data("force_py36")
1312         result = CliRunner().invoke(
1313             black.main,
1314             ["-", "-q", "--target-version=py36"],
1315             input=BytesIO(source.encode("utf8")),
1316         )
1317         self.assertEqual(result.exit_code, 0)
1318         actual = result.output
1319         self.assertFormatEqual(actual, expected)
1320
1321     def test_include_exclude(self) -> None:
1322         path = THIS_DIR / "data" / "include_exclude_tests"
1323         include = re.compile(r"\.pyi?$")
1324         exclude = re.compile(r"/exclude/|/\.definitely_exclude/")
1325         report = black.Report()
1326         sources: List[Path] = []
1327         expected = [
1328             Path(path / "b/dont_exclude/a.py"),
1329             Path(path / "b/dont_exclude/a.pyi"),
1330         ]
1331         this_abs = THIS_DIR.resolve()
1332         sources.extend(
1333             black.gen_python_files_in_dir(path, this_abs, include, exclude, report)
1334         )
1335         self.assertEqual(sorted(expected), sorted(sources))
1336
1337     def test_empty_include(self) -> None:
1338         path = THIS_DIR / "data" / "include_exclude_tests"
1339         report = black.Report()
1340         empty = re.compile(r"")
1341         sources: List[Path] = []
1342         expected = [
1343             Path(path / "b/exclude/a.pie"),
1344             Path(path / "b/exclude/a.py"),
1345             Path(path / "b/exclude/a.pyi"),
1346             Path(path / "b/dont_exclude/a.pie"),
1347             Path(path / "b/dont_exclude/a.py"),
1348             Path(path / "b/dont_exclude/a.pyi"),
1349             Path(path / "b/.definitely_exclude/a.pie"),
1350             Path(path / "b/.definitely_exclude/a.py"),
1351             Path(path / "b/.definitely_exclude/a.pyi"),
1352         ]
1353         this_abs = THIS_DIR.resolve()
1354         sources.extend(
1355             black.gen_python_files_in_dir(
1356                 path, this_abs, empty, re.compile(black.DEFAULT_EXCLUDES), report
1357             )
1358         )
1359         self.assertEqual(sorted(expected), sorted(sources))
1360
1361     def test_empty_exclude(self) -> None:
1362         path = THIS_DIR / "data" / "include_exclude_tests"
1363         report = black.Report()
1364         empty = re.compile(r"")
1365         sources: List[Path] = []
1366         expected = [
1367             Path(path / "b/dont_exclude/a.py"),
1368             Path(path / "b/dont_exclude/a.pyi"),
1369             Path(path / "b/exclude/a.py"),
1370             Path(path / "b/exclude/a.pyi"),
1371             Path(path / "b/.definitely_exclude/a.py"),
1372             Path(path / "b/.definitely_exclude/a.pyi"),
1373         ]
1374         this_abs = THIS_DIR.resolve()
1375         sources.extend(
1376             black.gen_python_files_in_dir(
1377                 path, this_abs, re.compile(black.DEFAULT_INCLUDES), empty, report
1378             )
1379         )
1380         self.assertEqual(sorted(expected), sorted(sources))
1381
1382     def test_invalid_include_exclude(self) -> None:
1383         for option in ["--include", "--exclude"]:
1384             self.invokeBlack(["-", option, "**()(!!*)"], exit_code=2)
1385
1386     def test_preserves_line_endings(self) -> None:
1387         with TemporaryDirectory() as workspace:
1388             test_file = Path(workspace) / "test.py"
1389             for nl in ["\n", "\r\n"]:
1390                 contents = nl.join(["def f(  ):", "    pass"])
1391                 test_file.write_bytes(contents.encode())
1392                 ff(test_file, write_back=black.WriteBack.YES)
1393                 updated_contents: bytes = test_file.read_bytes()
1394                 self.assertIn(nl.encode(), updated_contents)
1395                 if nl == "\n":
1396                     self.assertNotIn(b"\r\n", updated_contents)
1397
1398     def test_preserves_line_endings_via_stdin(self) -> None:
1399         for nl in ["\n", "\r\n"]:
1400             contents = nl.join(["def f(  ):", "    pass"])
1401             runner = BlackRunner()
1402             result = runner.invoke(
1403                 black.main, ["-", "--fast"], input=BytesIO(contents.encode("utf8"))
1404             )
1405             self.assertEqual(result.exit_code, 0)
1406             output = runner.stdout_bytes
1407             self.assertIn(nl.encode("utf8"), output)
1408             if nl == "\n":
1409                 self.assertNotIn(b"\r\n", output)
1410
1411     def test_assert_equivalent_different_asts(self) -> None:
1412         with self.assertRaises(AssertionError):
1413             black.assert_equivalent("{}", "None")
1414
1415     def test_symlink_out_of_root_directory(self) -> None:
1416         path = MagicMock()
1417         root = THIS_DIR
1418         child = MagicMock()
1419         include = re.compile(black.DEFAULT_INCLUDES)
1420         exclude = re.compile(black.DEFAULT_EXCLUDES)
1421         report = black.Report()
1422         # `child` should behave like a symlink which resolved path is clearly
1423         # outside of the `root` directory.
1424         path.iterdir.return_value = [child]
1425         child.resolve.return_value = Path("/a/b/c")
1426         child.is_symlink.return_value = True
1427         try:
1428             list(black.gen_python_files_in_dir(path, root, include, exclude, report))
1429         except ValueError as ve:
1430             self.fail(f"`get_python_files_in_dir()` failed: {ve}")
1431         path.iterdir.assert_called_once()
1432         child.resolve.assert_called_once()
1433         child.is_symlink.assert_called_once()
1434         # `child` should behave like a strange file which resolved path is clearly
1435         # outside of the `root` directory.
1436         child.is_symlink.return_value = False
1437         with self.assertRaises(ValueError):
1438             list(black.gen_python_files_in_dir(path, root, include, exclude, report))
1439         path.iterdir.assert_called()
1440         self.assertEqual(path.iterdir.call_count, 2)
1441         child.resolve.assert_called()
1442         self.assertEqual(child.resolve.call_count, 2)
1443         child.is_symlink.assert_called()
1444         self.assertEqual(child.is_symlink.call_count, 2)
1445
1446     def test_shhh_click(self) -> None:
1447         try:
1448             from click import _unicodefun  # type: ignore
1449         except ModuleNotFoundError:
1450             self.skipTest("Incompatible Click version")
1451         if not hasattr(_unicodefun, "_verify_python3_env"):
1452             self.skipTest("Incompatible Click version")
1453         # First, let's see if Click is crashing with a preferred ASCII charset.
1454         with patch("locale.getpreferredencoding") as gpe:
1455             gpe.return_value = "ASCII"
1456             with self.assertRaises(RuntimeError):
1457                 _unicodefun._verify_python3_env()
1458         # Now, let's silence Click...
1459         black.patch_click()
1460         # ...and confirm it's silent.
1461         with patch("locale.getpreferredencoding") as gpe:
1462             gpe.return_value = "ASCII"
1463             try:
1464                 _unicodefun._verify_python3_env()
1465             except RuntimeError as re:
1466                 self.fail(f"`patch_click()` failed, exception still raised: {re}")
1467
1468     def test_root_logger_not_used_directly(self) -> None:
1469         def fail(*args: Any, **kwargs: Any) -> None:
1470             self.fail("Record created with root logger")
1471
1472         with patch.multiple(
1473             logging.root,
1474             debug=fail,
1475             info=fail,
1476             warning=fail,
1477             error=fail,
1478             critical=fail,
1479             log=fail,
1480         ):
1481             ff(THIS_FILE)
1482
1483     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1484     @async_test
1485     async def test_blackd_request_needs_formatting(self) -> None:
1486         app = blackd.make_app()
1487         async with TestClient(TestServer(app)) as client:
1488             response = await client.post("/", data=b"print('hello world')")
1489             self.assertEqual(response.status, 200)
1490             self.assertEqual(response.charset, "utf8")
1491             self.assertEqual(await response.read(), b'print("hello world")\n')
1492
1493     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1494     @async_test
1495     async def test_blackd_request_no_change(self) -> None:
1496         app = blackd.make_app()
1497         async with TestClient(TestServer(app)) as client:
1498             response = await client.post("/", data=b'print("hello world")\n')
1499             self.assertEqual(response.status, 204)
1500             self.assertEqual(await response.read(), b"")
1501
1502     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1503     @async_test
1504     async def test_blackd_request_syntax_error(self) -> None:
1505         app = blackd.make_app()
1506         async with TestClient(TestServer(app)) as client:
1507             response = await client.post("/", data=b"what even ( is")
1508             self.assertEqual(response.status, 400)
1509             content = await response.text()
1510             self.assertTrue(
1511                 content.startswith("Cannot parse"),
1512                 msg=f"Expected error to start with 'Cannot parse', got {repr(content)}",
1513             )
1514
1515     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1516     @async_test
1517     async def test_blackd_unsupported_version(self) -> None:
1518         app = blackd.make_app()
1519         async with TestClient(TestServer(app)) as client:
1520             response = await client.post(
1521                 "/", data=b"what", headers={blackd.VERSION_HEADER: "2"}
1522             )
1523             self.assertEqual(response.status, 501)
1524
1525     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1526     @async_test
1527     async def test_blackd_supported_version(self) -> None:
1528         app = blackd.make_app()
1529         async with TestClient(TestServer(app)) as client:
1530             response = await client.post(
1531                 "/", data=b"what", headers={blackd.VERSION_HEADER: "1"}
1532             )
1533             self.assertEqual(response.status, 200)
1534
1535     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1536     @async_test
1537     async def test_blackd_invalid_python_variant(self) -> None:
1538         app = blackd.make_app()
1539         async with TestClient(TestServer(app)) as client:
1540
1541             async def check(header_value: str, expected_status: int = 400) -> None:
1542                 response = await client.post(
1543                     "/",
1544                     data=b"what",
1545                     headers={blackd.PYTHON_VARIANT_HEADER: header_value},
1546                 )
1547                 self.assertEqual(response.status, expected_status)
1548
1549             await check("lol")
1550             await check("ruby3.5")
1551             await check("pyi3.6")
1552             await check("py1.5")
1553             await check("2.8")
1554             await check("py2.8")
1555             await check("3.0")
1556             await check("pypy3.0")
1557             await check("jython3.4")
1558
1559     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1560     @async_test
1561     async def test_blackd_pyi(self) -> None:
1562         app = blackd.make_app()
1563         async with TestClient(TestServer(app)) as client:
1564             source, expected = read_data("stub.pyi")
1565             response = await client.post(
1566                 "/", data=source, headers={blackd.PYTHON_VARIANT_HEADER: "pyi"}
1567             )
1568             self.assertEqual(response.status, 200)
1569             self.assertEqual(await response.text(), expected)
1570
1571     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1572     @async_test
1573     async def test_blackd_python_variant(self) -> None:
1574         app = blackd.make_app()
1575         code = (
1576             "def f(\n"
1577             "    and_has_a_bunch_of,\n"
1578             "    very_long_arguments_too,\n"
1579             "    and_lots_of_them_as_well_lol,\n"
1580             "    **and_very_long_keyword_arguments\n"
1581             "):\n"
1582             "    pass\n"
1583         )
1584         async with TestClient(TestServer(app)) as client:
1585
1586             async def check(header_value: str, expected_status: int) -> None:
1587                 response = await client.post(
1588                     "/", data=code, headers={blackd.PYTHON_VARIANT_HEADER: header_value}
1589                 )
1590                 self.assertEqual(response.status, expected_status)
1591
1592             await check("3.6", 200)
1593             await check("py3.6", 200)
1594             await check("3.6,3.7", 200)
1595             await check("3.6,py3.7", 200)
1596
1597             await check("2", 204)
1598             await check("2.7", 204)
1599             await check("py2.7", 204)
1600             await check("3.4", 204)
1601             await check("py3.4", 204)
1602
1603     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1604     @async_test
1605     async def test_blackd_line_length(self) -> None:
1606         app = blackd.make_app()
1607         async with TestClient(TestServer(app)) as client:
1608             response = await client.post(
1609                 "/", data=b'print("hello")\n', headers={blackd.LINE_LENGTH_HEADER: "7"}
1610             )
1611             self.assertEqual(response.status, 200)
1612
1613     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1614     @async_test
1615     async def test_blackd_invalid_line_length(self) -> None:
1616         app = blackd.make_app()
1617         async with TestClient(TestServer(app)) as client:
1618             response = await client.post(
1619                 "/",
1620                 data=b'print("hello")\n',
1621                 headers={blackd.LINE_LENGTH_HEADER: "NaN"},
1622             )
1623             self.assertEqual(response.status, 400)
1624
1625     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1626     def test_blackd_main(self) -> None:
1627         with patch("blackd.web.run_app"):
1628             result = CliRunner().invoke(blackd.main, [])
1629             if result.exception is not None:
1630                 raise result.exception
1631             self.assertEqual(result.exit_code, 0)
1632
1633
if __name__ == "__main__":
    # When run as a script, let unittest load tests from the module named
    # "test_black" (instead of its default, "__main__").
    unittest.main(module="test_black")