git.madduck.net Git - etc/vim.git/blob - tests/test_black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access, can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

6705490ea13624aab062e12af969346f24eb4f2a
[etc/vim.git] / tests / test_black.py
1 #!/usr/bin/env python3
2 import asyncio
3 import logging
4 from concurrent.futures import ThreadPoolExecutor
5 from contextlib import contextmanager
6 from dataclasses import replace
7 from functools import partial
8 import inspect
9 from io import BytesIO, TextIOWrapper
10 import os
11 from pathlib import Path
12 import regex as re
13 import sys
14 from tempfile import TemporaryDirectory
15 import types
16 from typing import (
17     Any,
18     BinaryIO,
19     Callable,
20     Dict,
21     Generator,
22     List,
23     Tuple,
24     Iterator,
25     TypeVar,
26 )
27 import unittest
28 from unittest.mock import patch, MagicMock
29
30 import click
31 from click import unstyle
32 from click.testing import CliRunner
33
34 import black
35 from black import Feature, TargetVersion
36
37 try:
38     import blackd
39     from aiohttp.test_utils import AioHTTPTestCase, unittest_run_loop
40     from aiohttp import web
41 except ImportError:
42     has_blackd_deps = False
43 else:
44     has_blackd_deps = True
45
46 from pathspec import PathSpec
47
48 # Import other test classes
49 from .test_primer import PrimerCLITests  # noqa: F401
50
51
# Mode shared by most tests; string processing is enabled so the experimental
# string transformers are exercised too.
DEFAULT_MODE = black.FileMode(experimental_string_processing=True)
# Shorthands: ``ff`` formats a file in place, ``fs`` formats a string.
ff = partial(black.format_file_in_place, mode=DEFAULT_MODE, fast=True)
fs = partial(black.format_str, mode=DEFAULT_MODE)
THIS_FILE = Path(__file__)
THIS_DIR = THIS_FILE.parent
PROJECT_ROOT = THIS_DIR.parent
DETERMINISTIC_HEADER = "[Deterministic header]"
# Built from two pieces on purpose: read_data() strips this marker from every
# line it reads, and this very file is read back by test_self.
EMPTY_LINE = "# EMPTY LINE WITH WHITESPACE" + " (this comment will be removed)"
PY36_ARGS = [f"--target-version={v.name.lower()}" for v in black.PY36_VERSIONS]
T = TypeVar("T")
R = TypeVar("R")
65
66
def dump_to_stderr(*output: str) -> str:
    """Join *output* with newlines and pad the result with one newline per side.

    The tests patch this in place of ``black.dump_to_file`` via ``@patch``.
    """
    return "\n{}\n".format("\n".join(output))
69
70
def read_data(name: str, data: bool = True) -> Tuple[str, str]:
    """read_data('test_name') -> 'input', 'output'

    The fixture is split at a line consisting solely of the output marker
    comment; lines after it form the expected output. If there is no marker
    (or nothing follows it), the input doubles as the expected output.

    Names without a recognized extension get ``.py`` appended. With
    ``data=True`` the name is resolved inside the ``data`` directory next to
    this file, otherwise relative to the project root. The EMPTY_LINE marker
    text is removed from every line before splitting.
    """
    if not name.endswith((".py", ".pyi", ".out", ".diff")):
        name += ".py"
    base_dir = THIS_DIR / "data" if data else PROJECT_ROOT
    with open(base_dir / name, "r", encoding="utf8") as test:
        lines = test.readlines()
    src_lines: List[str] = []
    out_lines: List[str] = []
    in_output = False
    for raw in lines:
        line = raw.replace(EMPTY_LINE, "")
        if line.rstrip() == "# output":
            in_output = True
        elif in_output:
            out_lines.append(line)
        else:
            src_lines.append(line)
    if src_lines and not out_lines:
        # No expected section: treat the whole file as already formatted.
        out_lines = list(src_lines)
    return "".join(src_lines).strip() + "\n", "".join(out_lines).strip() + "\n"
92
93
@contextmanager
def cache_dir(exists: bool = True) -> Iterator[Path]:
    """Patch ``black.CACHE_DIR`` to a temporary location for the block's duration.

    With ``exists=False`` the yielded path is a ``new`` subdirectory of the
    fresh temporary directory that has not been created.
    """
    with TemporaryDirectory() as workspace:
        root = Path(workspace)
        target = root if exists else root / "new"
        with patch("black.CACHE_DIR", target):
            yield target
102
103
@contextmanager
def event_loop() -> Iterator[None]:
    """Install a brand-new asyncio loop as the current one; close it on exit."""
    loop = asyncio.get_event_loop_policy().new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        yield
    finally:
        loop.close()
114
115
116 @contextmanager
117 def skip_if_exception(e: str) -> Iterator[None]:
118     try:
119         yield
120     except Exception as exc:
121         if exc.__class__.__name__ == e:
122             unittest.skip(f"Encountered expected exception {exc}, skipping")
123         else:
124             raise
125
126
class FakeContext(click.Context):
    """A fake click Context for when calling functions that need it.

    ``click.Context.__init__`` is deliberately not called; the only state set
    up is an empty ``default_map``.
    """

    def __init__(self) -> None:
        defaults: Dict[str, Any] = dict()
        self.default_map = defaults
132
133
class FakeParameter(click.Parameter):
    """A fake click Parameter for when calling functions that need it.

    ``click.Parameter.__init__`` is deliberately skipped, so nothing it would
    normally initialize is set on the instance.
    """

    def __init__(self) -> None:
        """Intentionally a no-op; the base initializer is not invoked."""
139
140
class BlackRunner(CliRunner):
    """Modify CliRunner so that stderr is not merged with stdout.

    This is a hack that can be removed once we depend on Click 7.x"""

    def __init__(self) -> None:
        # Backing store for the replacement ``sys.stderr`` used in isolation().
        self.stderrbuf = BytesIO()
        self.stdoutbuf = BytesIO()
        # Raw bytes captured by the most recent isolation() run.
        self.stdout_bytes = b""
        self.stderr_bytes = b""
        super().__init__()

    @contextmanager
    def isolation(self, *args: Any, **kwargs: Any) -> Generator[BinaryIO, None, None]:
        """Wrap Click's isolation, routing stderr into ``self.stderrbuf``.

        On exit, ``stdout_bytes``/``stderr_bytes`` hold what was written to the
        isolated stdout and to the separate stderr buffer, and the previous
        ``sys.stderr`` is restored.
        """
        with super().isolation(*args, **kwargs) as output:
            hold_stderr = sys.stderr
            sys.stderr = TextIOWrapper(self.stderrbuf, encoding=self.charset)
            try:
                yield output
            finally:
                try:
                    # TextIOWrapper buffers text internally; without a flush,
                    # writes such as ``print(..., file=sys.stderr)`` would be
                    # missing from the byte buffers read below.
                    sys.stdout.flush()
                    sys.stderr.flush()
                    self.stdout_bytes = sys.stdout.buffer.getvalue()  # type: ignore
                    self.stderr_bytes = sys.stderr.buffer.getvalue()  # type: ignore
                finally:
                    sys.stderr = hold_stderr
