git.madduck.net Git - etc/vim.git/blob - tests/test_black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

f4055581f86adf1301d041fbc15e087a557d5c8f
[etc/vim.git] / tests / test_black.py
1 #!/usr/bin/env python3
2 import asyncio
3 import logging
4 from concurrent.futures import ThreadPoolExecutor
5 from contextlib import contextmanager
6 from functools import partial
7 from io import BytesIO, TextIOWrapper
8 import os
9 from pathlib import Path
10 import regex as re
11 import sys
12 from tempfile import TemporaryDirectory
13 from typing import Any, BinaryIO, Dict, Generator, List, Tuple, Iterator, TypeVar
14 import unittest
15 from unittest.mock import patch, MagicMock
16
17 import click
18 from click import unstyle
19 from click.testing import CliRunner
20
21 import black
22 from black import Feature, TargetVersion
23
24 try:
25     import blackd
26     from aiohttp.test_utils import AioHTTPTestCase, unittest_run_loop
27     from aiohttp import web
28 except ImportError:
29     has_blackd_deps = False
30 else:
31     has_blackd_deps = True
32
33 from pathspec import PathSpec
34
35 # Import other test classes
36 from .test_primer import PrimerCLITests  # noqa: F401
37
38
# Shortcuts that run Black with a default FileMode.
fs = partial(black.format_str, mode=black.FileMode())
ff = partial(black.format_file_in_place, mode=black.FileMode(), fast=True)

# Filesystem anchors used to locate fixtures and project sources.
THIS_FILE = Path(__file__)
THIS_DIR = THIS_FILE.parent
PROJECT_ROOT = THIS_FILE.parent.parent

# Substituted for timestamped diff headers so diffs compare deterministically.
DETERMINISTIC_HEADER = "[Deterministic header]"
# NOTE: built by concatenation; presumably so the full marker never appears
# verbatim in this file, which test_self feeds through read_data().
EMPTY_LINE = "# EMPTY LINE WITH WHITESPACE" + " (this comment will be removed)"
# One --target-version flag per entry in black.PY36_VERSIONS.
PY36_ARGS = [f"--target-version={v.name.lower()}" for v in black.PY36_VERSIONS]

T = TypeVar("T")
R = TypeVar("R")
51
52
def dump_to_stderr(*output: str) -> str:
    """Return ``output`` joined by newlines, with one newline added on each side.

    Tests patch ``black.dump_to_file`` with this function, so the patched calls
    return the text itself instead of a file name.
    """
    body = "\n".join(output)
    return f"\n{body}\n"
55
56
def read_data(name: str, data: bool = True) -> Tuple[str, str]:
    """Load a fixture and split it into ``(input, expected_output)``.

    ``name`` gains a ``.py`` suffix unless it already ends in .py/.pyi/.out/.diff.
    With ``data`` true it is resolved under ``tests/data``; otherwise under the
    project root. Lines before the output marker form the input and the rest the
    expected output; without a marker, the input doubles as the expected output.
    Both strings are stripped and then given a single trailing newline.
    """
    if not name.endswith((".py", ".pyi", ".out", ".diff")):
        name += ".py"
    base_dir = THIS_DIR / "data" if data else PROJECT_ROOT
    with open(base_dir / name, "r", encoding="utf8") as fixture:
        raw_lines = fixture.readlines()
    sections: List[List[str]] = [[], []]
    current = 0
    for raw in raw_lines:
        cleaned = raw.replace(EMPTY_LINE, "")
        if cleaned.rstrip() == "# output":
            current = 1
            continue
        sections[current].append(cleaned)
    source_lines, output_lines = sections
    if source_lines and not output_lines:
        # No marker at all: the fixture is expected to be left unchanged.
        output_lines = source_lines[:]
    return "".join(source_lines).strip() + "\n", "".join(output_lines).strip() + "\n"
78
79
@contextmanager
def cache_dir(exists: bool = True) -> Iterator[Path]:
    """Patch ``black.CACHE_DIR`` to a temporary directory and yield that path.

    When ``exists`` is false, the patched path is a ``new`` subdirectory of the
    temporary directory; this function does not create it.
    """
    with TemporaryDirectory() as workspace:
        root = Path(workspace)
        target = root if exists else root / "new"
        with patch("black.CACHE_DIR", target):
            yield target
88
89
@contextmanager
def event_loop() -> Iterator[None]:
    """Install a brand-new asyncio loop as the current loop; close it on exit.

    The closed loop remains the current event loop after the block exits.
    """
    fresh_loop = asyncio.get_event_loop_policy().new_event_loop()
    asyncio.set_event_loop(fresh_loop)
    try:
        yield
    finally:
        fresh_loop.close()
100
101
@contextmanager
def skip_if_exception(e: str) -> Iterator[None]:
    """Skip the running test if the body raises an exception whose class is ``e``.

    ``e`` is matched against the exception's class name only, so callers can
    name exceptions from optional dependencies without importing them. Any
    other exception propagates unchanged.

    Raises:
        unittest.SkipTest: when an exception whose class name equals ``e`` occurs.
    """
    try:
        yield
    except Exception as exc:
        if exc.__class__.__name__ == e:
            # ``unittest.skip()`` only builds a decorator (the original call was a
            # no-op that swallowed the error); raising SkipTest marks the skip.
            raise unittest.SkipTest(
                f"Encountered expected exception {exc}, skipping"
            ) from exc
        raise
111
112
class FakeContext(click.Context):
    """Minimal stand-in for a click Context, carrying only ``default_map``.

    ``click.Context.__init__`` is deliberately not called.
    """

    def __init__(self) -> None:
        self.default_map: Dict[str, Any] = dict()
118
119
class FakeParameter(click.Parameter):
    """Minimal stand-in for a click Parameter used where one must be passed."""

    def __init__(self) -> None:
        # Intentionally does not call click.Parameter.__init__.
        return None
125
126
class BlackRunner(CliRunner):
    """Modify CliRunner so that stderr is not merged with stdout.

    This is a hack that can be removed once we depend on Click 7.x"""

    def __init__(self) -> None:
        # ``stderrbuf`` receives stderr inside ``isolation``; the captured bytes
        # are copied to ``stdout_bytes`` / ``stderr_bytes`` when it exits.
        self.stderrbuf = BytesIO()
        self.stdoutbuf = BytesIO()
        self.stdout_bytes = b""
        self.stderr_bytes = b""
        super().__init__()

    @contextmanager
    def isolation(self, *args: Any, **kwargs: Any) -> Generator[BinaryIO, None, None]:
        """Run click's isolation with ``sys.stderr`` redirected into ``stderrbuf``.

        On exit, the raw stdout and stderr bytes are stored on the runner and the
        previous ``sys.stderr`` is restored.
        """
        with super().isolation(*args, **kwargs) as output:
            hold_stderr = sys.stderr
            try:
                sys.stderr = TextIOWrapper(self.stderrbuf, encoding=self.charset)
                yield output
            finally:
                # TextIOWrappers buffer text; flush so no pending output is lost
                # before the underlying byte buffers are read.
                sys.stdout.flush()
                sys.stderr.flush()
                self.stdout_bytes = sys.stdout.buffer.getvalue()  # type: ignore
                self.stderr_bytes = sys.stderr.buffer.getvalue()  # type: ignore
                sys.stderr = hold_stderr
