git.madduck.net Git - etc/vim.git/blob - tests/test_black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access, can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Upgrade docs to Sphinx 3+ and add doc build test (#1613)
[etc/vim.git] / tests / test_black.py
1 #!/usr/bin/env python3
2 import asyncio
3 import logging
4 from concurrent.futures import ThreadPoolExecutor
5 from contextlib import contextmanager
6 from dataclasses import replace
7 from functools import partial
8 from io import BytesIO, TextIOWrapper
9 import os
10 from pathlib import Path
11 import regex as re
12 import sys
13 from tempfile import TemporaryDirectory
14 from typing import Any, BinaryIO, Dict, Generator, List, Tuple, Iterator, TypeVar
15 import unittest
16 from unittest.mock import patch, MagicMock
17
18 import click
19 from click import unstyle
20 from click.testing import CliRunner
21
22 import black
23 from black import Feature, TargetVersion
24
25 try:
26     import blackd
27     from aiohttp.test_utils import AioHTTPTestCase, unittest_run_loop
28     from aiohttp import web
29 except ImportError:
30     has_blackd_deps = False
31 else:
32     has_blackd_deps = True
33
34 from pathspec import PathSpec
35
36 # Import other test classes
37 from .test_primer import PrimerCLITests  # noqa: F401
38
39
# Filesystem anchors used to locate the test data and the project sources.
THIS_FILE = Path(__file__)
THIS_DIR = THIS_FILE.parent
PROJECT_ROOT = THIS_DIR.parent
# Baseline mode for the tests; `fs` and `ff` are shorthands bound to it.
DEFAULT_MODE = black.FileMode(experimental_string_processing=True)
fs = partial(black.format_str, mode=DEFAULT_MODE)
ff = partial(black.format_file_in_place, mode=DEFAULT_MODE, fast=True)
DETERMINISTIC_HEADER = "[Deterministic header]"
# Kept as two concatenated pieces: `read_data` strips this marker from every line it
# reads, and `test_self` feeds this very file through `read_data`.
EMPTY_LINE = "# EMPTY LINE WITH WHITESPACE" + " (this comment will be removed)"
PY36_ARGS = [
    f"--target-version={version.name.lower()}" for version in black.PY36_VERSIONS
]
R = TypeVar("R")
T = TypeVar("T")
53
54
def dump_to_stderr(*output: str) -> str:
    """Return the `output` pieces newline-joined, with a newline on either end.

    Tests in this file patch it in as a replacement for `black.dump_to_file`.
    """
    joined = "\n".join(output)
    return f"\n{joined}\n"
57
58
def read_data(name: str, data: bool = True) -> Tuple[str, str]:
    """read_data('test_name') -> 'input', 'output'

    `name` gets a ".py" suffix unless it already ends in .py, .pyi, .out or .diff.
    It is looked up under `THIS_DIR / "data"` when `data` is true, otherwise under
    `PROJECT_ROOT`.  Lines before a "# output" marker line form the input, lines after
    it form the output; without a marker the input doubles as the output.  Every
    occurrence of `EMPTY_LINE` is removed, and both results are stripped and
    terminated with a single newline.
    """
    known_suffixes = (".py", ".pyi", ".out", ".diff")
    if not name.endswith(known_suffixes):
        name += ".py"
    base_dir = THIS_DIR / "data" if data else PROJECT_ROOT
    with open(base_dir / name, "r", encoding="utf8") as test:
        lines = test.readlines()
    before: List[str] = []
    after: List[str] = []
    bucket = before
    for raw in lines:
        line = raw.replace(EMPTY_LINE, "")
        if line.rstrip() == "# output":
            # Everything from here on belongs to the expected output.
            bucket = after
        else:
            bucket.append(line)
    if before and not after:
        # If there's no output marker, treat the entire file as already pre-formatted.
        after = list(before)
    return "".join(before).strip() + "\n", "".join(after).strip() + "\n"
80
81
@contextmanager
def cache_dir(exists: bool = True) -> Iterator[Path]:
    """Patch `black.CACHE_DIR` to a temporary location and yield that location.

    With `exists=False` the yielded path is a "new" entry inside the temporary
    directory; this helper does not create it.
    """
    with TemporaryDirectory() as workspace:
        location = Path(workspace) if exists else Path(workspace) / "new"
        with patch("black.CACHE_DIR", location):
            yield location
90
91
@contextmanager
def event_loop() -> Iterator[None]:
    """Install a fresh asyncio event loop as the current one and close it on exit."""
    loop = asyncio.get_event_loop_policy().new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        yield
    finally:
        loop.close()
102
103
@contextmanager
def skip_if_exception(e: str) -> Iterator[None]:
    """Skip the running test when the wrapped code raises an exception named `e`.

    `e` is compared against the raised exception's class name.  Any other exception
    propagates unchanged.

    Raises:
        unittest.SkipTest: when an exception whose class is named `e` was raised.
    """
    try:
        yield
    except Exception as exc:
        if exc.__class__.__name__ != e:
            raise

        # `unittest.skip()` merely builds a decorator, so calling it here would
        # silently swallow the exception; raising SkipTest actually skips the test.
        raise unittest.SkipTest(
            f"Encountered expected exception {exc}, skipping"
        ) from exc
113
114
class FakeContext(click.Context):
    """Minimal stand-in for a click Context, for functions that expect one.

    The parent initializer is not called; only `default_map` is set up.
    """

    def __init__(self) -> None:
        empty_defaults: Dict[str, Any] = {}
        self.default_map = empty_defaults
120
121
class FakeParameter(click.Parameter):
    """Minimal stand-in for a click Parameter, for functions that expect one."""

    def __init__(self) -> None:
        """Do nothing; the parent initializer is not called."""
127
128
class BlackRunner(CliRunner):
    """CliRunner variant that keeps stderr separate from stdout.

    After an invocation the raw captured streams are available as `stdout_bytes`
    and `stderr_bytes`.

    This is a hack that can be removed once we depend on Click 7.x
    """

    def __init__(self) -> None:
        self.stderrbuf = BytesIO()
        self.stdoutbuf = BytesIO()
        self.stdout_bytes = b""
        self.stderr_bytes = b""
        super().__init__()

    @contextmanager
    def isolation(self, *args: Any, **kwargs: Any) -> Generator[BinaryIO, None, None]:
        """Wrap the parent's isolation, pointing `sys.stderr` at our own buffer."""
        with super().isolation(*args, **kwargs) as output:
            try:
                original_stderr = sys.stderr
                sys.stderr = TextIOWrapper(self.stderrbuf, encoding=self.charset)
                yield output
            finally:
                # Snapshot both streams before the real stderr is put back.
                self.stdout_bytes = sys.stdout.buffer.getvalue()  # type: ignore
                self.stderr_bytes = sys.stderr.buffer.getvalue()  # type: ignore
                sys.stderr = original_stderr
152
153
154 class BlackTestCase(unittest.TestCase):
155     maxDiff = None
156
157     def assertFormatEqual(self, expected: str, actual: str) -> None:
158         if actual != expected and not os.environ.get("SKIP_AST_PRINT"):
159             bdv: black.DebugVisitor[Any]
160             black.out("Expected tree:", fg="green")
161             try:
162                 exp_node = black.lib2to3_parse(expected)
163                 bdv = black.DebugVisitor()
164                 list(bdv.visit(exp_node))
165             except Exception as ve:
166                 black.err(str(ve))
167             black.out("Actual tree:", fg="red")
168             try:
169                 exp_node = black.lib2to3_parse(actual)
170                 bdv = black.DebugVisitor()
171                 list(bdv.visit(exp_node))
172             except Exception as ve:
173                 black.err(str(ve))
174         self.assertEqual(expected, actual)
175
176     def invokeBlack(
177         self, args: List[str], exit_code: int = 0, ignore_config: bool = True
178     ) -> None:
179         runner = BlackRunner()
180         if ignore_config:
181             args = ["--verbose", "--config", str(THIS_DIR / "empty.toml"), *args]
182         result = runner.invoke(black.main, args)
183         self.assertEqual(
184             result.exit_code,
185             exit_code,
186             msg=(
187                 f"Failed with args: {args}\n"
188                 f"stdout: {runner.stdout_bytes.decode()!r}\n"
189                 f"stderr: {runner.stderr_bytes.decode()!r}\n"
190                 f"exception: {result.exception}"
191             ),
192         )
193
194     @patch("black.dump_to_file", dump_to_stderr)
195     def checkSourceFile(self, name: str, mode: black.FileMode = DEFAULT_MODE) -> None:
196         path = THIS_DIR.parent / name
197         source, expected = read_data(str(path), data=False)
198         actual = fs(source, mode=mode)
199         self.assertFormatEqual(expected, actual)
200         black.assert_equivalent(source, actual)
201         black.assert_stable(source, actual, mode)
202         self.assertFalse(ff(path))
203
204     @patch("black.dump_to_file", dump_to_stderr)
205     def test_empty(self) -> None:
206         source = expected = ""
207         actual = fs(source)
208         self.assertFormatEqual(expected, actual)
209         black.assert_equivalent(source, actual)
210         black.assert_stable(source, actual, DEFAULT_MODE)
211
212     def test_empty_ff(self) -> None:
213         expected = ""
214         tmp_file = Path(black.dump_to_file())
215         try:
216             self.assertFalse(ff(tmp_file, write_back=black.WriteBack.YES))
217             with open(tmp_file, encoding="utf8") as f:
218                 actual = f.read()
219         finally:
220             os.unlink(tmp_file)
221         self.assertFormatEqual(expected, actual)
222
223     def test_self(self) -> None:
224         self.checkSourceFile("tests/test_black.py")
225
226     def test_black(self) -> None:
227         self.checkSourceFile("src/black/__init__.py")
228
229     def test_pygram(self) -> None:
230         self.checkSourceFile("src/blib2to3/pygram.py")
231
232     def test_pytree(self) -> None:
233         self.checkSourceFile("src/blib2to3/pytree.py")
234
235     def test_conv(self) -> None:
236         self.checkSourceFile("src/blib2to3/pgen2/conv.py")
237
238     def test_driver(self) -> None:
239         self.checkSourceFile("src/blib2to3/pgen2/driver.py")
240
241     def test_grammar(self) -> None:
242         self.checkSourceFile("src/blib2to3/pgen2/grammar.py")
243
244     def test_literals(self) -> None:
245         self.checkSourceFile("src/blib2to3/pgen2/literals.py")
246
247     def test_parse(self) -> None:
248         self.checkSourceFile("src/blib2to3/pgen2/parse.py")
249
250     def test_pgen(self) -> None:
251         self.checkSourceFile("src/blib2to3/pgen2/pgen.py")
252
253     def test_tokenize(self) -> None:
254         self.checkSourceFile("src/blib2to3/pgen2/tokenize.py")
255
256     def test_token(self) -> None:
257         self.checkSourceFile("src/blib2to3/pgen2/token.py")
258
259     def test_setup(self) -> None:
260         self.checkSourceFile("setup.py")
261
262     def test_piping(self) -> None:
263         source, expected = read_data("src/black/__init__", data=False)
264         result = BlackRunner().invoke(
265             black.main,
266             ["-", "--fast", f"--line-length={black.DEFAULT_LINE_LENGTH}"],
267             input=BytesIO(source.encode("utf8")),
268         )
269         self.assertEqual(result.exit_code, 0)
270         self.assertFormatEqual(expected, result.output)
271         black.assert_equivalent(source, result.output)
272         black.assert_stable(source, result.output, DEFAULT_MODE)
273
274     def test_piping_diff(self) -> None:
275         diff_header = re.compile(
276             r"(STDIN|STDOUT)\t\d\d\d\d-\d\d-\d\d \d\d:\d\d:\d\d\.\d\d\d\d\d\d "
277             r"\+\d\d\d\d"
278         )
279         source, _ = read_data("expression.py")
280         expected, _ = read_data("expression.diff")
281         config = THIS_DIR / "data" / "empty_pyproject.toml"
282         args = [
283             "-",
284             "--fast",
285             f"--line-length={black.DEFAULT_LINE_LENGTH}",
286             "--diff",
287             f"--config={config}",
288         ]
289         result = BlackRunner().invoke(
290             black.main, args, input=BytesIO(source.encode("utf8"))
291         )
292         self.assertEqual(result.exit_code, 0)
293         actual = diff_header.sub(DETERMINISTIC_HEADER, result.output)
294         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
295         self.assertEqual(expected, actual)
296
297     def test_piping_diff_with_color(self) -> None:
298         source, _ = read_data("expression.py")
299         config = THIS_DIR / "data" / "empty_pyproject.toml"
300         args = [
301             "-",
302             "--fast",
303             f"--line-length={black.DEFAULT_LINE_LENGTH}",
304             "--diff",
305             "--color",
306             f"--config={config}",
307         ]
308         result = BlackRunner().invoke(
309             black.main, args, input=BytesIO(source.encode("utf8"))
310         )
311         actual = result.output
312         # Again, the contents are checked in a different test, so only look for colors.
313         self.assertIn("\033[1;37m", actual)
314         self.assertIn("\033[36m", actual)
315         self.assertIn("\033[32m", actual)
316         self.assertIn("\033[31m", actual)
317         self.assertIn("\033[0m", actual)
318
319     @patch("black.dump_to_file", dump_to_stderr)
320     def test_function(self) -> None:
321         source, expected = read_data("function")
322         actual = fs(source)
323         self.assertFormatEqual(expected, actual)
324         black.assert_equivalent(source, actual)
325         black.assert_stable(source, actual, DEFAULT_MODE)
326
327     @patch("black.dump_to_file", dump_to_stderr)
328     def test_function2(self) -> None:
329         source, expected = read_data("function2")
330         actual = fs(source)
331         self.assertFormatEqual(expected, actual)
332         black.assert_equivalent(source, actual)
333         black.assert_stable(source, actual, DEFAULT_MODE)
334
335     @patch("black.dump_to_file", dump_to_stderr)
336     def test_function_trailing_comma(self) -> None:
337         source, expected = read_data("function_trailing_comma")
338         actual = fs(source)
339         self.assertFormatEqual(expected, actual)
340         black.assert_equivalent(source, actual)
341         black.assert_stable(source, actual, DEFAULT_MODE)
342
343     @patch("black.dump_to_file", dump_to_stderr)
344     def test_expression(self) -> None:
345         source, expected = read_data("expression")
346         actual = fs(source)
347         self.assertFormatEqual(expected, actual)
348         black.assert_equivalent(source, actual)
349         black.assert_stable(source, actual, DEFAULT_MODE)
350
351     @patch("black.dump_to_file", dump_to_stderr)
352     def test_pep_572(self) -> None:
353         source, expected = read_data("pep_572")
354         actual = fs(source)
355         self.assertFormatEqual(expected, actual)
356         black.assert_stable(source, actual, DEFAULT_MODE)
357         if sys.version_info >= (3, 8):
358             black.assert_equivalent(source, actual)
359
360     def test_pep_572_version_detection(self) -> None:
361         source, _ = read_data("pep_572")
362         root = black.lib2to3_parse(source)
363         features = black.get_features_used(root)
364         self.assertIn(black.Feature.ASSIGNMENT_EXPRESSIONS, features)
365         versions = black.detect_target_versions(root)
366         self.assertIn(black.TargetVersion.PY38, versions)
367
368     def test_expression_ff(self) -> None:
369         source, expected = read_data("expression")
370         tmp_file = Path(black.dump_to_file(source))
371         try:
372             self.assertTrue(ff(tmp_file, write_back=black.WriteBack.YES))
373             with open(tmp_file, encoding="utf8") as f:
374                 actual = f.read()
375         finally:
376             os.unlink(tmp_file)
377         self.assertFormatEqual(expected, actual)
378         with patch("black.dump_to_file", dump_to_stderr):
379             black.assert_equivalent(source, actual)
380             black.assert_stable(source, actual, DEFAULT_MODE)
381
382     def test_expression_diff(self) -> None:
383         source, _ = read_data("expression.py")
384         expected, _ = read_data("expression.diff")
385         tmp_file = Path(black.dump_to_file(source))
386         diff_header = re.compile(
387             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
388             r"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
389         )
390         try:
391             result = BlackRunner().invoke(black.main, ["--diff", str(tmp_file)])
392             self.assertEqual(result.exit_code, 0)
393         finally:
394             os.unlink(tmp_file)
395         actual = result.output
396         actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
397         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
398         if expected != actual:
399             dump = black.dump_to_file(actual)
400             msg = (
401                 "Expected diff isn't equal to the actual. If you made changes to"
402                 " expression.py and this is an anticipated difference, overwrite"
403                 f" tests/data/expression.diff with {dump}"
404             )
405             self.assertEqual(expected, actual, msg)
406
407     def test_expression_diff_with_color(self) -> None:
408         source, _ = read_data("expression.py")
409         expected, _ = read_data("expression.diff")
410         tmp_file = Path(black.dump_to_file(source))
411         try:
412             result = BlackRunner().invoke(
413                 black.main, ["--diff", "--color", str(tmp_file)]
414             )
415         finally:
416             os.unlink(tmp_file)
417         actual = result.output
418         # We check the contents of the diff in `test_expression_diff`. All
419         # we need to check here is that color codes exist in the result.
420         self.assertIn("\033[1;37m", actual)
421         self.assertIn("\033[36m", actual)
422         self.assertIn("\033[32m", actual)
423         self.assertIn("\033[31m", actual)
424         self.assertIn("\033[0m", actual)
425
426     @patch("black.dump_to_file", dump_to_stderr)
427     def test_fstring(self) -> None:
428         source, expected = read_data("fstring")
429         actual = fs(source)
430         self.assertFormatEqual(expected, actual)
431         black.assert_equivalent(source, actual)
432         black.assert_stable(source, actual, DEFAULT_MODE)
433
434     @patch("black.dump_to_file", dump_to_stderr)
435     def test_pep_570(self) -> None:
436         source, expected = read_data("pep_570")
437         actual = fs(source)
438         self.assertFormatEqual(expected, actual)
439         black.assert_stable(source, actual, DEFAULT_MODE)
440         if sys.version_info >= (3, 8):
441             black.assert_equivalent(source, actual)
442
443     def test_detect_pos_only_arguments(self) -> None:
444         source, _ = read_data("pep_570")
445         root = black.lib2to3_parse(source)
446         features = black.get_features_used(root)
447         self.assertIn(black.Feature.POS_ONLY_ARGUMENTS, features)
448         versions = black.detect_target_versions(root)
449         self.assertIn(black.TargetVersion.PY38, versions)
450
451     @patch("black.dump_to_file", dump_to_stderr)
452     def test_string_quotes(self) -> None:
453         source, expected = read_data("string_quotes")
454         actual = fs(source)
455         self.assertFormatEqual(expected, actual)
456         black.assert_equivalent(source, actual)
457         black.assert_stable(source, actual, DEFAULT_MODE)
458         mode = replace(DEFAULT_MODE, string_normalization=False)
459         not_normalized = fs(source, mode=mode)
460         self.assertFormatEqual(source.replace("\\\n", ""), not_normalized)
461         black.assert_equivalent(source, not_normalized)
462         black.assert_stable(source, not_normalized, mode=mode)
463
464     @patch("black.dump_to_file", dump_to_stderr)
465     def test_docstring(self) -> None:
466         source, expected = read_data("docstring")
467         actual = fs(source)
468         self.assertFormatEqual(expected, actual)
469         black.assert_equivalent(source, actual)
470         black.assert_stable(source, actual, DEFAULT_MODE)
471
472     def test_long_strings(self) -> None:
473         """Tests for splitting long strings."""
474         source, expected = read_data("long_strings")
475         actual = fs(source)
476         self.assertFormatEqual(expected, actual)
477         black.assert_equivalent(source, actual)
478         black.assert_stable(source, actual, DEFAULT_MODE)
479
480     def test_long_strings_flag_disabled(self) -> None:
481         """Tests for turning off the string processing logic."""
482         source, expected = read_data("long_strings_flag_disabled")
483         mode = replace(DEFAULT_MODE, experimental_string_processing=False)
484         actual = fs(source, mode=mode)
485         self.assertFormatEqual(expected, actual)
486         black.assert_stable(expected, actual, mode)
487
488     @patch("black.dump_to_file", dump_to_stderr)
489     def test_long_strings__edge_case(self) -> None:
490         """Edge-case tests for splitting long strings."""
491         source, expected = read_data("long_strings__edge_case")
492         actual = fs(source)
493         self.assertFormatEqual(expected, actual)
494         black.assert_equivalent(source, actual)
495         black.assert_stable(source, actual, DEFAULT_MODE)
496
497     @patch("black.dump_to_file", dump_to_stderr)
498     def test_long_strings__regression(self) -> None:
499         """Regression tests for splitting long strings."""
500         source, expected = read_data("long_strings__regression")
501         actual = fs(source)
502         self.assertFormatEqual(expected, actual)
503         black.assert_equivalent(source, actual)
504         black.assert_stable(source, actual, DEFAULT_MODE)
505
506     @patch("black.dump_to_file", dump_to_stderr)
507     def test_slices(self) -> None:
508         source, expected = read_data("slices")
509         actual = fs(source)
510         self.assertFormatEqual(expected, actual)
511         black.assert_equivalent(source, actual)
512         black.assert_stable(source, actual, DEFAULT_MODE)
513
514     @patch("black.dump_to_file", dump_to_stderr)
515     def test_percent_precedence(self) -> None:
516         source, expected = read_data("percent_precedence")
517         actual = fs(source)
518         self.assertFormatEqual(expected, actual)
519         black.assert_equivalent(source, actual)
520         black.assert_stable(source, actual, DEFAULT_MODE)
521
522     @patch("black.dump_to_file", dump_to_stderr)
523     def test_comments(self) -> None:
524         source, expected = read_data("comments")
525         actual = fs(source)
526         self.assertFormatEqual(expected, actual)
527         black.assert_equivalent(source, actual)
528         black.assert_stable(source, actual, DEFAULT_MODE)
529
530     @patch("black.dump_to_file", dump_to_stderr)
531     def test_comments2(self) -> None:
532         source, expected = read_data("comments2")
533         actual = fs(source)
534         self.assertFormatEqual(expected, actual)
535         black.assert_equivalent(source, actual)
536         black.assert_stable(source, actual, DEFAULT_MODE)
537
538     @patch("black.dump_to_file", dump_to_stderr)
539     def test_comments3(self) -> None:
540         source, expected = read_data("comments3")
541         actual = fs(source)
542         self.assertFormatEqual(expected, actual)
543         black.assert_equivalent(source, actual)
544         black.assert_stable(source, actual, DEFAULT_MODE)
545
546     @patch("black.dump_to_file", dump_to_stderr)
547     def test_comments4(self) -> None:
548         source, expected = read_data("comments4")
549         actual = fs(source)
550         self.assertFormatEqual(expected, actual)
551         black.assert_equivalent(source, actual)
552         black.assert_stable(source, actual, DEFAULT_MODE)
553
554     @patch("black.dump_to_file", dump_to_stderr)
555     def test_comments5(self) -> None:
556         source, expected = read_data("comments5")
557         actual = fs(source)
558         self.assertFormatEqual(expected, actual)
559         black.assert_equivalent(source, actual)
560         black.assert_stable(source, actual, DEFAULT_MODE)
561
562     @patch("black.dump_to_file", dump_to_stderr)
563     def test_comments6(self) -> None:
564         source, expected = read_data("comments6")
565         actual = fs(source)
566         self.assertFormatEqual(expected, actual)
567         black.assert_equivalent(source, actual)
568         black.assert_stable(source, actual, DEFAULT_MODE)
569
570     @patch("black.dump_to_file", dump_to_stderr)
571     def test_comments7(self) -> None:
572         source, expected = read_data("comments7")
573         actual = fs(source)
574         self.assertFormatEqual(expected, actual)
575         black.assert_equivalent(source, actual)
576         black.assert_stable(source, actual, DEFAULT_MODE)
577
578     @patch("black.dump_to_file", dump_to_stderr)
579     def test_comment_after_escaped_newline(self) -> None:
580         source, expected = read_data("comment_after_escaped_newline")
581         actual = fs(source)
582         self.assertFormatEqual(expected, actual)
583         black.assert_equivalent(source, actual)
584         black.assert_stable(source, actual, DEFAULT_MODE)
585
586     @patch("black.dump_to_file", dump_to_stderr)
587     def test_cantfit(self) -> None:
588         source, expected = read_data("cantfit")
589         actual = fs(source)
590         self.assertFormatEqual(expected, actual)
591         black.assert_equivalent(source, actual)
592         black.assert_stable(source, actual, DEFAULT_MODE)
593
594     @patch("black.dump_to_file", dump_to_stderr)
595     def test_import_spacing(self) -> None:
596         source, expected = read_data("import_spacing")
597         actual = fs(source)
598         self.assertFormatEqual(expected, actual)
599         black.assert_equivalent(source, actual)
600         black.assert_stable(source, actual, DEFAULT_MODE)
601
602     @patch("black.dump_to_file", dump_to_stderr)
603     def test_composition(self) -> None:
604         source, expected = read_data("composition")
605         actual = fs(source)
606         self.assertFormatEqual(expected, actual)
607         black.assert_equivalent(source, actual)
608         black.assert_stable(source, actual, DEFAULT_MODE)
609
610     @patch("black.dump_to_file", dump_to_stderr)
611     def test_empty_lines(self) -> None:
612         source, expected = read_data("empty_lines")
613         actual = fs(source)
614         self.assertFormatEqual(expected, actual)
615         black.assert_equivalent(source, actual)
616         black.assert_stable(source, actual, DEFAULT_MODE)
617
618     @patch("black.dump_to_file", dump_to_stderr)
619     def test_remove_parens(self) -> None:
620         source, expected = read_data("remove_parens")
621         actual = fs(source)
622         self.assertFormatEqual(expected, actual)
623         black.assert_equivalent(source, actual)
624         black.assert_stable(source, actual, DEFAULT_MODE)
625
626     @patch("black.dump_to_file", dump_to_stderr)
627     def test_string_prefixes(self) -> None:
628         source, expected = read_data("string_prefixes")
629         actual = fs(source)
630         self.assertFormatEqual(expected, actual)
631         black.assert_equivalent(source, actual)
632         black.assert_stable(source, actual, DEFAULT_MODE)
633
634     @patch("black.dump_to_file", dump_to_stderr)
635     def test_numeric_literals(self) -> None:
636         source, expected = read_data("numeric_literals")
637         mode = replace(DEFAULT_MODE, target_versions=black.PY36_VERSIONS)
638         actual = fs(source, mode=mode)
639         self.assertFormatEqual(expected, actual)
640         black.assert_equivalent(source, actual)
641         black.assert_stable(source, actual, mode)
642
643     @patch("black.dump_to_file", dump_to_stderr)
644     def test_numeric_literals_ignoring_underscores(self) -> None:
645         source, expected = read_data("numeric_literals_skip_underscores")
646         mode = replace(DEFAULT_MODE, target_versions=black.PY36_VERSIONS)
647         actual = fs(source, mode=mode)
648         self.assertFormatEqual(expected, actual)
649         black.assert_equivalent(source, actual)
650         black.assert_stable(source, actual, mode)
651
652     @patch("black.dump_to_file", dump_to_stderr)
653     def test_numeric_literals_py2(self) -> None:
654         source, expected = read_data("numeric_literals_py2")
655         actual = fs(source)
656         self.assertFormatEqual(expected, actual)
657         black.assert_stable(source, actual, DEFAULT_MODE)
658
659     @patch("black.dump_to_file", dump_to_stderr)
660     def test_python2(self) -> None:
661         source, expected = read_data("python2")
662         actual = fs(source)
663         self.assertFormatEqual(expected, actual)
664         black.assert_equivalent(source, actual)
665         black.assert_stable(source, actual, DEFAULT_MODE)
666
667     @patch("black.dump_to_file", dump_to_stderr)
668     def test_python2_print_function(self) -> None:
669         source, expected = read_data("python2_print_function")
670         mode = replace(DEFAULT_MODE, target_versions={TargetVersion.PY27})
671         actual = fs(source, mode=mode)
672         self.assertFormatEqual(expected, actual)
673         black.assert_equivalent(source, actual)
674         black.assert_stable(source, actual, mode)
675
676     @patch("black.dump_to_file", dump_to_stderr)
677     def test_python2_unicode_literals(self) -> None:
678         source, expected = read_data("python2_unicode_literals")
679         actual = fs(source)
680         self.assertFormatEqual(expected, actual)
681         black.assert_equivalent(source, actual)
682         black.assert_stable(source, actual, DEFAULT_MODE)
683
684     @patch("black.dump_to_file", dump_to_stderr)
685     def test_stub(self) -> None:
686         mode = replace(DEFAULT_MODE, is_pyi=True)
687         source, expected = read_data("stub.pyi")
688         actual = fs(source, mode=mode)
689         self.assertFormatEqual(expected, actual)
690         black.assert_stable(source, actual, mode)
691
692     @patch("black.dump_to_file", dump_to_stderr)
693     def test_async_as_identifier(self) -> None:
694         source_path = (THIS_DIR / "data" / "async_as_identifier.py").resolve()
695         source, expected = read_data("async_as_identifier")
696         actual = fs(source)
697         self.assertFormatEqual(expected, actual)
698         major, minor = sys.version_info[:2]
699         if major < 3 or (major <= 3 and minor < 7):
700             black.assert_equivalent(source, actual)
701         black.assert_stable(source, actual, DEFAULT_MODE)
702         # ensure black can parse this when the target is 3.6
703         self.invokeBlack([str(source_path), "--target-version", "py36"])
704         # but not on 3.7, because async/await is no longer an identifier
705         self.invokeBlack([str(source_path), "--target-version", "py37"], exit_code=123)
706
707     @patch("black.dump_to_file", dump_to_stderr)
708     def test_python37(self) -> None:
709         source_path = (THIS_DIR / "data" / "python37.py").resolve()
710         source, expected = read_data("python37")
711         actual = fs(source)
712         self.assertFormatEqual(expected, actual)
713         major, minor = sys.version_info[:2]
714         if major > 3 or (major == 3 and minor >= 7):
715             black.assert_equivalent(source, actual)
716         black.assert_stable(source, actual, DEFAULT_MODE)
717         # ensure black can parse this when the target is 3.7
718         self.invokeBlack([str(source_path), "--target-version", "py37"])
719         # but not on 3.6, because we use async as a reserved keyword
720         self.invokeBlack([str(source_path), "--target-version", "py36"], exit_code=123)
721
722     @patch("black.dump_to_file", dump_to_stderr)
723     def test_python38(self) -> None:
724         source, expected = read_data("python38")
725         actual = fs(source)
726         self.assertFormatEqual(expected, actual)
727         major, minor = sys.version_info[:2]
728         if major > 3 or (major == 3 and minor >= 8):
729             black.assert_equivalent(source, actual)
730         black.assert_stable(source, actual, DEFAULT_MODE)
731
732     @patch("black.dump_to_file", dump_to_stderr)
733     def test_fmtonoff(self) -> None:
734         source, expected = read_data("fmtonoff")
735         actual = fs(source)
736         self.assertFormatEqual(expected, actual)
737         black.assert_equivalent(source, actual)
738         black.assert_stable(source, actual, DEFAULT_MODE)
739
740     @patch("black.dump_to_file", dump_to_stderr)
741     def test_fmtonoff2(self) -> None:
742         source, expected = read_data("fmtonoff2")
743         actual = fs(source)
744         self.assertFormatEqual(expected, actual)
745         black.assert_equivalent(source, actual)
746         black.assert_stable(source, actual, DEFAULT_MODE)
747
748     @patch("black.dump_to_file", dump_to_stderr)
749     def test_fmtonoff3(self) -> None:
750         source, expected = read_data("fmtonoff3")
751         actual = fs(source)
752         self.assertFormatEqual(expected, actual)
753         black.assert_equivalent(source, actual)
754         black.assert_stable(source, actual, DEFAULT_MODE)
755
756     @patch("black.dump_to_file", dump_to_stderr)
757     def test_fmtonoff4(self) -> None:
758         source, expected = read_data("fmtonoff4")
759         actual = fs(source)
760         self.assertFormatEqual(expected, actual)
761         black.assert_equivalent(source, actual)
762         black.assert_stable(source, actual, DEFAULT_MODE)
763
764     @patch("black.dump_to_file", dump_to_stderr)
765     def test_remove_empty_parentheses_after_class(self) -> None:
766         source, expected = read_data("class_blank_parentheses")
767         actual = fs(source)
768         self.assertFormatEqual(expected, actual)
769         black.assert_equivalent(source, actual)
770         black.assert_stable(source, actual, DEFAULT_MODE)
771
772     @patch("black.dump_to_file", dump_to_stderr)
773     def test_new_line_between_class_and_code(self) -> None:
774         source, expected = read_data("class_methods_new_line")
775         actual = fs(source)
776         self.assertFormatEqual(expected, actual)
777         black.assert_equivalent(source, actual)
778         black.assert_stable(source, actual, DEFAULT_MODE)
779
780     @patch("black.dump_to_file", dump_to_stderr)
781     def test_bracket_match(self) -> None:
782         source, expected = read_data("bracketmatch")
783         actual = fs(source)
784         self.assertFormatEqual(expected, actual)
785         black.assert_equivalent(source, actual)
786         black.assert_stable(source, actual, DEFAULT_MODE)
787
788     @patch("black.dump_to_file", dump_to_stderr)
789     def test_tuple_assign(self) -> None:
790         source, expected = read_data("tupleassign")
791         actual = fs(source)
792         self.assertFormatEqual(expected, actual)
793         black.assert_equivalent(source, actual)
794         black.assert_stable(source, actual, DEFAULT_MODE)
795
796     @patch("black.dump_to_file", dump_to_stderr)
797     def test_beginning_backslash(self) -> None:
798         source, expected = read_data("beginning_backslash")
799         actual = fs(source)
800         self.assertFormatEqual(expected, actual)
801         black.assert_equivalent(source, actual)
802         black.assert_stable(source, actual, DEFAULT_MODE)
803
804     def test_tab_comment_indentation(self) -> None:
805         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t# comment\n\tpass\n"
806         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
807         self.assertFormatEqual(contents_spc, fs(contents_spc))
808         self.assertFormatEqual(contents_spc, fs(contents_tab))
809
810         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t\t# comment\n\tpass\n"
811         contents_spc = "if 1:\n    if 2:\n        pass\n        # comment\n    pass\n"
812         self.assertFormatEqual(contents_spc, fs(contents_spc))
813         self.assertFormatEqual(contents_spc, fs(contents_tab))
814
815         # mixed tabs and spaces (valid Python 2 code)
816         contents_tab = "if 1:\n        if 2:\n\t\tpass\n\t# comment\n        pass\n"
817         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
818         self.assertFormatEqual(contents_spc, fs(contents_spc))
819         self.assertFormatEqual(contents_spc, fs(contents_tab))
820
821         contents_tab = "if 1:\n        if 2:\n\t\tpass\n\t\t# comment\n        pass\n"
822         contents_spc = "if 1:\n    if 2:\n        pass\n        # comment\n    pass\n"
823         self.assertFormatEqual(contents_spc, fs(contents_spc))
824         self.assertFormatEqual(contents_spc, fs(contents_tab))
825
826     def test_report_verbose(self) -> None:
827         report = black.Report(verbose=True)
828         out_lines = []
829         err_lines = []
830
831         def out(msg: str, **kwargs: Any) -> None:
832             out_lines.append(msg)
833
834         def err(msg: str, **kwargs: Any) -> None:
835             err_lines.append(msg)
836
837         with patch("black.out", out), patch("black.err", err):
838             report.done(Path("f1"), black.Changed.NO)
839             self.assertEqual(len(out_lines), 1)
840             self.assertEqual(len(err_lines), 0)
841             self.assertEqual(out_lines[-1], "f1 already well formatted, good job.")
842             self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
843             self.assertEqual(report.return_code, 0)
844             report.done(Path("f2"), black.Changed.YES)
845             self.assertEqual(len(out_lines), 2)
846             self.assertEqual(len(err_lines), 0)
847             self.assertEqual(out_lines[-1], "reformatted f2")
848             self.assertEqual(
849                 unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
850             )
851             report.done(Path("f3"), black.Changed.CACHED)
852             self.assertEqual(len(out_lines), 3)
853             self.assertEqual(len(err_lines), 0)
854             self.assertEqual(
855                 out_lines[-1], "f3 wasn't modified on disk since last run."
856             )
857             self.assertEqual(
858                 unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
859             )
860             self.assertEqual(report.return_code, 0)
861             report.check = True
862             self.assertEqual(report.return_code, 1)
863             report.check = False
864             report.failed(Path("e1"), "boom")
865             self.assertEqual(len(out_lines), 3)
866             self.assertEqual(len(err_lines), 1)
867             self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
868             self.assertEqual(
869                 unstyle(str(report)),
870                 "1 file reformatted, 2 files left unchanged, 1 file failed to"
871                 " reformat.",
872             )
873             self.assertEqual(report.return_code, 123)
874             report.done(Path("f3"), black.Changed.YES)
875             self.assertEqual(len(out_lines), 4)
876             self.assertEqual(len(err_lines), 1)
877             self.assertEqual(out_lines[-1], "reformatted f3")
878             self.assertEqual(
879                 unstyle(str(report)),
880                 "2 files reformatted, 2 files left unchanged, 1 file failed to"
881                 " reformat.",
882             )
883             self.assertEqual(report.return_code, 123)
884             report.failed(Path("e2"), "boom")
885             self.assertEqual(len(out_lines), 4)
886             self.assertEqual(len(err_lines), 2)
887             self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
888             self.assertEqual(
889                 unstyle(str(report)),
890                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
891                 " reformat.",
892             )
893             self.assertEqual(report.return_code, 123)
894             report.path_ignored(Path("wat"), "no match")
895             self.assertEqual(len(out_lines), 5)
896             self.assertEqual(len(err_lines), 2)
897             self.assertEqual(out_lines[-1], "wat ignored: no match")
898             self.assertEqual(
899                 unstyle(str(report)),
900                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
901                 " reformat.",
902             )
903             self.assertEqual(report.return_code, 123)
904             report.done(Path("f4"), black.Changed.NO)
905             self.assertEqual(len(out_lines), 6)
906             self.assertEqual(len(err_lines), 2)
907             self.assertEqual(out_lines[-1], "f4 already well formatted, good job.")
908             self.assertEqual(
909                 unstyle(str(report)),
910                 "2 files reformatted, 3 files left unchanged, 2 files failed to"
911                 " reformat.",
912             )
913             self.assertEqual(report.return_code, 123)
914             report.check = True
915             self.assertEqual(
916                 unstyle(str(report)),
917                 "2 files would be reformatted, 3 files would be left unchanged, 2 files"
918                 " would fail to reformat.",
919             )
920             report.check = False
921             report.diff = True
922             self.assertEqual(
923                 unstyle(str(report)),
924                 "2 files would be reformatted, 3 files would be left unchanged, 2 files"
925                 " would fail to reformat.",
926             )
927
928     def test_report_quiet(self) -> None:
929         report = black.Report(quiet=True)
930         out_lines = []
931         err_lines = []
932
933         def out(msg: str, **kwargs: Any) -> None:
934             out_lines.append(msg)
935
936         def err(msg: str, **kwargs: Any) -> None:
937             err_lines.append(msg)
938
939         with patch("black.out", out), patch("black.err", err):
940             report.done(Path("f1"), black.Changed.NO)
941             self.assertEqual(len(out_lines), 0)
942             self.assertEqual(len(err_lines), 0)
943             self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
944             self.assertEqual(report.return_code, 0)
945             report.done(Path("f2"), black.Changed.YES)
946             self.assertEqual(len(out_lines), 0)
947             self.assertEqual(len(err_lines), 0)
948             self.assertEqual(
949                 unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
950             )
951             report.done(Path("f3"), black.Changed.CACHED)
952             self.assertEqual(len(out_lines), 0)
953             self.assertEqual(len(err_lines), 0)
954             self.assertEqual(
955                 unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
956             )
957             self.assertEqual(report.return_code, 0)
958             report.check = True
959             self.assertEqual(report.return_code, 1)
960             report.check = False
961             report.failed(Path("e1"), "boom")
962             self.assertEqual(len(out_lines), 0)
963             self.assertEqual(len(err_lines), 1)
964             self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
965             self.assertEqual(
966                 unstyle(str(report)),
967                 "1 file reformatted, 2 files left unchanged, 1 file failed to"
968                 " reformat.",
969             )
970             self.assertEqual(report.return_code, 123)
971             report.done(Path("f3"), black.Changed.YES)
972             self.assertEqual(len(out_lines), 0)
973             self.assertEqual(len(err_lines), 1)
974             self.assertEqual(
975                 unstyle(str(report)),
976                 "2 files reformatted, 2 files left unchanged, 1 file failed to"
977                 " reformat.",
978             )
979             self.assertEqual(report.return_code, 123)
980             report.failed(Path("e2"), "boom")
981             self.assertEqual(len(out_lines), 0)
982             self.assertEqual(len(err_lines), 2)
983             self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
984             self.assertEqual(
985                 unstyle(str(report)),
986                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
987                 " reformat.",
988             )
989             self.assertEqual(report.return_code, 123)
990             report.path_ignored(Path("wat"), "no match")
991             self.assertEqual(len(out_lines), 0)
992             self.assertEqual(len(err_lines), 2)
993             self.assertEqual(
994                 unstyle(str(report)),
995                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
996                 " reformat.",
997             )
998             self.assertEqual(report.return_code, 123)
999             report.done(Path("f4"), black.Changed.NO)
1000             self.assertEqual(len(out_lines), 0)
1001             self.assertEqual(len(err_lines), 2)
1002             self.assertEqual(
1003                 unstyle(str(report)),
1004                 "2 files reformatted, 3 files left unchanged, 2 files failed to"
1005                 " reformat.",
1006             )
1007             self.assertEqual(report.return_code, 123)
1008             report.check = True
1009             self.assertEqual(
1010                 unstyle(str(report)),
1011                 "2 files would be reformatted, 3 files would be left unchanged, 2 files"
1012                 " would fail to reformat.",
1013             )
1014             report.check = False
1015             report.diff = True
1016             self.assertEqual(
1017                 unstyle(str(report)),
1018                 "2 files would be reformatted, 3 files would be left unchanged, 2 files"
1019                 " would fail to reformat.",
1020             )
1021
1022     def test_report_normal(self) -> None:
1023         report = black.Report()
1024         out_lines = []
1025         err_lines = []
1026
1027         def out(msg: str, **kwargs: Any) -> None:
1028             out_lines.append(msg)
1029
1030         def err(msg: str, **kwargs: Any) -> None:
1031             err_lines.append(msg)
1032
1033         with patch("black.out", out), patch("black.err", err):
1034             report.done(Path("f1"), black.Changed.NO)
1035             self.assertEqual(len(out_lines), 0)
1036             self.assertEqual(len(err_lines), 0)
1037             self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
1038             self.assertEqual(report.return_code, 0)
1039             report.done(Path("f2"), black.Changed.YES)
1040             self.assertEqual(len(out_lines), 1)
1041             self.assertEqual(len(err_lines), 0)
1042             self.assertEqual(out_lines[-1], "reformatted f2")
1043             self.assertEqual(
1044                 unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
1045             )
1046             report.done(Path("f3"), black.Changed.CACHED)
1047             self.assertEqual(len(out_lines), 1)
1048             self.assertEqual(len(err_lines), 0)
1049             self.assertEqual(out_lines[-1], "reformatted f2")
1050             self.assertEqual(
1051                 unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
1052             )
1053             self.assertEqual(report.return_code, 0)
1054             report.check = True
1055             self.assertEqual(report.return_code, 1)
1056             report.check = False
1057             report.failed(Path("e1"), "boom")
1058             self.assertEqual(len(out_lines), 1)
1059             self.assertEqual(len(err_lines), 1)
1060             self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
1061             self.assertEqual(
1062                 unstyle(str(report)),
1063                 "1 file reformatted, 2 files left unchanged, 1 file failed to"
1064                 " reformat.",
1065             )
1066             self.assertEqual(report.return_code, 123)
1067             report.done(Path("f3"), black.Changed.YES)
1068             self.assertEqual(len(out_lines), 2)
1069             self.assertEqual(len(err_lines), 1)
1070             self.assertEqual(out_lines[-1], "reformatted f3")
1071             self.assertEqual(
1072                 unstyle(str(report)),
1073                 "2 files reformatted, 2 files left unchanged, 1 file failed to"
1074                 " reformat.",
1075             )
1076             self.assertEqual(report.return_code, 123)
1077             report.failed(Path("e2"), "boom")
1078             self.assertEqual(len(out_lines), 2)
1079             self.assertEqual(len(err_lines), 2)
1080             self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
1081             self.assertEqual(
1082                 unstyle(str(report)),
1083                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
1084                 " reformat.",
1085             )
1086             self.assertEqual(report.return_code, 123)
1087             report.path_ignored(Path("wat"), "no match")
1088             self.assertEqual(len(out_lines), 2)
1089             self.assertEqual(len(err_lines), 2)
1090             self.assertEqual(
1091                 unstyle(str(report)),
1092                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
1093                 " reformat.",
1094             )
1095             self.assertEqual(report.return_code, 123)
1096             report.done(Path("f4"), black.Changed.NO)
1097             self.assertEqual(len(out_lines), 2)
1098             self.assertEqual(len(err_lines), 2)
1099             self.assertEqual(
1100                 unstyle(str(report)),
1101                 "2 files reformatted, 3 files left unchanged, 2 files failed to"
1102                 " reformat.",
1103             )
1104             self.assertEqual(report.return_code, 123)
1105             report.check = True
1106             self.assertEqual(
1107                 unstyle(str(report)),
1108                 "2 files would be reformatted, 3 files would be left unchanged, 2 files"
1109                 " would fail to reformat.",
1110             )
1111             report.check = False
1112             report.diff = True
1113             self.assertEqual(
1114                 unstyle(str(report)),
1115                 "2 files would be reformatted, 3 files would be left unchanged, 2 files"
1116                 " would fail to reformat.",
1117             )
1118
1119     def test_lib2to3_parse(self) -> None:
1120         with self.assertRaises(black.InvalidInput):
1121             black.lib2to3_parse("invalid syntax")
1122
1123         straddling = "x + y"
1124         black.lib2to3_parse(straddling)
1125         black.lib2to3_parse(straddling, {TargetVersion.PY27})
1126         black.lib2to3_parse(straddling, {TargetVersion.PY36})
1127         black.lib2to3_parse(straddling, {TargetVersion.PY27, TargetVersion.PY36})
1128
1129         py2_only = "print x"
1130         black.lib2to3_parse(py2_only)
1131         black.lib2to3_parse(py2_only, {TargetVersion.PY27})
1132         with self.assertRaises(black.InvalidInput):
1133             black.lib2to3_parse(py2_only, {TargetVersion.PY36})
1134         with self.assertRaises(black.InvalidInput):
1135             black.lib2to3_parse(py2_only, {TargetVersion.PY27, TargetVersion.PY36})
1136
1137         py3_only = "exec(x, end=y)"
1138         black.lib2to3_parse(py3_only)
1139         with self.assertRaises(black.InvalidInput):
1140             black.lib2to3_parse(py3_only, {TargetVersion.PY27})
1141         black.lib2to3_parse(py3_only, {TargetVersion.PY36})
1142         black.lib2to3_parse(py3_only, {TargetVersion.PY27, TargetVersion.PY36})
1143
1144     def test_get_features_used(self) -> None:
1145         node = black.lib2to3_parse("def f(*, arg): ...\n")
1146         self.assertEqual(black.get_features_used(node), set())
1147         node = black.lib2to3_parse("def f(*, arg,): ...\n")
1148         self.assertEqual(black.get_features_used(node), {Feature.TRAILING_COMMA_IN_DEF})
1149         node = black.lib2to3_parse("f(*arg,)\n")
1150         self.assertEqual(
1151             black.get_features_used(node), {Feature.TRAILING_COMMA_IN_CALL}
1152         )
1153         node = black.lib2to3_parse("def f(*, arg): f'string'\n")
1154         self.assertEqual(black.get_features_used(node), {Feature.F_STRINGS})
1155         node = black.lib2to3_parse("123_456\n")
1156         self.assertEqual(black.get_features_used(node), {Feature.NUMERIC_UNDERSCORES})
1157         node = black.lib2to3_parse("123456\n")
1158         self.assertEqual(black.get_features_used(node), set())
1159         source, expected = read_data("function")
1160         node = black.lib2to3_parse(source)
1161         expected_features = {
1162             Feature.TRAILING_COMMA_IN_CALL,
1163             Feature.TRAILING_COMMA_IN_DEF,
1164             Feature.F_STRINGS,
1165         }
1166         self.assertEqual(black.get_features_used(node), expected_features)
1167         node = black.lib2to3_parse(expected)
1168         self.assertEqual(black.get_features_used(node), expected_features)
1169         source, expected = read_data("expression")
1170         node = black.lib2to3_parse(source)
1171         self.assertEqual(black.get_features_used(node), set())
1172         node = black.lib2to3_parse(expected)
1173         self.assertEqual(black.get_features_used(node), set())
1174
1175     def test_get_future_imports(self) -> None:
1176         node = black.lib2to3_parse("\n")
1177         self.assertEqual(set(), black.get_future_imports(node))
1178         node = black.lib2to3_parse("from __future__ import black\n")
1179         self.assertEqual({"black"}, black.get_future_imports(node))
1180         node = black.lib2to3_parse("from __future__ import multiple, imports\n")
1181         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
1182         node = black.lib2to3_parse("from __future__ import (parenthesized, imports)\n")
1183         self.assertEqual({"parenthesized", "imports"}, black.get_future_imports(node))
1184         node = black.lib2to3_parse(
1185             "from __future__ import multiple\nfrom __future__ import imports\n"
1186         )
1187         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
1188         node = black.lib2to3_parse("# comment\nfrom __future__ import black\n")
1189         self.assertEqual({"black"}, black.get_future_imports(node))
1190         node = black.lib2to3_parse('"""docstring"""\nfrom __future__ import black\n')
1191         self.assertEqual({"black"}, black.get_future_imports(node))
1192         node = black.lib2to3_parse("some(other, code)\nfrom __future__ import black\n")
1193         self.assertEqual(set(), black.get_future_imports(node))
1194         node = black.lib2to3_parse("from some.module import black\n")
1195         self.assertEqual(set(), black.get_future_imports(node))
1196         node = black.lib2to3_parse(
1197             "from __future__ import unicode_literals as _unicode_literals"
1198         )
1199         self.assertEqual({"unicode_literals"}, black.get_future_imports(node))
1200         node = black.lib2to3_parse(
1201             "from __future__ import unicode_literals as _lol, print"
1202         )
1203         self.assertEqual({"unicode_literals", "print"}, black.get_future_imports(node))
1204
1205     def test_debug_visitor(self) -> None:
1206         source, _ = read_data("debug_visitor.py")
1207         expected, _ = read_data("debug_visitor.out")
1208         out_lines = []
1209         err_lines = []
1210
1211         def out(msg: str, **kwargs: Any) -> None:
1212             out_lines.append(msg)
1213
1214         def err(msg: str, **kwargs: Any) -> None:
1215             err_lines.append(msg)
1216
1217         with patch("black.out", out), patch("black.err", err):
1218             black.DebugVisitor.show(source)
1219         actual = "\n".join(out_lines) + "\n"
1220         log_name = ""
1221         if expected != actual:
1222             log_name = black.dump_to_file(*out_lines)
1223         self.assertEqual(
1224             expected,
1225             actual,
1226             f"AST print out is different. Actual version dumped to {log_name}",
1227         )
1228
1229     def test_format_file_contents(self) -> None:
1230         empty = ""
1231         mode = DEFAULT_MODE
1232         with self.assertRaises(black.NothingChanged):
1233             black.format_file_contents(empty, mode=mode, fast=False)
1234         just_nl = "\n"
1235         with self.assertRaises(black.NothingChanged):
1236             black.format_file_contents(just_nl, mode=mode, fast=False)
1237         same = "j = [1, 2, 3]\n"
1238         with self.assertRaises(black.NothingChanged):
1239             black.format_file_contents(same, mode=mode, fast=False)
1240         different = "j = [1,2,3]"
1241         expected = same
1242         actual = black.format_file_contents(different, mode=mode, fast=False)
1243         self.assertEqual(expected, actual)
1244         invalid = "return if you can"
1245         with self.assertRaises(black.InvalidInput) as e:
1246             black.format_file_contents(invalid, mode=mode, fast=False)
1247         self.assertEqual(str(e.exception), "Cannot parse: 1:7: return if you can")
1248
1249     def test_endmarker(self) -> None:
1250         n = black.lib2to3_parse("\n")
1251         self.assertEqual(n.type, black.syms.file_input)
1252         self.assertEqual(len(n.children), 1)
1253         self.assertEqual(n.children[0].type, black.token.ENDMARKER)
1254
1255     @unittest.skipIf(os.environ.get("SKIP_AST_PRINT"), "user set SKIP_AST_PRINT")
1256     def test_assertFormatEqual(self) -> None:
1257         out_lines = []
1258         err_lines = []
1259
1260         def out(msg: str, **kwargs: Any) -> None:
1261             out_lines.append(msg)
1262
1263         def err(msg: str, **kwargs: Any) -> None:
1264             err_lines.append(msg)
1265
1266         with patch("black.out", out), patch("black.err", err):
1267             with self.assertRaises(AssertionError):
1268                 self.assertFormatEqual("j = [1, 2, 3]", "j = [1, 2, 3,]")
1269
1270         out_str = "".join(out_lines)
1271         self.assertTrue("Expected tree:" in out_str)
1272         self.assertTrue("Actual tree:" in out_str)
1273         self.assertEqual("".join(err_lines), "")
1274
1275     def test_cache_broken_file(self) -> None:
1276         mode = DEFAULT_MODE
1277         with cache_dir() as workspace:
1278             cache_file = black.get_cache_file(mode)
1279             with cache_file.open("w") as fobj:
1280                 fobj.write("this is not a pickle")
1281             self.assertEqual(black.read_cache(mode), {})
1282             src = (workspace / "test.py").resolve()
1283             with src.open("w") as fobj:
1284                 fobj.write("print('hello')")
1285             self.invokeBlack([str(src)])
1286             cache = black.read_cache(mode)
1287             self.assertIn(src, cache)
1288
1289     def test_cache_single_file_already_cached(self) -> None:
1290         mode = DEFAULT_MODE
1291         with cache_dir() as workspace:
1292             src = (workspace / "test.py").resolve()
1293             with src.open("w") as fobj:
1294                 fobj.write("print('hello')")
1295             black.write_cache({}, [src], mode)
1296             self.invokeBlack([str(src)])
1297             with src.open("r") as fobj:
1298                 self.assertEqual(fobj.read(), "print('hello')")
1299
1300     @event_loop()
1301     def test_cache_multiple_files(self) -> None:
1302         mode = DEFAULT_MODE
1303         with cache_dir() as workspace, patch(
1304             "black.ProcessPoolExecutor", new=ThreadPoolExecutor
1305         ):
1306             one = (workspace / "one.py").resolve()
1307             with one.open("w") as fobj:
1308                 fobj.write("print('hello')")
1309             two = (workspace / "two.py").resolve()
1310             with two.open("w") as fobj:
1311                 fobj.write("print('hello')")
1312             black.write_cache({}, [one], mode)
1313             self.invokeBlack([str(workspace)])
1314             with one.open("r") as fobj:
1315                 self.assertEqual(fobj.read(), "print('hello')")
1316             with two.open("r") as fobj:
1317                 self.assertEqual(fobj.read(), 'print("hello")\n')
1318             cache = black.read_cache(mode)
1319             self.assertIn(one, cache)
1320             self.assertIn(two, cache)
1321
1322     def test_no_cache_when_writeback_diff(self) -> None:
1323         mode = DEFAULT_MODE
1324         with cache_dir() as workspace:
1325             src = (workspace / "test.py").resolve()
1326             with src.open("w") as fobj:
1327                 fobj.write("print('hello')")
1328             self.invokeBlack([str(src), "--diff"])
1329             cache_file = black.get_cache_file(mode)
1330             self.assertFalse(cache_file.exists())
1331
1332     def test_no_cache_when_stdin(self) -> None:
1333         mode = DEFAULT_MODE
1334         with cache_dir():
1335             result = CliRunner().invoke(
1336                 black.main, ["-"], input=BytesIO(b"print('hello')")
1337             )
1338             self.assertEqual(result.exit_code, 0)
1339             cache_file = black.get_cache_file(mode)
1340             self.assertFalse(cache_file.exists())
1341
1342     def test_read_cache_no_cachefile(self) -> None:
1343         mode = DEFAULT_MODE
1344         with cache_dir():
1345             self.assertEqual(black.read_cache(mode), {})
1346
1347     def test_write_cache_read_cache(self) -> None:
1348         mode = DEFAULT_MODE
1349         with cache_dir() as workspace:
1350             src = (workspace / "test.py").resolve()
1351             src.touch()
1352             black.write_cache({}, [src], mode)
1353             cache = black.read_cache(mode)
1354             self.assertIn(src, cache)
1355             self.assertEqual(cache[src], black.get_cache_info(src))
1356
1357     def test_filter_cached(self) -> None:
1358         with TemporaryDirectory() as workspace:
1359             path = Path(workspace)
1360             uncached = (path / "uncached").resolve()
1361             cached = (path / "cached").resolve()
1362             cached_but_changed = (path / "changed").resolve()
1363             uncached.touch()
1364             cached.touch()
1365             cached_but_changed.touch()
1366             cache = {cached: black.get_cache_info(cached), cached_but_changed: (0.0, 0)}
1367             todo, done = black.filter_cached(
1368                 cache, {uncached, cached, cached_but_changed}
1369             )
1370             self.assertEqual(todo, {uncached, cached_but_changed})
1371             self.assertEqual(done, {cached})
1372
1373     def test_write_cache_creates_directory_if_needed(self) -> None:
1374         mode = DEFAULT_MODE
1375         with cache_dir(exists=False) as workspace:
1376             self.assertFalse(workspace.exists())
1377             black.write_cache({}, [], mode)
1378             self.assertTrue(workspace.exists())
1379
1380     @event_loop()
1381     def test_failed_formatting_does_not_get_cached(self) -> None:
1382         mode = DEFAULT_MODE
1383         with cache_dir() as workspace, patch(
1384             "black.ProcessPoolExecutor", new=ThreadPoolExecutor
1385         ):
1386             failing = (workspace / "failing.py").resolve()
1387             with failing.open("w") as fobj:
1388                 fobj.write("not actually python")
1389             clean = (workspace / "clean.py").resolve()
1390             with clean.open("w") as fobj:
1391                 fobj.write('print("hello")\n')
1392             self.invokeBlack([str(workspace)], exit_code=123)
1393             cache = black.read_cache(mode)
1394             self.assertNotIn(failing, cache)
1395             self.assertIn(clean, cache)
1396
1397     def test_write_cache_write_fail(self) -> None:
1398         mode = DEFAULT_MODE
1399         with cache_dir(), patch.object(Path, "open") as mock:
1400             mock.side_effect = OSError
1401             black.write_cache({}, [], mode)
1402
1403     @event_loop()
1404     @patch("black.ProcessPoolExecutor", MagicMock(side_effect=OSError))
1405     def test_works_in_mono_process_only_environment(self) -> None:
1406         with cache_dir() as workspace:
1407             for f in [
1408                 (workspace / "one.py").resolve(),
1409                 (workspace / "two.py").resolve(),
1410             ]:
1411                 f.write_text('print("hello")\n')
1412             self.invokeBlack([str(workspace)])
1413
1414     @event_loop()
1415     def test_check_diff_use_together(self) -> None:
1416         with cache_dir():
1417             # Files which will be reformatted.
1418             src1 = (THIS_DIR / "data" / "string_quotes.py").resolve()
1419             self.invokeBlack([str(src1), "--diff", "--check"], exit_code=1)
1420             # Files which will not be reformatted.
1421             src2 = (THIS_DIR / "data" / "composition.py").resolve()
1422             self.invokeBlack([str(src2), "--diff", "--check"])
1423             # Multi file command.
1424             self.invokeBlack([str(src1), str(src2), "--diff", "--check"], exit_code=1)
1425
1426     def test_no_files(self) -> None:
1427         with cache_dir():
1428             # Without an argument, black exits with error code 0.
1429             self.invokeBlack([])
1430
1431     def test_broken_symlink(self) -> None:
1432         with cache_dir() as workspace:
1433             symlink = workspace / "broken_link.py"
1434             try:
1435                 symlink.symlink_to("nonexistent.py")
1436             except OSError as e:
1437                 self.skipTest(f"Can't create symlinks: {e}")
1438             self.invokeBlack([str(workspace.resolve())])
1439
1440     def test_read_cache_line_lengths(self) -> None:
1441         mode = DEFAULT_MODE
1442         short_mode = replace(DEFAULT_MODE, line_length=1)
1443         with cache_dir() as workspace:
1444             path = (workspace / "file.py").resolve()
1445             path.touch()
1446             black.write_cache({}, [path], mode)
1447             one = black.read_cache(mode)
1448             self.assertIn(path, one)
1449             two = black.read_cache(short_mode)
1450             self.assertNotIn(path, two)
1451
1452     def test_tricky_unicode_symbols(self) -> None:
1453         source, expected = read_data("tricky_unicode_symbols")
1454         actual = fs(source)
1455         self.assertFormatEqual(expected, actual)
1456         black.assert_equivalent(source, actual)
1457         black.assert_stable(source, actual, DEFAULT_MODE)
1458
1459     def test_single_file_force_pyi(self) -> None:
1460         reg_mode = DEFAULT_MODE
1461         pyi_mode = replace(DEFAULT_MODE, is_pyi=True)
1462         contents, expected = read_data("force_pyi")
1463         with cache_dir() as workspace:
1464             path = (workspace / "file.py").resolve()
1465             with open(path, "w") as fh:
1466                 fh.write(contents)
1467             self.invokeBlack([str(path), "--pyi"])
1468             with open(path, "r") as fh:
1469                 actual = fh.read()
1470             # verify cache with --pyi is separate
1471             pyi_cache = black.read_cache(pyi_mode)
1472             self.assertIn(path, pyi_cache)
1473             normal_cache = black.read_cache(reg_mode)
1474             self.assertNotIn(path, normal_cache)
1475         self.assertEqual(actual, expected)
1476
1477     @event_loop()
1478     def test_multi_file_force_pyi(self) -> None:
1479         reg_mode = DEFAULT_MODE
1480         pyi_mode = replace(DEFAULT_MODE, is_pyi=True)
1481         contents, expected = read_data("force_pyi")
1482         with cache_dir() as workspace:
1483             paths = [
1484                 (workspace / "file1.py").resolve(),
1485                 (workspace / "file2.py").resolve(),
1486             ]
1487             for path in paths:
1488                 with open(path, "w") as fh:
1489                     fh.write(contents)
1490             self.invokeBlack([str(p) for p in paths] + ["--pyi"])
1491             for path in paths:
1492                 with open(path, "r") as fh:
1493                     actual = fh.read()
1494                 self.assertEqual(actual, expected)
1495             # verify cache with --pyi is separate
1496             pyi_cache = black.read_cache(pyi_mode)
1497             normal_cache = black.read_cache(reg_mode)
1498             for path in paths:
1499                 self.assertIn(path, pyi_cache)
1500                 self.assertNotIn(path, normal_cache)
1501
1502     def test_pipe_force_pyi(self) -> None:
1503         source, expected = read_data("force_pyi")
1504         result = CliRunner().invoke(
1505             black.main, ["-", "-q", "--pyi"], input=BytesIO(source.encode("utf8"))
1506         )
1507         self.assertEqual(result.exit_code, 0)
1508         actual = result.output
1509         self.assertFormatEqual(actual, expected)
1510
1511     def test_single_file_force_py36(self) -> None:
1512         reg_mode = DEFAULT_MODE
1513         py36_mode = replace(DEFAULT_MODE, target_versions=black.PY36_VERSIONS)
1514         source, expected = read_data("force_py36")
1515         with cache_dir() as workspace:
1516             path = (workspace / "file.py").resolve()
1517             with open(path, "w") as fh:
1518                 fh.write(source)
1519             self.invokeBlack([str(path), *PY36_ARGS])
1520             with open(path, "r") as fh:
1521                 actual = fh.read()
1522             # verify cache with --target-version is separate
1523             py36_cache = black.read_cache(py36_mode)
1524             self.assertIn(path, py36_cache)
1525             normal_cache = black.read_cache(reg_mode)
1526             self.assertNotIn(path, normal_cache)
1527         self.assertEqual(actual, expected)
1528
1529     @event_loop()
1530     def test_multi_file_force_py36(self) -> None:
1531         reg_mode = DEFAULT_MODE
1532         py36_mode = replace(DEFAULT_MODE, target_versions=black.PY36_VERSIONS)
1533         source, expected = read_data("force_py36")
1534         with cache_dir() as workspace:
1535             paths = [
1536                 (workspace / "file1.py").resolve(),
1537                 (workspace / "file2.py").resolve(),
1538             ]
1539             for path in paths:
1540                 with open(path, "w") as fh:
1541                     fh.write(source)
1542             self.invokeBlack([str(p) for p in paths] + PY36_ARGS)
1543             for path in paths:
1544                 with open(path, "r") as fh:
1545                     actual = fh.read()
1546                 self.assertEqual(actual, expected)
1547             # verify cache with --target-version is separate
1548             pyi_cache = black.read_cache(py36_mode)
1549             normal_cache = black.read_cache(reg_mode)
1550             for path in paths:
1551                 self.assertIn(path, pyi_cache)
1552                 self.assertNotIn(path, normal_cache)
1553
1554     def test_collections(self) -> None:
1555         source, expected = read_data("collections")
1556         actual = fs(source)
1557         self.assertFormatEqual(expected, actual)
1558         black.assert_equivalent(source, actual)
1559         black.assert_stable(source, actual, DEFAULT_MODE)
1560
1561     def test_pipe_force_py36(self) -> None:
1562         source, expected = read_data("force_py36")
1563         result = CliRunner().invoke(
1564             black.main,
1565             ["-", "-q", "--target-version=py36"],
1566             input=BytesIO(source.encode("utf8")),
1567         )
1568         self.assertEqual(result.exit_code, 0)
1569         actual = result.output
1570         self.assertFormatEqual(actual, expected)
1571
1572     def test_include_exclude(self) -> None:
1573         path = THIS_DIR / "data" / "include_exclude_tests"
1574         include = re.compile(r"\.pyi?$")
1575         exclude = re.compile(r"/exclude/|/\.definitely_exclude/")
1576         report = black.Report()
1577         gitignore = PathSpec.from_lines("gitwildmatch", [])
1578         sources: List[Path] = []
1579         expected = [
1580             Path(path / "b/dont_exclude/a.py"),
1581             Path(path / "b/dont_exclude/a.pyi"),
1582         ]
1583         this_abs = THIS_DIR.resolve()
1584         sources.extend(
1585             black.gen_python_files(
1586                 path.iterdir(), this_abs, include, exclude, None, report, gitignore
1587             )
1588         )
1589         self.assertEqual(sorted(expected), sorted(sources))
1590
1591     @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
1592     def test_exclude_for_issue_1572(self) -> None:
1593         # Exclude shouldn't touch files that were explicitly given to Black through the
1594         # CLI. Exclude is supposed to only apply to the recursive discovery of files.
1595         # https://github.com/psf/black/issues/1572
1596         path = THIS_DIR / "data" / "include_exclude_tests"
1597         include = ""
1598         exclude = r"/exclude/|a\.py"
1599         src = str(path / "b/exclude/a.py")
1600         report = black.Report()
1601         expected = [Path(path / "b/exclude/a.py")]
1602         sources = list(
1603             black.get_sources(
1604                 ctx=FakeContext(),
1605                 src=(src,),
1606                 quiet=True,
1607                 verbose=False,
1608                 include=include,
1609                 exclude=exclude,
1610                 force_exclude=None,
1611                 report=report,
1612             )
1613         )
1614         self.assertEqual(sorted(expected), sorted(sources))
1615
1616     def test_gitignore_exclude(self) -> None:
1617         path = THIS_DIR / "data" / "include_exclude_tests"
1618         include = re.compile(r"\.pyi?$")
1619         exclude = re.compile(r"")
1620         report = black.Report()
1621         gitignore = PathSpec.from_lines(
1622             "gitwildmatch", ["exclude/", ".definitely_exclude"]
1623         )
1624         sources: List[Path] = []
1625         expected = [
1626             Path(path / "b/dont_exclude/a.py"),
1627             Path(path / "b/dont_exclude/a.pyi"),
1628         ]
1629         this_abs = THIS_DIR.resolve()
1630         sources.extend(
1631             black.gen_python_files(
1632                 path.iterdir(), this_abs, include, exclude, None, report, gitignore
1633             )
1634         )
1635         self.assertEqual(sorted(expected), sorted(sources))
1636
1637     def test_empty_include(self) -> None:
1638         path = THIS_DIR / "data" / "include_exclude_tests"
1639         report = black.Report()
1640         gitignore = PathSpec.from_lines("gitwildmatch", [])
1641         empty = re.compile(r"")
1642         sources: List[Path] = []
1643         expected = [
1644             Path(path / "b/exclude/a.pie"),
1645             Path(path / "b/exclude/a.py"),
1646             Path(path / "b/exclude/a.pyi"),
1647             Path(path / "b/dont_exclude/a.pie"),
1648             Path(path / "b/dont_exclude/a.py"),
1649             Path(path / "b/dont_exclude/a.pyi"),
1650             Path(path / "b/.definitely_exclude/a.pie"),
1651             Path(path / "b/.definitely_exclude/a.py"),
1652             Path(path / "b/.definitely_exclude/a.pyi"),
1653         ]
1654         this_abs = THIS_DIR.resolve()
1655         sources.extend(
1656             black.gen_python_files(
1657                 path.iterdir(),
1658                 this_abs,
1659                 empty,
1660                 re.compile(black.DEFAULT_EXCLUDES),
1661                 None,
1662                 report,
1663                 gitignore,
1664             )
1665         )
1666         self.assertEqual(sorted(expected), sorted(sources))
1667
1668     def test_empty_exclude(self) -> None:
1669         path = THIS_DIR / "data" / "include_exclude_tests"
1670         report = black.Report()
1671         gitignore = PathSpec.from_lines("gitwildmatch", [])
1672         empty = re.compile(r"")
1673         sources: List[Path] = []
1674         expected = [
1675             Path(path / "b/dont_exclude/a.py"),
1676             Path(path / "b/dont_exclude/a.pyi"),
1677             Path(path / "b/exclude/a.py"),
1678             Path(path / "b/exclude/a.pyi"),
1679             Path(path / "b/.definitely_exclude/a.py"),
1680             Path(path / "b/.definitely_exclude/a.pyi"),
1681         ]
1682         this_abs = THIS_DIR.resolve()
1683         sources.extend(
1684             black.gen_python_files(
1685                 path.iterdir(),
1686                 this_abs,
1687                 re.compile(black.DEFAULT_INCLUDES),
1688                 empty,
1689                 None,
1690                 report,
1691                 gitignore,
1692             )
1693         )
1694         self.assertEqual(sorted(expected), sorted(sources))
1695
1696     def test_invalid_include_exclude(self) -> None:
1697         for option in ["--include", "--exclude"]:
1698             self.invokeBlack(["-", option, "**()(!!*)"], exit_code=2)
1699
1700     def test_preserves_line_endings(self) -> None:
1701         with TemporaryDirectory() as workspace:
1702             test_file = Path(workspace) / "test.py"
1703             for nl in ["\n", "\r\n"]:
1704                 contents = nl.join(["def f(  ):", "    pass"])
1705                 test_file.write_bytes(contents.encode())
1706                 ff(test_file, write_back=black.WriteBack.YES)
1707                 updated_contents: bytes = test_file.read_bytes()
1708                 self.assertIn(nl.encode(), updated_contents)
1709                 if nl == "\n":
1710                     self.assertNotIn(b"\r\n", updated_contents)
1711
1712     def test_preserves_line_endings_via_stdin(self) -> None:
1713         for nl in ["\n", "\r\n"]:
1714             contents = nl.join(["def f(  ):", "    pass"])
1715             runner = BlackRunner()
1716             result = runner.invoke(
1717                 black.main, ["-", "--fast"], input=BytesIO(contents.encode("utf8"))
1718             )
1719             self.assertEqual(result.exit_code, 0)
1720             output = runner.stdout_bytes
1721             self.assertIn(nl.encode("utf8"), output)
1722             if nl == "\n":
1723                 self.assertNotIn(b"\r\n", output)
1724
1725     def test_assert_equivalent_different_asts(self) -> None:
1726         with self.assertRaises(AssertionError):
1727             black.assert_equivalent("{}", "None")
1728
1729     def test_symlink_out_of_root_directory(self) -> None:
1730         path = MagicMock()
1731         root = THIS_DIR.resolve()
1732         child = MagicMock()
1733         include = re.compile(black.DEFAULT_INCLUDES)
1734         exclude = re.compile(black.DEFAULT_EXCLUDES)
1735         report = black.Report()
1736         gitignore = PathSpec.from_lines("gitwildmatch", [])
1737         # `child` should behave like a symlink which resolved path is clearly
1738         # outside of the `root` directory.
1739         path.iterdir.return_value = [child]
1740         child.resolve.return_value = Path("/a/b/c")
1741         child.as_posix.return_value = "/a/b/c"
1742         child.is_symlink.return_value = True
1743         try:
1744             list(
1745                 black.gen_python_files(
1746                     path.iterdir(), root, include, exclude, None, report, gitignore
1747                 )
1748             )
1749         except ValueError as ve:
1750             self.fail(f"`get_python_files_in_dir()` failed: {ve}")
1751         path.iterdir.assert_called_once()
1752         child.resolve.assert_called_once()
1753         child.is_symlink.assert_called_once()
1754         # `child` should behave like a strange file which resolved path is clearly
1755         # outside of the `root` directory.
1756         child.is_symlink.return_value = False
1757         with self.assertRaises(ValueError):
1758             list(
1759                 black.gen_python_files(
1760                     path.iterdir(), root, include, exclude, None, report, gitignore
1761                 )
1762             )
1763         path.iterdir.assert_called()
1764         self.assertEqual(path.iterdir.call_count, 2)
1765         child.resolve.assert_called()
1766         self.assertEqual(child.resolve.call_count, 2)
1767         child.is_symlink.assert_called()
1768         self.assertEqual(child.is_symlink.call_count, 2)
1769
1770     def test_shhh_click(self) -> None:
1771         try:
1772             from click import _unicodefun  # type: ignore
1773         except ModuleNotFoundError:
1774             self.skipTest("Incompatible Click version")
1775         if not hasattr(_unicodefun, "_verify_python3_env"):
1776             self.skipTest("Incompatible Click version")
1777         # First, let's see if Click is crashing with a preferred ASCII charset.
1778         with patch("locale.getpreferredencoding") as gpe:
1779             gpe.return_value = "ASCII"
1780             with self.assertRaises(RuntimeError):
1781                 _unicodefun._verify_python3_env()
1782         # Now, let's silence Click...
1783         black.patch_click()
1784         # ...and confirm it's silent.
1785         with patch("locale.getpreferredencoding") as gpe:
1786             gpe.return_value = "ASCII"
1787             try:
1788                 _unicodefun._verify_python3_env()
1789             except RuntimeError as re:
1790                 self.fail(f"`patch_click()` failed, exception still raised: {re}")
1791
1792     def test_root_logger_not_used_directly(self) -> None:
1793         def fail(*args: Any, **kwargs: Any) -> None:
1794             self.fail("Record created with root logger")
1795
1796         with patch.multiple(
1797             logging.root,
1798             debug=fail,
1799             info=fail,
1800             warning=fail,
1801             error=fail,
1802             critical=fail,
1803             log=fail,
1804         ):
1805             ff(THIS_FILE)
1806
1807     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1808     def test_blackd_main(self) -> None:
1809         with patch("blackd.web.run_app"):
1810             result = CliRunner().invoke(blackd.main, [])
1811             if result.exception is not None:
1812                 raise result.exception
1813             self.assertEqual(result.exit_code, 0)
1814
1815     def test_invalid_config_return_code(self) -> None:
1816         tmp_file = Path(black.dump_to_file())
1817         try:
1818             tmp_config = Path(black.dump_to_file())
1819             tmp_config.unlink()
1820             args = ["--config", str(tmp_config), str(tmp_file)]
1821             self.invokeBlack(args, exit_code=2, ignore_config=False)
1822         finally:
1823             tmp_file.unlink()
1824
1825     def test_parse_pyproject_toml(self) -> None:
1826         test_toml_file = THIS_DIR / "test.toml"
1827         config = black.parse_pyproject_toml(str(test_toml_file))
1828         self.assertEqual(config["verbose"], 1)
1829         self.assertEqual(config["check"], "no")
1830         self.assertEqual(config["diff"], "y")
1831         self.assertEqual(config["color"], True)
1832         self.assertEqual(config["line_length"], 79)
1833         self.assertEqual(config["target_version"], ["py36", "py37", "py38"])
1834         self.assertEqual(config["exclude"], r"\.pyi?$")
1835         self.assertEqual(config["include"], r"\.py?$")
1836
1837     def test_read_pyproject_toml(self) -> None:
1838         test_toml_file = THIS_DIR / "test.toml"
1839         fake_ctx = FakeContext()
1840         black.read_pyproject_toml(fake_ctx, FakeParameter(), str(test_toml_file))
1841         config = fake_ctx.default_map
1842         self.assertEqual(config["verbose"], "1")
1843         self.assertEqual(config["check"], "no")
1844         self.assertEqual(config["diff"], "y")
1845         self.assertEqual(config["color"], "True")
1846         self.assertEqual(config["line_length"], "79")
1847         self.assertEqual(config["target_version"], ["py36", "py37", "py38"])
1848         self.assertEqual(config["exclude"], r"\.pyi?$")
1849         self.assertEqual(config["include"], r"\.py?$")
1850
1851     def test_find_project_root(self) -> None:
1852         with TemporaryDirectory() as workspace:
1853             root = Path(workspace)
1854             test_dir = root / "test"
1855             test_dir.mkdir()
1856
1857             src_dir = root / "src"
1858             src_dir.mkdir()
1859
1860             root_pyproject = root / "pyproject.toml"
1861             root_pyproject.touch()
1862             src_pyproject = src_dir / "pyproject.toml"
1863             src_pyproject.touch()
1864             src_python = src_dir / "foo.py"
1865             src_python.touch()
1866
1867             self.assertEqual(
1868                 black.find_project_root((src_dir, test_dir)), root.resolve()
1869             )
1870             self.assertEqual(black.find_project_root((src_dir,)), src_dir.resolve())
1871             self.assertEqual(black.find_project_root((src_python,)), src_dir.resolve())
1872
1873
if has_blackd_deps:
    # The class statement itself needs `AioHTTPTestCase`, `unittest_run_loop`,
    # `web` and `blackd`, which are only bound when the optional imports at the
    # top of this file succeeded.  Defining the class unconditionally raised
    # `NameError` while importing this module without those packages, so the
    # definition is guarded instead of decorating every test with `skipUnless`.

    class BlackDTestCase(AioHTTPTestCase):
        # HTTP-level tests for blackd: each test posts to "/" on the application
        # returned by `blackd.make_app()` and checks the response.

        async def get_application(self) -> web.Application:
            # Supplies the application that `self.client` talks to.
            return blackd.make_app()

        # TODO: remove these decorators once the below is released
        # https://github.com/aio-libs/aiohttp/pull/3727
        @skip_if_exception("ClientOSError")
        @unittest_run_loop
        async def test_blackd_request_needs_formatting(self) -> None:
            # Unformatted source comes back reformatted with status 200.
            response = await self.client.post("/", data=b"print('hello world')")
            self.assertEqual(response.status, 200)
            self.assertEqual(response.charset, "utf8")
            self.assertEqual(await response.read(), b'print("hello world")\n')

        @skip_if_exception("ClientOSError")
        @unittest_run_loop
        async def test_blackd_request_no_change(self) -> None:
            # Already formatted source yields 204 with an empty body.
            response = await self.client.post("/", data=b'print("hello world")\n')
            self.assertEqual(response.status, 204)
            self.assertEqual(await response.read(), b"")

        @skip_if_exception("ClientOSError")
        @unittest_run_loop
        async def test_blackd_request_syntax_error(self) -> None:
            # Unparsable source yields 400 and a "Cannot parse" message.
            response = await self.client.post("/", data=b"what even ( is")
            self.assertEqual(response.status, 400)
            content = await response.text()
            self.assertTrue(
                content.startswith("Cannot parse"),
                msg=(
                    "Expected error to start with 'Cannot parse',"
                    f" got {repr(content)}"
                ),
            )

        @skip_if_exception("ClientOSError")
        @unittest_run_loop
        async def test_blackd_unsupported_version(self) -> None:
            # Protocol version "2" is answered with 501.
            response = await self.client.post(
                "/", data=b"what", headers={blackd.PROTOCOL_VERSION_HEADER: "2"}
            )
            self.assertEqual(response.status, 501)

        @skip_if_exception("ClientOSError")
        @unittest_run_loop
        async def test_blackd_supported_version(self) -> None:
            # Protocol version "1" is accepted.
            response = await self.client.post(
                "/", data=b"what", headers={blackd.PROTOCOL_VERSION_HEADER: "1"}
            )
            self.assertEqual(response.status, 200)

        @skip_if_exception("ClientOSError")
        @unittest_run_loop
        async def test_blackd_invalid_python_variant(self) -> None:
            # Every header value below has to be rejected with 400.
            async def check(header_value: str, expected_status: int = 400) -> None:
                response = await self.client.post(
                    "/",
                    data=b"what",
                    headers={blackd.PYTHON_VARIANT_HEADER: header_value},
                )
                self.assertEqual(response.status, expected_status)

            await check("lol")
            await check("ruby3.5")
            await check("pyi3.6")
            await check("py1.5")
            await check("2.8")
            await check("py2.8")
            await check("3.0")
            await check("pypy3.0")
            await check("jython3.4")

        @skip_if_exception("ClientOSError")
        @unittest_run_loop
        async def test_blackd_pyi(self) -> None:
            # The "pyi" variant formats the stub fixture to its expected output.
            source, expected = read_data("stub.pyi")
            response = await self.client.post(
                "/", data=source, headers={blackd.PYTHON_VARIANT_HEADER: "pyi"}
            )
            self.assertEqual(response.status, 200)
            self.assertEqual(await response.text(), expected)

        @skip_if_exception("ClientOSError")
        @unittest_run_loop
        async def test_blackd_diff(self) -> None:
            # The diff headers carry timestamps; they are replaced with
            # `DETERMINISTIC_HEADER` before comparing against the fixture.
            diff_header = re.compile(
                r"(In|Out)\t\d\d\d\d-\d\d-\d\d \d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
            )

            source, _ = read_data("blackd_diff.py")
            expected, _ = read_data("blackd_diff.diff")

            response = await self.client.post(
                "/", data=source, headers={blackd.DIFF_HEADER: "true"}
            )
            self.assertEqual(response.status, 200)

            actual = await response.text()
            actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
            self.assertEqual(actual, expected)

        @skip_if_exception("ClientOSError")
        @unittest_run_loop
        async def test_blackd_python_variant(self) -> None:
            # The same source is posted under different variant headers: the
            # first group expects a reformat (200), the second none (204).
            code = (
                "def f(\n"
                "    and_has_a_bunch_of,\n"
                "    very_long_arguments_too,\n"
                "    and_lots_of_them_as_well_lol,\n"
                "    **and_very_long_keyword_arguments\n"
                "):\n"
                "    pass\n"
            )

            async def check(header_value: str, expected_status: int) -> None:
                response = await self.client.post(
                    "/",
                    data=code,
                    headers={blackd.PYTHON_VARIANT_HEADER: header_value},
                )
                self.assertEqual(
                    response.status, expected_status, msg=await response.text()
                )

            await check("3.6", 200)
            await check("py3.6", 200)
            await check("3.6,3.7", 200)
            await check("3.6,py3.7", 200)
            await check("py36,py37", 200)
            await check("36", 200)
            await check("3.6.4", 200)

            await check("2", 204)
            await check("2.7", 204)
            await check("py2.7", 204)
            await check("3.4", 204)
            await check("py3.4", 204)
            await check("py34,py36", 204)
            await check("34", 204)

        @skip_if_exception("ClientOSError")
        @unittest_run_loop
        async def test_blackd_line_length(self) -> None:
            # A line length of "7" forces a reformat of the short source.
            response = await self.client.post(
                "/", data=b'print("hello")\n', headers={blackd.LINE_LENGTH_HEADER: "7"}
            )
            self.assertEqual(response.status, 200)

        @skip_if_exception("ClientOSError")
        @unittest_run_loop
        async def test_blackd_invalid_line_length(self) -> None:
            # A non-numeric line length is rejected with 400.
            response = await self.client.post(
                "/",
                data=b'print("hello")\n',
                headers={blackd.LINE_LENGTH_HEADER: "NaN"},
            )
            self.assertEqual(response.status, 400)

        @skip_if_exception("ClientOSError")
        @unittest_run_loop
        async def test_blackd_response_black_version_header(self) -> None:
            # Responses carry the black version header.
            response = await self.client.post("/")
            self.assertIsNotNone(response.headers.get(blackd.BLACK_VERSION_HEADER))
2040
2041
if __name__ == "__main__":
    # Run the tests through unittest when this file is executed directly.
    unittest.main(module="test_black")