164
165
166 class BlackTestCase(unittest.TestCase):
167     maxDiff = None
168     _diffThreshold = 2 ** 20
169
170     def assertFormatEqual(self, expected: str, actual: str) -> None:
171         if actual != expected and not os.environ.get("SKIP_AST_PRINT"):
172             bdv: black.DebugVisitor[Any]
173             black.out("Expected tree:", fg="green")
174             try:
175                 exp_node = black.lib2to3_parse(expected)
176                 bdv = black.DebugVisitor()
177                 list(bdv.visit(exp_node))
178             except Exception as ve:
179                 black.err(str(ve))
180             black.out("Actual tree:", fg="red")
181             try:
182                 exp_node = black.lib2to3_parse(actual)
183                 bdv = black.DebugVisitor()
184                 list(bdv.visit(exp_node))
185             except Exception as ve:
186                 black.err(str(ve))
187         self.assertMultiLineEqual(expected, actual)
188
189     def invokeBlack(
190         self, args: List[str], exit_code: int = 0, ignore_config: bool = True
191     ) -> None:
192         runner = BlackRunner()
193         if ignore_config:
194             args = ["--verbose", "--config", str(THIS_DIR / "empty.toml"), *args]
195         result = runner.invoke(black.main, args)
196         self.assertEqual(
197             result.exit_code,
198             exit_code,
199             msg=(
200                 f"Failed with args: {args}\n"
201                 f"stdout: {runner.stdout_bytes.decode()!r}\n"
202                 f"stderr: {runner.stderr_bytes.decode()!r}\n"
203                 f"exception: {result.exception}"
204             ),
205         )
206
207     @patch("black.dump_to_file", dump_to_stderr)
208     def checkSourceFile(self, name: str, mode: black.FileMode = DEFAULT_MODE) -> None:
209         path = THIS_DIR.parent / name
210         source, expected = read_data(str(path), data=False)
211         actual = fs(source, mode=mode)
212         self.assertFormatEqual(expected, actual)
213         black.assert_equivalent(source, actual)
214         black.assert_stable(source, actual, mode)
215         self.assertFalse(ff(path))
216
217     @patch("black.dump_to_file", dump_to_stderr)
218     def test_empty(self) -> None:
219         source = expected = ""
220         actual = fs(source)
221         self.assertFormatEqual(expected, actual)
222         black.assert_equivalent(source, actual)
223         black.assert_stable(source, actual, DEFAULT_MODE)
224
225     def test_empty_ff(self) -> None:
226         expected = ""
227         tmp_file = Path(black.dump_to_file())
228         try:
229             self.assertFalse(ff(tmp_file, write_back=black.WriteBack.YES))
230             with open(tmp_file, encoding="utf8") as f:
231                 actual = f.read()
232         finally:
233             os.unlink(tmp_file)
234         self.assertFormatEqual(expected, actual)
235
236     def test_self(self) -> None:
237         self.checkSourceFile("tests/test_black.py")
238
239     def test_black(self) -> None:
240         self.checkSourceFile("src/black/__init__.py")
241
242     def test_pygram(self) -> None:
243         self.checkSourceFile("src/blib2to3/pygram.py")
244
245     def test_pytree(self) -> None:
246         self.checkSourceFile("src/blib2to3/pytree.py")
247
248     def test_conv(self) -> None:
249         self.checkSourceFile("src/blib2to3/pgen2/conv.py")
250
251     def test_driver(self) -> None:
252         self.checkSourceFile("src/blib2to3/pgen2/driver.py")
253
254     def test_grammar(self) -> None:
255         self.checkSourceFile("src/blib2to3/pgen2/grammar.py")
256
257     def test_literals(self) -> None:
258         self.checkSourceFile("src/blib2to3/pgen2/literals.py")
259
260     def test_parse(self) -> None:
261         self.checkSourceFile("src/blib2to3/pgen2/parse.py")
262
263     def test_pgen(self) -> None:
264         self.checkSourceFile("src/blib2to3/pgen2/pgen.py")
265
266     def test_tokenize(self) -> None:
267         self.checkSourceFile("src/blib2to3/pgen2/tokenize.py")
268
269     def test_token(self) -> None:
270         self.checkSourceFile("src/blib2to3/pgen2/token.py")
271
272     def test_setup(self) -> None:
273         self.checkSourceFile("setup.py")
274
275     def test_piping(self) -> None:
276         source, expected = read_data("src/black/__init__", data=False)
277         result = BlackRunner().invoke(
278             black.main,
279             ["-", "--fast", f"--line-length={black.DEFAULT_LINE_LENGTH}"],
280             input=BytesIO(source.encode("utf8")),
281         )
282         self.assertEqual(result.exit_code, 0)
283         self.assertFormatEqual(expected, result.output)
284         black.assert_equivalent(source, result.output)
285         black.assert_stable(source, result.output, DEFAULT_MODE)
286
287     def test_piping_diff(self) -> None:
288         diff_header = re.compile(
289             r"(STDIN|STDOUT)\t\d\d\d\d-\d\d-\d\d \d\d:\d\d:\d\d\.\d\d\d\d\d\d "
290             r"\+\d\d\d\d"
291         )
292         source, _ = read_data("expression.py")
293         expected, _ = read_data("expression.diff")
294         config = THIS_DIR / "data" / "empty_pyproject.toml"
295         args = [
296             "-",
297             "--fast",
298             f"--line-length={black.DEFAULT_LINE_LENGTH}",
299             "--diff",
300             f"--config={config}",
301         ]
302         result = BlackRunner().invoke(
303             black.main, args, input=BytesIO(source.encode("utf8"))
304         )
305         self.assertEqual(result.exit_code, 0)
306         actual = diff_header.sub(DETERMINISTIC_HEADER, result.output)
307         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
308         self.assertEqual(expected, actual)
309
310     def test_piping_diff_with_color(self) -> None:
311         source, _ = read_data("expression.py")
312         config = THIS_DIR / "data" / "empty_pyproject.toml"
313         args = [
314             "-",
315             "--fast",
316             f"--line-length={black.DEFAULT_LINE_LENGTH}",
317             "--diff",
318             "--color",
319             f"--config={config}",
320         ]
321         result = BlackRunner().invoke(
322             black.main, args, input=BytesIO(source.encode("utf8"))
323         )
324         actual = result.output
325         # Again, the contents are checked in a different test, so only look for colors.
326         self.assertIn("\033[1;37m", actual)
327         self.assertIn("\033[36m", actual)
328         self.assertIn("\033[32m", actual)
329         self.assertIn("\033[31m", actual)
330         self.assertIn("\033[0m", actual)
331
332     @patch("black.dump_to_file", dump_to_stderr)
333     def test_function(self) -> None:
334         source, expected = read_data("function")
335         actual = fs(source)
336         self.assertFormatEqual(expected, actual)
337         black.assert_equivalent(source, actual)
338         black.assert_stable(source, actual, DEFAULT_MODE)
339
340     @patch("black.dump_to_file", dump_to_stderr)
341     def test_function2(self) -> None:
342         source, expected = read_data("function2")
343         actual = fs(source)
344         self.assertFormatEqual(expected, actual)
345         black.assert_equivalent(source, actual)
346         black.assert_stable(source, actual, DEFAULT_MODE)
347
348     @patch("black.dump_to_file", dump_to_stderr)
349     def _test_wip(self) -> None:
350         source, expected = read_data("wip")
351         sys.settrace(tracefunc)
352         mode = replace(
353             DEFAULT_MODE,
354             experimental_string_processing=False,
355             target_versions={black.TargetVersion.PY38},
356         )
357         actual = fs(source, mode=mode)
358         sys.settrace(None)
359         self.assertFormatEqual(expected, actual)
360         black.assert_equivalent(source, actual)
361         black.assert_stable(source, actual, black.FileMode())
362
363     @patch("black.dump_to_file", dump_to_stderr)
364     def test_function_trailing_comma(self) -> None:
365         source, expected = read_data("function_trailing_comma")
366         actual = fs(source)
367         self.assertFormatEqual(expected, actual)
368         black.assert_equivalent(source, actual)
369         black.assert_stable(source, actual, DEFAULT_MODE)
370
371     @patch("black.dump_to_file", dump_to_stderr)
372     def test_expression(self) -> None:
373         source, expected = read_data("expression")
374         actual = fs(source)
375         self.assertFormatEqual(expected, actual)
376         black.assert_equivalent(source, actual)
377         black.assert_stable(source, actual, DEFAULT_MODE)
378
379     @patch("black.dump_to_file", dump_to_stderr)
380     def test_pep_572(self) -> None:
381         source, expected = read_data("pep_572")
382         actual = fs(source)
383         self.assertFormatEqual(expected, actual)
384         black.assert_stable(source, actual, DEFAULT_MODE)
385         if sys.version_info >= (3, 8):
386             black.assert_equivalent(source, actual)
387
388     def test_pep_572_version_detection(self) -> None:
389         source, _ = read_data("pep_572")
390         root = black.lib2to3_parse(source)
391         features = black.get_features_used(root)
392         self.assertIn(black.Feature.ASSIGNMENT_EXPRESSIONS, features)
393         versions = black.detect_target_versions(root)
394         self.assertIn(black.TargetVersion.PY38, versions)
395
396     def test_expression_ff(self) -> None:
397         source, expected = read_data("expression")
398         tmp_file = Path(black.dump_to_file(source))
399         try:
400             self.assertTrue(ff(tmp_file, write_back=black.WriteBack.YES))
401             with open(tmp_file, encoding="utf8") as f:
402                 actual = f.read()
403         finally:
404             os.unlink(tmp_file)
405         self.assertFormatEqual(expected, actual)
406         with patch("black.dump_to_file", dump_to_stderr):
407             black.assert_equivalent(source, actual)
408             black.assert_stable(source, actual, DEFAULT_MODE)
409
410     def test_expression_diff(self) -> None:
411         source, _ = read_data("expression.py")
412         expected, _ = read_data("expression.diff")
413         tmp_file = Path(black.dump_to_file(source))
414         diff_header = re.compile(
415             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
416             r"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
417         )
418         try:
419             result = BlackRunner().invoke(black.main, ["--diff", str(tmp_file)])
420             self.assertEqual(result.exit_code, 0)
421         finally:
422             os.unlink(tmp_file)
423         actual = result.output
424         actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
425         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
426         if expected != actual:
427             dump = black.dump_to_file(actual)
428             msg = (
429                 "Expected diff isn't equal to the actual. If you made changes to"
430                 " expression.py and this is an anticipated difference, overwrite"
431                 f" tests/data/expression.diff with {dump}"
432             )
433             self.assertEqual(expected, actual, msg)
434
435     def test_expression_diff_with_color(self) -> None:
436         source, _ = read_data("expression.py")
437         expected, _ = read_data("expression.diff")
438         tmp_file = Path(black.dump_to_file(source))
439         try:
440             result = BlackRunner().invoke(
441                 black.main, ["--diff", "--color", str(tmp_file)]
442             )
443         finally:
444             os.unlink(tmp_file)
445         actual = result.output
446         # We check the contents of the diff in `test_expression_diff`. All
447         # we need to check here is that color codes exist in the result.
448         self.assertIn("\033[1;37m", actual)
449         self.assertIn("\033[36m", actual)
450         self.assertIn("\033[32m", actual)
451         self.assertIn("\033[31m", actual)
452         self.assertIn("\033[0m", actual)
453
454     @patch("black.dump_to_file", dump_to_stderr)
455     def test_fstring(self) -> None:
456         source, expected = read_data("fstring")
457         actual = fs(source)
458         self.assertFormatEqual(expected, actual)
459         black.assert_equivalent(source, actual)
460         black.assert_stable(source, actual, DEFAULT_MODE)
461
462     @patch("black.dump_to_file", dump_to_stderr)
463     def test_pep_570(self) -> None:
464         source, expected = read_data("pep_570")
465         actual = fs(source)
466         self.assertFormatEqual(expected, actual)
467         black.assert_stable(source, actual, DEFAULT_MODE)
468         if sys.version_info >= (3, 8):
469             black.assert_equivalent(source, actual)
470
471     def test_detect_pos_only_arguments(self) -> None:
472         source, _ = read_data("pep_570")
473         root = black.lib2to3_parse(source)
474         features = black.get_features_used(root)
475         self.assertIn(black.Feature.POS_ONLY_ARGUMENTS, features)
476         versions = black.detect_target_versions(root)
477         self.assertIn(black.TargetVersion.PY38, versions)
478
479     @patch("black.dump_to_file", dump_to_stderr)
480     def test_string_quotes(self) -> None:
481         source, expected = read_data("string_quotes")
482         actual = fs(source)
483         self.assertFormatEqual(expected, actual)
484         black.assert_equivalent(source, actual)
485         black.assert_stable(source, actual, DEFAULT_MODE)
486         mode = replace(DEFAULT_MODE, string_normalization=False)
487         not_normalized = fs(source, mode=mode)
488         self.assertFormatEqual(source.replace("\\\n", ""), not_normalized)
489         black.assert_equivalent(source, not_normalized)
490         black.assert_stable(source, not_normalized, mode=mode)
491
492     @patch("black.dump_to_file", dump_to_stderr)
493     def test_docstring(self) -> None:
494         source, expected = read_data("docstring")
495         actual = fs(source)
496         self.assertFormatEqual(expected, actual)
497         black.assert_equivalent(source, actual)
498         black.assert_stable(source, actual, DEFAULT_MODE)
499
500     def test_long_strings(self) -> None:
501         """Tests for splitting long strings."""
502         source, expected = read_data("long_strings")
503         actual = fs(source)
504         self.assertFormatEqual(expected, actual)
505         black.assert_equivalent(source, actual)
506         black.assert_stable(source, actual, DEFAULT_MODE)
507
508     def test_long_strings_flag_disabled(self) -> None:
509         """Tests for turning off the string processing logic."""
510         source, expected = read_data("long_strings_flag_disabled")
511         mode = replace(DEFAULT_MODE, experimental_string_processing=False)
512         actual = fs(source, mode=mode)
513         self.assertFormatEqual(expected, actual)
514         black.assert_stable(expected, actual, mode)
515
516     @patch("black.dump_to_file", dump_to_stderr)
517     def test_long_strings__edge_case(self) -> None:
518         """Edge-case tests for splitting long strings."""
519         source, expected = read_data("long_strings__edge_case")
520         actual = fs(source)
521         self.assertFormatEqual(expected, actual)
522         black.assert_equivalent(source, actual)
523         black.assert_stable(source, actual, DEFAULT_MODE)
524
525     @patch("black.dump_to_file", dump_to_stderr)
526     def test_long_strings__regression(self) -> None:
527         """Regression tests for splitting long strings."""
528         source, expected = read_data("long_strings__regression")
529         actual = fs(source)
530         self.assertFormatEqual(expected, actual)
531         black.assert_equivalent(source, actual)
532         black.assert_stable(source, actual, DEFAULT_MODE)
533
534     @patch("black.dump_to_file", dump_to_stderr)
535     def test_slices(self) -> None:
536         source, expected = read_data("slices")
537         actual = fs(source)
538         self.assertFormatEqual(expected, actual)
539         black.assert_equivalent(source, actual)
540         black.assert_stable(source, actual, DEFAULT_MODE)
541
542     @patch("black.dump_to_file", dump_to_stderr)
543     def test_percent_precedence(self) -> None:
544         source, expected = read_data("percent_precedence")
545         actual = fs(source)
546         self.assertFormatEqual(expected, actual)
547         black.assert_equivalent(source, actual)
548         black.assert_stable(source, actual, DEFAULT_MODE)
549
550     @patch("black.dump_to_file", dump_to_stderr)
551     def test_comments(self) -> None:
552         source, expected = read_data("comments")
553         actual = fs(source)
554         self.assertFormatEqual(expected, actual)
555         black.assert_equivalent(source, actual)
556         black.assert_stable(source, actual, DEFAULT_MODE)
557
558     @patch("black.dump_to_file", dump_to_stderr)
559     def test_comments2(self) -> None:
560         source, expected = read_data("comments2")
561         actual = fs(source)
562         self.assertFormatEqual(expected, actual)
563         black.assert_equivalent(source, actual)
564         black.assert_stable(source, actual, DEFAULT_MODE)
565
566     @patch("black.dump_to_file", dump_to_stderr)
567     def test_comments3(self) -> None:
568         source, expected = read_data("comments3")
569         actual = fs(source)
570         self.assertFormatEqual(expected, actual)
571         black.assert_equivalent(source, actual)
572         black.assert_stable(source, actual, DEFAULT_MODE)
573
574     @patch("black.dump_to_file", dump_to_stderr)
575     def test_comments4(self) -> None:
576         source, expected = read_data("comments4")
577         actual = fs(source)
578         self.assertFormatEqual(expected, actual)
579         black.assert_equivalent(source, actual)
580         black.assert_stable(source, actual, DEFAULT_MODE)
581
582     @patch("black.dump_to_file", dump_to_stderr)
583     def test_comments5(self) -> None:
584         source, expected = read_data("comments5")
585         actual = fs(source)
586         self.assertFormatEqual(expected, actual)
587         black.assert_equivalent(source, actual)
588         black.assert_stable(source, actual, DEFAULT_MODE)
589
590     @patch("black.dump_to_file", dump_to_stderr)
591     def test_comments6(self) -> None:
592         source, expected = read_data("comments6")
593         actual = fs(source)
594         self.assertFormatEqual(expected, actual)
595         black.assert_equivalent(source, actual)
596         black.assert_stable(source, actual, DEFAULT_MODE)
597
598     @patch("black.dump_to_file", dump_to_stderr)
599     def test_comments7(self) -> None:
600         source, expected = read_data("comments7")
601         mode = replace(DEFAULT_MODE, target_versions={black.TargetVersion.PY38})
602         actual = fs(source, mode=mode)
603         self.assertFormatEqual(expected, actual)
604         black.assert_equivalent(source, actual)
605         black.assert_stable(source, actual, DEFAULT_MODE)
606
607     @patch("black.dump_to_file", dump_to_stderr)
608     def test_comment_after_escaped_newline(self) -> None:
609         source, expected = read_data("comment_after_escaped_newline")
610         actual = fs(source)
611         self.assertFormatEqual(expected, actual)
612         black.assert_equivalent(source, actual)
613         black.assert_stable(source, actual, DEFAULT_MODE)
614
615     @patch("black.dump_to_file", dump_to_stderr)
616     def test_cantfit(self) -> None:
617         source, expected = read_data("cantfit")
618         actual = fs(source)
619         self.assertFormatEqual(expected, actual)
620         black.assert_equivalent(source, actual)
621         black.assert_stable(source, actual, DEFAULT_MODE)
622
623     @patch("black.dump_to_file", dump_to_stderr)
624     def test_import_spacing(self) -> None:
625         source, expected = read_data("import_spacing")
626         actual = fs(source)
627         self.assertFormatEqual(expected, actual)
628         black.assert_equivalent(source, actual)
629         black.assert_stable(source, actual, DEFAULT_MODE)
630
631     @patch("black.dump_to_file", dump_to_stderr)
632     def test_composition(self) -> None:
633         source, expected = read_data("composition")
634         actual = fs(source)
635         self.assertFormatEqual(expected, actual)
636         black.assert_equivalent(source, actual)
637         black.assert_stable(source, actual, DEFAULT_MODE)
638
639     @patch("black.dump_to_file", dump_to_stderr)
640     def test_composition_no_trailing_comma(self) -> None:
641         source, expected = read_data("composition_no_trailing_comma")
642         mode = replace(DEFAULT_MODE, target_versions={black.TargetVersion.PY38})
643         actual = fs(source, mode=mode)
644         self.assertFormatEqual(expected, actual)
645         black.assert_equivalent(source, actual)
646         black.assert_stable(source, actual, DEFAULT_MODE)
647
648     @patch("black.dump_to_file", dump_to_stderr)
649     def test_empty_lines(self) -> None:
650         source, expected = read_data("empty_lines")
651         actual = fs(source)
652         self.assertFormatEqual(expected, actual)
653         black.assert_equivalent(source, actual)
654         black.assert_stable(source, actual, DEFAULT_MODE)
655
656     @patch("black.dump_to_file", dump_to_stderr)
657     def test_remove_parens(self) -> None:
658         source, expected = read_data("remove_parens")
659         actual = fs(source)
660         self.assertFormatEqual(expected, actual)
661         black.assert_equivalent(source, actual)
662         black.assert_stable(source, actual, DEFAULT_MODE)
663
664     @patch("black.dump_to_file", dump_to_stderr)
665     def test_string_prefixes(self) -> None:
666         source, expected = read_data("string_prefixes")
667         actual = fs(source)
668         self.assertFormatEqual(expected, actual)
669         black.assert_equivalent(source, actual)
670         black.assert_stable(source, actual, DEFAULT_MODE)
671
672     @patch("black.dump_to_file", dump_to_stderr)
673     def test_numeric_literals(self) -> None:
674         source, expected = read_data("numeric_literals")
675         mode = replace(DEFAULT_MODE, target_versions=black.PY36_VERSIONS)
676         actual = fs(source, mode=mode)
677         self.assertFormatEqual(expected, actual)
678         black.assert_equivalent(source, actual)
679         black.assert_stable(source, actual, mode)
680
681     @patch("black.dump_to_file", dump_to_stderr)
682     def test_numeric_literals_ignoring_underscores(self) -> None:
683         source, expected = read_data("numeric_literals_skip_underscores")
684         mode = replace(DEFAULT_MODE, target_versions=black.PY36_VERSIONS)
685         actual = fs(source, mode=mode)
686         self.assertFormatEqual(expected, actual)
687         black.assert_equivalent(source, actual)
688         black.assert_stable(source, actual, mode)
689
690     @patch("black.dump_to_file", dump_to_stderr)
691     def test_numeric_literals_py2(self) -> None:
692         source, expected = read_data("numeric_literals_py2")
693         actual = fs(source)
694         self.assertFormatEqual(expected, actual)
695         black.assert_stable(source, actual, DEFAULT_MODE)
696
697     @patch("black.dump_to_file", dump_to_stderr)
698     def test_python2(self) -> None:
699         source, expected = read_data("python2")
700         actual = fs(source)
701         self.assertFormatEqual(expected, actual)
702         black.assert_equivalent(source, actual)
703         black.assert_stable(source, actual, DEFAULT_MODE)
704
705     @patch("black.dump_to_file", dump_to_stderr)
706     def test_python2_print_function(self) -> None:
707         source, expected = read_data("python2_print_function")
708         mode = replace(DEFAULT_MODE, target_versions={TargetVersion.PY27})
709         actual = fs(source, mode=mode)
710         self.assertFormatEqual(expected, actual)
711         black.assert_equivalent(source, actual)
712         black.assert_stable(source, actual, mode)
713
714     @patch("black.dump_to_file", dump_to_stderr)
715     def test_python2_unicode_literals(self) -> None:
716         source, expected = read_data("python2_unicode_literals")
717         actual = fs(source)
718         self.assertFormatEqual(expected, actual)
719         black.assert_equivalent(source, actual)
720         black.assert_stable(source, actual, DEFAULT_MODE)
721
722     @patch("black.dump_to_file", dump_to_stderr)
723     def test_stub(self) -> None:
724         mode = replace(DEFAULT_MODE, is_pyi=True)
725         source, expected = read_data("stub.pyi")
726         actual = fs(source, mode=mode)
727         self.assertFormatEqual(expected, actual)
728         black.assert_stable(source, actual, mode)
729
730     @patch("black.dump_to_file", dump_to_stderr)
731     def test_async_as_identifier(self) -> None:
732         source_path = (THIS_DIR / "data" / "async_as_identifier.py").resolve()
733         source, expected = read_data("async_as_identifier")
734         actual = fs(source)
735         self.assertFormatEqual(expected, actual)
736         major, minor = sys.version_info[:2]
737         if major < 3 or (major <= 3 and minor < 7):
738             black.assert_equivalent(source, actual)
739         black.assert_stable(source, actual, DEFAULT_MODE)
740         # ensure black can parse this when the target is 3.6
741         self.invokeBlack([str(source_path), "--target-version", "py36"])
742         # but not on 3.7, because async/await is no longer an identifier
743         self.invokeBlack([str(source_path), "--target-version", "py37"], exit_code=123)
744
745     @patch("black.dump_to_file", dump_to_stderr)
746     def test_python37(self) -> None:
747         source_path = (THIS_DIR / "data" / "python37.py").resolve()
748         source, expected = read_data("python37")
749         actual = fs(source)
750         self.assertFormatEqual(expected, actual)
751         major, minor = sys.version_info[:2]
752         if major > 3 or (major == 3 and minor >= 7):
753             black.assert_equivalent(source, actual)
754         black.assert_stable(source, actual, DEFAULT_MODE)
755         # ensure black can parse this when the target is 3.7
756         self.invokeBlack([str(source_path), "--target-version", "py37"])
757         # but not on 3.6, because we use async as a reserved keyword
758         self.invokeBlack([str(source_path), "--target-version", "py36"], exit_code=123)
759
760     @patch("black.dump_to_file", dump_to_stderr)
761     def test_python38(self) -> None:
762         source, expected = read_data("python38")
763         actual = fs(source)
764         self.assertFormatEqual(expected, actual)
765         major, minor = sys.version_info[:2]
766         if major > 3 or (major == 3 and minor >= 8):
767             black.assert_equivalent(source, actual)
768         black.assert_stable(source, actual, DEFAULT_MODE)
769
770     @patch("black.dump_to_file", dump_to_stderr)
771     def test_fmtonoff(self) -> None:
772         source, expected = read_data("fmtonoff")
773         actual = fs(source)
774         self.assertFormatEqual(expected, actual)
775         black.assert_equivalent(source, actual)
776         black.assert_stable(source, actual, DEFAULT_MODE)
777
778     @patch("black.dump_to_file", dump_to_stderr)
779     def test_fmtonoff2(self) -> None:
780         source, expected = read_data("fmtonoff2")
781         actual = fs(source)
782         self.assertFormatEqual(expected, actual)
783         black.assert_equivalent(source, actual)
784         black.assert_stable(source, actual, DEFAULT_MODE)
785
786     @patch("black.dump_to_file", dump_to_stderr)
787     def test_fmtonoff3(self) -> None:
788         source, expected = read_data("fmtonoff3")
789         actual = fs(source)
790         self.assertFormatEqual(expected, actual)
791         black.assert_equivalent(source, actual)
792         black.assert_stable(source, actual, DEFAULT_MODE)
793
794     @patch("black.dump_to_file", dump_to_stderr)
795     def test_fmtonoff4(self) -> None:
796         source, expected = read_data("fmtonoff4")
797         actual = fs(source)
798         self.assertFormatEqual(expected, actual)
799         black.assert_equivalent(source, actual)
800         black.assert_stable(source, actual, DEFAULT_MODE)
801
802     @patch("black.dump_to_file", dump_to_stderr)
803     def test_remove_empty_parentheses_after_class(self) -> None:
804         source, expected = read_data("class_blank_parentheses")
805         actual = fs(source)
806         self.assertFormatEqual(expected, actual)
807         black.assert_equivalent(source, actual)
808         black.assert_stable(source, actual, DEFAULT_MODE)
809
810     @patch("black.dump_to_file", dump_to_stderr)
811     def test_new_line_between_class_and_code(self) -> None:
812         source, expected = read_data("class_methods_new_line")
813         actual = fs(source)
814         self.assertFormatEqual(expected, actual)
815         black.assert_equivalent(source, actual)
816         black.assert_stable(source, actual, DEFAULT_MODE)
817
818     @patch("black.dump_to_file", dump_to_stderr)
819     def test_bracket_match(self) -> None:
820         source, expected = read_data("bracketmatch")
821         actual = fs(source)
822         self.assertFormatEqual(expected, actual)
823         black.assert_equivalent(source, actual)
824         black.assert_stable(source, actual, DEFAULT_MODE)
825
826     @patch("black.dump_to_file", dump_to_stderr)
827     def test_tuple_assign(self) -> None:
828         source, expected = read_data("tupleassign")
829         actual = fs(source)
830         self.assertFormatEqual(expected, actual)
831         black.assert_equivalent(source, actual)
832         black.assert_stable(source, actual, DEFAULT_MODE)
833
834     @patch("black.dump_to_file", dump_to_stderr)
835     def test_beginning_backslash(self) -> None:
836         source, expected = read_data("beginning_backslash")
837         actual = fs(source)
838         self.assertFormatEqual(expected, actual)
839         black.assert_equivalent(source, actual)
840         black.assert_stable(source, actual, DEFAULT_MODE)
841
842     def test_tab_comment_indentation(self) -> None:
843         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t# comment\n\tpass\n"
844         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
845         self.assertFormatEqual(contents_spc, fs(contents_spc))
846         self.assertFormatEqual(contents_spc, fs(contents_tab))
847
848         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t\t# comment\n\tpass\n"
849         contents_spc = "if 1:\n    if 2:\n        pass\n        # comment\n    pass\n"
850         self.assertFormatEqual(contents_spc, fs(contents_spc))
851         self.assertFormatEqual(contents_spc, fs(contents_tab))
852
853         # mixed tabs and spaces (valid Python 2 code)
854         contents_tab = "if 1:\n        if 2:\n\t\tpass\n\t# comment\n        pass\n"
855         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
856         self.assertFormatEqual(contents_spc, fs(contents_spc))
857         self.assertFormatEqual(contents_spc, fs(contents_tab))
858
859         contents_tab = "if 1:\n        if 2:\n\t\tpass\n\t\t# comment\n        pass\n"
860         contents_spc = "if 1:\n    if 2:\n        pass\n        # comment\n    pass\n"
861         self.assertFormatEqual(contents_spc, fs(contents_spc))
862         self.assertFormatEqual(contents_spc, fs(contents_tab))
863
    def test_report_verbose(self) -> None:
        """Exercise ``black.Report`` in verbose mode.

        Every ``done``/``path_ignored`` call is echoed to ``out`` and every
        ``failed`` call to ``err``. ``str(report)`` and ``report.return_code``
        track the running totals.
        """
        report = black.Report(verbose=True)
        # Messages black would print are captured here instead.
        out_lines = []
        err_lines = []

        def out(msg: str, **kwargs: Any) -> None:
            out_lines.append(msg)

        def err(msg: str, **kwargs: Any) -> None:
            err_lines.append(msg)

        with patch("black.out", out), patch("black.err", err):
            report.done(Path("f1"), black.Changed.NO)
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "f1 already well formatted, good job.")
            self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
            self.assertEqual(report.return_code, 0)
            report.done(Path("f2"), black.Changed.YES)
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "reformatted f2")
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
            )
            # Cached files count as "left unchanged".
            report.done(Path("f3"), black.Changed.CACHED)
            self.assertEqual(len(out_lines), 3)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(
                out_lines[-1], "f3 wasn't modified on disk since last run."
            )
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
            )
            self.assertEqual(report.return_code, 0)
            # In check mode, a reformatted file alone yields return code 1.
            report.check = True
            self.assertEqual(report.return_code, 1)
            report.check = False
            # Once any file has failed, the return code is 123.
            report.failed(Path("e1"), "boom")
            self.assertEqual(len(out_lines), 3)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
            self.assertEqual(
                unstyle(str(report)),
                "1 file reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f3"), black.Changed.YES)
            self.assertEqual(len(out_lines), 4)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(out_lines[-1], "reformatted f3")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.failed(Path("e2"), "boom")
            self.assertEqual(len(out_lines), 4)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Ignored paths are echoed in verbose mode but not counted in the summary.
            report.path_ignored(Path("wat"), "no match")
            self.assertEqual(len(out_lines), 5)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(out_lines[-1], "wat ignored: no match")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f4"), black.Changed.NO)
            self.assertEqual(len(out_lines), 6)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(out_lines[-1], "f4 already well formatted, good job.")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 3 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # With either check or diff set, the summary uses conditional wording.
            report.check = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
            report.check = False
            report.diff = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
