]> git.madduck.net Git - etc/vim.git/blob - tests/test_black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access, can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

16002c0b728c9c82bad30726ceb9820ff1a2ee8d
[etc/vim.git] / tests / test_black.py
1 #!/usr/bin/env python3
2 import asyncio
3 import logging
4 from concurrent.futures import ThreadPoolExecutor
5 from contextlib import contextmanager
6 from dataclasses import replace
7 from functools import partial
8 import inspect
9 from io import BytesIO, TextIOWrapper
10 import os
11 from pathlib import Path
12 import regex as re
13 import sys
14 from tempfile import TemporaryDirectory
15 import types
16 from typing import (
17     Any,
18     BinaryIO,
19     Callable,
20     Dict,
21     Generator,
22     List,
23     Tuple,
24     Iterator,
25     TypeVar,
26 )
27 import unittest
28 from unittest.mock import patch, MagicMock
29
30 import click
31 from click import unstyle
32 from click.testing import CliRunner
33
34 import black
35 from black import Feature, TargetVersion
36
37 try:
38     import blackd
39     from aiohttp.test_utils import AioHTTPTestCase, unittest_run_loop
40     from aiohttp import web
41 except ImportError:
42     has_blackd_deps = False
43 else:
44     has_blackd_deps = True
45
46 from pathspec import PathSpec
47
48 # Import other test classes
49 from .test_primer import PrimerCLITests  # noqa: F401
50
51
# Formatting mode used by the helpers below; experimental string processing is on.
DEFAULT_MODE = black.FileMode(experimental_string_processing=True)
# Shorthands: format a file in place (fast mode) / format a source string.
ff = partial(black.format_file_in_place, mode=DEFAULT_MODE, fast=True)
fs = partial(black.format_str, mode=DEFAULT_MODE)
# Paths derived from the location of this test module.
THIS_FILE = Path(__file__)
THIS_DIR = THIS_FILE.parent
PROJECT_ROOT = THIS_DIR.parent
# Replacement text the diff tests substitute for timestamped diff headers.
DETERMINISTIC_HEADER = "[Deterministic header]"
# Marker text that read_data() removes from every line it reads.
EMPTY_LINE = "# EMPTY LINE WITH WHITESPACE" + " (this comment will be removed)"
# One `--target-version=...` CLI flag per entry in black.PY36_VERSIONS.
PY36_ARGS = [
    f"--target-version={version.name.lower()}" for version in black.PY36_VERSIONS
]
# Generic type variables.
T = TypeVar("T")
R = TypeVar("R")
65
66
def dump_to_stderr(*output: str) -> str:
    """Return the given strings as one newline-framed, newline-separated block."""
    body = "\n".join(output)
    return f"\n{body}\n"
69
70
def read_data(name: str, data: bool = True) -> Tuple[str, str]:
    """read_data('test_name') -> 'input', 'output'

    A name without a known extension gets ``.py`` appended.  The file is read
    from ``THIS_DIR / "data"`` when `data` is true, else from PROJECT_ROOT.
    Lines before a ``# output`` line form the input and lines after it the
    output; without such a line the input doubles as the output.
    """
    known_suffixes = (".py", ".pyi", ".out", ".diff")
    if not name.endswith(known_suffixes):
        name = f"{name}.py"
    root = THIS_DIR / "data" if data else PROJECT_ROOT
    with open(root / name, "r", encoding="utf8") as handle:
        raw_lines = handle.readlines()

    before_marker: List[str] = []
    after_marker: List[str] = []
    sink = before_marker
    for raw in raw_lines:
        cleaned = raw.replace(EMPTY_LINE, "")
        if cleaned.rstrip() == "# output":
            # The marker itself is dropped; later lines go to the output.
            sink = after_marker
        else:
            sink.append(cleaned)

    if before_marker and not after_marker:
        # If there's no output marker, treat the entire file as already pre-formatted.
        after_marker = before_marker[:]
    source_text = "".join(before_marker).strip() + "\n"
    expected_text = "".join(after_marker).strip() + "\n"
    return source_text, expected_text
92
93
@contextmanager
def cache_dir(exists: bool = True) -> Iterator[Path]:
    """Patch ``black.CACHE_DIR`` to a temporary location and yield that path.

    With ``exists=False`` the yielded path is a "new" entry below the
    temporary directory, which this helper does not create.
    """
    with TemporaryDirectory() as tmp_root:
        location = Path(tmp_root) if exists else Path(tmp_root) / "new"
        with patch("black.CACHE_DIR", location):
            yield location
102
103
@contextmanager
def event_loop() -> Iterator[None]:
    """Install a fresh asyncio event loop for the block and close it afterwards."""
    fresh_loop = asyncio.get_event_loop_policy().new_event_loop()
    asyncio.set_event_loop(fresh_loop)
    try:
        yield
    finally:
        fresh_loop.close()
114
115
@contextmanager
def skip_if_exception(e: str) -> Iterator[None]:
    """Skip the running test when the block raises an exception named `e`.

    `e` is compared against the class name of the raised exception.  A match
    is converted into `unittest.SkipTest`; anything else propagates unchanged.
    """
    try:
        yield
    except Exception as exc:
        if exc.__class__.__name__ != e:
            raise

        # `unittest.skip()` only builds a decorator and skips nothing when merely
        # called, so raise SkipTest to have the test reported as skipped.
        reason = f"Encountered expected exception {exc}, skipping"
        raise unittest.SkipTest(reason) from exc
125
126
class FakeContext(click.Context):
    """Minimal stand-in for a click Context, for functions that require one.

    The parent initializer is not called; only an empty `default_map` is set.
    """

    def __init__(self) -> None:
        self.default_map: Dict[str, Any] = dict()
132
133
class FakeParameter(click.Parameter):
    """Minimal stand-in for a click Parameter, for functions that require one."""

    def __init__(self) -> None:
        """Do nothing; the parent initializer is not called."""
139
140
class BlackRunner(CliRunner):
    """Modify CliRunner so that stderr is not merged with stdout.

    After an invocation the raw captured streams are available as
    `stdout_bytes` and `stderr_bytes`.

    This is a hack that can be removed once we depend on Click 7.x"""

    def __init__(self) -> None:
        self.stderrbuf = BytesIO()
        self.stdoutbuf = BytesIO()
        self.stdout_bytes = b""
        self.stderr_bytes = b""
        super().__init__()

    @contextmanager
    def isolation(self, *args: Any, **kwargs: Any) -> Generator[BinaryIO, None, None]:
        """Run CliRunner's isolation with `sys.stderr` redirected to `stderrbuf`.

        On exit the bytes written to stdout and stderr are stored on the runner
        and the previous `sys.stderr` is restored.
        """
        with super().isolation(*args, **kwargs) as output:
            # Saved before the `try` so the `finally` clause can always restore it.
            hold_stderr = sys.stderr
            sys.stderr = TextIOWrapper(self.stderrbuf, encoding=self.charset)
            try:
                yield output
            finally:
                # The text wrappers buffer writes; flush them first, otherwise
                # output that was never explicitly flushed is missing below.
                sys.stdout.flush()
                sys.stderr.flush()
                self.stdout_bytes = sys.stdout.buffer.getvalue()  # type: ignore
                self.stderr_bytes = sys.stderr.buffer.getvalue()  # type: ignore
                sys.stderr = hold_stderr