150
151
152 class BlackTestCase(unittest.TestCase):
153     maxDiff = None
154
155     def assertFormatEqual(self, expected: str, actual: str) -> None:
156         if actual != expected and not os.environ.get("SKIP_AST_PRINT"):
157             bdv: black.DebugVisitor[Any]
158             black.out("Expected tree:", fg="green")
159             try:
160                 exp_node = black.lib2to3_parse(expected)
161                 bdv = black.DebugVisitor()
162                 list(bdv.visit(exp_node))
163             except Exception as ve:
164                 black.err(str(ve))
165             black.out("Actual tree:", fg="red")
166             try:
167                 exp_node = black.lib2to3_parse(actual)
168                 bdv = black.DebugVisitor()
169                 list(bdv.visit(exp_node))
170             except Exception as ve:
171                 black.err(str(ve))
172         self.assertEqual(expected, actual)
173
174     def invokeBlack(
175         self, args: List[str], exit_code: int = 0, ignore_config: bool = True
176     ) -> None:
177         runner = BlackRunner()
178         if ignore_config:
179             args = ["--verbose", "--config", str(THIS_DIR / "empty.toml"), *args]
180         result = runner.invoke(black.main, args)
181         self.assertEqual(
182             result.exit_code,
183             exit_code,
184             msg=(
185                 f"Failed with args: {args}\n"
186                 f"stdout: {runner.stdout_bytes.decode()!r}\n"
187                 f"stderr: {runner.stderr_bytes.decode()!r}\n"
188                 f"exception: {result.exception}"
189             ),
190         )
191
192     @patch("black.dump_to_file", dump_to_stderr)
193     def checkSourceFile(self, name: str) -> None:
194         path = THIS_DIR.parent / name
195         source, expected = read_data(str(path), data=False)
196         actual = fs(source)
197         self.assertFormatEqual(expected, actual)
198         black.assert_equivalent(source, actual)
199         black.assert_stable(source, actual, black.FileMode())
200         self.assertFalse(ff(path))
201
202     @patch("black.dump_to_file", dump_to_stderr)
203     def test_empty(self) -> None:
204         source = expected = ""
205         actual = fs(source)
206         self.assertFormatEqual(expected, actual)
207         black.assert_equivalent(source, actual)
208         black.assert_stable(source, actual, black.FileMode())
209
210     def test_empty_ff(self) -> None:
211         expected = ""
212         tmp_file = Path(black.dump_to_file())
213         try:
214             self.assertFalse(ff(tmp_file, write_back=black.WriteBack.YES))
215             with open(tmp_file, encoding="utf8") as f:
216                 actual = f.read()
217         finally:
218             os.unlink(tmp_file)
219         self.assertFormatEqual(expected, actual)
220
221     def test_self(self) -> None:
222         self.checkSourceFile("tests/test_black.py")
223
224     def test_black(self) -> None:
225         self.checkSourceFile("src/black/__init__.py")
226
227     def test_pygram(self) -> None:
228         self.checkSourceFile("src/blib2to3/pygram.py")
229
230     def test_pytree(self) -> None:
231         self.checkSourceFile("src/blib2to3/pytree.py")
232
233     def test_conv(self) -> None:
234         self.checkSourceFile("src/blib2to3/pgen2/conv.py")
235
236     def test_driver(self) -> None:
237         self.checkSourceFile("src/blib2to3/pgen2/driver.py")
238
239     def test_grammar(self) -> None:
240         self.checkSourceFile("src/blib2to3/pgen2/grammar.py")
241
242     def test_literals(self) -> None:
243         self.checkSourceFile("src/blib2to3/pgen2/literals.py")
244
245     def test_parse(self) -> None:
246         self.checkSourceFile("src/blib2to3/pgen2/parse.py")
247
248     def test_pgen(self) -> None:
249         self.checkSourceFile("src/blib2to3/pgen2/pgen.py")
250
251     def test_tokenize(self) -> None:
252         self.checkSourceFile("src/blib2to3/pgen2/tokenize.py")
253
254     def test_token(self) -> None:
255         self.checkSourceFile("src/blib2to3/pgen2/token.py")
256
257     def test_setup(self) -> None:
258         self.checkSourceFile("setup.py")
259
260     def test_piping(self) -> None:
261         source, expected = read_data("src/black/__init__", data=False)
262         result = BlackRunner().invoke(
263             black.main,
264             ["-", "--fast", f"--line-length={black.DEFAULT_LINE_LENGTH}"],
265             input=BytesIO(source.encode("utf8")),
266         )
267         self.assertEqual(result.exit_code, 0)
268         self.assertFormatEqual(expected, result.output)
269         black.assert_equivalent(source, result.output)
270         black.assert_stable(source, result.output, black.FileMode())
271
272     def test_piping_diff(self) -> None:
273         diff_header = re.compile(
274             r"(STDIN|STDOUT)\t\d\d\d\d-\d\d-\d\d \d\d:\d\d:\d\d\.\d\d\d\d\d\d "
275             r"\+\d\d\d\d"
276         )
277         source, _ = read_data("expression.py")
278         expected, _ = read_data("expression.diff")
279         config = THIS_DIR / "data" / "empty_pyproject.toml"
280         args = [
281             "-",
282             "--fast",
283             f"--line-length={black.DEFAULT_LINE_LENGTH}",
284             "--diff",
285             f"--config={config}",
286         ]
287         result = BlackRunner().invoke(
288             black.main, args, input=BytesIO(source.encode("utf8"))
289         )
290         self.assertEqual(result.exit_code, 0)
291         actual = diff_header.sub(DETERMINISTIC_HEADER, result.output)
292         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
293         self.assertEqual(expected, actual)
294
295     def test_piping_diff_with_color(self) -> None:
296         source, _ = read_data("expression.py")
297         config = THIS_DIR / "data" / "empty_pyproject.toml"
298         args = [
299             "-",
300             "--fast",
301             f"--line-length={black.DEFAULT_LINE_LENGTH}",
302             "--diff",
303             "--color",
304             f"--config={config}",
305         ]
306         result = BlackRunner().invoke(
307             black.main, args, input=BytesIO(source.encode("utf8"))
308         )
309         actual = result.output
310         # Again, the contents are checked in a different test, so only look for colors.
311         self.assertIn("\033[1;37m", actual)
312         self.assertIn("\033[36m", actual)
313         self.assertIn("\033[32m", actual)
314         self.assertIn("\033[31m", actual)
315         self.assertIn("\033[0m", actual)
316
317     @patch("black.dump_to_file", dump_to_stderr)
318     def test_function(self) -> None:
319         source, expected = read_data("function")
320         actual = fs(source)
321         self.assertFormatEqual(expected, actual)
322         black.assert_equivalent(source, actual)
323         black.assert_stable(source, actual, black.FileMode())
324
325     @patch("black.dump_to_file", dump_to_stderr)
326     def test_function2(self) -> None:
327         source, expected = read_data("function2")
328         actual = fs(source)
329         self.assertFormatEqual(expected, actual)
330         black.assert_equivalent(source, actual)
331         black.assert_stable(source, actual, black.FileMode())
332
333     @patch("black.dump_to_file", dump_to_stderr)
334     def test_function_trailing_comma(self) -> None:
335         source, expected = read_data("function_trailing_comma")
336         actual = fs(source)
337         self.assertFormatEqual(expected, actual)
338         black.assert_equivalent(source, actual)
339         black.assert_stable(source, actual, black.FileMode())
340
341     @patch("black.dump_to_file", dump_to_stderr)
342     def test_expression(self) -> None:
343         source, expected = read_data("expression")
344         actual = fs(source)
345         self.assertFormatEqual(expected, actual)
346         black.assert_equivalent(source, actual)
347         black.assert_stable(source, actual, black.FileMode())
348
349     @patch("black.dump_to_file", dump_to_stderr)
350     def test_pep_572(self) -> None:
351         source, expected = read_data("pep_572")
352         actual = fs(source)
353         self.assertFormatEqual(expected, actual)
354         black.assert_stable(source, actual, black.FileMode())
355         if sys.version_info >= (3, 8):
356             black.assert_equivalent(source, actual)
357
358     def test_pep_572_version_detection(self) -> None:
359         source, _ = read_data("pep_572")
360         root = black.lib2to3_parse(source)
361         features = black.get_features_used(root)
362         self.assertIn(black.Feature.ASSIGNMENT_EXPRESSIONS, features)
363         versions = black.detect_target_versions(root)
364         self.assertIn(black.TargetVersion.PY38, versions)
365
366     def test_expression_ff(self) -> None:
367         source, expected = read_data("expression")
368         tmp_file = Path(black.dump_to_file(source))
369         try:
370             self.assertTrue(ff(tmp_file, write_back=black.WriteBack.YES))
371             with open(tmp_file, encoding="utf8") as f:
372                 actual = f.read()
373         finally:
374             os.unlink(tmp_file)
375         self.assertFormatEqual(expected, actual)
376         with patch("black.dump_to_file", dump_to_stderr):
377             black.assert_equivalent(source, actual)
378             black.assert_stable(source, actual, black.FileMode())
379
380     def test_expression_diff(self) -> None:
381         source, _ = read_data("expression.py")
382         expected, _ = read_data("expression.diff")
383         tmp_file = Path(black.dump_to_file(source))
384         diff_header = re.compile(
385             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
386             r"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
387         )
388         try:
389             result = BlackRunner().invoke(black.main, ["--diff", str(tmp_file)])
390             self.assertEqual(result.exit_code, 0)
391         finally:
392             os.unlink(tmp_file)
393         actual = result.output
394         actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
395         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
396         if expected != actual:
397             dump = black.dump_to_file(actual)
398             msg = (
399                 "Expected diff isn't equal to the actual. If you made changes to"
400                 " expression.py and this is an anticipated difference, overwrite"
401                 f" tests/data/expression.diff with {dump}"
402             )
403             self.assertEqual(expected, actual, msg)
404
405     def test_expression_diff_with_color(self) -> None:
406         source, _ = read_data("expression.py")
407         expected, _ = read_data("expression.diff")
408         tmp_file = Path(black.dump_to_file(source))
409         try:
410             result = BlackRunner().invoke(
411                 black.main, ["--diff", "--color", str(tmp_file)]
412             )
413         finally:
414             os.unlink(tmp_file)
415         actual = result.output
416         # We check the contents of the diff in `test_expression_diff`. All
417         # we need to check here is that color codes exist in the result.
418         self.assertIn("\033[1;37m", actual)
419         self.assertIn("\033[36m", actual)
420         self.assertIn("\033[32m", actual)
421         self.assertIn("\033[31m", actual)
422         self.assertIn("\033[0m", actual)
423
424     @patch("black.dump_to_file", dump_to_stderr)
425     def test_fstring(self) -> None:
426         source, expected = read_data("fstring")
427         actual = fs(source)
428         self.assertFormatEqual(expected, actual)
429         black.assert_equivalent(source, actual)
430         black.assert_stable(source, actual, black.FileMode())
431
432     @patch("black.dump_to_file", dump_to_stderr)
433     def test_pep_570(self) -> None:
434         source, expected = read_data("pep_570")
435         actual = fs(source)
436         self.assertFormatEqual(expected, actual)
437         black.assert_stable(source, actual, black.FileMode())
438         if sys.version_info >= (3, 8):
439             black.assert_equivalent(source, actual)
440
441     def test_detect_pos_only_arguments(self) -> None:
442         source, _ = read_data("pep_570")
443         root = black.lib2to3_parse(source)
444         features = black.get_features_used(root)
445         self.assertIn(black.Feature.POS_ONLY_ARGUMENTS, features)
446         versions = black.detect_target_versions(root)
447         self.assertIn(black.TargetVersion.PY38, versions)
448
449     @patch("black.dump_to_file", dump_to_stderr)
450     def test_string_quotes(self) -> None:
451         source, expected = read_data("string_quotes")
452         actual = fs(source)
453         self.assertFormatEqual(expected, actual)
454         black.assert_equivalent(source, actual)
455         black.assert_stable(source, actual, black.FileMode())
456         mode = black.FileMode(string_normalization=False)
457         not_normalized = fs(source, mode=mode)
458         self.assertFormatEqual(source.replace("\\\n", ""), not_normalized)
459         black.assert_equivalent(source, not_normalized)
460         black.assert_stable(source, not_normalized, mode=mode)
461
462     @patch("black.dump_to_file", dump_to_stderr)
463     def test_docstring(self) -> None:
464         source, expected = read_data("docstring")
465         actual = fs(source)
466         self.assertFormatEqual(expected, actual)
467         black.assert_equivalent(source, actual)
468         black.assert_stable(source, actual, black.FileMode())
469
470     def test_long_strings(self) -> None:
471         """Tests for splitting long strings."""
472         source, expected = read_data("long_strings")
473         actual = fs(source)
474         self.assertFormatEqual(expected, actual)
475         black.assert_equivalent(source, actual)
476         black.assert_stable(source, actual, black.FileMode())
477
478     @patch("black.dump_to_file", dump_to_stderr)
479     def test_long_strings__edge_case(self) -> None:
480         """Edge-case tests for splitting long strings."""
481         source, expected = read_data("long_strings__edge_case")
482         actual = fs(source)
483         self.assertFormatEqual(expected, actual)
484         black.assert_equivalent(source, actual)
485         black.assert_stable(source, actual, black.FileMode())
486
487     @patch("black.dump_to_file", dump_to_stderr)
488     def test_long_strings__regression(self) -> None:
489         """Regression tests for splitting long strings."""
490         source, expected = read_data("long_strings__regression")
491         actual = fs(source)
492         self.assertFormatEqual(expected, actual)
493         black.assert_equivalent(source, actual)
494         black.assert_stable(source, actual, black.FileMode())
495
496     @patch("black.dump_to_file", dump_to_stderr)
497     def test_slices(self) -> None:
498         source, expected = read_data("slices")
499         actual = fs(source)
500         self.assertFormatEqual(expected, actual)
501         black.assert_equivalent(source, actual)
502         black.assert_stable(source, actual, black.FileMode())
503
504     @patch("black.dump_to_file", dump_to_stderr)
505     def test_comments(self) -> None:
506         source, expected = read_data("comments")
507         actual = fs(source)
508         self.assertFormatEqual(expected, actual)
509         black.assert_equivalent(source, actual)
510         black.assert_stable(source, actual, black.FileMode())
511
512     @patch("black.dump_to_file", dump_to_stderr)
513     def test_comments2(self) -> None:
514         source, expected = read_data("comments2")
515         actual = fs(source)
516         self.assertFormatEqual(expected, actual)
517         black.assert_equivalent(source, actual)
518         black.assert_stable(source, actual, black.FileMode())
519
520     @patch("black.dump_to_file", dump_to_stderr)
521     def test_comments3(self) -> None:
522         source, expected = read_data("comments3")
523         actual = fs(source)
524         self.assertFormatEqual(expected, actual)
525         black.assert_equivalent(source, actual)
526         black.assert_stable(source, actual, black.FileMode())
527
528     @patch("black.dump_to_file", dump_to_stderr)
529     def test_comments4(self) -> None:
530         source, expected = read_data("comments4")
531         actual = fs(source)
532         self.assertFormatEqual(expected, actual)
533         black.assert_equivalent(source, actual)
534         black.assert_stable(source, actual, black.FileMode())
535
536     @patch("black.dump_to_file", dump_to_stderr)
537     def test_comments5(self) -> None:
538         source, expected = read_data("comments5")
539         actual = fs(source)
540         self.assertFormatEqual(expected, actual)
541         black.assert_equivalent(source, actual)
542         black.assert_stable(source, actual, black.FileMode())
543
544     @patch("black.dump_to_file", dump_to_stderr)
545     def test_comments6(self) -> None:
546         source, expected = read_data("comments6")
547         actual = fs(source)
548         self.assertFormatEqual(expected, actual)
549         black.assert_equivalent(source, actual)
550         black.assert_stable(source, actual, black.FileMode())
551
552     @patch("black.dump_to_file", dump_to_stderr)
553     def test_comments7(self) -> None:
554         source, expected = read_data("comments7")
555         actual = fs(source)
556         self.assertFormatEqual(expected, actual)
557         black.assert_equivalent(source, actual)
558         black.assert_stable(source, actual, black.FileMode())
559
560     @patch("black.dump_to_file", dump_to_stderr)
561     def test_comment_after_escaped_newline(self) -> None:
562         source, expected = read_data("comment_after_escaped_newline")
563         actual = fs(source)
564         self.assertFormatEqual(expected, actual)
565         black.assert_equivalent(source, actual)
566         black.assert_stable(source, actual, black.FileMode())
567
568     @patch("black.dump_to_file", dump_to_stderr)
569     def test_cantfit(self) -> None:
570         source, expected = read_data("cantfit")
571         actual = fs(source)
572         self.assertFormatEqual(expected, actual)
573         black.assert_equivalent(source, actual)
574         black.assert_stable(source, actual, black.FileMode())
575
576     @patch("black.dump_to_file", dump_to_stderr)
577     def test_import_spacing(self) -> None:
578         source, expected = read_data("import_spacing")
579         actual = fs(source)
580         self.assertFormatEqual(expected, actual)
581         black.assert_equivalent(source, actual)
582         black.assert_stable(source, actual, black.FileMode())
583
584     @patch("black.dump_to_file", dump_to_stderr)
585     def test_composition(self) -> None:
586         source, expected = read_data("composition")
587         actual = fs(source)
588         self.assertFormatEqual(expected, actual)
589         black.assert_equivalent(source, actual)
590         black.assert_stable(source, actual, black.FileMode())
591
592     @patch("black.dump_to_file", dump_to_stderr)
593     def test_empty_lines(self) -> None:
594         source, expected = read_data("empty_lines")
595         actual = fs(source)
596         self.assertFormatEqual(expected, actual)
597         black.assert_equivalent(source, actual)
598         black.assert_stable(source, actual, black.FileMode())
599
600     @patch("black.dump_to_file", dump_to_stderr)
601     def test_remove_parens(self) -> None:
602         source, expected = read_data("remove_parens")
603         actual = fs(source)
604         self.assertFormatEqual(expected, actual)
605         black.assert_equivalent(source, actual)
606         black.assert_stable(source, actual, black.FileMode())
607
608     @patch("black.dump_to_file", dump_to_stderr)
609     def test_string_prefixes(self) -> None:
610         source, expected = read_data("string_prefixes")
611         actual = fs(source)
612         self.assertFormatEqual(expected, actual)
613         black.assert_equivalent(source, actual)
614         black.assert_stable(source, actual, black.FileMode())
615
616     @patch("black.dump_to_file", dump_to_stderr)
617     def test_numeric_literals(self) -> None:
618         source, expected = read_data("numeric_literals")
619         mode = black.FileMode(target_versions=black.PY36_VERSIONS)
620         actual = fs(source, mode=mode)
621         self.assertFormatEqual(expected, actual)
622         black.assert_equivalent(source, actual)
623         black.assert_stable(source, actual, mode)
624
625     @patch("black.dump_to_file", dump_to_stderr)
626     def test_numeric_literals_ignoring_underscores(self) -> None:
627         source, expected = read_data("numeric_literals_skip_underscores")
628         mode = black.FileMode(target_versions=black.PY36_VERSIONS)
629         actual = fs(source, mode=mode)
630         self.assertFormatEqual(expected, actual)
631         black.assert_equivalent(source, actual)
632         black.assert_stable(source, actual, mode)
633
634     @patch("black.dump_to_file", dump_to_stderr)
635     def test_numeric_literals_py2(self) -> None:
636         source, expected = read_data("numeric_literals_py2")
637         actual = fs(source)
638         self.assertFormatEqual(expected, actual)
639         black.assert_stable(source, actual, black.FileMode())
640
641     @patch("black.dump_to_file", dump_to_stderr)
642     def test_python2(self) -> None:
643         source, expected = read_data("python2")
644         actual = fs(source)
645         self.assertFormatEqual(expected, actual)
646         black.assert_equivalent(source, actual)
647         black.assert_stable(source, actual, black.FileMode())
648
649     @patch("black.dump_to_file", dump_to_stderr)
650     def test_python2_print_function(self) -> None:
651         source, expected = read_data("python2_print_function")
652         mode = black.FileMode(target_versions={TargetVersion.PY27})
653         actual = fs(source, mode=mode)
654         self.assertFormatEqual(expected, actual)
655         black.assert_equivalent(source, actual)
656         black.assert_stable(source, actual, mode)
657
658     @patch("black.dump_to_file", dump_to_stderr)
659     def test_python2_unicode_literals(self) -> None:
660         source, expected = read_data("python2_unicode_literals")
661         actual = fs(source)
662         self.assertFormatEqual(expected, actual)
663         black.assert_equivalent(source, actual)
664         black.assert_stable(source, actual, black.FileMode())
665
666     @patch("black.dump_to_file", dump_to_stderr)
667     def test_stub(self) -> None:
668         mode = black.FileMode(is_pyi=True)
669         source, expected = read_data("stub.pyi")
670         actual = fs(source, mode=mode)
671         self.assertFormatEqual(expected, actual)
672         black.assert_stable(source, actual, mode)
673
674     @patch("black.dump_to_file", dump_to_stderr)
675     def test_async_as_identifier(self) -> None:
676         source_path = (THIS_DIR / "data" / "async_as_identifier.py").resolve()
677         source, expected = read_data("async_as_identifier")
678         actual = fs(source)
679         self.assertFormatEqual(expected, actual)
680         major, minor = sys.version_info[:2]
681         if major < 3 or (major <= 3 and minor < 7):
682             black.assert_equivalent(source, actual)
683         black.assert_stable(source, actual, black.FileMode())
684         # ensure black can parse this when the target is 3.6
685         self.invokeBlack([str(source_path), "--target-version", "py36"])
686         # but not on 3.7, because async/await is no longer an identifier
687         self.invokeBlack([str(source_path), "--target-version", "py37"], exit_code=123)
688
689     @patch("black.dump_to_file", dump_to_stderr)
690     def test_python37(self) -> None:
691         source_path = (THIS_DIR / "data" / "python37.py").resolve()
692         source, expected = read_data("python37")
693         actual = fs(source)
694         self.assertFormatEqual(expected, actual)
695         major, minor = sys.version_info[:2]
696         if major > 3 or (major == 3 and minor >= 7):
697             black.assert_equivalent(source, actual)
698         black.assert_stable(source, actual, black.FileMode())
699         # ensure black can parse this when the target is 3.7
700         self.invokeBlack([str(source_path), "--target-version", "py37"])
701         # but not on 3.6, because we use async as a reserved keyword
702         self.invokeBlack([str(source_path), "--target-version", "py36"], exit_code=123)
703
704     @patch("black.dump_to_file", dump_to_stderr)
705     def test_python38(self) -> None:
706         source, expected = read_data("python38")
707         actual = fs(source)
708         self.assertFormatEqual(expected, actual)
709         major, minor = sys.version_info[:2]
710         if major > 3 or (major == 3 and minor >= 8):
711             black.assert_equivalent(source, actual)
712         black.assert_stable(source, actual, black.FileMode())
713
714     @patch("black.dump_to_file", dump_to_stderr)
715     def test_fmtonoff(self) -> None:
716         source, expected = read_data("fmtonoff")
717         actual = fs(source)
718         self.assertFormatEqual(expected, actual)
719         black.assert_equivalent(source, actual)
720         black.assert_stable(source, actual, black.FileMode())
721
722     @patch("black.dump_to_file", dump_to_stderr)
723     def test_fmtonoff2(self) -> None:
724         source, expected = read_data("fmtonoff2")
725         actual = fs(source)
726         self.assertFormatEqual(expected, actual)
727         black.assert_equivalent(source, actual)
728         black.assert_stable(source, actual, black.FileMode())
729
730     @patch("black.dump_to_file", dump_to_stderr)
731     def test_fmtonoff3(self) -> None:
732         source, expected = read_data("fmtonoff3")
733         actual = fs(source)
734         self.assertFormatEqual(expected, actual)
735         black.assert_equivalent(source, actual)
736         black.assert_stable(source, actual, black.FileMode())
737
738     @patch("black.dump_to_file", dump_to_stderr)
739     def test_fmtonoff4(self) -> None:
740         source, expected = read_data("fmtonoff4")
741         actual = fs(source)
742         self.assertFormatEqual(expected, actual)
743         black.assert_equivalent(source, actual)
744         black.assert_stable(source, actual, black.FileMode())
745
746     @patch("black.dump_to_file", dump_to_stderr)
747     def test_remove_empty_parentheses_after_class(self) -> None:
748         source, expected = read_data("class_blank_parentheses")
749         actual = fs(source)
750         self.assertFormatEqual(expected, actual)
751         black.assert_equivalent(source, actual)
752         black.assert_stable(source, actual, black.FileMode())
753
754     @patch("black.dump_to_file", dump_to_stderr)
755     def test_new_line_between_class_and_code(self) -> None:
756         source, expected = read_data("class_methods_new_line")
757         actual = fs(source)
758         self.assertFormatEqual(expected, actual)
759         black.assert_equivalent(source, actual)
760         black.assert_stable(source, actual, black.FileMode())
761
762     @patch("black.dump_to_file", dump_to_stderr)
763     def test_bracket_match(self) -> None:
764         source, expected = read_data("bracketmatch")
765         actual = fs(source)
766         self.assertFormatEqual(expected, actual)
767         black.assert_equivalent(source, actual)
768         black.assert_stable(source, actual, black.FileMode())
769
770     @patch("black.dump_to_file", dump_to_stderr)
771     def test_tuple_assign(self) -> None:
772         source, expected = read_data("tupleassign")
773         actual = fs(source)
774         self.assertFormatEqual(expected, actual)
775         black.assert_equivalent(source, actual)
776         black.assert_stable(source, actual, black.FileMode())
777
778     @patch("black.dump_to_file", dump_to_stderr)
779     def test_beginning_backslash(self) -> None:
780         source, expected = read_data("beginning_backslash")
781         actual = fs(source)
782         self.assertFormatEqual(expected, actual)
783         black.assert_equivalent(source, actual)
784         black.assert_stable(source, actual, black.FileMode())
785
786     def test_tab_comment_indentation(self) -> None:
787         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t# comment\n\tpass\n"
788         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
789         self.assertFormatEqual(contents_spc, fs(contents_spc))
790         self.assertFormatEqual(contents_spc, fs(contents_tab))
791
792         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t\t# comment\n\tpass\n"
793         contents_spc = "if 1:\n    if 2:\n        pass\n        # comment\n    pass\n"
794         self.assertFormatEqual(contents_spc, fs(contents_spc))
795         self.assertFormatEqual(contents_spc, fs(contents_tab))
796
797         # mixed tabs and spaces (valid Python 2 code)
798         contents_tab = "if 1:\n        if 2:\n\t\tpass\n\t# comment\n        pass\n"
799         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
800         self.assertFormatEqual(contents_spc, fs(contents_spc))
801         self.assertFormatEqual(contents_spc, fs(contents_tab))
802
803         contents_tab = "if 1:\n        if 2:\n\t\tpass\n\t\t# comment\n        pass\n"
804         contents_spc = "if 1:\n    if 2:\n        pass\n        # comment\n    pass\n"
805         self.assertFormatEqual(contents_spc, fs(contents_spc))
806         self.assertFormatEqual(contents_spc, fs(contents_tab))
807
808     def test_report_verbose(self) -> None:
809         report = black.Report(verbose=True)
810         out_lines = []
811         err_lines = []
812
813         def out(msg: str, **kwargs: Any) -> None:
814             out_lines.append(msg)
815
816         def err(msg: str, **kwargs: Any) -> None:
817             err_lines.append(msg)
818
819         with patch("black.out", out), patch("black.err", err):
820             report.done(Path("f1"), black.Changed.NO)
821             self.assertEqual(len(out_lines), 1)
822             self.assertEqual(len(err_lines), 0)
823             self.assertEqual(out_lines[-1], "f1 already well formatted, good job.")
824             self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
825             self.assertEqual(report.return_code, 0)
826             report.done(Path("f2"), black.Changed.YES)
827             self.assertEqual(len(out_lines), 2)
828             self.assertEqual(len(err_lines), 0)
829             self.assertEqual(out_lines[-1], "reformatted f2")
830             self.assertEqual(
831                 unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
832             )
833             report.done(Path("f3"), black.Changed.CACHED)
834             self.assertEqual(len(out_lines), 3)
835             self.assertEqual(len(err_lines), 0)
836             self.assertEqual(
837                 out_lines[-1], "f3 wasn't modified on disk since last run."
838             )
839             self.assertEqual(
840                 unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
841             )
842             self.assertEqual(report.return_code, 0)
843             report.check = True
844             self.assertEqual(report.return_code, 1)
845             report.check = False
846             report.failed(Path("e1"), "boom")
847             self.assertEqual(len(out_lines), 3)
848             self.assertEqual(len(err_lines), 1)
849             self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
850             self.assertEqual(
851                 unstyle(str(report)),
852                 "1 file reformatted, 2 files left unchanged, 1 file failed to"
853                 " reformat.",
854             )
855             self.assertEqual(report.return_code, 123)
856             report.done(Path("f3"), black.Changed.YES)
857             self.assertEqual(len(out_lines), 4)
858             self.assertEqual(len(err_lines), 1)
859             self.assertEqual(out_lines[-1], "reformatted f3")
860             self.assertEqual(
861                 unstyle(str(report)),
862                 "2 files reformatted, 2 files left unchanged, 1 file failed to"
863                 " reformat.",
864             )
865             self.assertEqual(report.return_code, 123)
866             report.failed(Path("e2"), "boom")
867             self.assertEqual(len(out_lines), 4)
868             self.assertEqual(len(err_lines), 2)
869             self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
870             self.assertEqual(
871                 unstyle(str(report)),
872                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
873                 " reformat.",
874             )
875             self.assertEqual(report.return_code, 123)
876             report.path_ignored(Path("wat"), "no match")
877             self.assertEqual(len(out_lines), 5)
878             self.assertEqual(len(err_lines), 2)
879             self.assertEqual(out_lines[-1], "wat ignored: no match")
880             self.assertEqual(
881                 unstyle(str(report)),
882                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
883                 " reformat.",
884             )
885             self.assertEqual(report.return_code, 123)
886             report.done(Path("f4"), black.Changed.NO)
887             self.assertEqual(len(out_lines), 6)
888             self.assertEqual(len(err_lines), 2)
889             self.assertEqual(out_lines[-1], "f4 already well formatted, good job.")
890             self.assertEqual(
891                 unstyle(str(report)),
892                 "2 files reformatted, 3 files left unchanged, 2 files failed to"
893                 " reformat.",
894             )
895             self.assertEqual(report.return_code, 123)
896             report.check = True
897             self.assertEqual(
898                 unstyle(str(report)),
899                 "2 files would be reformatted, 3 files would be left unchanged, 2 files"
900                 " would fail to reformat.",
901             )
902             report.check = False
903             report.diff = True
904             self.assertEqual(
905                 unstyle(str(report)),
906                 "2 files would be reformatted, 3 files would be left unchanged, 2 files"
907                 " would fail to reformat.",
908             )
909
910     def test_report_quiet(self) -> None:
911         report = black.Report(quiet=True)
912         out_lines = []
913         err_lines = []
914
915         def out(msg: str, **kwargs: Any) -> None:
916             out_lines.append(msg)
917
918         def err(msg: str, **kwargs: Any) -> None:
919             err_lines.append(msg)
920
921         with patch("black.out", out), patch("black.err", err):
922             report.done(Path("f1"), black.Changed.NO)
923             self.assertEqual(len(out_lines), 0)
924             self.assertEqual(len(err_lines), 0)
925             self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
926             self.assertEqual(report.return_code, 0)
927             report.done(Path("f2"), black.Changed.YES)
928             self.assertEqual(len(out_lines), 0)
929             self.assertEqual(len(err_lines), 0)
930             self.assertEqual(
931                 unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
932             )
933             report.done(Path("f3"), black.Changed.CACHED)
934             self.assertEqual(len(out_lines), 0)
935             self.assertEqual(len(err_lines), 0)
936             self.assertEqual(
937                 unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
938             )
939             self.assertEqual(report.return_code, 0)
940             report.check = True
941             self.assertEqual(report.return_code, 1)
942             report.check = False
943             report.failed(Path("e1"), "boom")
944             self.assertEqual(len(out_lines), 0)
945             self.assertEqual(len(err_lines), 1)
946             self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
947             self.assertEqual(
948                 unstyle(str(report)),
949                 "1 file reformatted, 2 files left unchanged, 1 file failed to"
950                 " reformat.",
951             )
952             self.assertEqual(report.return_code, 123)
953             report.done(Path("f3"), black.Changed.YES)
954             self.assertEqual(len(out_lines), 0)
955             self.assertEqual(len(err_lines), 1)
956             self.assertEqual(
957                 unstyle(str(report)),
958                 "2 files reformatted, 2 files left unchanged, 1 file failed to"
959                 " reformat.",
960             )
961             self.assertEqual(report.return_code, 123)
962             report.failed(Path("e2"), "boom")
963             self.assertEqual(len(out_lines), 0)
964             self.assertEqual(len(err_lines), 2)
965             self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
966             self.assertEqual(
967                 unstyle(str(report)),
968                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
969                 " reformat.",
970             )
971             self.assertEqual(report.return_code, 123)
972             report.path_ignored(Path("wat"), "no match")
973             self.assertEqual(len(out_lines), 0)
974             self.assertEqual(len(err_lines), 2)
975             self.assertEqual(
976                 unstyle(str(report)),
977                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
978                 " reformat.",
979             )
980             self.assertEqual(report.return_code, 123)
981             report.done(Path("f4"), black.Changed.NO)
982             self.assertEqual(len(out_lines), 0)
983             self.assertEqual(len(err_lines), 2)
984             self.assertEqual(
985                 unstyle(str(report)),
986                 "2 files reformatted, 3 files left unchanged, 2 files failed to"
987                 " reformat.",
988             )
989             self.assertEqual(report.return_code, 123)
990             report.check = True
991             self.assertEqual(
992                 unstyle(str(report)),
993                 "2 files would be reformatted, 3 files would be left unchanged, 2 files"
994                 " would fail to reformat.",
995             )
996             report.check = False
997             report.diff = True
998             self.assertEqual(
999                 unstyle(str(report)),
1000                 "2 files would be reformatted, 3 files would be left unchanged, 2 files"
1001                 " would fail to reformat.",
1002             )
1003
1004     def test_report_normal(self) -> None:
1005         report = black.Report()
1006         out_lines = []
1007         err_lines = []
1008
1009         def out(msg: str, **kwargs: Any) -> None:
1010             out_lines.append(msg)
1011
1012         def err(msg: str, **kwargs: Any) -> None:
1013             err_lines.append(msg)
1014
1015         with patch("black.out", out), patch("black.err", err):
1016             report.done(Path("f1"), black.Changed.NO)
1017             self.assertEqual(len(out_lines), 0)
1018             self.assertEqual(len(err_lines), 0)
1019             self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
1020             self.assertEqual(report.return_code, 0)
1021             report.done(Path("f2"), black.Changed.YES)
1022             self.assertEqual(len(out_lines), 1)
1023             self.assertEqual(len(err_lines), 0)
1024             self.assertEqual(out_lines[-1], "reformatted f2")
1025             self.assertEqual(
1026                 unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
1027             )
1028             report.done(Path("f3"), black.Changed.CACHED)
1029             self.assertEqual(len(out_lines), 1)
1030             self.assertEqual(len(err_lines), 0)
1031             self.assertEqual(out_lines[-1], "reformatted f2")
1032             self.assertEqual(
1033                 unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
1034             )
1035             self.assertEqual(report.return_code, 0)
1036             report.check = True
1037             self.assertEqual(report.return_code, 1)
1038             report.check = False
1039             report.failed(Path("e1"), "boom")
1040             self.assertEqual(len(out_lines), 1)
1041             self.assertEqual(len(err_lines), 1)
1042             self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
1043             self.assertEqual(
1044                 unstyle(str(report)),
1045                 "1 file reformatted, 2 files left unchanged, 1 file failed to"
1046                 " reformat.",
1047             )
1048             self.assertEqual(report.return_code, 123)
1049             report.done(Path("f3"), black.Changed.YES)
1050             self.assertEqual(len(out_lines), 2)
1051             self.assertEqual(len(err_lines), 1)
1052             self.assertEqual(out_lines[-1], "reformatted f3")
1053             self.assertEqual(
1054                 unstyle(str(report)),
1055                 "2 files reformatted, 2 files left unchanged, 1 file failed to"
1056                 " reformat.",
1057             )
1058             self.assertEqual(report.return_code, 123)
1059             report.failed(Path("e2"), "boom")
1060             self.assertEqual(len(out_lines), 2)
1061             self.assertEqual(len(err_lines), 2)
1062             self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
1063             self.assertEqual(
1064                 unstyle(str(report)),
1065                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
1066                 " reformat.",
1067             )
1068             self.assertEqual(report.return_code, 123)
1069             report.path_ignored(Path("wat"), "no match")
1070             self.assertEqual(len(out_lines), 2)
1071             self.assertEqual(len(err_lines), 2)
1072             self.assertEqual(
1073                 unstyle(str(report)),
1074                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
1075                 " reformat.",
1076             )
1077             self.assertEqual(report.return_code, 123)
1078             report.done(Path("f4"), black.Changed.NO)
1079             self.assertEqual(len(out_lines), 2)
1080             self.assertEqual(len(err_lines), 2)
1081             self.assertEqual(
1082                 unstyle(str(report)),
1083                 "2 files reformatted, 3 files left unchanged, 2 files failed to"
1084                 " reformat.",
1085             )
1086             self.assertEqual(report.return_code, 123)
1087             report.check = True
1088             self.assertEqual(
1089                 unstyle(str(report)),
1090                 "2 files would be reformatted, 3 files would be left unchanged, 2 files"
1091                 " would fail to reformat.",
1092             )
1093             report.check = False
1094             report.diff = True
1095             self.assertEqual(
1096                 unstyle(str(report)),
1097                 "2 files would be reformatted, 3 files would be left unchanged, 2 files"
1098                 " would fail to reformat.",
1099             )
1100
1101     def test_lib2to3_parse(self) -> None:
1102         with self.assertRaises(black.InvalidInput):
1103             black.lib2to3_parse("invalid syntax")
1104
1105         straddling = "x + y"
1106         black.lib2to3_parse(straddling)
1107         black.lib2to3_parse(straddling, {TargetVersion.PY27})
1108         black.lib2to3_parse(straddling, {TargetVersion.PY36})
1109         black.lib2to3_parse(straddling, {TargetVersion.PY27, TargetVersion.PY36})
1110
1111         py2_only = "print x"
1112         black.lib2to3_parse(py2_only)
1113         black.lib2to3_parse(py2_only, {TargetVersion.PY27})
1114         with self.assertRaises(black.InvalidInput):
1115             black.lib2to3_parse(py2_only, {TargetVersion.PY36})
1116         with self.assertRaises(black.InvalidInput):
1117             black.lib2to3_parse(py2_only, {TargetVersion.PY27, TargetVersion.PY36})
1118
1119         py3_only = "exec(x, end=y)"
1120         black.lib2to3_parse(py3_only)
1121         with self.assertRaises(black.InvalidInput):
1122             black.lib2to3_parse(py3_only, {TargetVersion.PY27})
1123         black.lib2to3_parse(py3_only, {TargetVersion.PY36})
1124         black.lib2to3_parse(py3_only, {TargetVersion.PY27, TargetVersion.PY36})
1125
1126     def test_get_features_used(self) -> None:
1127         node = black.lib2to3_parse("def f(*, arg): ...\n")
1128         self.assertEqual(black.get_features_used(node), set())
1129         node = black.lib2to3_parse("def f(*, arg,): ...\n")
1130         self.assertEqual(black.get_features_used(node), {Feature.TRAILING_COMMA_IN_DEF})
1131         node = black.lib2to3_parse("f(*arg,)\n")
1132         self.assertEqual(
1133             black.get_features_used(node), {Feature.TRAILING_COMMA_IN_CALL}
1134         )
1135         node = black.lib2to3_parse("def f(*, arg): f'string'\n")
1136         self.assertEqual(black.get_features_used(node), {Feature.F_STRINGS})
1137         node = black.lib2to3_parse("123_456\n")
1138         self.assertEqual(black.get_features_used(node), {Feature.NUMERIC_UNDERSCORES})
1139         node = black.lib2to3_parse("123456\n")
1140         self.assertEqual(black.get_features_used(node), set())
1141         source, expected = read_data("function")
1142         node = black.lib2to3_parse(source)
1143         expected_features = {
1144             Feature.TRAILING_COMMA_IN_CALL,
1145             Feature.TRAILING_COMMA_IN_DEF,
1146             Feature.F_STRINGS,
1147         }
1148         self.assertEqual(black.get_features_used(node), expected_features)
1149         node = black.lib2to3_parse(expected)
1150         self.assertEqual(black.get_features_used(node), expected_features)
1151         source, expected = read_data("expression")
1152         node = black.lib2to3_parse(source)
1153         self.assertEqual(black.get_features_used(node), set())
1154         node = black.lib2to3_parse(expected)
1155         self.assertEqual(black.get_features_used(node), set())
1156
1157     def test_get_future_imports(self) -> None:
1158         node = black.lib2to3_parse("\n")
1159         self.assertEqual(set(), black.get_future_imports(node))
1160         node = black.lib2to3_parse("from __future__ import black\n")
1161         self.assertEqual({"black"}, black.get_future_imports(node))
1162         node = black.lib2to3_parse("from __future__ import multiple, imports\n")
1163         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
1164         node = black.lib2to3_parse("from __future__ import (parenthesized, imports)\n")
1165         self.assertEqual({"parenthesized", "imports"}, black.get_future_imports(node))
1166         node = black.lib2to3_parse(
1167             "from __future__ import multiple\nfrom __future__ import imports\n"
1168         )
1169         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
1170         node = black.lib2to3_parse("# comment\nfrom __future__ import black\n")
1171         self.assertEqual({"black"}, black.get_future_imports(node))
1172         node = black.lib2to3_parse('"""docstring"""\nfrom __future__ import black\n')
1173         self.assertEqual({"black"}, black.get_future_imports(node))
1174         node = black.lib2to3_parse("some(other, code)\nfrom __future__ import black\n")
1175         self.assertEqual(set(), black.get_future_imports(node))
1176         node = black.lib2to3_parse("from some.module import black\n")
1177         self.assertEqual(set(), black.get_future_imports(node))
1178         node = black.lib2to3_parse(
1179             "from __future__ import unicode_literals as _unicode_literals"
1180         )
1181         self.assertEqual({"unicode_literals"}, black.get_future_imports(node))
1182         node = black.lib2to3_parse(
1183             "from __future__ import unicode_literals as _lol, print"
1184         )
1185         self.assertEqual({"unicode_literals", "print"}, black.get_future_imports(node))
1186
1187     def test_debug_visitor(self) -> None:
1188         source, _ = read_data("debug_visitor.py")
1189         expected, _ = read_data("debug_visitor.out")
1190         out_lines = []
1191         err_lines = []
1192
1193         def out(msg: str, **kwargs: Any) -> None:
1194             out_lines.append(msg)
1195
1196         def err(msg: str, **kwargs: Any) -> None:
1197             err_lines.append(msg)
1198
1199         with patch("black.out", out), patch("black.err", err):
1200             black.DebugVisitor.show(source)
1201         actual = "\n".join(out_lines) + "\n"
1202         log_name = ""
1203         if expected != actual:
1204             log_name = black.dump_to_file(*out_lines)
1205         self.assertEqual(
1206             expected,
1207             actual,
1208             f"AST print out is different. Actual version dumped to {log_name}",
1209         )
1210
1211     def test_format_file_contents(self) -> None:
1212         empty = ""
1213         mode = black.FileMode()
1214         with self.assertRaises(black.NothingChanged):
1215             black.format_file_contents(empty, mode=mode, fast=False)
1216         just_nl = "\n"
1217         with self.assertRaises(black.NothingChanged):
1218             black.format_file_contents(just_nl, mode=mode, fast=False)
1219         same = "j = [1, 2, 3]\n"
1220         with self.assertRaises(black.NothingChanged):
1221             black.format_file_contents(same, mode=mode, fast=False)
1222         different = "j = [1,2,3]"
1223         expected = same
1224         actual = black.format_file_contents(different, mode=mode, fast=False)
1225         self.assertEqual(expected, actual)
1226         invalid = "return if you can"
1227         with self.assertRaises(black.InvalidInput) as e:
1228             black.format_file_contents(invalid, mode=mode, fast=False)
1229         self.assertEqual(str(e.exception), "Cannot parse: 1:7: return if you can")
1230
1231     def test_endmarker(self) -> None:
1232         n = black.lib2to3_parse("\n")
1233         self.assertEqual(n.type, black.syms.file_input)
1234         self.assertEqual(len(n.children), 1)
1235         self.assertEqual(n.children[0].type, black.token.ENDMARKER)
1236
1237     @unittest.skipIf(os.environ.get("SKIP_AST_PRINT"), "user set SKIP_AST_PRINT")
1238     def test_assertFormatEqual(self) -> None:
1239         out_lines = []
1240         err_lines = []
1241
1242         def out(msg: str, **kwargs: Any) -> None:
1243             out_lines.append(msg)
1244
1245         def err(msg: str, **kwargs: Any) -> None:
1246             err_lines.append(msg)
1247
1248         with patch("black.out", out), patch("black.err", err):
1249             with self.assertRaises(AssertionError):
1250                 self.assertFormatEqual("j = [1, 2, 3]", "j = [1, 2, 3,]")
1251
1252         out_str = "".join(out_lines)
1253         self.assertTrue("Expected tree:" in out_str)
1254         self.assertTrue("Actual tree:" in out_str)
1255         self.assertEqual("".join(err_lines), "")
1256
1257     def test_cache_broken_file(self) -> None:
1258         mode = black.FileMode()
1259         with cache_dir() as workspace:
1260             cache_file = black.get_cache_file(mode)
1261             with cache_file.open("w") as fobj:
1262                 fobj.write("this is not a pickle")
1263             self.assertEqual(black.read_cache(mode), {})
1264             src = (workspace / "test.py").resolve()
1265             with src.open("w") as fobj:
1266                 fobj.write("print('hello')")
1267             self.invokeBlack([str(src)])
1268             cache = black.read_cache(mode)
1269             self.assertIn(src, cache)
1270
1271     def test_cache_single_file_already_cached(self) -> None:
1272         mode = black.FileMode()
1273         with cache_dir() as workspace:
1274             src = (workspace / "test.py").resolve()
1275             with src.open("w") as fobj:
1276                 fobj.write("print('hello')")
1277             black.write_cache({}, [src], mode)
1278             self.invokeBlack([str(src)])
1279             with src.open("r") as fobj:
1280                 self.assertEqual(fobj.read(), "print('hello')")
1281
1282     @event_loop()
1283     def test_cache_multiple_files(self) -> None:
1284         mode = black.FileMode()
1285         with cache_dir() as workspace, patch(
1286             "black.ProcessPoolExecutor", new=ThreadPoolExecutor
1287         ):
1288             one = (workspace / "one.py").resolve()
1289             with one.open("w") as fobj:
1290                 fobj.write("print('hello')")
1291             two = (workspace / "two.py").resolve()
1292             with two.open("w") as fobj:
1293                 fobj.write("print('hello')")
1294             black.write_cache({}, [one], mode)
1295             self.invokeBlack([str(workspace)])
1296             with one.open("r") as fobj:
1297                 self.assertEqual(fobj.read(), "print('hello')")
1298             with two.open("r") as fobj:
1299                 self.assertEqual(fobj.read(), 'print("hello")\n')
1300             cache = black.read_cache(mode)
1301             self.assertIn(one, cache)
1302             self.assertIn(two, cache)
1303
1304     def test_no_cache_when_writeback_diff(self) -> None:
1305         mode = black.FileMode()
1306         with cache_dir() as workspace:
1307             src = (workspace / "test.py").resolve()
1308             with src.open("w") as fobj:
1309                 fobj.write("print('hello')")
1310             self.invokeBlack([str(src), "--diff"])
1311             cache_file = black.get_cache_file(mode)
1312             self.assertFalse(cache_file.exists())
1313
1314     def test_no_cache_when_stdin(self) -> None:
1315         mode = black.FileMode()
1316         with cache_dir():
1317             result = CliRunner().invoke(
1318                 black.main, ["-"], input=BytesIO(b"print('hello')")
1319             )
1320             self.assertEqual(result.exit_code, 0)
1321             cache_file = black.get_cache_file(mode)
1322             self.assertFalse(cache_file.exists())
1323
1324     def test_read_cache_no_cachefile(self) -> None:
1325         mode = black.FileMode()
1326         with cache_dir():
1327             self.assertEqual(black.read_cache(mode), {})
1328
1329     def test_write_cache_read_cache(self) -> None:
1330         mode = black.FileMode()
1331         with cache_dir() as workspace:
1332             src = (workspace / "test.py").resolve()
1333             src.touch()
1334             black.write_cache({}, [src], mode)
1335             cache = black.read_cache(mode)
1336             self.assertIn(src, cache)
1337             self.assertEqual(cache[src], black.get_cache_info(src))
1338
1339     def test_filter_cached(self) -> None:
1340         with TemporaryDirectory() as workspace:
1341             path = Path(workspace)
1342             uncached = (path / "uncached").resolve()
1343             cached = (path / "cached").resolve()
1344             cached_but_changed = (path / "changed").resolve()
1345             uncached.touch()
1346             cached.touch()
1347             cached_but_changed.touch()
1348             cache = {cached: black.get_cache_info(cached), cached_but_changed: (0.0, 0)}
1349             todo, done = black.filter_cached(
1350                 cache, {uncached, cached, cached_but_changed}
1351             )
1352             self.assertEqual(todo, {uncached, cached_but_changed})
1353             self.assertEqual(done, {cached})
1354
1355     def test_write_cache_creates_directory_if_needed(self) -> None:
1356         mode = black.FileMode()
1357         with cache_dir(exists=False) as workspace:
1358             self.assertFalse(workspace.exists())
1359             black.write_cache({}, [], mode)
1360             self.assertTrue(workspace.exists())
1361
1362     @event_loop()
1363     def test_failed_formatting_does_not_get_cached(self) -> None:
1364         mode = black.FileMode()
1365         with cache_dir() as workspace, patch(
1366             "black.ProcessPoolExecutor", new=ThreadPoolExecutor
1367         ):
1368             failing = (workspace / "failing.py").resolve()
1369             with failing.open("w") as fobj:
1370                 fobj.write("not actually python")
1371             clean = (workspace / "clean.py").resolve()
1372             with clean.open("w") as fobj:
1373                 fobj.write('print("hello")\n')
1374             self.invokeBlack([str(workspace)], exit_code=123)
1375             cache = black.read_cache(mode)
1376             self.assertNotIn(failing, cache)
1377             self.assertIn(clean, cache)
1378
1379     def test_write_cache_write_fail(self) -> None:
1380         mode = black.FileMode()
1381         with cache_dir(), patch.object(Path, "open") as mock:
1382             mock.side_effect = OSError
1383             black.write_cache({}, [], mode)
1384
1385     @event_loop()
1386     @patch("black.ProcessPoolExecutor", MagicMock(side_effect=OSError))
1387     def test_works_in_mono_process_only_environment(self) -> None:
1388         with cache_dir() as workspace:
1389             for f in [
1390                 (workspace / "one.py").resolve(),
1391                 (workspace / "two.py").resolve(),
1392             ]:
1393                 f.write_text('print("hello")\n')
1394             self.invokeBlack([str(workspace)])
1395
1396     @event_loop()
1397     def test_check_diff_use_together(self) -> None:
1398         with cache_dir():
1399             # Files which will be reformatted.
1400             src1 = (THIS_DIR / "data" / "string_quotes.py").resolve()
1401             self.invokeBlack([str(src1), "--diff", "--check"], exit_code=1)
1402             # Files which will not be reformatted.
1403             src2 = (THIS_DIR / "data" / "composition.py").resolve()
1404             self.invokeBlack([str(src2), "--diff", "--check"])
1405             # Multi file command.
1406             self.invokeBlack([str(src1), str(src2), "--diff", "--check"], exit_code=1)
1407
1408     def test_no_files(self) -> None:
1409         with cache_dir():
1410             # Without an argument, black exits with error code 0.
1411             self.invokeBlack([])
1412
1413     def test_broken_symlink(self) -> None:
1414         with cache_dir() as workspace:
1415             symlink = workspace / "broken_link.py"
1416             try:
1417                 symlink.symlink_to("nonexistent.py")
1418             except OSError as e:
1419                 self.skipTest(f"Can't create symlinks: {e}")
1420             self.invokeBlack([str(workspace.resolve())])
1421
1422     def test_read_cache_line_lengths(self) -> None:
1423         mode = black.FileMode()
1424         short_mode = black.FileMode(line_length=1)
1425         with cache_dir() as workspace:
1426             path = (workspace / "file.py").resolve()
1427             path.touch()
1428             black.write_cache({}, [path], mode)
1429             one = black.read_cache(mode)
1430             self.assertIn(path, one)
1431             two = black.read_cache(short_mode)
1432             self.assertNotIn(path, two)
1433
1434     def test_tricky_unicode_symbols(self) -> None:
1435         source, expected = read_data("tricky_unicode_symbols")
1436         actual = fs(source)
1437         self.assertFormatEqual(expected, actual)
1438         black.assert_equivalent(source, actual)
1439         black.assert_stable(source, actual, black.FileMode())
1440
1441     def test_single_file_force_pyi(self) -> None:
1442         reg_mode = black.FileMode()
1443         pyi_mode = black.FileMode(is_pyi=True)
1444         contents, expected = read_data("force_pyi")
1445         with cache_dir() as workspace:
1446             path = (workspace / "file.py").resolve()
1447             with open(path, "w") as fh:
1448                 fh.write(contents)
1449             self.invokeBlack([str(path), "--pyi"])
1450             with open(path, "r") as fh:
1451                 actual = fh.read()
1452             # verify cache with --pyi is separate
1453             pyi_cache = black.read_cache(pyi_mode)
1454             self.assertIn(path, pyi_cache)
1455             normal_cache = black.read_cache(reg_mode)
1456             self.assertNotIn(path, normal_cache)
1457         self.assertEqual(actual, expected)
1458
1459     @event_loop()
1460     def test_multi_file_force_pyi(self) -> None:
1461         reg_mode = black.FileMode()
1462         pyi_mode = black.FileMode(is_pyi=True)
1463         contents, expected = read_data("force_pyi")
1464         with cache_dir() as workspace:
1465             paths = [
1466                 (workspace / "file1.py").resolve(),
1467                 (workspace / "file2.py").resolve(),
1468             ]
1469             for path in paths:
1470                 with open(path, "w") as fh:
1471                     fh.write(contents)
1472             self.invokeBlack([str(p) for p in paths] + ["--pyi"])
1473             for path in paths:
1474                 with open(path, "r") as fh:
1475                     actual = fh.read()
1476                 self.assertEqual(actual, expected)
1477             # verify cache with --pyi is separate
1478             pyi_cache = black.read_cache(pyi_mode)
1479             normal_cache = black.read_cache(reg_mode)
1480             for path in paths:
1481                 self.assertIn(path, pyi_cache)
1482                 self.assertNotIn(path, normal_cache)
1483
1484     def test_pipe_force_pyi(self) -> None:
1485         source, expected = read_data("force_pyi")
1486         result = CliRunner().invoke(
1487             black.main, ["-", "-q", "--pyi"], input=BytesIO(source.encode("utf8"))
1488         )
1489         self.assertEqual(result.exit_code, 0)
1490         actual = result.output
1491         self.assertFormatEqual(actual, expected)
1492
1493     def test_single_file_force_py36(self) -> None:
1494         reg_mode = black.FileMode()
1495         py36_mode = black.FileMode(target_versions=black.PY36_VERSIONS)
1496         source, expected = read_data("force_py36")
1497         with cache_dir() as workspace:
1498             path = (workspace / "file.py").resolve()
1499             with open(path, "w") as fh:
1500                 fh.write(source)
1501             self.invokeBlack([str(path), *PY36_ARGS])
1502             with open(path, "r") as fh:
1503                 actual = fh.read()
1504             # verify cache with --target-version is separate
1505             py36_cache = black.read_cache(py36_mode)
1506             self.assertIn(path, py36_cache)
1507             normal_cache = black.read_cache(reg_mode)
1508             self.assertNotIn(path, normal_cache)
1509         self.assertEqual(actual, expected)
1510
1511     @event_loop()
1512     def test_multi_file_force_py36(self) -> None:
1513         reg_mode = black.FileMode()
1514         py36_mode = black.FileMode(target_versions=black.PY36_VERSIONS)
1515         source, expected = read_data("force_py36")
1516         with cache_dir() as workspace:
1517             paths = [
1518                 (workspace / "file1.py").resolve(),
1519                 (workspace / "file2.py").resolve(),
1520             ]
1521             for path in paths:
1522                 with open(path, "w") as fh:
1523                     fh.write(source)
1524             self.invokeBlack([str(p) for p in paths] + PY36_ARGS)
1525             for path in paths:
1526                 with open(path, "r") as fh:
1527                     actual = fh.read()
1528                 self.assertEqual(actual, expected)
1529             # verify cache with --target-version is separate
1530             pyi_cache = black.read_cache(py36_mode)
1531             normal_cache = black.read_cache(reg_mode)
1532             for path in paths:
1533                 self.assertIn(path, pyi_cache)
1534                 self.assertNotIn(path, normal_cache)
1535
1536     def test_collections(self) -> None:
1537         source, expected = read_data("collections")
1538         actual = fs(source)
1539         self.assertFormatEqual(expected, actual)
1540         black.assert_equivalent(source, actual)
1541         black.assert_stable(source, actual, black.FileMode())
1542
1543     def test_pipe_force_py36(self) -> None:
1544         source, expected = read_data("force_py36")
1545         result = CliRunner().invoke(
1546             black.main,
1547             ["-", "-q", "--target-version=py36"],
1548             input=BytesIO(source.encode("utf8")),
1549         )
1550         self.assertEqual(result.exit_code, 0)
1551         actual = result.output
1552         self.assertFormatEqual(actual, expected)
1553
1554     def test_include_exclude(self) -> None:
1555         path = THIS_DIR / "data" / "include_exclude_tests"
1556         include = re.compile(r"\.pyi?$")
1557         exclude = re.compile(r"/exclude/|/\.definitely_exclude/")
1558         report = black.Report()
1559         gitignore = PathSpec.from_lines("gitwildmatch", [])
1560         sources: List[Path] = []
1561         expected = [
1562             Path(path / "b/dont_exclude/a.py"),
1563             Path(path / "b/dont_exclude/a.pyi"),
1564         ]
1565         this_abs = THIS_DIR.resolve()
1566         sources.extend(
1567             black.gen_python_files(
1568                 path.iterdir(), this_abs, include, exclude, None, report, gitignore
1569             )
1570         )
1571         self.assertEqual(sorted(expected), sorted(sources))
1572
1573     @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
1574     def test_exclude_for_issue_1572(self) -> None:
1575         # Exclude shouldn't touch files that were explicitly given to Black through the
1576         # CLI. Exclude is supposed to only apply to the recursive discovery of files.
1577         # https://github.com/psf/black/issues/1572
1578         path = THIS_DIR / "data" / "include_exclude_tests"
1579         include = ""
1580         exclude = r"/exclude/|a\.py"
1581         src = str(path / "b/exclude/a.py")
1582         report = black.Report()
1583         expected = [Path(path / "b/exclude/a.py")]
1584         sources = list(
1585             black.get_sources(
1586                 ctx=FakeContext(),
1587                 src=(src,),
1588                 quiet=True,
1589                 verbose=False,
1590                 include=include,
1591                 exclude=exclude,
1592                 force_exclude=None,
1593                 report=report,
1594             )
1595         )
1596         self.assertEqual(sorted(expected), sorted(sources))
1597
1598     def test_gitignore_exclude(self) -> None:
1599         path = THIS_DIR / "data" / "include_exclude_tests"
1600         include = re.compile(r"\.pyi?$")
1601         exclude = re.compile(r"")
1602         report = black.Report()
1603         gitignore = PathSpec.from_lines(
1604             "gitwildmatch", ["exclude/", ".definitely_exclude"]
1605         )
1606         sources: List[Path] = []
1607         expected = [
1608             Path(path / "b/dont_exclude/a.py"),
1609             Path(path / "b/dont_exclude/a.pyi"),
1610         ]
1611         this_abs = THIS_DIR.resolve()
1612         sources.extend(
1613             black.gen_python_files(
1614                 path.iterdir(), this_abs, include, exclude, None, report, gitignore
1615             )
1616         )
1617         self.assertEqual(sorted(expected), sorted(sources))
1618
1619     def test_empty_include(self) -> None:
1620         path = THIS_DIR / "data" / "include_exclude_tests"
1621         report = black.Report()
1622         gitignore = PathSpec.from_lines("gitwildmatch", [])
1623         empty = re.compile(r"")
1624         sources: List[Path] = []
1625         expected = [
1626             Path(path / "b/exclude/a.pie"),
1627             Path(path / "b/exclude/a.py"),
1628             Path(path / "b/exclude/a.pyi"),
1629             Path(path / "b/dont_exclude/a.pie"),
1630             Path(path / "b/dont_exclude/a.py"),
1631             Path(path / "b/dont_exclude/a.pyi"),
1632             Path(path / "b/.definitely_exclude/a.pie"),
1633             Path(path / "b/.definitely_exclude/a.py"),
1634             Path(path / "b/.definitely_exclude/a.pyi"),
1635         ]
1636         this_abs = THIS_DIR.resolve()
1637         sources.extend(
1638             black.gen_python_files(
1639                 path.iterdir(),
1640                 this_abs,
1641                 empty,
1642                 re.compile(black.DEFAULT_EXCLUDES),
1643                 None,
1644                 report,
1645                 gitignore,
1646             )
1647         )
1648         self.assertEqual(sorted(expected), sorted(sources))
1649
1650     def test_empty_exclude(self) -> None:
1651         path = THIS_DIR / "data" / "include_exclude_tests"
1652         report = black.Report()
1653         gitignore = PathSpec.from_lines("gitwildmatch", [])
1654         empty = re.compile(r"")
1655         sources: List[Path] = []
1656         expected = [
1657             Path(path / "b/dont_exclude/a.py"),
1658             Path(path / "b/dont_exclude/a.pyi"),
1659             Path(path / "b/exclude/a.py"),
1660             Path(path / "b/exclude/a.pyi"),
1661             Path(path / "b/.definitely_exclude/a.py"),
1662             Path(path / "b/.definitely_exclude/a.pyi"),
1663         ]
1664         this_abs = THIS_DIR.resolve()
1665         sources.extend(
1666             black.gen_python_files(
1667                 path.iterdir(),
1668                 this_abs,
1669                 re.compile(black.DEFAULT_INCLUDES),
1670                 empty,
1671                 None,
1672                 report,
1673                 gitignore,
1674             )
1675         )
1676         self.assertEqual(sorted(expected), sorted(sources))
1677
1678     def test_invalid_include_exclude(self) -> None:
1679         for option in ["--include", "--exclude"]:
1680             self.invokeBlack(["-", option, "**()(!!*)"], exit_code=2)
1681
1682     def test_preserves_line_endings(self) -> None:
1683         with TemporaryDirectory() as workspace:
1684             test_file = Path(workspace) / "test.py"
1685             for nl in ["\n", "\r\n"]:
1686                 contents = nl.join(["def f(  ):", "    pass"])
1687                 test_file.write_bytes(contents.encode())
1688                 ff(test_file, write_back=black.WriteBack.YES)
1689                 updated_contents: bytes = test_file.read_bytes()
1690                 self.assertIn(nl.encode(), updated_contents)
1691                 if nl == "\n":
1692                     self.assertNotIn(b"\r\n", updated_contents)
1693
1694     def test_preserves_line_endings_via_stdin(self) -> None:
1695         for nl in ["\n", "\r\n"]:
1696             contents = nl.join(["def f(  ):", "    pass"])
1697             runner = BlackRunner()
1698             result = runner.invoke(
1699                 black.main, ["-", "--fast"], input=BytesIO(contents.encode("utf8"))
1700             )
1701             self.assertEqual(result.exit_code, 0)
1702             output = runner.stdout_bytes
1703             self.assertIn(nl.encode("utf8"), output)
1704             if nl == "\n":
1705                 self.assertNotIn(b"\r\n", output)
1706
1707     def test_assert_equivalent_different_asts(self) -> None:
1708         with self.assertRaises(AssertionError):
1709             black.assert_equivalent("{}", "None")
1710
1711     def test_symlink_out_of_root_directory(self) -> None:
1712         path = MagicMock()
1713         root = THIS_DIR.resolve()
1714         child = MagicMock()
1715         include = re.compile(black.DEFAULT_INCLUDES)
1716         exclude = re.compile(black.DEFAULT_EXCLUDES)
1717         report = black.Report()
1718         gitignore = PathSpec.from_lines("gitwildmatch", [])
1719         # `child` should behave like a symlink which resolved path is clearly
1720         # outside of the `root` directory.
1721         path.iterdir.return_value = [child]
1722         child.resolve.return_value = Path("/a/b/c")
1723         child.as_posix.return_value = "/a/b/c"
1724         child.is_symlink.return_value = True
1725         try:
1726             list(
1727                 black.gen_python_files(
1728                     path.iterdir(), root, include, exclude, None, report, gitignore
1729                 )
1730             )
1731         except ValueError as ve:
1732             self.fail(f"`get_python_files_in_dir()` failed: {ve}")
1733         path.iterdir.assert_called_once()
1734         child.resolve.assert_called_once()
1735         child.is_symlink.assert_called_once()
1736         # `child` should behave like a strange file which resolved path is clearly
1737         # outside of the `root` directory.
1738         child.is_symlink.return_value = False
1739         with self.assertRaises(ValueError):
1740             list(
1741                 black.gen_python_files(
1742                     path.iterdir(), root, include, exclude, None, report, gitignore
1743                 )
1744             )
1745         path.iterdir.assert_called()
1746         self.assertEqual(path.iterdir.call_count, 2)
1747         child.resolve.assert_called()
1748         self.assertEqual(child.resolve.call_count, 2)
1749         child.is_symlink.assert_called()
1750         self.assertEqual(child.is_symlink.call_count, 2)
1751
1752     def test_shhh_click(self) -> None:
1753         try:
1754             from click import _unicodefun  # type: ignore
1755         except ModuleNotFoundError:
1756             self.skipTest("Incompatible Click version")
1757         if not hasattr(_unicodefun, "_verify_python3_env"):
1758             self.skipTest("Incompatible Click version")
1759         # First, let's see if Click is crashing with a preferred ASCII charset.
1760         with patch("locale.getpreferredencoding") as gpe:
1761             gpe.return_value = "ASCII"
1762             with self.assertRaises(RuntimeError):
1763                 _unicodefun._verify_python3_env()
1764         # Now, let's silence Click...
1765         black.patch_click()
1766         # ...and confirm it's silent.
1767         with patch("locale.getpreferredencoding") as gpe:
1768             gpe.return_value = "ASCII"
1769             try:
1770                 _unicodefun._verify_python3_env()
1771             except RuntimeError as re:
1772                 self.fail(f"`patch_click()` failed, exception still raised: {re}")
1773
1774     def test_root_logger_not_used_directly(self) -> None:
1775         def fail(*args: Any, **kwargs: Any) -> None:
1776             self.fail("Record created with root logger")
1777
1778         with patch.multiple(
1779             logging.root,
1780             debug=fail,
1781             info=fail,
1782             warning=fail,
1783             error=fail,
1784             critical=fail,
1785             log=fail,
1786         ):
1787             ff(THIS_FILE)
1788
1789     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1790     def test_blackd_main(self) -> None:
1791         with patch("blackd.web.run_app"):
1792             result = CliRunner().invoke(blackd.main, [])
1793             if result.exception is not None:
1794                 raise result.exception
1795             self.assertEqual(result.exit_code, 0)
1796
1797     def test_invalid_config_return_code(self) -> None:
1798         tmp_file = Path(black.dump_to_file())
1799         try:
1800             tmp_config = Path(black.dump_to_file())
1801             tmp_config.unlink()
1802             args = ["--config", str(tmp_config), str(tmp_file)]
1803             self.invokeBlack(args, exit_code=2, ignore_config=False)
1804         finally:
1805             tmp_file.unlink()
1806
1807     def test_parse_pyproject_toml(self) -> None:
1808         test_toml_file = THIS_DIR / "test.toml"
1809         config = black.parse_pyproject_toml(str(test_toml_file))
1810         self.assertEqual(config["verbose"], 1)
1811         self.assertEqual(config["check"], "no")
1812         self.assertEqual(config["diff"], "y")
1813         self.assertEqual(config["color"], True)
1814         self.assertEqual(config["line_length"], 79)
1815         self.assertEqual(config["target_version"], ["py36", "py37", "py38"])
1816         self.assertEqual(config["exclude"], r"\.pyi?$")
1817         self.assertEqual(config["include"], r"\.py?$")
1818
1819     def test_read_pyproject_toml(self) -> None:
1820         test_toml_file = THIS_DIR / "test.toml"
1821         fake_ctx = FakeContext()
1822         black.read_pyproject_toml(
1823             fake_ctx, FakeParameter(), str(test_toml_file),
1824         )
1825         config = fake_ctx.default_map
1826         self.assertEqual(config["verbose"], "1")
1827         self.assertEqual(config["check"], "no")
1828         self.assertEqual(config["diff"], "y")
1829         self.assertEqual(config["color"], "True")
1830         self.assertEqual(config["line_length"], "79")
1831         self.assertEqual(config["target_version"], ["py36", "py37", "py38"])
1832         self.assertEqual(config["exclude"], r"\.pyi?$")
1833         self.assertEqual(config["include"], r"\.py?$")
1834
1835     def test_find_project_root(self) -> None:
1836         with TemporaryDirectory() as workspace:
1837             root = Path(workspace)
1838             test_dir = root / "test"
1839             test_dir.mkdir()
1840
1841             src_dir = root / "src"
1842             src_dir.mkdir()
1843
1844             root_pyproject = root / "pyproject.toml"
1845             root_pyproject.touch()
1846             src_pyproject = src_dir / "pyproject.toml"
1847             src_pyproject.touch()
1848             src_python = src_dir / "foo.py"
1849             src_python.touch()
1850
1851             self.assertEqual(
1852                 black.find_project_root((src_dir, test_dir)), root.resolve()
1853             )
1854             self.assertEqual(black.find_project_root((src_dir,)), src_dir.resolve())
1855             self.assertEqual(black.find_project_root((src_python,)), src_dir.resolve())
1856
1857
if has_blackd_deps:
    # The base class, the `web.Application` annotation and the
    # `unittest_run_loop` decorator below only exist when blackd's optional
    # dependencies imported successfully. Defining the class unconditionally
    # raises NameError while importing this module without them, which takes
    # the whole test module down. Because the class is only defined when the
    # deps are present, per-test `skipUnless(has_blackd_deps, ...)` is redundant.

    class BlackDTestCase(AioHTTPTestCase):
        """HTTP-level tests for the application returned by blackd.make_app()."""

        async def get_application(self) -> web.Application:
            return blackd.make_app()

        # TODO: remove these decorators once the below is released
        # https://github.com/aio-libs/aiohttp/pull/3727
        @skip_if_exception("ClientOSError")
        @unittest_run_loop
        async def test_blackd_request_needs_formatting(self) -> None:
            """Unformatted input gets 200 and the reformatted source as utf8."""
            response = await self.client.post("/", data=b"print('hello world')")
            self.assertEqual(response.status, 200)
            self.assertEqual(response.charset, "utf8")
            self.assertEqual(await response.read(), b'print("hello world")\n')

        @skip_if_exception("ClientOSError")
        @unittest_run_loop
        async def test_blackd_request_no_change(self) -> None:
            """Already-formatted input gets 204 with an empty body."""
            response = await self.client.post("/", data=b'print("hello world")\n')
            self.assertEqual(response.status, 204)
            self.assertEqual(await response.read(), b"")

        @skip_if_exception("ClientOSError")
        @unittest_run_loop
        async def test_blackd_request_syntax_error(self) -> None:
            """Unparsable input gets 400 with a 'Cannot parse' message."""
            response = await self.client.post("/", data=b"what even ( is")
            self.assertEqual(response.status, 400)
            content = await response.text()
            self.assertTrue(
                content.startswith("Cannot parse"),
                msg=f"Expected error to start with 'Cannot parse', got {content!r}",
            )

        @skip_if_exception("ClientOSError")
        @unittest_run_loop
        async def test_blackd_unsupported_version(self) -> None:
            """Protocol version 2 is rejected with 501."""
            response = await self.client.post(
                "/", data=b"what", headers={blackd.PROTOCOL_VERSION_HEADER: "2"}
            )
            self.assertEqual(response.status, 501)

        @skip_if_exception("ClientOSError")
        @unittest_run_loop
        async def test_blackd_supported_version(self) -> None:
            """Protocol version 1 is accepted."""
            response = await self.client.post(
                "/", data=b"what", headers={blackd.PROTOCOL_VERSION_HEADER: "1"}
            )
            self.assertEqual(response.status, 200)

        @skip_if_exception("ClientOSError")
        @unittest_run_loop
        async def test_blackd_invalid_python_variant(self) -> None:
            """Malformed or unsupported python-variant header values get 400."""

            async def check(header_value: str, expected_status: int = 400) -> None:
                response = await self.client.post(
                    "/",
                    data=b"what",
                    headers={blackd.PYTHON_VARIANT_HEADER: header_value},
                )
                self.assertEqual(response.status, expected_status)

            await check("lol")
            await check("ruby3.5")
            await check("pyi3.6")
            await check("py1.5")
            await check("2.8")
            await check("py2.8")
            await check("3.0")
            await check("pypy3.0")
            await check("jython3.4")

        @skip_if_exception("ClientOSError")
        @unittest_run_loop
        async def test_blackd_pyi(self) -> None:
            """The 'pyi' variant formats the stub.pyi fixture as expected."""
            source, expected = read_data("stub.pyi")
            response = await self.client.post(
                "/", data=source, headers={blackd.PYTHON_VARIANT_HEADER: "pyi"}
            )
            self.assertEqual(response.status, 200)
            self.assertEqual(await response.text(), expected)

        @skip_if_exception("ClientOSError")
        @unittest_run_loop
        async def test_blackd_diff(self) -> None:
            """With DIFF_HEADER set, the response matches the blackd_diff.diff
            fixture once its timestamped In/Out header lines are normalised."""
            diff_header = re.compile(
                r"(In|Out)\t\d\d\d\d-\d\d-\d\d \d\d:\d\d:\d\d\.\d\d\d\d\d\d"
                r" \+\d\d\d\d"
            )

            source, _ = read_data("blackd_diff.py")
            expected, _ = read_data("blackd_diff.diff")

            response = await self.client.post(
                "/", data=source, headers={blackd.DIFF_HEADER: "true"}
            )
            self.assertEqual(response.status, 200)

            actual = await response.text()
            # Timestamps differ per run; replace them with a fixed header.
            actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
            self.assertEqual(actual, expected)

        @skip_if_exception("ClientOSError")
        @unittest_run_loop
        async def test_blackd_python_variant(self) -> None:
            """3.6+ targets reformat the sample (200); 2.x/3.4 targets leave it
            unchanged (204) -- presumably because a trailing comma after
            **kwargs is only valid from 3.6."""
            code = (
                "def f(\n"
                "    and_has_a_bunch_of,\n"
                "    very_long_arguments_too,\n"
                "    and_lots_of_them_as_well_lol,\n"
                "    **and_very_long_keyword_arguments\n"
                "):\n"
                "    pass\n"
            )

            async def check(header_value: str, expected_status: int) -> None:
                response = await self.client.post(
                    "/",
                    data=code,
                    headers={blackd.PYTHON_VARIANT_HEADER: header_value},
                )
                self.assertEqual(
                    response.status, expected_status, msg=await response.text()
                )

            await check("3.6", 200)
            await check("py3.6", 200)
            await check("3.6,3.7", 200)
            await check("3.6,py3.7", 200)
            await check("py36,py37", 200)
            await check("36", 200)
            await check("3.6.4", 200)

            await check("2", 204)
            await check("2.7", 204)
            await check("py2.7", 204)
            await check("3.4", 204)
            await check("py3.4", 204)
            await check("py34,py36", 204)
            await check("34", 204)

        @skip_if_exception("ClientOSError")
        @unittest_run_loop
        async def test_blackd_line_length(self) -> None:
            """A numeric line-length header is accepted."""
            response = await self.client.post(
                "/",
                data=b'print("hello")\n',
                headers={blackd.LINE_LENGTH_HEADER: "7"},
            )
            self.assertEqual(response.status, 200)

        @skip_if_exception("ClientOSError")
        @unittest_run_loop
        async def test_blackd_invalid_line_length(self) -> None:
            """A non-numeric line-length header is rejected with 400."""
            response = await self.client.post(
                "/",
                data=b'print("hello")\n',
                headers={blackd.LINE_LENGTH_HEADER: "NaN"},
            )
            self.assertEqual(response.status, 400)

        @skip_if_exception("ClientOSError")
        @unittest_run_loop
        async def test_blackd_response_black_version_header(self) -> None:
            """Responses carry the BLACK_VERSION_HEADER header."""
            response = await self.client.post("/")
            self.assertIsNotNone(response.headers.get(blackd.BLACK_VERSION_HEADER))
2024
2025
if __name__ == "__main__":
    # Script entry point: hand control to unittest's CLI, loading tests from
    # the module importable under the name "test_black".
    unittest.main(module="test_black")