965
    def test_report_quiet(self) -> None:
        """Exercise ``black.Report`` in quiet mode.

        Nothing is ever written to ``out``. Failures still reach ``err``, and
        ``str(report)`` and ``report.return_code`` track the running totals
        exactly as in the other modes.
        """
        report = black.Report(quiet=True)
        # Messages black would print are captured here instead.
        out_lines = []
        err_lines = []

        def out(msg: str, **kwargs: Any) -> None:
            out_lines.append(msg)

        def err(msg: str, **kwargs: Any) -> None:
            err_lines.append(msg)

        with patch("black.out", out), patch("black.err", err):
            report.done(Path("f1"), black.Changed.NO)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
            self.assertEqual(report.return_code, 0)
            report.done(Path("f2"), black.Changed.YES)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
            )
            report.done(Path("f3"), black.Changed.CACHED)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
            )
            self.assertEqual(report.return_code, 0)
            # In check mode, a reformatted file alone yields return code 1.
            report.check = True
            self.assertEqual(report.return_code, 1)
            report.check = False
            # Failures are still reported on err even in quiet mode.
            report.failed(Path("e1"), "boom")
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
            self.assertEqual(
                unstyle(str(report)),
                "1 file reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f3"), black.Changed.YES)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.failed(Path("e2"), "boom")
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Ignored paths are silent and do not change the summary.
            report.path_ignored(Path("wat"), "no match")
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f4"), black.Changed.NO)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 3 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # With either check or diff set, the summary uses conditional wording.
            report.check = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
            report.check = False
            report.diff = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