164
165
166 class BlackTestCase(unittest.TestCase):
167     maxDiff = None
168     _diffThreshold = 2 ** 20
169
170     def assertFormatEqual(self, expected: str, actual: str) -> None:
171         if actual != expected and not os.environ.get("SKIP_AST_PRINT"):
172             bdv: black.DebugVisitor[Any]
173             black.out("Expected tree:", fg="green")
174             try:
175                 exp_node = black.lib2to3_parse(expected)
176                 bdv = black.DebugVisitor()
177                 list(bdv.visit(exp_node))
178             except Exception as ve:
179                 black.err(str(ve))
180             black.out("Actual tree:", fg="red")
181             try:
182                 exp_node = black.lib2to3_parse(actual)
183                 bdv = black.DebugVisitor()
184                 list(bdv.visit(exp_node))
185             except Exception as ve:
186                 black.err(str(ve))
187         self.assertMultiLineEqual(expected, actual)
188
189     def invokeBlack(
190         self, args: List[str], exit_code: int = 0, ignore_config: bool = True
191     ) -> None:
192         runner = BlackRunner()
193         if ignore_config:
194             args = ["--verbose", "--config", str(THIS_DIR / "empty.toml"), *args]
195         result = runner.invoke(black.main, args)
196         self.assertEqual(
197             result.exit_code,
198             exit_code,
199             msg=(
200                 f"Failed with args: {args}\n"
201                 f"stdout: {runner.stdout_bytes.decode()!r}\n"
202                 f"stderr: {runner.stderr_bytes.decode()!r}\n"
203                 f"exception: {result.exception}"
204             ),
205         )
206
207     @patch("black.dump_to_file", dump_to_stderr)
208     def checkSourceFile(self, name: str, mode: black.FileMode = DEFAULT_MODE) -> None:
209         path = THIS_DIR.parent / name
210         source, expected = read_data(str(path), data=False)
211         actual = fs(source, mode=mode)
212         self.assertFormatEqual(expected, actual)
213         black.assert_equivalent(source, actual)
214         black.assert_stable(source, actual, mode)
215         self.assertFalse(ff(path))
216
217     @patch("black.dump_to_file", dump_to_stderr)
218     def test_empty(self) -> None:
219         source = expected = ""
220         actual = fs(source)
221         self.assertFormatEqual(expected, actual)
222         black.assert_equivalent(source, actual)
223         black.assert_stable(source, actual, DEFAULT_MODE)
224
225     def test_empty_ff(self) -> None:
226         expected = ""
227         tmp_file = Path(black.dump_to_file())
228         try:
229             self.assertFalse(ff(tmp_file, write_back=black.WriteBack.YES))
230             with open(tmp_file, encoding="utf8") as f:
231                 actual = f.read()
232         finally:
233             os.unlink(tmp_file)
234         self.assertFormatEqual(expected, actual)
235
236     def test_self(self) -> None:
237         self.checkSourceFile("tests/test_black.py")
238
239     def test_black(self) -> None:
240         self.checkSourceFile("src/black/__init__.py")
241
242     def test_pygram(self) -> None:
243         self.checkSourceFile("src/blib2to3/pygram.py")
244
245     def test_pytree(self) -> None:
246         self.checkSourceFile("src/blib2to3/pytree.py")
247
248     def test_conv(self) -> None:
249         self.checkSourceFile("src/blib2to3/pgen2/conv.py")
250
251     def test_driver(self) -> None:
252         self.checkSourceFile("src/blib2to3/pgen2/driver.py")
253
254     def test_grammar(self) -> None:
255         self.checkSourceFile("src/blib2to3/pgen2/grammar.py")
256
257     def test_literals(self) -> None:
258         self.checkSourceFile("src/blib2to3/pgen2/literals.py")
259
260     def test_parse(self) -> None:
261         self.checkSourceFile("src/blib2to3/pgen2/parse.py")
262
263     def test_pgen(self) -> None:
264         self.checkSourceFile("src/blib2to3/pgen2/pgen.py")
265
266     def test_tokenize(self) -> None:
267         self.checkSourceFile("src/blib2to3/pgen2/tokenize.py")
268
269     def test_token(self) -> None:
270         self.checkSourceFile("src/blib2to3/pgen2/token.py")
271
272     def test_setup(self) -> None:
273         self.checkSourceFile("setup.py")
274
275     def test_piping(self) -> None:
276         source, expected = read_data("src/black/__init__", data=False)
277         result = BlackRunner().invoke(
278             black.main,
279             ["-", "--fast", f"--line-length={black.DEFAULT_LINE_LENGTH}"],
280             input=BytesIO(source.encode("utf8")),
281         )
282         self.assertEqual(result.exit_code, 0)
283         self.assertFormatEqual(expected, result.output)
284         black.assert_equivalent(source, result.output)
285         black.assert_stable(source, result.output, DEFAULT_MODE)
286
287     def test_piping_diff(self) -> None:
288         diff_header = re.compile(
289             r"(STDIN|STDOUT)\t\d\d\d\d-\d\d-\d\d \d\d:\d\d:\d\d\.\d\d\d\d\d\d "
290             r"\+\d\d\d\d"
291         )
292         source, _ = read_data("expression.py")
293         expected, _ = read_data("expression.diff")
294         config = THIS_DIR / "data" / "empty_pyproject.toml"
295         args = [
296             "-",
297             "--fast",
298             f"--line-length={black.DEFAULT_LINE_LENGTH}",
299             "--diff",
300             f"--config={config}",
301         ]
302         result = BlackRunner().invoke(
303             black.main, args, input=BytesIO(source.encode("utf8"))
304         )
305         self.assertEqual(result.exit_code, 0)
306         actual = diff_header.sub(DETERMINISTIC_HEADER, result.output)
307         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
308         self.assertEqual(expected, actual)
309
310     def test_piping_diff_with_color(self) -> None:
311         source, _ = read_data("expression.py")
312         config = THIS_DIR / "data" / "empty_pyproject.toml"
313         args = [
314             "-",
315             "--fast",
316             f"--line-length={black.DEFAULT_LINE_LENGTH}",
317             "--diff",
318             "--color",
319             f"--config={config}",
320         ]
321         result = BlackRunner().invoke(
322             black.main, args, input=BytesIO(source.encode("utf8"))
323         )
324         actual = result.output
325         # Again, the contents are checked in a different test, so only look for colors.
326         self.assertIn("\033[1;37m", actual)
327         self.assertIn("\033[36m", actual)
328         self.assertIn("\033[32m", actual)
329         self.assertIn("\033[31m", actual)
330         self.assertIn("\033[0m", actual)
331
332     @patch("black.dump_to_file", dump_to_stderr)
333     def test_function(self) -> None:
334         source, expected = read_data("function")
335         actual = fs(source)
336         self.assertFormatEqual(expected, actual)
337         black.assert_equivalent(source, actual)
338         black.assert_stable(source, actual, DEFAULT_MODE)
339
340     @patch("black.dump_to_file", dump_to_stderr)
341     def test_function2(self) -> None:
342         source, expected = read_data("function2")
343         actual = fs(source)
344         self.assertFormatEqual(expected, actual)
345         black.assert_equivalent(source, actual)
346         black.assert_stable(source, actual, DEFAULT_MODE)
347
348     @patch("black.dump_to_file", dump_to_stderr)
349     def test_function_trailing_comma_wip(self) -> None:
350         source, expected = read_data("function_trailing_comma_wip")
351         # sys.settrace(tracefunc)
352         actual = fs(source)
353         # sys.settrace(None)
354         self.assertFormatEqual(expected, actual)
355         black.assert_equivalent(source, actual)
356         black.assert_stable(source, actual, black.FileMode())
357
358     @patch("black.dump_to_file", dump_to_stderr)
359     def test_function_trailing_comma(self) -> None:
360         source, expected = read_data("function_trailing_comma")
361         actual = fs(source)
362         self.assertFormatEqual(expected, actual)
363         black.assert_equivalent(source, actual)
364         black.assert_stable(source, actual, DEFAULT_MODE)
365
366     @patch("black.dump_to_file", dump_to_stderr)
367     def test_expression(self) -> None:
368         source, expected = read_data("expression")
369         actual = fs(source)
370         self.assertFormatEqual(expected, actual)
371         black.assert_equivalent(source, actual)
372         black.assert_stable(source, actual, DEFAULT_MODE)
373
374     @patch("black.dump_to_file", dump_to_stderr)
375     def test_pep_572(self) -> None:
376         source, expected = read_data("pep_572")
377         actual = fs(source)
378         self.assertFormatEqual(expected, actual)
379         black.assert_stable(source, actual, DEFAULT_MODE)
380         if sys.version_info >= (3, 8):
381             black.assert_equivalent(source, actual)
382
383     def test_pep_572_version_detection(self) -> None:
384         source, _ = read_data("pep_572")
385         root = black.lib2to3_parse(source)
386         features = black.get_features_used(root)
387         self.assertIn(black.Feature.ASSIGNMENT_EXPRESSIONS, features)
388         versions = black.detect_target_versions(root)
389         self.assertIn(black.TargetVersion.PY38, versions)
390
391     def test_expression_ff(self) -> None:
392         source, expected = read_data("expression")
393         tmp_file = Path(black.dump_to_file(source))
394         try:
395             self.assertTrue(ff(tmp_file, write_back=black.WriteBack.YES))
396             with open(tmp_file, encoding="utf8") as f:
397                 actual = f.read()
398         finally:
399             os.unlink(tmp_file)
400         self.assertFormatEqual(expected, actual)
401         with patch("black.dump_to_file", dump_to_stderr):
402             black.assert_equivalent(source, actual)
403             black.assert_stable(source, actual, DEFAULT_MODE)
404
405     def test_expression_diff(self) -> None:
406         source, _ = read_data("expression.py")
407         expected, _ = read_data("expression.diff")
408         tmp_file = Path(black.dump_to_file(source))
409         diff_header = re.compile(
410             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
411             r"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
412         )
413         try:
414             result = BlackRunner().invoke(black.main, ["--diff", str(tmp_file)])
415             self.assertEqual(result.exit_code, 0)
416         finally:
417             os.unlink(tmp_file)
418         actual = result.output
419         actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
420         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
421         if expected != actual:
422             dump = black.dump_to_file(actual)
423             msg = (
424                 "Expected diff isn't equal to the actual. If you made changes to"
425                 " expression.py and this is an anticipated difference, overwrite"
426                 f" tests/data/expression.diff with {dump}"
427             )
428             self.assertEqual(expected, actual, msg)
429
430     def test_expression_diff_with_color(self) -> None:
431         source, _ = read_data("expression.py")
432         expected, _ = read_data("expression.diff")
433         tmp_file = Path(black.dump_to_file(source))
434         try:
435             result = BlackRunner().invoke(
436                 black.main, ["--diff", "--color", str(tmp_file)]
437             )
438         finally:
439             os.unlink(tmp_file)
440         actual = result.output
441         # We check the contents of the diff in `test_expression_diff`. All
442         # we need to check here is that color codes exist in the result.
443         self.assertIn("\033[1;37m", actual)
444         self.assertIn("\033[36m", actual)
445         self.assertIn("\033[32m", actual)
446         self.assertIn("\033[31m", actual)
447         self.assertIn("\033[0m", actual)
448
449     @patch("black.dump_to_file", dump_to_stderr)
450     def test_fstring(self) -> None:
451         source, expected = read_data("fstring")
452         actual = fs(source)
453         self.assertFormatEqual(expected, actual)
454         black.assert_equivalent(source, actual)
455         black.assert_stable(source, actual, DEFAULT_MODE)
456
457     @patch("black.dump_to_file", dump_to_stderr)
458     def test_pep_570(self) -> None:
459         source, expected = read_data("pep_570")
460         actual = fs(source)
461         self.assertFormatEqual(expected, actual)
462         black.assert_stable(source, actual, DEFAULT_MODE)
463         if sys.version_info >= (3, 8):
464             black.assert_equivalent(source, actual)
465
466     def test_detect_pos_only_arguments(self) -> None:
467         source, _ = read_data("pep_570")
468         root = black.lib2to3_parse(source)
469         features = black.get_features_used(root)
470         self.assertIn(black.Feature.POS_ONLY_ARGUMENTS, features)
471         versions = black.detect_target_versions(root)
472         self.assertIn(black.TargetVersion.PY38, versions)
473
474     @patch("black.dump_to_file", dump_to_stderr)
475     def test_string_quotes(self) -> None:
476         source, expected = read_data("string_quotes")
477         actual = fs(source)
478         self.assertFormatEqual(expected, actual)
479         black.assert_equivalent(source, actual)
480         black.assert_stable(source, actual, DEFAULT_MODE)
481         mode = replace(DEFAULT_MODE, string_normalization=False)
482         not_normalized = fs(source, mode=mode)
483         self.assertFormatEqual(source.replace("\\\n", ""), not_normalized)
484         black.assert_equivalent(source, not_normalized)
485         black.assert_stable(source, not_normalized, mode=mode)
486
487     @patch("black.dump_to_file", dump_to_stderr)
488     def test_docstring(self) -> None:
489         source, expected = read_data("docstring")
490         actual = fs(source)
491         self.assertFormatEqual(expected, actual)
492         black.assert_equivalent(source, actual)
493         black.assert_stable(source, actual, DEFAULT_MODE)
494
495     def test_long_strings(self) -> None:
496         """Tests for splitting long strings."""
497         source, expected = read_data("long_strings")
498         actual = fs(source)
499         self.assertFormatEqual(expected, actual)
500         black.assert_equivalent(source, actual)
501         black.assert_stable(source, actual, DEFAULT_MODE)
502
503     def test_long_strings_flag_disabled(self) -> None:
504         """Tests for turning off the string processing logic."""
505         source, expected = read_data("long_strings_flag_disabled")
506         mode = replace(DEFAULT_MODE, experimental_string_processing=False)
507         actual = fs(source, mode=mode)
508         self.assertFormatEqual(expected, actual)
509         black.assert_stable(expected, actual, mode)
510
511     @patch("black.dump_to_file", dump_to_stderr)
512     def test_long_strings__edge_case(self) -> None:
513         """Edge-case tests for splitting long strings."""
514         source, expected = read_data("long_strings__edge_case")
515         actual = fs(source)
516         self.assertFormatEqual(expected, actual)
517         black.assert_equivalent(source, actual)
518         black.assert_stable(source, actual, DEFAULT_MODE)
519
520     @patch("black.dump_to_file", dump_to_stderr)
521     def test_long_strings__regression(self) -> None:
522         """Regression tests for splitting long strings."""
523         source, expected = read_data("long_strings__regression")
524         actual = fs(source)
525         self.assertFormatEqual(expected, actual)
526         black.assert_equivalent(source, actual)
527         black.assert_stable(source, actual, DEFAULT_MODE)
528
529     @patch("black.dump_to_file", dump_to_stderr)
530     def test_slices(self) -> None:
531         source, expected = read_data("slices")
532         actual = fs(source)
533         self.assertFormatEqual(expected, actual)
534         black.assert_equivalent(source, actual)
535         black.assert_stable(source, actual, DEFAULT_MODE)
536
537     @patch("black.dump_to_file", dump_to_stderr)
538     def test_percent_precedence(self) -> None:
539         source, expected = read_data("percent_precedence")
540         actual = fs(source)
541         self.assertFormatEqual(expected, actual)
542         black.assert_equivalent(source, actual)
543         black.assert_stable(source, actual, DEFAULT_MODE)
544
545     @patch("black.dump_to_file", dump_to_stderr)
546     def test_comments(self) -> None:
547         source, expected = read_data("comments")
548         actual = fs(source)
549         self.assertFormatEqual(expected, actual)
550         black.assert_equivalent(source, actual)
551         black.assert_stable(source, actual, DEFAULT_MODE)
552
553     @patch("black.dump_to_file", dump_to_stderr)
554     def test_comments2(self) -> None:
555         source, expected = read_data("comments2")
556         actual = fs(source)
557         self.assertFormatEqual(expected, actual)
558         black.assert_equivalent(source, actual)
559         black.assert_stable(source, actual, DEFAULT_MODE)
560
561     @patch("black.dump_to_file", dump_to_stderr)
562     def test_comments3(self) -> None:
563         source, expected = read_data("comments3")
564         actual = fs(source)
565         self.assertFormatEqual(expected, actual)
566         black.assert_equivalent(source, actual)
567         black.assert_stable(source, actual, DEFAULT_MODE)
568
569     @patch("black.dump_to_file", dump_to_stderr)
570     def test_comments4(self) -> None:
571         source, expected = read_data("comments4")
572         actual = fs(source)
573         self.assertFormatEqual(expected, actual)
574         black.assert_equivalent(source, actual)
575         black.assert_stable(source, actual, DEFAULT_MODE)
576
577     @patch("black.dump_to_file", dump_to_stderr)
578     def test_comments5(self) -> None:
579         source, expected = read_data("comments5")
580         actual = fs(source)
581         self.assertFormatEqual(expected, actual)
582         black.assert_equivalent(source, actual)
583         black.assert_stable(source, actual, DEFAULT_MODE)
584
585     @patch("black.dump_to_file", dump_to_stderr)
586     def test_comments6(self) -> None:
587         source, expected = read_data("comments6")
588         actual = fs(source)
589         self.assertFormatEqual(expected, actual)
590         black.assert_equivalent(source, actual)
591         black.assert_stable(source, actual, DEFAULT_MODE)
592
593     @patch("black.dump_to_file", dump_to_stderr)
594     def test_comments7(self) -> None:
595         source, expected = read_data("comments7")
596         mode = replace(DEFAULT_MODE, target_versions={black.TargetVersion.PY38})
597         actual = fs(source, mode=mode)
598         self.assertFormatEqual(expected, actual)
599         black.assert_equivalent(source, actual)
600         black.assert_stable(source, actual, DEFAULT_MODE)
601
602     @patch("black.dump_to_file", dump_to_stderr)
603     def test_comment_after_escaped_newline(self) -> None:
604         source, expected = read_data("comment_after_escaped_newline")
605         actual = fs(source)
606         self.assertFormatEqual(expected, actual)
607         black.assert_equivalent(source, actual)
608         black.assert_stable(source, actual, DEFAULT_MODE)
609
610     @patch("black.dump_to_file", dump_to_stderr)
611     def test_cantfit(self) -> None:
612         source, expected = read_data("cantfit")
613         actual = fs(source)
614         self.assertFormatEqual(expected, actual)
615         black.assert_equivalent(source, actual)
616         black.assert_stable(source, actual, DEFAULT_MODE)
617
618     @patch("black.dump_to_file", dump_to_stderr)
619     def test_import_spacing(self) -> None:
620         source, expected = read_data("import_spacing")
621         actual = fs(source)
622         self.assertFormatEqual(expected, actual)
623         black.assert_equivalent(source, actual)
624         black.assert_stable(source, actual, DEFAULT_MODE)
625
626     @patch("black.dump_to_file", dump_to_stderr)
627     def test_composition(self) -> None:
628         source, expected = read_data("composition")
629         actual = fs(source)
630         self.assertFormatEqual(expected, actual)
631         black.assert_equivalent(source, actual)
632         black.assert_stable(source, actual, DEFAULT_MODE)
633
634     @patch("black.dump_to_file", dump_to_stderr)
635     def test_composition_no_trailing_comma(self) -> None:
636         source, expected = read_data("composition_no_trailing_comma")
637         mode = replace(DEFAULT_MODE, target_versions={black.TargetVersion.PY38})
638         actual = fs(source, mode=mode)
639         self.assertFormatEqual(expected, actual)
640         black.assert_equivalent(source, actual)
641         black.assert_stable(source, actual, DEFAULT_MODE)
642
643     @patch("black.dump_to_file", dump_to_stderr)
644     def test_empty_lines(self) -> None:
645         source, expected = read_data("empty_lines")
646         actual = fs(source)
647         self.assertFormatEqual(expected, actual)
648         black.assert_equivalent(source, actual)
649         black.assert_stable(source, actual, DEFAULT_MODE)
650
651     @patch("black.dump_to_file", dump_to_stderr)
652     def test_remove_parens(self) -> None:
653         source, expected = read_data("remove_parens")
654         actual = fs(source)
655         self.assertFormatEqual(expected, actual)
656         black.assert_equivalent(source, actual)
657         black.assert_stable(source, actual, DEFAULT_MODE)
658
659     @patch("black.dump_to_file", dump_to_stderr)
660     def test_string_prefixes(self) -> None:
661         source, expected = read_data("string_prefixes")
662         actual = fs(source)
663         self.assertFormatEqual(expected, actual)
664         black.assert_equivalent(source, actual)
665         black.assert_stable(source, actual, DEFAULT_MODE)
666
667     @patch("black.dump_to_file", dump_to_stderr)
668     def test_numeric_literals(self) -> None:
669         source, expected = read_data("numeric_literals")
670         mode = replace(DEFAULT_MODE, target_versions=black.PY36_VERSIONS)
671         actual = fs(source, mode=mode)
672         self.assertFormatEqual(expected, actual)
673         black.assert_equivalent(source, actual)
674         black.assert_stable(source, actual, mode)
675
676     @patch("black.dump_to_file", dump_to_stderr)
677     def test_numeric_literals_ignoring_underscores(self) -> None:
678         source, expected = read_data("numeric_literals_skip_underscores")
679         mode = replace(DEFAULT_MODE, target_versions=black.PY36_VERSIONS)
680         actual = fs(source, mode=mode)
681         self.assertFormatEqual(expected, actual)
682         black.assert_equivalent(source, actual)
683         black.assert_stable(source, actual, mode)
684
685     @patch("black.dump_to_file", dump_to_stderr)
686     def test_numeric_literals_py2(self) -> None:
687         source, expected = read_data("numeric_literals_py2")
688         actual = fs(source)
689         self.assertFormatEqual(expected, actual)
690         black.assert_stable(source, actual, DEFAULT_MODE)
691
692     @patch("black.dump_to_file", dump_to_stderr)
693     def test_python2(self) -> None:
694         source, expected = read_data("python2")
695         actual = fs(source)
696         self.assertFormatEqual(expected, actual)
697         black.assert_equivalent(source, actual)
698         black.assert_stable(source, actual, DEFAULT_MODE)
699
700     @patch("black.dump_to_file", dump_to_stderr)
701     def test_python2_print_function(self) -> None:
702         source, expected = read_data("python2_print_function")
703         mode = replace(DEFAULT_MODE, target_versions={TargetVersion.PY27})
704         actual = fs(source, mode=mode)
705         self.assertFormatEqual(expected, actual)
706         black.assert_equivalent(source, actual)
707         black.assert_stable(source, actual, mode)
708
709     @patch("black.dump_to_file", dump_to_stderr)
710     def test_python2_unicode_literals(self) -> None:
711         source, expected = read_data("python2_unicode_literals")
712         actual = fs(source)
713         self.assertFormatEqual(expected, actual)
714         black.assert_equivalent(source, actual)
715         black.assert_stable(source, actual, DEFAULT_MODE)
716
717     @patch("black.dump_to_file", dump_to_stderr)
718     def test_stub(self) -> None:
719         mode = replace(DEFAULT_MODE, is_pyi=True)
720         source, expected = read_data("stub.pyi")
721         actual = fs(source, mode=mode)
722         self.assertFormatEqual(expected, actual)
723         black.assert_stable(source, actual, mode)
724
725     @patch("black.dump_to_file", dump_to_stderr)
726     def test_async_as_identifier(self) -> None:
727         source_path = (THIS_DIR / "data" / "async_as_identifier.py").resolve()
728         source, expected = read_data("async_as_identifier")
729         actual = fs(source)
730         self.assertFormatEqual(expected, actual)
731         major, minor = sys.version_info[:2]
732         if major < 3 or (major <= 3 and minor < 7):
733             black.assert_equivalent(source, actual)
734         black.assert_stable(source, actual, DEFAULT_MODE)
735         # ensure black can parse this when the target is 3.6
736         self.invokeBlack([str(source_path), "--target-version", "py36"])
737         # but not on 3.7, because async/await is no longer an identifier
738         self.invokeBlack([str(source_path), "--target-version", "py37"], exit_code=123)
739
740     @patch("black.dump_to_file", dump_to_stderr)
741     def test_python37(self) -> None:
742         source_path = (THIS_DIR / "data" / "python37.py").resolve()
743         source, expected = read_data("python37")
744         actual = fs(source)
745         self.assertFormatEqual(expected, actual)
746         major, minor = sys.version_info[:2]
747         if major > 3 or (major == 3 and minor >= 7):
748             black.assert_equivalent(source, actual)
749         black.assert_stable(source, actual, DEFAULT_MODE)
750         # ensure black can parse this when the target is 3.7
751         self.invokeBlack([str(source_path), "--target-version", "py37"])
752         # but not on 3.6, because we use async as a reserved keyword
753         self.invokeBlack([str(source_path), "--target-version", "py36"], exit_code=123)
754
755     @patch("black.dump_to_file", dump_to_stderr)
756     def test_python38(self) -> None:
757         source, expected = read_data("python38")
758         actual = fs(source)
759         self.assertFormatEqual(expected, actual)
760         major, minor = sys.version_info[:2]
761         if major > 3 or (major == 3 and minor >= 8):
762             black.assert_equivalent(source, actual)
763         black.assert_stable(source, actual, DEFAULT_MODE)
764
765     @patch("black.dump_to_file", dump_to_stderr)
766     def test_fmtonoff(self) -> None:
767         source, expected = read_data("fmtonoff")
768         actual = fs(source)
769         self.assertFormatEqual(expected, actual)
770         black.assert_equivalent(source, actual)
771         black.assert_stable(source, actual, DEFAULT_MODE)
772
773     @patch("black.dump_to_file", dump_to_stderr)
774     def test_fmtonoff2(self) -> None:
775         source, expected = read_data("fmtonoff2")
776         actual = fs(source)
777         self.assertFormatEqual(expected, actual)
778         black.assert_equivalent(source, actual)
779         black.assert_stable(source, actual, DEFAULT_MODE)
780
781     @patch("black.dump_to_file", dump_to_stderr)
782     def test_fmtonoff3(self) -> None:
783         source, expected = read_data("fmtonoff3")
784         actual = fs(source)
785         self.assertFormatEqual(expected, actual)
786         black.assert_equivalent(source, actual)
787         black.assert_stable(source, actual, DEFAULT_MODE)
788
789     @patch("black.dump_to_file", dump_to_stderr)
790     def test_fmtonoff4(self) -> None:
791         source, expected = read_data("fmtonoff4")
792         actual = fs(source)
793         self.assertFormatEqual(expected, actual)
794         black.assert_equivalent(source, actual)
795         black.assert_stable(source, actual, DEFAULT_MODE)
796
797     @patch("black.dump_to_file", dump_to_stderr)
798     def test_remove_empty_parentheses_after_class(self) -> None:
799         source, expected = read_data("class_blank_parentheses")
800         actual = fs(source)
801         self.assertFormatEqual(expected, actual)
802         black.assert_equivalent(source, actual)
803         black.assert_stable(source, actual, DEFAULT_MODE)
804
805     @patch("black.dump_to_file", dump_to_stderr)
806     def test_new_line_between_class_and_code(self) -> None:
807         source, expected = read_data("class_methods_new_line")
808         actual = fs(source)
809         self.assertFormatEqual(expected, actual)
810         black.assert_equivalent(source, actual)
811         black.assert_stable(source, actual, DEFAULT_MODE)
812
813     @patch("black.dump_to_file", dump_to_stderr)
814     def test_bracket_match(self) -> None:
815         source, expected = read_data("bracketmatch")
816         actual = fs(source)
817         self.assertFormatEqual(expected, actual)
818         black.assert_equivalent(source, actual)
819         black.assert_stable(source, actual, DEFAULT_MODE)
820
821     @patch("black.dump_to_file", dump_to_stderr)
822     def test_tuple_assign(self) -> None:
823         source, expected = read_data("tupleassign")
824         actual = fs(source)
825         self.assertFormatEqual(expected, actual)
826         black.assert_equivalent(source, actual)
827         black.assert_stable(source, actual, DEFAULT_MODE)
828
829     @patch("black.dump_to_file", dump_to_stderr)
830     def test_beginning_backslash(self) -> None:
831         source, expected = read_data("beginning_backslash")
832         actual = fs(source)
833         self.assertFormatEqual(expected, actual)
834         black.assert_equivalent(source, actual)
835         black.assert_stable(source, actual, DEFAULT_MODE)
836
837     def test_tab_comment_indentation(self) -> None:
838         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t# comment\n\tpass\n"
839         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
840         self.assertFormatEqual(contents_spc, fs(contents_spc))
841         self.assertFormatEqual(contents_spc, fs(contents_tab))
842
843         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t\t# comment\n\tpass\n"
844         contents_spc = "if 1:\n    if 2:\n        pass\n        # comment\n    pass\n"
845         self.assertFormatEqual(contents_spc, fs(contents_spc))
846         self.assertFormatEqual(contents_spc, fs(contents_tab))
847
848         # mixed tabs and spaces (valid Python 2 code)
849         contents_tab = "if 1:\n        if 2:\n\t\tpass\n\t# comment\n        pass\n"
850         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
851         self.assertFormatEqual(contents_spc, fs(contents_spc))
852         self.assertFormatEqual(contents_spc, fs(contents_tab))
853
854         contents_tab = "if 1:\n        if 2:\n\t\tpass\n\t\t# comment\n        pass\n"
855         contents_spc = "if 1:\n    if 2:\n        pass\n        # comment\n    pass\n"
856         self.assertFormatEqual(contents_spc, fs(contents_spc))
857         self.assertFormatEqual(contents_spc, fs(contents_tab))
858
    def test_report_verbose(self) -> None:
        # Drives a single black.Report(verbose=True) through a fixed sequence of
        # done()/failed()/path_ignored() calls. After each call it checks the
        # captured out/err messages, the summary string and the return code.
        # The report is stateful, so the order of the steps below matters.
        report = black.Report(verbose=True)
        out_lines = []
        err_lines = []

        # Stand-ins for black.out / black.err that record the message text only.
        def out(msg: str, **kwargs: Any) -> None:
            out_lines.append(msg)

        def err(msg: str, **kwargs: Any) -> None:
            err_lines.append(msg)

        with patch("black.out", out), patch("black.err", err):
            # Unchanged file: in verbose mode a message is expected on out.
            report.done(Path("f1"), black.Changed.NO)
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "f1 already well formatted, good job.")
            self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
            self.assertEqual(report.return_code, 0)
            # Reformatted file.
            report.done(Path("f2"), black.Changed.YES)
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "reformatted f2")
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
            )
            # Cached file: expected to be counted as "left unchanged".
            report.done(Path("f3"), black.Changed.CACHED)
            self.assertEqual(len(out_lines), 3)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(
                out_lines[-1], "f3 wasn't modified on disk since last run."
            )
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
            )
            self.assertEqual(report.return_code, 0)
            # With check enabled the same state is expected to yield return code 1.
            report.check = True
            self.assertEqual(report.return_code, 1)
            report.check = False
            # First failure: message goes to err and the return code becomes 123.
            report.failed(Path("e1"), "boom")
            self.assertEqual(len(out_lines), 3)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
            self.assertEqual(
                unstyle(str(report)),
                "1 file reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Another reformatted file; the return code is expected to stay 123.
            report.done(Path("f3"), black.Changed.YES)
            self.assertEqual(len(out_lines), 4)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(out_lines[-1], "reformatted f3")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Second failure.
            report.failed(Path("e2"), "boom")
            self.assertEqual(len(out_lines), 4)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Ignored path: a message on out, but the summary counts do not change.
            report.path_ignored(Path("wat"), "no match")
            self.assertEqual(len(out_lines), 5)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(out_lines[-1], "wat ignored: no match")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # One more unchanged file.
            report.done(Path("f4"), black.Changed.NO)
            self.assertEqual(len(out_lines), 6)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(out_lines[-1], "f4 already well formatted, good job.")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 3 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # With check or diff set, the summary is expected in "would be" wording.
            report.check = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
            report.check = False
            report.diff = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
960
    def test_report_quiet(self) -> None:
        # Drives a single black.Report(quiet=True) through the same sequence of
        # calls as the verbose variant. Here nothing is expected on out at any
        # step; only failures are expected to produce messages (on err).
        # The report is stateful, so the order of the steps below matters.
        report = black.Report(quiet=True)
        out_lines = []
        err_lines = []

        # Stand-ins for black.out / black.err that record the message text only.
        def out(msg: str, **kwargs: Any) -> None:
            out_lines.append(msg)

        def err(msg: str, **kwargs: Any) -> None:
            err_lines.append(msg)

        with patch("black.out", out), patch("black.err", err):
            # Unchanged file: no output expected.
            report.done(Path("f1"), black.Changed.NO)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
            self.assertEqual(report.return_code, 0)
            # Reformatted file: still no output expected.
            report.done(Path("f2"), black.Changed.YES)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
            )
            # Cached file: expected to be counted as "left unchanged".
            report.done(Path("f3"), black.Changed.CACHED)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
            )
            self.assertEqual(report.return_code, 0)
            # With check enabled the same state is expected to yield return code 1.
            report.check = True
            self.assertEqual(report.return_code, 1)
            report.check = False
            # First failure: message goes to err and the return code becomes 123.
            report.failed(Path("e1"), "boom")
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
            self.assertEqual(
                unstyle(str(report)),
                "1 file reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Another reformatted file; the return code is expected to stay 123.
            report.done(Path("f3"), black.Changed.YES)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Second failure.
            report.failed(Path("e2"), "boom")
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Ignored path: no output expected and the summary counts do not change.
            report.path_ignored(Path("wat"), "no match")
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # One more unchanged file.
            report.done(Path("f4"), black.Changed.NO)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 3 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # With check or diff set, the summary is expected in "would be" wording.
            report.check = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
            report.check = False
            report.diff = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
1054
    def test_report_normal(self) -> None:
        # Drives a default black.Report() through the same sequence of calls as
        # the verbose and quiet variants. Here only reformatted files are
        # expected to produce messages on out; failures are expected on err.
        # The report is stateful, so the order of the steps below matters.
        report = black.Report()
        out_lines = []
        err_lines = []

        # Stand-ins for black.out / black.err that record the message text only.
        def out(msg: str, **kwargs: Any) -> None:
            out_lines.append(msg)

        def err(msg: str, **kwargs: Any) -> None:
            err_lines.append(msg)

        with patch("black.out", out), patch("black.err", err):
            # Unchanged file: no output expected.
            report.done(Path("f1"), black.Changed.NO)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
            self.assertEqual(report.return_code, 0)
            # Reformatted file: one message expected on out.
            report.done(Path("f2"), black.Changed.YES)
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "reformatted f2")
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
            )
            # Cached file: no new message, so the last out line is still f2's.
            report.done(Path("f3"), black.Changed.CACHED)
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "reformatted f2")
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
            )
            self.assertEqual(report.return_code, 0)
            # With check enabled the same state is expected to yield return code 1.
            report.check = True
            self.assertEqual(report.return_code, 1)
            report.check = False
            # First failure: message goes to err and the return code becomes 123.
            report.failed(Path("e1"), "boom")
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
            self.assertEqual(
                unstyle(str(report)),
                "1 file reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Another reformatted file; the return code is expected to stay 123.
            report.done(Path("f3"), black.Changed.YES)
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(out_lines[-1], "reformatted f3")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Second failure.
            report.failed(Path("e2"), "boom")
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Ignored path: no output expected and the summary counts do not change.
            report.path_ignored(Path("wat"), "no match")
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # One more unchanged file.
            report.done(Path("f4"), black.Changed.NO)
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 3 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # With check or diff set, the summary is expected in "would be" wording.
            report.check = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
            report.check = False
            report.diff = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
1151
1152     def test_lib2to3_parse(self) -> None:
1153         with self.assertRaises(black.InvalidInput):
1154             black.lib2to3_parse("invalid syntax")
1155
1156         straddling = "x + y"
1157         black.lib2to3_parse(straddling)
1158         black.lib2to3_parse(straddling, {TargetVersion.PY27})
1159         black.lib2to3_parse(straddling, {TargetVersion.PY36})
1160         black.lib2to3_parse(straddling, {TargetVersion.PY27, TargetVersion.PY36})
1161
1162         py2_only = "print x"
1163         black.lib2to3_parse(py2_only)
1164         black.lib2to3_parse(py2_only, {TargetVersion.PY27})
1165         with self.assertRaises(black.InvalidInput):
1166             black.lib2to3_parse(py2_only, {TargetVersion.PY36})
1167         with self.assertRaises(black.InvalidInput):
1168             black.lib2to3_parse(py2_only, {TargetVersion.PY27, TargetVersion.PY36})
1169
1170         py3_only = "exec(x, end=y)"
1171         black.lib2to3_parse(py3_only)
1172         with self.assertRaises(black.InvalidInput):
1173             black.lib2to3_parse(py3_only, {TargetVersion.PY27})
1174         black.lib2to3_parse(py3_only, {TargetVersion.PY36})
1175         black.lib2to3_parse(py3_only, {TargetVersion.PY27, TargetVersion.PY36})
1176
1177     def test_get_features_used(self) -> None:
1178         node = black.lib2to3_parse("def f(*, arg): ...\n")
1179         self.assertEqual(black.get_features_used(node), set())
1180         node = black.lib2to3_parse("def f(*, arg,): ...\n")
1181         self.assertEqual(black.get_features_used(node), {Feature.TRAILING_COMMA_IN_DEF})
1182         node = black.lib2to3_parse("f(*arg,)\n")
1183         self.assertEqual(
1184             black.get_features_used(node), {Feature.TRAILING_COMMA_IN_CALL}
1185         )
1186         node = black.lib2to3_parse("def f(*, arg): f'string'\n")
1187         self.assertEqual(black.get_features_used(node), {Feature.F_STRINGS})
1188         node = black.lib2to3_parse("123_456\n")
1189         self.assertEqual(black.get_features_used(node), {Feature.NUMERIC_UNDERSCORES})
1190         node = black.lib2to3_parse("123456\n")
1191         self.assertEqual(black.get_features_used(node), set())
1192         source, expected = read_data("function")
1193         node = black.lib2to3_parse(source)
1194         expected_features = {
1195             Feature.TRAILING_COMMA_IN_CALL,
1196             Feature.TRAILING_COMMA_IN_DEF,
1197             Feature.F_STRINGS,
1198         }
1199         self.assertEqual(black.get_features_used(node), expected_features)
1200         node = black.lib2to3_parse(expected)
1201         self.assertEqual(black.get_features_used(node), expected_features)
1202         source, expected = read_data("expression")
1203         node = black.lib2to3_parse(source)
1204         self.assertEqual(black.get_features_used(node), set())
1205         node = black.lib2to3_parse(expected)
1206         self.assertEqual(black.get_features_used(node), set())
1207
1208     def test_get_future_imports(self) -> None:
1209         node = black.lib2to3_parse("\n")
1210         self.assertEqual(set(), black.get_future_imports(node))
1211         node = black.lib2to3_parse("from __future__ import black\n")
1212         self.assertEqual({"black"}, black.get_future_imports(node))
1213         node = black.lib2to3_parse("from __future__ import multiple, imports\n")
1214         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
1215         node = black.lib2to3_parse("from __future__ import (parenthesized, imports)\n")
1216         self.assertEqual({"parenthesized", "imports"}, black.get_future_imports(node))
1217         node = black.lib2to3_parse(
1218             "from __future__ import multiple\nfrom __future__ import imports\n"
1219         )
1220         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
1221         node = black.lib2to3_parse("# comment\nfrom __future__ import black\n")
1222         self.assertEqual({"black"}, black.get_future_imports(node))
1223         node = black.lib2to3_parse('"""docstring"""\nfrom __future__ import black\n')
1224         self.assertEqual({"black"}, black.get_future_imports(node))
1225         node = black.lib2to3_parse("some(other, code)\nfrom __future__ import black\n")
1226         self.assertEqual(set(), black.get_future_imports(node))
1227         node = black.lib2to3_parse("from some.module import black\n")
1228         self.assertEqual(set(), black.get_future_imports(node))
1229         node = black.lib2to3_parse(
1230             "from __future__ import unicode_literals as _unicode_literals"
1231         )
1232         self.assertEqual({"unicode_literals"}, black.get_future_imports(node))
1233         node = black.lib2to3_parse(
1234             "from __future__ import unicode_literals as _lol, print"
1235         )
1236         self.assertEqual({"unicode_literals", "print"}, black.get_future_imports(node))
1237
1238     def test_debug_visitor(self) -> None:
1239         source, _ = read_data("debug_visitor.py")
1240         expected, _ = read_data("debug_visitor.out")
1241         out_lines = []
1242         err_lines = []
1243
1244         def out(msg: str, **kwargs: Any) -> None:
1245             out_lines.append(msg)
1246
1247         def err(msg: str, **kwargs: Any) -> None:
1248             err_lines.append(msg)
1249
1250         with patch("black.out", out), patch("black.err", err):
1251             black.DebugVisitor.show(source)
1252         actual = "\n".join(out_lines) + "\n"
1253         log_name = ""
1254         if expected != actual:
1255             log_name = black.dump_to_file(*out_lines)
1256         self.assertEqual(
1257             expected,
1258             actual,
1259             f"AST print out is different. Actual version dumped to {log_name}",
1260         )
1261
1262     def test_format_file_contents(self) -> None:
1263         empty = ""
1264         mode = DEFAULT_MODE
1265         with self.assertRaises(black.NothingChanged):
1266             black.format_file_contents(empty, mode=mode, fast=False)
1267         just_nl = "\n"
1268         with self.assertRaises(black.NothingChanged):
1269             black.format_file_contents(just_nl, mode=mode, fast=False)
1270         same = "j = [1, 2, 3]\n"
1271         with self.assertRaises(black.NothingChanged):
1272             black.format_file_contents(same, mode=mode, fast=False)
1273         different = "j = [1,2,3]"
1274         expected = same
1275         actual = black.format_file_contents(different, mode=mode, fast=False)
1276         self.assertEqual(expected, actual)
1277         invalid = "return if you can"
1278         with self.assertRaises(black.InvalidInput) as e:
1279             black.format_file_contents(invalid, mode=mode, fast=False)
1280         self.assertEqual(str(e.exception), "Cannot parse: 1:7: return if you can")
1281
1282     def test_endmarker(self) -> None:
1283         n = black.lib2to3_parse("\n")
1284         self.assertEqual(n.type, black.syms.file_input)
1285         self.assertEqual(len(n.children), 1)
1286         self.assertEqual(n.children[0].type, black.token.ENDMARKER)
1287
1288     @unittest.skipIf(os.environ.get("SKIP_AST_PRINT"), "user set SKIP_AST_PRINT")
1289     def test_assertFormatEqual(self) -> None:
1290         out_lines = []
1291         err_lines = []
1292
1293         def out(msg: str, **kwargs: Any) -> None:
1294             out_lines.append(msg)
1295
1296         def err(msg: str, **kwargs: Any) -> None:
1297             err_lines.append(msg)
1298
1299         with patch("black.out", out), patch("black.err", err):
1300             with self.assertRaises(AssertionError):
1301                 self.assertFormatEqual("j = [1, 2, 3]", "j = [1, 2, 3,]")
1302
1303         out_str = "".join(out_lines)
1304         self.assertTrue("Expected tree:" in out_str)
1305         self.assertTrue("Actual tree:" in out_str)
1306         self.assertEqual("".join(err_lines), "")
1307
1308     def test_cache_broken_file(self) -> None:
1309         mode = DEFAULT_MODE
1310         with cache_dir() as workspace:
1311             cache_file = black.get_cache_file(mode)
1312             with cache_file.open("w") as fobj:
1313                 fobj.write("this is not a pickle")
1314             self.assertEqual(black.read_cache(mode), {})
1315             src = (workspace / "test.py").resolve()
1316             with src.open("w") as fobj:
1317                 fobj.write("print('hello')")
1318             self.invokeBlack([str(src)])
1319             cache = black.read_cache(mode)
1320             self.assertIn(src, cache)
1321
1322     def test_cache_single_file_already_cached(self) -> None:
1323         mode = DEFAULT_MODE
1324         with cache_dir() as workspace:
1325             src = (workspace / "test.py").resolve()
1326             with src.open("w") as fobj:
1327                 fobj.write("print('hello')")
1328             black.write_cache({}, [src], mode)
1329             self.invokeBlack([str(src)])
1330             with src.open("r") as fobj:
1331                 self.assertEqual(fobj.read(), "print('hello')")
1332
1333     @event_loop()
1334     def test_cache_multiple_files(self) -> None:
1335         mode = DEFAULT_MODE
1336         with cache_dir() as workspace, patch(
1337             "black.ProcessPoolExecutor", new=ThreadPoolExecutor
1338         ):
1339             one = (workspace / "one.py").resolve()
1340             with one.open("w") as fobj:
1341                 fobj.write("print('hello')")
1342             two = (workspace / "two.py").resolve()
1343             with two.open("w") as fobj:
1344                 fobj.write("print('hello')")
1345             black.write_cache({}, [one], mode)
1346             self.invokeBlack([str(workspace)])
1347             with one.open("r") as fobj:
1348                 self.assertEqual(fobj.read(), "print('hello')")
1349             with two.open("r") as fobj:
1350                 self.assertEqual(fobj.read(), 'print("hello")\n')
1351             cache = black.read_cache(mode)
1352             self.assertIn(one, cache)
1353             self.assertIn(two, cache)
1354
1355     def test_no_cache_when_writeback_diff(self) -> None:
1356         mode = DEFAULT_MODE
1357         with cache_dir() as workspace:
1358             src = (workspace / "test.py").resolve()
1359             with src.open("w") as fobj:
1360                 fobj.write("print('hello')")
1361             self.invokeBlack([str(src), "--diff"])
1362             cache_file = black.get_cache_file(mode)
1363             self.assertFalse(cache_file.exists())
1364
1365     def test_no_cache_when_stdin(self) -> None:
1366         mode = DEFAULT_MODE
1367         with cache_dir():
1368             result = CliRunner().invoke(
1369                 black.main, ["-"], input=BytesIO(b"print('hello')")
1370             )
1371             self.assertEqual(result.exit_code, 0)
1372             cache_file = black.get_cache_file(mode)
1373             self.assertFalse(cache_file.exists())
1374
1375     def test_read_cache_no_cachefile(self) -> None:
1376         mode = DEFAULT_MODE
1377         with cache_dir():
1378             self.assertEqual(black.read_cache(mode), {})
1379
1380     def test_write_cache_read_cache(self) -> None:
1381         mode = DEFAULT_MODE
1382         with cache_dir() as workspace:
1383             src = (workspace / "test.py").resolve()
1384             src.touch()
1385             black.write_cache({}, [src], mode)
1386             cache = black.read_cache(mode)
1387             self.assertIn(src, cache)
1388             self.assertEqual(cache[src], black.get_cache_info(src))
1389
1390     def test_filter_cached(self) -> None:
1391         with TemporaryDirectory() as workspace:
1392             path = Path(workspace)
1393             uncached = (path / "uncached").resolve()
1394             cached = (path / "cached").resolve()
1395             cached_but_changed = (path / "changed").resolve()
1396             uncached.touch()
1397             cached.touch()
1398             cached_but_changed.touch()
1399             cache = {cached: black.get_cache_info(cached), cached_but_changed: (0.0, 0)}
1400             todo, done = black.filter_cached(
1401                 cache, {uncached, cached, cached_but_changed}
1402             )
1403             self.assertEqual(todo, {uncached, cached_but_changed})
1404             self.assertEqual(done, {cached})
1405
1406     def test_write_cache_creates_directory_if_needed(self) -> None:
1407         mode = DEFAULT_MODE
1408         with cache_dir(exists=False) as workspace:
1409             self.assertFalse(workspace.exists())
1410             black.write_cache({}, [], mode)
1411             self.assertTrue(workspace.exists())
1412
1413     @event_loop()
1414     def test_failed_formatting_does_not_get_cached(self) -> None:
1415         mode = DEFAULT_MODE
1416         with cache_dir() as workspace, patch(
1417             "black.ProcessPoolExecutor", new=ThreadPoolExecutor
1418         ):
1419             failing = (workspace / "failing.py").resolve()
1420             with failing.open("w") as fobj:
1421                 fobj.write("not actually python")
1422             clean = (workspace / "clean.py").resolve()
1423             with clean.open("w") as fobj:
1424                 fobj.write('print("hello")\n')
1425             self.invokeBlack([str(workspace)], exit_code=123)
1426             cache = black.read_cache(mode)
1427             self.assertNotIn(failing, cache)
1428             self.assertIn(clean, cache)
1429
1430     def test_write_cache_write_fail(self) -> None:
1431         mode = DEFAULT_MODE
1432         with cache_dir(), patch.object(Path, "open") as mock:
1433             mock.side_effect = OSError
1434             black.write_cache({}, [], mode)
1435
1436     @event_loop()
1437     @patch("black.ProcessPoolExecutor", MagicMock(side_effect=OSError))
1438     def test_works_in_mono_process_only_environment(self) -> None:
1439         with cache_dir() as workspace:
1440             for f in [
1441                 (workspace / "one.py").resolve(),
1442                 (workspace / "two.py").resolve(),
1443             ]:
1444                 f.write_text('print("hello")\n')
1445             self.invokeBlack([str(workspace)])
1446
1447     @event_loop()
1448     def test_check_diff_use_together(self) -> None:
1449         with cache_dir():
1450             # Files which will be reformatted.
1451             src1 = (THIS_DIR / "data" / "string_quotes.py").resolve()
1452             self.invokeBlack([str(src1), "--diff", "--check"], exit_code=1)
1453             # Files which will not be reformatted.
1454             src2 = (THIS_DIR / "data" / "composition.py").resolve()
1455             self.invokeBlack([str(src2), "--diff", "--check"])
1456             # Multi file command.
1457             self.invokeBlack([str(src1), str(src2), "--diff", "--check"], exit_code=1)
1458
1459     def test_no_files(self) -> None:
1460         with cache_dir():
1461             # Without an argument, black exits with error code 0.
1462             self.invokeBlack([])
1463
1464     def test_broken_symlink(self) -> None:
1465         with cache_dir() as workspace:
1466             symlink = workspace / "broken_link.py"
1467             try:
1468                 symlink.symlink_to("nonexistent.py")
1469             except OSError as e:
1470                 self.skipTest(f"Can't create symlinks: {e}")
1471             self.invokeBlack([str(workspace.resolve())])
1472
1473     def test_read_cache_line_lengths(self) -> None:
1474         mode = DEFAULT_MODE
1475         short_mode = replace(DEFAULT_MODE, line_length=1)
1476         with cache_dir() as workspace:
1477             path = (workspace / "file.py").resolve()
1478             path.touch()
1479             black.write_cache({}, [path], mode)
1480             one = black.read_cache(mode)
1481             self.assertIn(path, one)
1482             two = black.read_cache(short_mode)
1483             self.assertNotIn(path, two)
1484
1485     def test_tricky_unicode_symbols(self) -> None:
1486         source, expected = read_data("tricky_unicode_symbols")
1487         actual = fs(source)
1488         self.assertFormatEqual(expected, actual)
1489         black.assert_equivalent(source, actual)
1490         black.assert_stable(source, actual, DEFAULT_MODE)
1491
1492     def test_single_file_force_pyi(self) -> None:
1493         reg_mode = DEFAULT_MODE
1494         pyi_mode = replace(DEFAULT_MODE, is_pyi=True)
1495         contents, expected = read_data("force_pyi")
1496         with cache_dir() as workspace:
1497             path = (workspace / "file.py").resolve()
1498             with open(path, "w") as fh:
1499                 fh.write(contents)
1500             self.invokeBlack([str(path), "--pyi"])
1501             with open(path, "r") as fh:
1502                 actual = fh.read()
1503             # verify cache with --pyi is separate
1504             pyi_cache = black.read_cache(pyi_mode)
1505             self.assertIn(path, pyi_cache)
1506             normal_cache = black.read_cache(reg_mode)
1507             self.assertNotIn(path, normal_cache)
1508         self.assertEqual(actual, expected)
1509
1510     @event_loop()
1511     def test_multi_file_force_pyi(self) -> None:
1512         reg_mode = DEFAULT_MODE
1513         pyi_mode = replace(DEFAULT_MODE, is_pyi=True)
1514         contents, expected = read_data("force_pyi")
1515         with cache_dir() as workspace:
1516             paths = [
1517                 (workspace / "file1.py").resolve(),
1518                 (workspace / "file2.py").resolve(),
1519             ]
1520             for path in paths:
1521                 with open(path, "w") as fh:
1522                     fh.write(contents)
1523             self.invokeBlack([str(p) for p in paths] + ["--pyi"])
1524             for path in paths:
1525                 with open(path, "r") as fh:
1526                     actual = fh.read()
1527                 self.assertEqual(actual, expected)
1528             # verify cache with --pyi is separate
1529             pyi_cache = black.read_cache(pyi_mode)
1530             normal_cache = black.read_cache(reg_mode)
1531             for path in paths:
1532                 self.assertIn(path, pyi_cache)
1533                 self.assertNotIn(path, normal_cache)
1534
1535     def test_pipe_force_pyi(self) -> None:
1536         source, expected = read_data("force_pyi")
1537         result = CliRunner().invoke(
1538             black.main, ["-", "-q", "--pyi"], input=BytesIO(source.encode("utf8"))
1539         )
1540         self.assertEqual(result.exit_code, 0)
1541         actual = result.output
1542         self.assertFormatEqual(actual, expected)
1543
1544     def test_single_file_force_py36(self) -> None:
1545         reg_mode = DEFAULT_MODE
1546         py36_mode = replace(DEFAULT_MODE, target_versions=black.PY36_VERSIONS)
1547         source, expected = read_data("force_py36")
1548         with cache_dir() as workspace:
1549             path = (workspace / "file.py").resolve()
1550             with open(path, "w") as fh:
1551                 fh.write(source)
1552             self.invokeBlack([str(path), *PY36_ARGS])
1553             with open(path, "r") as fh:
1554                 actual = fh.read()
1555             # verify cache with --target-version is separate
1556             py36_cache = black.read_cache(py36_mode)
1557             self.assertIn(path, py36_cache)
1558             normal_cache = black.read_cache(reg_mode)
1559             self.assertNotIn(path, normal_cache)
1560         self.assertEqual(actual, expected)
1561
1562     @event_loop()
1563     def test_multi_file_force_py36(self) -> None:
1564         reg_mode = DEFAULT_MODE
1565         py36_mode = replace(DEFAULT_MODE, target_versions=black.PY36_VERSIONS)
1566         source, expected = read_data("force_py36")
1567         with cache_dir() as workspace:
1568             paths = [
1569                 (workspace / "file1.py").resolve(),
1570                 (workspace / "file2.py").resolve(),
1571             ]
1572             for path in paths:
1573                 with open(path, "w") as fh:
1574                     fh.write(source)
1575             self.invokeBlack([str(p) for p in paths] + PY36_ARGS)
1576             for path in paths:
1577                 with open(path, "r") as fh:
1578                     actual = fh.read()
1579                 self.assertEqual(actual, expected)
1580             # verify cache with --target-version is separate
1581             pyi_cache = black.read_cache(py36_mode)
1582             normal_cache = black.read_cache(reg_mode)
1583             for path in paths:
1584                 self.assertIn(path, pyi_cache)
1585                 self.assertNotIn(path, normal_cache)
1586
1587     def test_collections(self) -> None:
1588         source, expected = read_data("collections")
1589         actual = fs(source)
1590         self.assertFormatEqual(expected, actual)
1591         black.assert_equivalent(source, actual)
1592         black.assert_stable(source, actual, DEFAULT_MODE)
1593
1594     def test_pipe_force_py36(self) -> None:
1595         source, expected = read_data("force_py36")
1596         result = CliRunner().invoke(
1597             black.main,
1598             ["-", "-q", "--target-version=py36"],
1599             input=BytesIO(source.encode("utf8")),
1600         )
1601         self.assertEqual(result.exit_code, 0)
1602         actual = result.output
1603         self.assertFormatEqual(actual, expected)
1604
1605     def test_include_exclude(self) -> None:
1606         path = THIS_DIR / "data" / "include_exclude_tests"
1607         include = re.compile(r"\.pyi?$")
1608         exclude = re.compile(r"/exclude/|/\.definitely_exclude/")
1609         report = black.Report()
1610         gitignore = PathSpec.from_lines("gitwildmatch", [])
1611         sources: List[Path] = []
1612         expected = [
1613             Path(path / "b/dont_exclude/a.py"),
1614             Path(path / "b/dont_exclude/a.pyi"),
1615         ]
1616         this_abs = THIS_DIR.resolve()
1617         sources.extend(
1618             black.gen_python_files(
1619                 path.iterdir(), this_abs, include, exclude, None, report, gitignore
1620             )
1621         )
1622         self.assertEqual(sorted(expected), sorted(sources))
1623
1624     @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
1625     def test_exclude_for_issue_1572(self) -> None:
1626         # Exclude shouldn't touch files that were explicitly given to Black through the
1627         # CLI. Exclude is supposed to only apply to the recursive discovery of files.
1628         # https://github.com/psf/black/issues/1572
1629         path = THIS_DIR / "data" / "include_exclude_tests"
1630         include = ""
1631         exclude = r"/exclude/|a\.py"
1632         src = str(path / "b/exclude/a.py")
1633         report = black.Report()
1634         expected = [Path(path / "b/exclude/a.py")]
1635         sources = list(
1636             black.get_sources(
1637                 ctx=FakeContext(),
1638                 src=(src,),
1639                 quiet=True,
1640                 verbose=False,
1641                 include=include,
1642                 exclude=exclude,
1643                 force_exclude=None,
1644                 report=report,
1645             )
1646         )
1647         self.assertEqual(sorted(expected), sorted(sources))
1648
1649     def test_gitignore_exclude(self) -> None:
1650         path = THIS_DIR / "data" / "include_exclude_tests"
1651         include = re.compile(r"\.pyi?$")
1652         exclude = re.compile(r"")
1653         report = black.Report()
1654         gitignore = PathSpec.from_lines(
1655             "gitwildmatch", ["exclude/", ".definitely_exclude"]
1656         )
1657         sources: List[Path] = []
1658         expected = [
1659             Path(path / "b/dont_exclude/a.py"),
1660             Path(path / "b/dont_exclude/a.pyi"),
1661         ]
1662         this_abs = THIS_DIR.resolve()
1663         sources.extend(
1664             black.gen_python_files(
1665                 path.iterdir(), this_abs, include, exclude, None, report, gitignore
1666             )
1667         )
1668         self.assertEqual(sorted(expected), sorted(sources))
1669
1670     def test_empty_include(self) -> None:
1671         path = THIS_DIR / "data" / "include_exclude_tests"
1672         report = black.Report()
1673         gitignore = PathSpec.from_lines("gitwildmatch", [])
1674         empty = re.compile(r"")
1675         sources: List[Path] = []
1676         expected = [
1677             Path(path / "b/exclude/a.pie"),
1678             Path(path / "b/exclude/a.py"),
1679             Path(path / "b/exclude/a.pyi"),
1680             Path(path / "b/dont_exclude/a.pie"),
1681             Path(path / "b/dont_exclude/a.py"),
1682             Path(path / "b/dont_exclude/a.pyi"),
1683             Path(path / "b/.definitely_exclude/a.pie"),
1684             Path(path / "b/.definitely_exclude/a.py"),
1685             Path(path / "b/.definitely_exclude/a.pyi"),
1686         ]
1687         this_abs = THIS_DIR.resolve()
1688         sources.extend(
1689             black.gen_python_files(
1690                 path.iterdir(),
1691                 this_abs,
1692                 empty,
1693                 re.compile(black.DEFAULT_EXCLUDES),
1694                 None,
1695                 report,
1696                 gitignore,
1697             )
1698         )
1699         self.assertEqual(sorted(expected), sorted(sources))
1700
1701     def test_empty_exclude(self) -> None:
1702         path = THIS_DIR / "data" / "include_exclude_tests"
1703         report = black.Report()
1704         gitignore = PathSpec.from_lines("gitwildmatch", [])
1705         empty = re.compile(r"")
1706         sources: List[Path] = []
1707         expected = [
1708             Path(path / "b/dont_exclude/a.py"),
1709             Path(path / "b/dont_exclude/a.pyi"),
1710             Path(path / "b/exclude/a.py"),
1711             Path(path / "b/exclude/a.pyi"),
1712             Path(path / "b/.definitely_exclude/a.py"),
1713             Path(path / "b/.definitely_exclude/a.pyi"),
1714         ]
1715         this_abs = THIS_DIR.resolve()
1716         sources.extend(
1717             black.gen_python_files(
1718                 path.iterdir(),
1719                 this_abs,
1720                 re.compile(black.DEFAULT_INCLUDES),
1721                 empty,
1722                 None,
1723                 report,
1724                 gitignore,
1725             )
1726         )
1727         self.assertEqual(sorted(expected), sorted(sources))
1728
1729     def test_invalid_include_exclude(self) -> None:
1730         for option in ["--include", "--exclude"]:
1731             self.invokeBlack(["-", option, "**()(!!*)"], exit_code=2)
1732
1733     def test_preserves_line_endings(self) -> None:
1734         with TemporaryDirectory() as workspace:
1735             test_file = Path(workspace) / "test.py"
1736             for nl in ["\n", "\r\n"]:
1737                 contents = nl.join(["def f(  ):", "    pass"])
1738                 test_file.write_bytes(contents.encode())
1739                 ff(test_file, write_back=black.WriteBack.YES)
1740                 updated_contents: bytes = test_file.read_bytes()
1741                 self.assertIn(nl.encode(), updated_contents)
1742                 if nl == "\n":
1743                     self.assertNotIn(b"\r\n", updated_contents)
1744
1745     def test_preserves_line_endings_via_stdin(self) -> None:
1746         for nl in ["\n", "\r\n"]:
1747             contents = nl.join(["def f(  ):", "    pass"])
1748             runner = BlackRunner()
1749             result = runner.invoke(
1750                 black.main, ["-", "--fast"], input=BytesIO(contents.encode("utf8"))
1751             )
1752             self.assertEqual(result.exit_code, 0)
1753             output = runner.stdout_bytes
1754             self.assertIn(nl.encode("utf8"), output)
1755             if nl == "\n":
1756                 self.assertNotIn(b"\r\n", output)
1757
1758     def test_assert_equivalent_different_asts(self) -> None:
1759         with self.assertRaises(AssertionError):
1760             black.assert_equivalent("{}", "None")
1761
1762     def test_symlink_out_of_root_directory(self) -> None:
1763         path = MagicMock()
1764         root = THIS_DIR.resolve()
1765         child = MagicMock()
1766         include = re.compile(black.DEFAULT_INCLUDES)
1767         exclude = re.compile(black.DEFAULT_EXCLUDES)
1768         report = black.Report()
1769         gitignore = PathSpec.from_lines("gitwildmatch", [])
1770         # `child` should behave like a symlink which resolved path is clearly
1771         # outside of the `root` directory.
1772         path.iterdir.return_value = [child]
1773         child.resolve.return_value = Path("/a/b/c")
1774         child.as_posix.return_value = "/a/b/c"
1775         child.is_symlink.return_value = True
1776         try:
1777             list(
1778                 black.gen_python_files(
1779                     path.iterdir(), root, include, exclude, None, report, gitignore
1780                 )
1781             )
1782         except ValueError as ve:
1783             self.fail(f"`get_python_files_in_dir()` failed: {ve}")
1784         path.iterdir.assert_called_once()
1785         child.resolve.assert_called_once()
1786         child.is_symlink.assert_called_once()
1787         # `child` should behave like a strange file which resolved path is clearly
1788         # outside of the `root` directory.
1789         child.is_symlink.return_value = False
1790         with self.assertRaises(ValueError):
1791             list(
1792                 black.gen_python_files(
1793                     path.iterdir(), root, include, exclude, None, report, gitignore
1794                 )
1795             )
1796         path.iterdir.assert_called()
1797         self.assertEqual(path.iterdir.call_count, 2)
1798         child.resolve.assert_called()
1799         self.assertEqual(child.resolve.call_count, 2)
1800         child.is_symlink.assert_called()
1801         self.assertEqual(child.is_symlink.call_count, 2)
1802
1803     def test_shhh_click(self) -> None:
1804         try:
1805             from click import _unicodefun  # type: ignore
1806         except ModuleNotFoundError:
1807             self.skipTest("Incompatible Click version")
1808         if not hasattr(_unicodefun, "_verify_python3_env"):
1809             self.skipTest("Incompatible Click version")
1810         # First, let's see if Click is crashing with a preferred ASCII charset.
1811         with patch("locale.getpreferredencoding") as gpe:
1812             gpe.return_value = "ASCII"
1813             with self.assertRaises(RuntimeError):
1814                 _unicodefun._verify_python3_env()
1815         # Now, let's silence Click...
1816         black.patch_click()
1817         # ...and confirm it's silent.
1818         with patch("locale.getpreferredencoding") as gpe:
1819             gpe.return_value = "ASCII"
1820             try:
1821                 _unicodefun._verify_python3_env()
1822             except RuntimeError as re:
1823                 self.fail(f"`patch_click()` failed, exception still raised: {re}")
1824
1825     def test_root_logger_not_used_directly(self) -> None:
1826         def fail(*args: Any, **kwargs: Any) -> None:
1827             self.fail("Record created with root logger")
1828
1829         with patch.multiple(
1830             logging.root,
1831             debug=fail,
1832             info=fail,
1833             warning=fail,
1834             error=fail,
1835             critical=fail,
1836             log=fail,
1837         ):
1838             ff(THIS_FILE)
1839
1840     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1841     def test_blackd_main(self) -> None:
1842         with patch("blackd.web.run_app"):
1843             result = CliRunner().invoke(blackd.main, [])
1844             if result.exception is not None:
1845                 raise result.exception
1846             self.assertEqual(result.exit_code, 0)
1847
1848     def test_invalid_config_return_code(self) -> None:
1849         tmp_file = Path(black.dump_to_file())
1850         try:
1851             tmp_config = Path(black.dump_to_file())
1852             tmp_config.unlink()
1853             args = ["--config", str(tmp_config), str(tmp_file)]
1854             self.invokeBlack(args, exit_code=2, ignore_config=False)
1855         finally:
1856             tmp_file.unlink()
1857
1858     def test_parse_pyproject_toml(self) -> None:
1859         test_toml_file = THIS_DIR / "test.toml"
1860         config = black.parse_pyproject_toml(str(test_toml_file))
1861         self.assertEqual(config["verbose"], 1)
1862         self.assertEqual(config["check"], "no")
1863         self.assertEqual(config["diff"], "y")
1864         self.assertEqual(config["color"], True)
1865         self.assertEqual(config["line_length"], 79)
1866         self.assertEqual(config["target_version"], ["py36", "py37", "py38"])
1867         self.assertEqual(config["exclude"], r"\.pyi?$")
1868         self.assertEqual(config["include"], r"\.py?$")
1869
1870     def test_read_pyproject_toml(self) -> None:
1871         test_toml_file = THIS_DIR / "test.toml"
1872         fake_ctx = FakeContext()
1873         black.read_pyproject_toml(fake_ctx, FakeParameter(), str(test_toml_file))
1874         config = fake_ctx.default_map
1875         self.assertEqual(config["verbose"], "1")
1876         self.assertEqual(config["check"], "no")
1877         self.assertEqual(config["diff"], "y")
1878         self.assertEqual(config["color"], "True")
1879         self.assertEqual(config["line_length"], "79")
1880         self.assertEqual(config["target_version"], ["py36", "py37", "py38"])
1881         self.assertEqual(config["exclude"], r"\.pyi?$")
1882         self.assertEqual(config["include"], r"\.py?$")
1883
1884     def test_find_project_root(self) -> None:
1885         with TemporaryDirectory() as workspace:
1886             root = Path(workspace)
1887             test_dir = root / "test"
1888             test_dir.mkdir()
1889
1890             src_dir = root / "src"
1891             src_dir.mkdir()
1892
1893             root_pyproject = root / "pyproject.toml"
1894             root_pyproject.touch()
1895             src_pyproject = src_dir / "pyproject.toml"
1896             src_pyproject.touch()
1897             src_python = src_dir / "foo.py"
1898             src_python.touch()
1899
1900             self.assertEqual(
1901                 black.find_project_root((src_dir, test_dir)), root.resolve()
1902             )
1903             self.assertEqual(black.find_project_root((src_dir,)), src_dir.resolve())
1904             self.assertEqual(black.find_project_root((src_python,)), src_dir.resolve())
1905
1906
if has_blackd_deps:
    # Only define the blackd tests when the optional imports succeeded.
    # `AioHTTPTestCase`, `unittest_run_loop` and `web` are unbound otherwise,
    # so an unconditional class statement would raise NameError while this
    # module is being imported.  Because the class only exists when the
    # dependencies are installed, per-test `skipUnless(has_blackd_deps)`
    # decorators are not needed inside it.

    class BlackDTestCase(AioHTTPTestCase):
        """Exercise the HTTP application built by `blackd.make_app()`."""

        async def get_application(self) -> web.Application:
            # Hook used by AioHTTPTestCase to obtain the application under test.
            return blackd.make_app()

        # TODO: remove the skip_if_exception decorators once the below is released
        # https://github.com/aio-libs/aiohttp/pull/3727
        @skip_if_exception("ClientOSError")
        @unittest_run_loop
        async def test_blackd_request_needs_formatting(self) -> None:
            # Unformatted source: 200 with the reformatted code as the body.
            response = await self.client.post("/", data=b"print('hello world')")
            self.assertEqual(response.status, 200)
            self.assertEqual(response.charset, "utf8")
            self.assertEqual(await response.read(), b'print("hello world")\n')

        @skip_if_exception("ClientOSError")
        @unittest_run_loop
        async def test_blackd_request_no_change(self) -> None:
            # Already formatted source: 204 with an empty body.
            response = await self.client.post("/", data=b'print("hello world")\n')
            self.assertEqual(response.status, 204)
            self.assertEqual(await response.read(), b"")

        @skip_if_exception("ClientOSError")
        @unittest_run_loop
        async def test_blackd_request_syntax_error(self) -> None:
            # Unparsable source: 400 with a "Cannot parse" message.
            response = await self.client.post("/", data=b"what even ( is")
            self.assertEqual(response.status, 400)
            content = await response.text()
            self.assertTrue(
                content.startswith("Cannot parse"),
                msg=(
                    "Expected error to start with 'Cannot parse', "
                    f"got {repr(content)}"
                ),
            )

        @skip_if_exception("ClientOSError")
        @unittest_run_loop
        async def test_blackd_unsupported_version(self) -> None:
            # Protocol version "2" is answered with 501.
            response = await self.client.post(
                "/", data=b"what", headers={blackd.PROTOCOL_VERSION_HEADER: "2"}
            )
            self.assertEqual(response.status, 501)

        @skip_if_exception("ClientOSError")
        @unittest_run_loop
        async def test_blackd_supported_version(self) -> None:
            # Protocol version "1" is accepted and the request is processed.
            response = await self.client.post(
                "/", data=b"what", headers={blackd.PROTOCOL_VERSION_HEADER: "1"}
            )
            self.assertEqual(response.status, 200)

        @skip_if_exception("ClientOSError")
        @unittest_run_loop
        async def test_blackd_invalid_python_variant(self) -> None:
            # Each of these Python-variant header values must be rejected
            # with 400.
            async def check(header_value: str, expected_status: int = 400) -> None:
                response = await self.client.post(
                    "/",
                    data=b"what",
                    headers={blackd.PYTHON_VARIANT_HEADER: header_value},
                )
                self.assertEqual(response.status, expected_status)

            await check("lol")
            await check("ruby3.5")
            await check("pyi3.6")
            await check("py1.5")
            await check("2.8")
            await check("py2.8")
            await check("3.0")
            await check("pypy3.0")
            await check("jython3.4")

        @skip_if_exception("ClientOSError")
        @unittest_run_loop
        async def test_blackd_pyi(self) -> None:
            # The "pyi" variant must produce the expected output of the
            # "stub.pyi" fixture.
            source, expected = read_data("stub.pyi")
            response = await self.client.post(
                "/", data=source, headers={blackd.PYTHON_VARIANT_HEADER: "pyi"}
            )
            self.assertEqual(response.status, 200)
            self.assertEqual(await response.text(), expected)

        @skip_if_exception("ClientOSError")
        @unittest_run_loop
        async def test_blackd_diff(self) -> None:
            # The timestamps in the diff headers vary from run to run, so they
            # are replaced by DETERMINISTIC_HEADER before comparing.
            diff_header = re.compile(
                r"(In|Out)\t\d\d\d\d-\d\d-\d\d \d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
            )

            source, _ = read_data("blackd_diff.py")
            expected, _ = read_data("blackd_diff.diff")

            response = await self.client.post(
                "/", data=source, headers={blackd.DIFF_HEADER: "true"}
            )
            self.assertEqual(response.status, 200)

            actual = await response.text()
            actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
            self.assertEqual(actual, expected)

        @skip_if_exception("ClientOSError")
        @unittest_run_loop
        async def test_blackd_python_variant(self) -> None:
            # The same source is posted with different target variants; the
            # expected status is 200 for the first group and 204 for the second.
            code = (
                "def f(\n"
                "    and_has_a_bunch_of,\n"
                "    very_long_arguments_too,\n"
                "    and_lots_of_them_as_well_lol,\n"
                "    **and_very_long_keyword_arguments\n"
                "):\n"
                "    pass\n"
            )

            async def check(header_value: str, expected_status: int) -> None:
                response = await self.client.post(
                    "/",
                    data=code,
                    headers={blackd.PYTHON_VARIANT_HEADER: header_value},
                )
                self.assertEqual(
                    response.status, expected_status, msg=await response.text()
                )

            await check("3.6", 200)
            await check("py3.6", 200)
            await check("3.6,3.7", 200)
            await check("3.6,py3.7", 200)
            await check("py36,py37", 200)
            await check("36", 200)
            await check("3.6.4", 200)

            await check("2", 204)
            await check("2.7", 204)
            await check("py2.7", 204)
            await check("3.4", 204)
            await check("py3.4", 204)
            await check("py34,py36", 204)
            await check("34", 204)

        @skip_if_exception("ClientOSError")
        @unittest_run_loop
        async def test_blackd_line_length(self) -> None:
            # A numeric line-length header is accepted.
            response = await self.client.post(
                "/",
                data=b'print("hello")\n',
                headers={blackd.LINE_LENGTH_HEADER: "7"},
            )
            self.assertEqual(response.status, 200)

        @skip_if_exception("ClientOSError")
        @unittest_run_loop
        async def test_blackd_invalid_line_length(self) -> None:
            # A non-numeric line-length header is rejected with 400.
            response = await self.client.post(
                "/",
                data=b'print("hello")\n',
                headers={blackd.LINE_LENGTH_HEADER: "NaN"},
            )
            self.assertEqual(response.status, 400)

        @skip_if_exception("ClientOSError")
        @unittest_run_loop
        async def test_blackd_response_black_version_header(self) -> None:
            # Every response must carry the Black version header.
            response = await self.client.post("/")
            self.assertIsNotNone(response.headers.get(blackd.BLACK_VERSION_HEADER))
2073
2074
# Source lines of the imported `black` module (line endings kept); used by
# `tracefunc()` to show the signature of each traced function.
with open(black.__file__, encoding="utf-8") as _bf:
    black_source_lines = list(_bf)
2077
2078
def tracefunc(frame: types.FrameType, event: str, arg: Any) -> Callable:
    """Show function calls from `black/__init__.py` as they happen.

    Register this with `sys.settrace()` in a test you're debugging.

    Only "call" events whose frame belongs to `black/__init__.py` are printed,
    as `<indent><lineno>:<source line of the function signature>`.  Every other
    event or file is ignored.  The function always returns itself.
    """
    if event != "call":
        return tracefunc

    filename = frame.f_code.co_filename
    if "black/__init__.py" not in filename:
        # Line numbers of other files must never be used to index Black's own
        # source lines: they may be out of range or point at an unrelated line.
        return tracefunc

    # The constant discounts frames already on the stack before Black's code
    # runs -- presumably tuned for the unittest runner; adjust if needed.
    stack = len(inspect.stack()) - 19
    lineno = frame.f_lineno
    func_sig_lineno = lineno - 1
    funcname = black_source_lines[func_sig_lineno].strip()
    # Skip decorator lines until the actual signature line is reached.
    while funcname.startswith("@") and func_sig_lineno + 1 < len(black_source_lines):
        func_sig_lineno += 1
        funcname = black_source_lines[func_sig_lineno].strip()
    print(f"{' ' * stack}{lineno}:{funcname}")
    return tracefunc
2098
2099
if __name__ == "__main__":
    # Run the tests of the `test_black` module when executed as a script.
    unittest.main(module="test_black")