1059
    def test_report_normal(self) -> None:
        """Exercise ``black.Report`` with default verbosity.

        Only reformatted files are announced on ``out``. Unchanged, cached and
        ignored files produce no output, and failures go to ``err``.
        """
        report = black.Report()
        # Messages black would print are captured here instead.
        out_lines = []
        err_lines = []

        def out(msg: str, **kwargs: Any) -> None:
            out_lines.append(msg)

        def err(msg: str, **kwargs: Any) -> None:
            err_lines.append(msg)

        with patch("black.out", out), patch("black.err", err):
            report.done(Path("f1"), black.Changed.NO)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
            self.assertEqual(report.return_code, 0)
            report.done(Path("f2"), black.Changed.YES)
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "reformatted f2")
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
            )
            # A cached file prints nothing: the last message is still the one for f2.
            report.done(Path("f3"), black.Changed.CACHED)
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "reformatted f2")
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
            )
            self.assertEqual(report.return_code, 0)
            # In check mode, a reformatted file alone yields return code 1.
            report.check = True
            self.assertEqual(report.return_code, 1)
            report.check = False
            # Once any file has failed, the return code is 123.
            report.failed(Path("e1"), "boom")
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
            self.assertEqual(
                unstyle(str(report)),
                "1 file reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f3"), black.Changed.YES)
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(out_lines[-1], "reformatted f3")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.failed(Path("e2"), "boom")
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Ignored paths are silent at default verbosity and are not counted.
            report.path_ignored(Path("wat"), "no match")
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f4"), black.Changed.NO)
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 3 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # With either check or diff set, the summary uses conditional wording.
            report.check = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
            report.check = False
            report.diff = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
1156
1157     def test_lib2to3_parse(self) -> None:
1158         with self.assertRaises(black.InvalidInput):
1159             black.lib2to3_parse("invalid syntax")
1160
1161         straddling = "x + y"
1162         black.lib2to3_parse(straddling)
1163         black.lib2to3_parse(straddling, {TargetVersion.PY27})
1164         black.lib2to3_parse(straddling, {TargetVersion.PY36})
1165         black.lib2to3_parse(straddling, {TargetVersion.PY27, TargetVersion.PY36})
1166
1167         py2_only = "print x"
1168         black.lib2to3_parse(py2_only)
1169         black.lib2to3_parse(py2_only, {TargetVersion.PY27})
1170         with self.assertRaises(black.InvalidInput):
1171             black.lib2to3_parse(py2_only, {TargetVersion.PY36})
1172         with self.assertRaises(black.InvalidInput):
1173             black.lib2to3_parse(py2_only, {TargetVersion.PY27, TargetVersion.PY36})
1174
1175         py3_only = "exec(x, end=y)"
1176         black.lib2to3_parse(py3_only)
1177         with self.assertRaises(black.InvalidInput):
1178             black.lib2to3_parse(py3_only, {TargetVersion.PY27})
1179         black.lib2to3_parse(py3_only, {TargetVersion.PY36})
1180         black.lib2to3_parse(py3_only, {TargetVersion.PY27, TargetVersion.PY36})
1181
1182     def test_get_features_used(self) -> None:
1183         node = black.lib2to3_parse("def f(*, arg): ...\n")
1184         self.assertEqual(black.get_features_used(node), set())
1185         node = black.lib2to3_parse("def f(*, arg,): ...\n")
1186         self.assertEqual(black.get_features_used(node), {Feature.TRAILING_COMMA_IN_DEF})
1187         node = black.lib2to3_parse("f(*arg,)\n")
1188         self.assertEqual(
1189             black.get_features_used(node), {Feature.TRAILING_COMMA_IN_CALL}
1190         )
1191         node = black.lib2to3_parse("def f(*, arg): f'string'\n")
1192         self.assertEqual(black.get_features_used(node), {Feature.F_STRINGS})
1193         node = black.lib2to3_parse("123_456\n")
1194         self.assertEqual(black.get_features_used(node), {Feature.NUMERIC_UNDERSCORES})
1195         node = black.lib2to3_parse("123456\n")
1196         self.assertEqual(black.get_features_used(node), set())
1197         source, expected = read_data("function")
1198         node = black.lib2to3_parse(source)
1199         expected_features = {
1200             Feature.TRAILING_COMMA_IN_CALL,
1201             Feature.TRAILING_COMMA_IN_DEF,
1202             Feature.F_STRINGS,
1203         }
1204         self.assertEqual(black.get_features_used(node), expected_features)
1205         node = black.lib2to3_parse(expected)
1206         self.assertEqual(black.get_features_used(node), expected_features)
1207         source, expected = read_data("expression")
1208         node = black.lib2to3_parse(source)
1209         self.assertEqual(black.get_features_used(node), set())
1210         node = black.lib2to3_parse(expected)
1211         self.assertEqual(black.get_features_used(node), set())
1212
1213     def test_get_future_imports(self) -> None:
1214         node = black.lib2to3_parse("\n")
1215         self.assertEqual(set(), black.get_future_imports(node))
1216         node = black.lib2to3_parse("from __future__ import black\n")
1217         self.assertEqual({"black"}, black.get_future_imports(node))
1218         node = black.lib2to3_parse("from __future__ import multiple, imports\n")
1219         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
1220         node = black.lib2to3_parse("from __future__ import (parenthesized, imports)\n")
1221         self.assertEqual({"parenthesized", "imports"}, black.get_future_imports(node))
1222         node = black.lib2to3_parse(
1223             "from __future__ import multiple\nfrom __future__ import imports\n"
1224         )
1225         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
1226         node = black.lib2to3_parse("# comment\nfrom __future__ import black\n")
1227         self.assertEqual({"black"}, black.get_future_imports(node))
1228         node = black.lib2to3_parse('"""docstring"""\nfrom __future__ import black\n')
1229         self.assertEqual({"black"}, black.get_future_imports(node))
1230         node = black.lib2to3_parse("some(other, code)\nfrom __future__ import black\n")
1231         self.assertEqual(set(), black.get_future_imports(node))
1232         node = black.lib2to3_parse("from some.module import black\n")
1233         self.assertEqual(set(), black.get_future_imports(node))
1234         node = black.lib2to3_parse(
1235             "from __future__ import unicode_literals as _unicode_literals"
1236         )
1237         self.assertEqual({"unicode_literals"}, black.get_future_imports(node))
1238         node = black.lib2to3_parse(
1239             "from __future__ import unicode_literals as _lol, print"
1240         )
1241         self.assertEqual({"unicode_literals", "print"}, black.get_future_imports(node))
1242
1243     def test_debug_visitor(self) -> None:
1244         source, _ = read_data("debug_visitor.py")
1245         expected, _ = read_data("debug_visitor.out")
1246         out_lines = []
1247         err_lines = []
1248
1249         def out(msg: str, **kwargs: Any) -> None:
1250             out_lines.append(msg)
1251
1252         def err(msg: str, **kwargs: Any) -> None:
1253             err_lines.append(msg)
1254
1255         with patch("black.out", out), patch("black.err", err):
1256             black.DebugVisitor.show(source)
1257         actual = "\n".join(out_lines) + "\n"
1258         log_name = ""
1259         if expected != actual:
1260             log_name = black.dump_to_file(*out_lines)
1261         self.assertEqual(
1262             expected,
1263             actual,
1264             f"AST print out is different. Actual version dumped to {log_name}",
1265         )
1266
1267     def test_format_file_contents(self) -> None:
1268         empty = ""
1269         mode = DEFAULT_MODE
1270         with self.assertRaises(black.NothingChanged):
1271             black.format_file_contents(empty, mode=mode, fast=False)
1272         just_nl = "\n"
1273         with self.assertRaises(black.NothingChanged):
1274             black.format_file_contents(just_nl, mode=mode, fast=False)
1275         same = "j = [1, 2, 3]\n"
1276         with self.assertRaises(black.NothingChanged):
1277             black.format_file_contents(same, mode=mode, fast=False)
1278         different = "j = [1,2,3]"
1279         expected = same
1280         actual = black.format_file_contents(different, mode=mode, fast=False)
1281         self.assertEqual(expected, actual)
1282         invalid = "return if you can"
1283         with self.assertRaises(black.InvalidInput) as e:
1284             black.format_file_contents(invalid, mode=mode, fast=False)
1285         self.assertEqual(str(e.exception), "Cannot parse: 1:7: return if you can")
1286
1287     def test_endmarker(self) -> None:
1288         n = black.lib2to3_parse("\n")
1289         self.assertEqual(n.type, black.syms.file_input)
1290         self.assertEqual(len(n.children), 1)
1291         self.assertEqual(n.children[0].type, black.token.ENDMARKER)
1292
1293     @unittest.skipIf(os.environ.get("SKIP_AST_PRINT"), "user set SKIP_AST_PRINT")
1294     def test_assertFormatEqual(self) -> None:
1295         out_lines = []
1296         err_lines = []
1297
1298         def out(msg: str, **kwargs: Any) -> None:
1299             out_lines.append(msg)
1300
1301         def err(msg: str, **kwargs: Any) -> None:
1302             err_lines.append(msg)
1303
1304         with patch("black.out", out), patch("black.err", err):
1305             with self.assertRaises(AssertionError):
1306                 self.assertFormatEqual("j = [1, 2, 3]", "j = [1, 2, 3,]")
1307
1308         out_str = "".join(out_lines)
1309         self.assertTrue("Expected tree:" in out_str)
1310         self.assertTrue("Actual tree:" in out_str)
1311         self.assertEqual("".join(err_lines), "")
1312
1313     def test_cache_broken_file(self) -> None:
1314         mode = DEFAULT_MODE
1315         with cache_dir() as workspace:
1316             cache_file = black.get_cache_file(mode)
1317             with cache_file.open("w") as fobj:
1318                 fobj.write("this is not a pickle")
1319             self.assertEqual(black.read_cache(mode), {})
1320             src = (workspace / "test.py").resolve()
1321             with src.open("w") as fobj:
1322                 fobj.write("print('hello')")
1323             self.invokeBlack([str(src)])
1324             cache = black.read_cache(mode)
1325             self.assertIn(src, cache)
1326
1327     def test_cache_single_file_already_cached(self) -> None:
1328         mode = DEFAULT_MODE
1329         with cache_dir() as workspace:
1330             src = (workspace / "test.py").resolve()
1331             with src.open("w") as fobj:
1332                 fobj.write("print('hello')")
1333             black.write_cache({}, [src], mode)
1334             self.invokeBlack([str(src)])
1335             with src.open("r") as fobj:
1336                 self.assertEqual(fobj.read(), "print('hello')")
1337
1338     @event_loop()
1339     def test_cache_multiple_files(self) -> None:
1340         mode = DEFAULT_MODE
1341         with cache_dir() as workspace, patch(
1342             "black.ProcessPoolExecutor", new=ThreadPoolExecutor
1343         ):
1344             one = (workspace / "one.py").resolve()
1345             with one.open("w") as fobj:
1346                 fobj.write("print('hello')")
1347             two = (workspace / "two.py").resolve()
1348             with two.open("w") as fobj:
1349                 fobj.write("print('hello')")
1350             black.write_cache({}, [one], mode)
1351             self.invokeBlack([str(workspace)])
1352             with one.open("r") as fobj:
1353                 self.assertEqual(fobj.read(), "print('hello')")
1354             with two.open("r") as fobj:
1355                 self.assertEqual(fobj.read(), 'print("hello")\n')
1356             cache = black.read_cache(mode)
1357             self.assertIn(one, cache)
1358             self.assertIn(two, cache)
1359
1360     def test_no_cache_when_writeback_diff(self) -> None:
1361         mode = DEFAULT_MODE
1362         with cache_dir() as workspace:
1363             src = (workspace / "test.py").resolve()
1364             with src.open("w") as fobj:
1365                 fobj.write("print('hello')")
1366             self.invokeBlack([str(src), "--diff"])
1367             cache_file = black.get_cache_file(mode)
1368             self.assertFalse(cache_file.exists())
1369
1370     def test_no_cache_when_stdin(self) -> None:
1371         mode = DEFAULT_MODE
1372         with cache_dir():
1373             result = CliRunner().invoke(
1374                 black.main, ["-"], input=BytesIO(b"print('hello')")
1375             )
1376             self.assertEqual(result.exit_code, 0)
1377             cache_file = black.get_cache_file(mode)
1378             self.assertFalse(cache_file.exists())
1379
1380     def test_read_cache_no_cachefile(self) -> None:
1381         mode = DEFAULT_MODE
1382         with cache_dir():
1383             self.assertEqual(black.read_cache(mode), {})
1384
1385     def test_write_cache_read_cache(self) -> None:
1386         mode = DEFAULT_MODE
1387         with cache_dir() as workspace:
1388             src = (workspace / "test.py").resolve()
1389             src.touch()
1390             black.write_cache({}, [src], mode)
1391             cache = black.read_cache(mode)
1392             self.assertIn(src, cache)
1393             self.assertEqual(cache[src], black.get_cache_info(src))
1394
1395     def test_filter_cached(self) -> None:
1396         with TemporaryDirectory() as workspace:
1397             path = Path(workspace)
1398             uncached = (path / "uncached").resolve()
1399             cached = (path / "cached").resolve()
1400             cached_but_changed = (path / "changed").resolve()
1401             uncached.touch()
1402             cached.touch()
1403             cached_but_changed.touch()
1404             cache = {cached: black.get_cache_info(cached), cached_but_changed: (0.0, 0)}
1405             todo, done = black.filter_cached(
1406                 cache, {uncached, cached, cached_but_changed}
1407             )
1408             self.assertEqual(todo, {uncached, cached_but_changed})
1409             self.assertEqual(done, {cached})
1410
1411     def test_write_cache_creates_directory_if_needed(self) -> None:
1412         mode = DEFAULT_MODE
1413         with cache_dir(exists=False) as workspace:
1414             self.assertFalse(workspace.exists())
1415             black.write_cache({}, [], mode)
1416             self.assertTrue(workspace.exists())
1417
1418     @event_loop()
1419     def test_failed_formatting_does_not_get_cached(self) -> None:
1420         mode = DEFAULT_MODE
1421         with cache_dir() as workspace, patch(
1422             "black.ProcessPoolExecutor", new=ThreadPoolExecutor
1423         ):
1424             failing = (workspace / "failing.py").resolve()
1425             with failing.open("w") as fobj:
1426                 fobj.write("not actually python")
1427             clean = (workspace / "clean.py").resolve()
1428             with clean.open("w") as fobj:
1429                 fobj.write('print("hello")\n')
1430             self.invokeBlack([str(workspace)], exit_code=123)
1431             cache = black.read_cache(mode)
1432             self.assertNotIn(failing, cache)
1433             self.assertIn(clean, cache)
1434
1435     def test_write_cache_write_fail(self) -> None:
1436         mode = DEFAULT_MODE
1437         with cache_dir(), patch.object(Path, "open") as mock:
1438             mock.side_effect = OSError
1439             black.write_cache({}, [], mode)
1440
1441     @event_loop()
1442     @patch("black.ProcessPoolExecutor", MagicMock(side_effect=OSError))
1443     def test_works_in_mono_process_only_environment(self) -> None:
1444         with cache_dir() as workspace:
1445             for f in [
1446                 (workspace / "one.py").resolve(),
1447                 (workspace / "two.py").resolve(),
1448             ]:
1449                 f.write_text('print("hello")\n')
1450             self.invokeBlack([str(workspace)])
1451
1452     @event_loop()
1453     def test_check_diff_use_together(self) -> None:
1454         with cache_dir():
1455             # Files which will be reformatted.
1456             src1 = (THIS_DIR / "data" / "string_quotes.py").resolve()
1457             self.invokeBlack([str(src1), "--diff", "--check"], exit_code=1)
1458             # Files which will not be reformatted.
1459             src2 = (THIS_DIR / "data" / "composition.py").resolve()
1460             self.invokeBlack([str(src2), "--diff", "--check"])
1461             # Multi file command.
1462             self.invokeBlack([str(src1), str(src2), "--diff", "--check"], exit_code=1)
1463
1464     def test_no_files(self) -> None:
1465         with cache_dir():
1466             # Without an argument, black exits with error code 0.
1467             self.invokeBlack([])
1468
1469     def test_broken_symlink(self) -> None:
1470         with cache_dir() as workspace:
1471             symlink = workspace / "broken_link.py"
1472             try:
1473                 symlink.symlink_to("nonexistent.py")
1474             except OSError as e:
1475                 self.skipTest(f"Can't create symlinks: {e}")
1476             self.invokeBlack([str(workspace.resolve())])
1477
1478     def test_read_cache_line_lengths(self) -> None:
1479         mode = DEFAULT_MODE
1480         short_mode = replace(DEFAULT_MODE, line_length=1)
1481         with cache_dir() as workspace:
1482             path = (workspace / "file.py").resolve()
1483             path.touch()
1484             black.write_cache({}, [path], mode)
1485             one = black.read_cache(mode)
1486             self.assertIn(path, one)
1487             two = black.read_cache(short_mode)
1488             self.assertNotIn(path, two)
1489
1490     def test_tricky_unicode_symbols(self) -> None:
1491         source, expected = read_data("tricky_unicode_symbols")
1492         actual = fs(source)
1493         self.assertFormatEqual(expected, actual)
1494         black.assert_equivalent(source, actual)
1495         black.assert_stable(source, actual, DEFAULT_MODE)
1496
1497     def test_single_file_force_pyi(self) -> None:
1498         reg_mode = DEFAULT_MODE
1499         pyi_mode = replace(DEFAULT_MODE, is_pyi=True)
1500         contents, expected = read_data("force_pyi")
1501         with cache_dir() as workspace:
1502             path = (workspace / "file.py").resolve()
1503             with open(path, "w") as fh:
1504                 fh.write(contents)
1505             self.invokeBlack([str(path), "--pyi"])
1506             with open(path, "r") as fh:
1507                 actual = fh.read()
1508             # verify cache with --pyi is separate
1509             pyi_cache = black.read_cache(pyi_mode)
1510             self.assertIn(path, pyi_cache)
1511             normal_cache = black.read_cache(reg_mode)
1512             self.assertNotIn(path, normal_cache)
1513         self.assertEqual(actual, expected)
1514
1515     @event_loop()
1516     def test_multi_file_force_pyi(self) -> None:
1517         reg_mode = DEFAULT_MODE
1518         pyi_mode = replace(DEFAULT_MODE, is_pyi=True)
1519         contents, expected = read_data("force_pyi")
1520         with cache_dir() as workspace:
1521             paths = [
1522                 (workspace / "file1.py").resolve(),
1523                 (workspace / "file2.py").resolve(),
1524             ]
1525             for path in paths:
1526                 with open(path, "w") as fh:
1527                     fh.write(contents)
1528             self.invokeBlack([str(p) for p in paths] + ["--pyi"])
1529             for path in paths:
1530                 with open(path, "r") as fh:
1531                     actual = fh.read()
1532                 self.assertEqual(actual, expected)
1533             # verify cache with --pyi is separate
1534             pyi_cache = black.read_cache(pyi_mode)
1535             normal_cache = black.read_cache(reg_mode)
1536             for path in paths:
1537                 self.assertIn(path, pyi_cache)
1538                 self.assertNotIn(path, normal_cache)
1539
1540     def test_pipe_force_pyi(self) -> None:
1541         source, expected = read_data("force_pyi")
1542         result = CliRunner().invoke(
1543             black.main, ["-", "-q", "--pyi"], input=BytesIO(source.encode("utf8"))
1544         )
1545         self.assertEqual(result.exit_code, 0)
1546         actual = result.output
1547         self.assertFormatEqual(actual, expected)
1548
1549     def test_single_file_force_py36(self) -> None:
1550         reg_mode = DEFAULT_MODE
1551         py36_mode = replace(DEFAULT_MODE, target_versions=black.PY36_VERSIONS)
1552         source, expected = read_data("force_py36")
1553         with cache_dir() as workspace:
1554             path = (workspace / "file.py").resolve()
1555             with open(path, "w") as fh:
1556                 fh.write(source)
1557             self.invokeBlack([str(path), *PY36_ARGS])
1558             with open(path, "r") as fh:
1559                 actual = fh.read()
1560             # verify cache with --target-version is separate
1561             py36_cache = black.read_cache(py36_mode)
1562             self.assertIn(path, py36_cache)
1563             normal_cache = black.read_cache(reg_mode)
1564             self.assertNotIn(path, normal_cache)
1565         self.assertEqual(actual, expected)
1566
1567     @event_loop()
1568     def test_multi_file_force_py36(self) -> None:
1569         reg_mode = DEFAULT_MODE
1570         py36_mode = replace(DEFAULT_MODE, target_versions=black.PY36_VERSIONS)
1571         source, expected = read_data("force_py36")
1572         with cache_dir() as workspace:
1573             paths = [
1574                 (workspace / "file1.py").resolve(),
1575                 (workspace / "file2.py").resolve(),
1576             ]
1577             for path in paths:
1578                 with open(path, "w") as fh:
1579                     fh.write(source)
1580             self.invokeBlack([str(p) for p in paths] + PY36_ARGS)
1581             for path in paths:
1582                 with open(path, "r") as fh:
1583                     actual = fh.read()
1584                 self.assertEqual(actual, expected)
1585             # verify cache with --target-version is separate
1586             pyi_cache = black.read_cache(py36_mode)
1587             normal_cache = black.read_cache(reg_mode)
1588             for path in paths:
1589                 self.assertIn(path, pyi_cache)
1590                 self.assertNotIn(path, normal_cache)
1591
1592     def test_collections(self) -> None:
1593         source, expected = read_data("collections")
1594         actual = fs(source)
1595         self.assertFormatEqual(expected, actual)
1596         black.assert_equivalent(source, actual)
1597         black.assert_stable(source, actual, DEFAULT_MODE)
1598
1599     def test_pipe_force_py36(self) -> None:
1600         source, expected = read_data("force_py36")
1601         result = CliRunner().invoke(
1602             black.main,
1603             ["-", "-q", "--target-version=py36"],
1604             input=BytesIO(source.encode("utf8")),
1605         )
1606         self.assertEqual(result.exit_code, 0)
1607         actual = result.output
1608         self.assertFormatEqual(actual, expected)
1609
1610     def test_include_exclude(self) -> None:
1611         path = THIS_DIR / "data" / "include_exclude_tests"
1612         include = re.compile(r"\.pyi?$")
1613         exclude = re.compile(r"/exclude/|/\.definitely_exclude/")
1614         report = black.Report()
1615         gitignore = PathSpec.from_lines("gitwildmatch", [])
1616         sources: List[Path] = []
1617         expected = [
1618             Path(path / "b/dont_exclude/a.py"),
1619             Path(path / "b/dont_exclude/a.pyi"),
1620         ]
1621         this_abs = THIS_DIR.resolve()
1622         sources.extend(
1623             black.gen_python_files(
1624                 path.iterdir(), this_abs, include, exclude, None, report, gitignore
1625             )
1626         )
1627         self.assertEqual(sorted(expected), sorted(sources))
1628
1629     @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
1630     def test_exclude_for_issue_1572(self) -> None:
1631         # Exclude shouldn't touch files that were explicitly given to Black through the
1632         # CLI. Exclude is supposed to only apply to the recursive discovery of files.
1633         # https://github.com/psf/black/issues/1572
1634         path = THIS_DIR / "data" / "include_exclude_tests"
1635         include = ""
1636         exclude = r"/exclude/|a\.py"
1637         src = str(path / "b/exclude/a.py")
1638         report = black.Report()
1639         expected = [Path(path / "b/exclude/a.py")]
1640         sources = list(
1641             black.get_sources(
1642                 ctx=FakeContext(),
1643                 src=(src,),
1644                 quiet=True,
1645                 verbose=False,
1646                 include=include,
1647                 exclude=exclude,
1648                 force_exclude=None,
1649                 report=report,
1650             )
1651         )
1652         self.assertEqual(sorted(expected), sorted(sources))
1653
1654     def test_gitignore_exclude(self) -> None:
1655         path = THIS_DIR / "data" / "include_exclude_tests"
1656         include = re.compile(r"\.pyi?$")
1657         exclude = re.compile(r"")
1658         report = black.Report()
1659         gitignore = PathSpec.from_lines(
1660             "gitwildmatch", ["exclude/", ".definitely_exclude"]
1661         )
1662         sources: List[Path] = []
1663         expected = [
1664             Path(path / "b/dont_exclude/a.py"),
1665             Path(path / "b/dont_exclude/a.pyi"),
1666         ]
1667         this_abs = THIS_DIR.resolve()
1668         sources.extend(
1669             black.gen_python_files(
1670                 path.iterdir(), this_abs, include, exclude, None, report, gitignore
1671             )
1672         )
1673         self.assertEqual(sorted(expected), sorted(sources))
1674
1675     def test_empty_include(self) -> None:
1676         path = THIS_DIR / "data" / "include_exclude_tests"
1677         report = black.Report()
1678         gitignore = PathSpec.from_lines("gitwildmatch", [])
1679         empty = re.compile(r"")
1680         sources: List[Path] = []
1681         expected = [
1682             Path(path / "b/exclude/a.pie"),
1683             Path(path / "b/exclude/a.py"),
1684             Path(path / "b/exclude/a.pyi"),
1685             Path(path / "b/dont_exclude/a.pie"),
1686             Path(path / "b/dont_exclude/a.py"),
1687             Path(path / "b/dont_exclude/a.pyi"),
1688             Path(path / "b/.definitely_exclude/a.pie"),
1689             Path(path / "b/.definitely_exclude/a.py"),
1690             Path(path / "b/.definitely_exclude/a.pyi"),
1691         ]
1692         this_abs = THIS_DIR.resolve()
1693         sources.extend(
1694             black.gen_python_files(
1695                 path.iterdir(),
1696                 this_abs,
1697                 empty,
1698                 re.compile(black.DEFAULT_EXCLUDES),
1699                 None,
1700                 report,
1701                 gitignore,
1702             )
1703         )
1704         self.assertEqual(sorted(expected), sorted(sources))
1705
1706     def test_empty_exclude(self) -> None:
1707         path = THIS_DIR / "data" / "include_exclude_tests"
1708         report = black.Report()
1709         gitignore = PathSpec.from_lines("gitwildmatch", [])
1710         empty = re.compile(r"")
1711         sources: List[Path] = []
1712         expected = [
1713             Path(path / "b/dont_exclude/a.py"),
1714             Path(path / "b/dont_exclude/a.pyi"),
1715             Path(path / "b/exclude/a.py"),
1716             Path(path / "b/exclude/a.pyi"),
1717             Path(path / "b/.definitely_exclude/a.py"),
1718             Path(path / "b/.definitely_exclude/a.pyi"),
1719         ]
1720         this_abs = THIS_DIR.resolve()
1721         sources.extend(
1722             black.gen_python_files(
1723                 path.iterdir(),
1724                 this_abs,
1725                 re.compile(black.DEFAULT_INCLUDES),
1726                 empty,
1727                 None,
1728                 report,
1729                 gitignore,
1730             )
1731         )
1732         self.assertEqual(sorted(expected), sorted(sources))
1733
1734     def test_invalid_include_exclude(self) -> None:
1735         for option in ["--include", "--exclude"]:
1736             self.invokeBlack(["-", option, "**()(!!*)"], exit_code=2)
1737
1738     def test_preserves_line_endings(self) -> None:
1739         with TemporaryDirectory() as workspace:
1740             test_file = Path(workspace) / "test.py"
1741             for nl in ["\n", "\r\n"]:
1742                 contents = nl.join(["def f(  ):", "    pass"])
1743                 test_file.write_bytes(contents.encode())
1744                 ff(test_file, write_back=black.WriteBack.YES)
1745                 updated_contents: bytes = test_file.read_bytes()
1746                 self.assertIn(nl.encode(), updated_contents)
1747                 if nl == "\n":
1748                     self.assertNotIn(b"\r\n", updated_contents)
1749
1750     def test_preserves_line_endings_via_stdin(self) -> None:
1751         for nl in ["\n", "\r\n"]:
1752             contents = nl.join(["def f(  ):", "    pass"])
1753             runner = BlackRunner()
1754             result = runner.invoke(
1755                 black.main, ["-", "--fast"], input=BytesIO(contents.encode("utf8"))
1756             )
1757             self.assertEqual(result.exit_code, 0)
1758             output = runner.stdout_bytes
1759             self.assertIn(nl.encode("utf8"), output)
1760             if nl == "\n":
1761                 self.assertNotIn(b"\r\n", output)
1762
1763     def test_assert_equivalent_different_asts(self) -> None:
1764         with self.assertRaises(AssertionError):
1765             black.assert_equivalent("{}", "None")
1766
1767     def test_symlink_out_of_root_directory(self) -> None:
1768         path = MagicMock()
1769         root = THIS_DIR.resolve()
1770         child = MagicMock()
1771         include = re.compile(black.DEFAULT_INCLUDES)
1772         exclude = re.compile(black.DEFAULT_EXCLUDES)
1773         report = black.Report()
1774         gitignore = PathSpec.from_lines("gitwildmatch", [])
1775         # `child` should behave like a symlink which resolved path is clearly
1776         # outside of the `root` directory.
1777         path.iterdir.return_value = [child]
1778         child.resolve.return_value = Path("/a/b/c")
1779         child.as_posix.return_value = "/a/b/c"
1780         child.is_symlink.return_value = True
1781         try:
1782             list(
1783                 black.gen_python_files(
1784                     path.iterdir(), root, include, exclude, None, report, gitignore
1785                 )
1786             )
1787         except ValueError as ve:
1788             self.fail(f"`get_python_files_in_dir()` failed: {ve}")
1789         path.iterdir.assert_called_once()
1790         child.resolve.assert_called_once()
1791         child.is_symlink.assert_called_once()
1792         # `child` should behave like a strange file which resolved path is clearly
1793         # outside of the `root` directory.
1794         child.is_symlink.return_value = False
1795         with self.assertRaises(ValueError):
1796             list(
1797                 black.gen_python_files(
1798                     path.iterdir(), root, include, exclude, None, report, gitignore
1799                 )
1800             )
1801         path.iterdir.assert_called()
1802         self.assertEqual(path.iterdir.call_count, 2)
1803         child.resolve.assert_called()
1804         self.assertEqual(child.resolve.call_count, 2)
1805         child.is_symlink.assert_called()
1806         self.assertEqual(child.is_symlink.call_count, 2)
1807
1808     def test_shhh_click(self) -> None:
1809         try:
1810             from click import _unicodefun  # type: ignore
1811         except ModuleNotFoundError:
1812             self.skipTest("Incompatible Click version")
1813         if not hasattr(_unicodefun, "_verify_python3_env"):
1814             self.skipTest("Incompatible Click version")
1815         # First, let's see if Click is crashing with a preferred ASCII charset.
1816         with patch("locale.getpreferredencoding") as gpe:
1817             gpe.return_value = "ASCII"
1818             with self.assertRaises(RuntimeError):
1819                 _unicodefun._verify_python3_env()
1820         # Now, let's silence Click...
1821         black.patch_click()
1822         # ...and confirm it's silent.
1823         with patch("locale.getpreferredencoding") as gpe:
1824             gpe.return_value = "ASCII"
1825             try:
1826                 _unicodefun._verify_python3_env()
1827             except RuntimeError as re:
1828                 self.fail(f"`patch_click()` failed, exception still raised: {re}")
1829
1830     def test_root_logger_not_used_directly(self) -> None:
1831         def fail(*args: Any, **kwargs: Any) -> None:
1832             self.fail("Record created with root logger")
1833
1834         with patch.multiple(
1835             logging.root,
1836             debug=fail,
1837             info=fail,
1838             warning=fail,
1839             error=fail,
1840             critical=fail,
1841             log=fail,
1842         ):
1843             ff(THIS_FILE)
1844
1845     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1846     def test_blackd_main(self) -> None:
1847         with patch("blackd.web.run_app"):
1848             result = CliRunner().invoke(blackd.main, [])
1849             if result.exception is not None:
1850                 raise result.exception
1851             self.assertEqual(result.exit_code, 0)
1852
1853     def test_invalid_config_return_code(self) -> None:
1854         tmp_file = Path(black.dump_to_file())
1855         try:
1856             tmp_config = Path(black.dump_to_file())
1857             tmp_config.unlink()
1858             args = ["--config", str(tmp_config), str(tmp_file)]
1859             self.invokeBlack(args, exit_code=2, ignore_config=False)
1860         finally:
1861             tmp_file.unlink()
1862
1863     def test_parse_pyproject_toml(self) -> None:
1864         test_toml_file = THIS_DIR / "test.toml"
1865         config = black.parse_pyproject_toml(str(test_toml_file))
1866         self.assertEqual(config["verbose"], 1)
1867         self.assertEqual(config["check"], "no")
1868         self.assertEqual(config["diff"], "y")
1869         self.assertEqual(config["color"], True)
1870         self.assertEqual(config["line_length"], 79)
1871         self.assertEqual(config["target_version"], ["py36", "py37", "py38"])
1872         self.assertEqual(config["exclude"], r"\.pyi?$")
1873         self.assertEqual(config["include"], r"\.py?$")
1874
1875     def test_read_pyproject_toml(self) -> None:
1876         test_toml_file = THIS_DIR / "test.toml"
1877         fake_ctx = FakeContext()
1878         black.read_pyproject_toml(fake_ctx, FakeParameter(), str(test_toml_file))
1879         config = fake_ctx.default_map
1880         self.assertEqual(config["verbose"], "1")
1881         self.assertEqual(config["check"], "no")
1882         self.assertEqual(config["diff"], "y")
1883         self.assertEqual(config["color"], "True")
1884         self.assertEqual(config["line_length"], "79")
1885         self.assertEqual(config["target_version"], ["py36", "py37", "py38"])
1886         self.assertEqual(config["exclude"], r"\.pyi?$")
1887         self.assertEqual(config["include"], r"\.py?$")
1888
1889     def test_find_project_root(self) -> None:
1890         with TemporaryDirectory() as workspace:
1891             root = Path(workspace)
1892             test_dir = root / "test"
1893             test_dir.mkdir()
1894
1895             src_dir = root / "src"
1896             src_dir.mkdir()
1897
1898             root_pyproject = root / "pyproject.toml"
1899             root_pyproject.touch()
1900             src_pyproject = src_dir / "pyproject.toml"
1901             src_pyproject.touch()
1902             src_python = src_dir / "foo.py"
1903             src_python.touch()
1904
1905             self.assertEqual(
1906                 black.find_project_root((src_dir, test_dir)), root.resolve()
1907             )
1908             self.assertEqual(black.find_project_root((src_dir,)), src_dir.resolve())
1909             self.assertEqual(black.find_project_root((src_python,)), src_dir.resolve())
1910
1911
# Define the blackd HTTP tests only when blackd and aiohttp imported cleanly.
# Otherwise `AioHTTPTestCase`, `unittest_run_loop`, `web` and `blackd` are unbound
# (see the guarded import at the top of this file), and evaluating this class
# statement would raise NameError while the test module is being imported.
if has_blackd_deps:

    class BlackDTestCase(AioHTTPTestCase):
        """Tests of the blackd HTTP application through aiohttp's test client."""

        async def get_application(self) -> web.Application:
            return blackd.make_app()

        # TODO: remove these decorators once the below is released
        # https://github.com/aio-libs/aiohttp/pull/3727
        @skip_if_exception("ClientOSError")
        @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
        @unittest_run_loop
        async def test_blackd_request_needs_formatting(self) -> None:
            response = await self.client.post("/", data=b"print('hello world')")
            self.assertEqual(response.status, 200)
            self.assertEqual(response.charset, "utf8")
            self.assertEqual(await response.read(), b'print("hello world")\n')

        @skip_if_exception("ClientOSError")
        @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
        @unittest_run_loop
        async def test_blackd_request_no_change(self) -> None:
            response = await self.client.post("/", data=b'print("hello world")\n')
            self.assertEqual(response.status, 204)
            self.assertEqual(await response.read(), b"")

        @skip_if_exception("ClientOSError")
        @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
        @unittest_run_loop
        async def test_blackd_request_syntax_error(self) -> None:
            response = await self.client.post("/", data=b"what even ( is")
            self.assertEqual(response.status, 400)
            content = await response.text()
            self.assertTrue(
                content.startswith("Cannot parse"),
                msg=f"Expected error to start with 'Cannot parse', got {repr(content)}",
            )

        @skip_if_exception("ClientOSError")
        @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
        @unittest_run_loop
        async def test_blackd_unsupported_version(self) -> None:
            response = await self.client.post(
                "/", data=b"what", headers={blackd.PROTOCOL_VERSION_HEADER: "2"}
            )
            self.assertEqual(response.status, 501)

        @skip_if_exception("ClientOSError")
        @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
        @unittest_run_loop
        async def test_blackd_supported_version(self) -> None:
            response = await self.client.post(
                "/", data=b"what", headers={blackd.PROTOCOL_VERSION_HEADER: "1"}
            )
            self.assertEqual(response.status, 200)

        @skip_if_exception("ClientOSError")
        @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
        @unittest_run_loop
        async def test_blackd_invalid_python_variant(self) -> None:
            async def check(header_value: str, expected_status: int = 400) -> None:
                response = await self.client.post(
                    "/",
                    data=b"what",
                    headers={blackd.PYTHON_VARIANT_HEADER: header_value},
                )
                self.assertEqual(response.status, expected_status)

            await check("lol")
            await check("ruby3.5")
            await check("pyi3.6")
            await check("py1.5")
            await check("2.8")
            await check("py2.8")
            await check("3.0")
            await check("pypy3.0")
            await check("jython3.4")

        @skip_if_exception("ClientOSError")
        @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
        @unittest_run_loop
        async def test_blackd_pyi(self) -> None:
            source, expected = read_data("stub.pyi")
            response = await self.client.post(
                "/", data=source, headers={blackd.PYTHON_VARIANT_HEADER: "pyi"}
            )
            self.assertEqual(response.status, 200)
            self.assertEqual(await response.text(), expected)

        @skip_if_exception("ClientOSError")
        @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
        @unittest_run_loop
        async def test_blackd_diff(self) -> None:
            # Matches the timestamped "In"/"Out" diff header lines so they can be
            # replaced with DETERMINISTIC_HEADER before comparing.
            diff_header = re.compile(
                r"(In|Out)\t\d\d\d\d-\d\d-\d\d \d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
            )

            source, _ = read_data("blackd_diff.py")
            expected, _ = read_data("blackd_diff.diff")

            response = await self.client.post(
                "/", data=source, headers={blackd.DIFF_HEADER: "true"}
            )
            self.assertEqual(response.status, 200)

            actual = await response.text()
            actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
            self.assertEqual(actual, expected)

        @skip_if_exception("ClientOSError")
        @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
        @unittest_run_loop
        async def test_blackd_python_variant(self) -> None:
            code = (
                "def f(\n"
                "    and_has_a_bunch_of,\n"
                "    very_long_arguments_too,\n"
                "    and_lots_of_them_as_well_lol,\n"
                "    **and_very_long_keyword_arguments\n"
                "):\n"
                "    pass\n"
            )

            async def check(header_value: str, expected_status: int) -> None:
                response = await self.client.post(
                    "/", data=code, headers={blackd.PYTHON_VARIANT_HEADER: header_value}
                )
                self.assertEqual(
                    response.status, expected_status, msg=await response.text()
                )

            await check("3.6", 200)
            await check("py3.6", 200)
            await check("3.6,3.7", 200)
            await check("3.6,py3.7", 200)
            await check("py36,py37", 200)
            await check("36", 200)
            await check("3.6.4", 200)

            await check("2", 204)
            await check("2.7", 204)
            await check("py2.7", 204)
            await check("3.4", 204)
            await check("py3.4", 204)
            await check("py34,py36", 204)
            await check("34", 204)

        @skip_if_exception("ClientOSError")
        @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
        @unittest_run_loop
        async def test_blackd_line_length(self) -> None:
            response = await self.client.post(
                "/", data=b'print("hello")\n', headers={blackd.LINE_LENGTH_HEADER: "7"}
            )
            self.assertEqual(response.status, 200)

        @skip_if_exception("ClientOSError")
        @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
        @unittest_run_loop
        async def test_blackd_invalid_line_length(self) -> None:
            response = await self.client.post(
                "/",
                data=b'print("hello")\n',
                headers={blackd.LINE_LENGTH_HEADER: "NaN"},
            )
            self.assertEqual(response.status, 400)

        @skip_if_exception("ClientOSError")
        @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
        @unittest_run_loop
        async def test_blackd_response_black_version_header(self) -> None:
            response = await self.client.post("/")
            self.assertIsNotNone(response.headers.get(blackd.BLACK_VERSION_HEADER))
2078
2079
# Source lines of black's module file, keeping line terminators; `tracefunc`
# below indexes into this to print the signature line of each traced call.
with Path(black.__file__).open(encoding="utf-8") as _black_src:
    black_source_lines = list(_black_src)
2082
2083
def tracefunc(frame: types.FrameType, event: str, arg: Any) -> Callable:
    """Show function calls `from black/__init__.py` as they happen.

    Register this with `sys.settrace()` in a test you're debugging.

    Only "call" events for frames whose code lives in ``black/__init__.py``
    are printed, as ``<indent><lineno>:<source of the def line>``. The indent
    is twice the interpreter stack depth minus 19. That offset is presumably
    the depth of the test harness frames; adjust it if the output is
    misaligned.

    Returns itself, so it is also installed as the frame's local trace function.
    """
    if event != "call":
        return tracefunc

    filename = frame.f_code.co_filename
    # Filter before touching `black_source_lines`: line numbers from any other
    # file do not index into black's source and could raise IndexError.
    if "black/__init__.py" not in filename:
        return tracefunc

    # inspect.stack() is expensive, so only compute it for frames we print.
    stack = (len(inspect.stack()) - 19) * 2
    lineno = frame.f_lineno
    func_sig_lineno = lineno - 1  # 1-based line number -> 0-based list index
    funcname = black_source_lines[func_sig_lineno].strip()
    # Step past decorator lines to reach the actual `def` line.
    while funcname.startswith("@"):
        func_sig_lineno += 1
        funcname = black_source_lines[func_sig_lineno].strip()
    print(f"{' ' * stack}{lineno}:{funcname}")
    return tracefunc
2104
2105
if __name__ == "__main__":
    # Run this module's tests through unittest when executed directly.
    # NOTE(review): `module="test_black"` uses the bare module name, while this
    # file also uses a package-relative import (`from .test_primer import ...`).
    # Running the file as a plain script may therefore fail before reaching
    # this line. Confirm the intended invocation (e.g. via the test runner).
    unittest.main(module="test_black")