(.venv)$ tox -e py -- --print-tree-diff=False
```
-`Black` has two pytest command-line options affecting test files in `tests/data/` that
+### Testing
+
+All aspects of the _Black_ style should be tested. Normally, tests should be created as
+files in the `tests/data/cases` directory. These files consist of up to three parts:
+
+- An optional first line that starts with `# flags: ` followed by a set of command-line
+ options. For example, if the line is `# flags: --preview --skip-magic-trailing-comma`,
+ the test case will be run with preview mode on and the magic trailing comma off. The
+ options accepted are mostly a subset of those of _Black_ itself, except for the
+ `--minimum-version=` flag, which should be used when testing a grammar feature that
+ works only in newer versions of Python. This flag ensures that we don't try to
+ validate the AST on older versions, and it also makes the test check that _Black_
+ autodetects the Python version correctly when the feature is used. For the exact
+ flags accepted, see the function `get_flags_parser` in `tests/util.py`. If this line
+ is omitted, the default options are used.
+- A block of Python code used as input for the formatter.
+- The line `# output`, followed by the output of _Black_ when run on the previous block.
+ If this is omitted, the test asserts that _Black_ will leave the input code unchanged.
+
+_Black_ has two pytest command-line options affecting test files in `tests/data/` that
-are split into an input part, and an output part, separated by a line with`# output`.
+are split into an input part and an output part, separated by a line with `# output`.
These can be passed to `pytest` through `tox`, or directly into pytest if not using
`tox`.
+# flags: --skip-string-normalization
class ALonelyClass:
'''
A multiline class docstring.
+# flags: --preview --skip-string-normalization
def do_not_touch_this_prefix():
R"""There was a bug where docstring prefixes would be normalized even with -S."""
+# flags: --preview --minimum-version=3.10
# normal, short, function definition
def foo(a, b) -> tuple[int, float]: ...
+# flags: --pyi
def f(): # type: ignore
...
+# flags: --line-length=6
# Regression test for #3427, which reproes only with line length <= 6
def f():
"""
+# flags: --pyi --preview
import sys
class Outer:
-#!/usr/bin/env python3.6
-
x = 123456789
x = 123456
x = .1
# output
-
-#!/usr/bin/env python3.6
-
x = 123456789
x = 123456
x = 0.1
-#!/usr/bin/env python3.6
-
x = 123456789
x = 1_2_3_4_5_6_7
x = 1E+1
# output
-#!/usr/bin/env python3.6
-
x = 123456789
x = 1_2_3_4_5_6_7
x = 1e1
+# flags: --minimum-version=3.10
with (CtxManager() as example):
...
+# flags: --minimum-version=3.10
# Cases sampled from Lib/test/test_patma.py
# case black_test_patma_098
+# flags: --minimum-version=3.10
import match
match something:
+# flags: --minimum-version=3.10
re.match()
match = a
with match() as match:
+# flags: --minimum-version=3.10
# Cases sampled from PEP 636 examples
match command.split():
+# flags: --minimum-version=3.10
match something:
case b(): print(1+1)
case c(
+# flags: --preview --minimum-version=3.10
# This has always worked
z= Loooooooooooooooooooooooong | Loooooooooooooooooooooooong | Loooooooooooooooooooooooong | Loooooooooooooooooooooooong
+# flags: --minimum-version=3.8
def positional_only_arg(a, /):
pass
+# flags: --minimum-version=3.8
(a := 1)
(a := a)
if (match := pattern.search(data)) is None:
+# flags: --fast
# Most of the following examples are really dumb, some of them aren't even accepted by Python,
# we're fixing them only so fuzzers (which follow the grammar which actually allows these
# examples matter of fact!) don't yell at us :p
+# flags: --minimum-version=3.10
# Unparenthesized walruses are now allowed in indices since Python 3.10.
x[a:=0]
x[a:=0, b:=1]
+# flags: --minimum-version=3.9
# Unparenthesized walruses are now allowed in set literals & set comprehensions
# since Python 3.9
{x := 1, 2, 3}
+# flags: --minimum-version=3.8
if (foo := 0):
pass
+# flags: --minimum-version=3.11
A[*b]
A[*b] = 1
A
+# flags: --minimum-version=3.11
try:
raise OSError("blah")
except* ExceptionGroup as e:
+# flags: --minimum-version=3.11
try:
raise OSError("blah")
except * ExceptionGroup as e:
+# flags: --line-length=0
importA;()<<0**0#
# output
+# flags: --preview
async def func() -> (int):
return 0
+# flags: --preview
# long variable name
this_is_a_ridiculously_long_name_and_nobody_in_their_right_mind_would_use_one_like_it = 0
this_is_a_ridiculously_long_name_and_nobody_in_their_right_mind_would_use_one_like_it = 1 # with a comment
+# flags: --preview
from .config import (
Any,
Bool,
+# flags: --preview --minimum-version=3.8
with \
make_context_manager1() as cm1, \
make_context_manager2() as cm2, \
+# flags: --preview --minimum-version=3.9
with \
make_context_manager1() as cm1, \
make_context_manager2() as cm2, \
+# flags: --preview --minimum-version=3.10
# This file uses pattern matching introduced in Python 3.10.
+# flags: --preview --minimum-version=3.11
# This file uses except* clause in Python 3.11.
+# flags: --preview
# This file doesn't use any Python 3.9+ only grammars.
+# flags: --preview --minimum-version=3.9
# This file uses parenthesized context managers introduced in Python 3.9.
+# flags: --preview
from typing import NoReturn, Protocol, Union, overload
+# flags: --preview
x = "\x1F"
x = "\\x1B"
x = "\\\x1B"
+# flags: --preview
my_dict = {
"something_something":
r"Lorem ipsum dolor sit amet, an sed convenire eloquentiam \t"
+# flags: --preview
x = "This is a really long string that can't possibly be expected to fit all together on one line. In fact it may even take up three or more lines... like four or five... but probably just three."
x += "This is a really long string that can't possibly be expected to fit all together on one line. In fact it may even take up three or more lines... like four or five... but probably just three."
+# flags: --preview\r
# The following strings do not have not-so-many chars, but are long enough\r
# when these are rendered in a monospace font (if the renderer respects\r
# Unicode East Asian Width properties).\r
+# flags: --preview
some_variable = "This string is long but not so long that it needs to be split just yet"
some_variable = 'This string is long but not so long that it needs to be split just yet'
some_variable = "This string is long, just long enough that it needs to be split, u get?"
+# flags: --preview
class A:
def foo():
result = type(message)("")
+# flags: --preview
def func(
arg1,
arg2,
+# flags: --preview
"""cow
say""",
call(3, "dogsay", textwrap.dedent("""dove
+# flags: --preview
def line_before_docstring():
"""Please move me up"""
+# flags: --preview
x[(a:=0):]
x[:(a:=0)]
+# flags: --preview
("" % a) ** 2
("" % a)[0]
("" % a)()
+# flags: --preview
first_item, second_item = (
some_looooooooong_module.some_looooooooooooooong_function_name(
first_argument, second_argument, third_argument
+# flags: --preview
# Long string example
def frobnicate() -> "ThisIsTrulyUnreasonablyExtremelyLongClassName | list[ThisIsTrulyUnreasonablyExtremelyLongClassName]":
pass
+# flags: --preview
e = {
"a": fun(msg, "ts"),
"longggggggggggggggid": ...,
+# flags: --preview --minimum-version=3.10
x[a:=0]
x[a := 0]
x[a := 0, b := 1]
-#!/usr/bin/env python3.7
+# flags: --minimum-version=3.7
def f():
# output
-#!/usr/bin/env python3.7
-
-
def f():
return (i * 2 async for i in arange(42))
-#!/usr/bin/env python3.8
+# flags: --minimum-version=3.8
def starred_return():
# output
-#!/usr/bin/env python3.8
-
-
def starred_return():
my_list = ["value2", "value3"]
return "value1", *my_list
-#!/usr/bin/env python3.9
+# flags: --minimum-version=3.9
@relaxed_decorator[0]
def f():
# output
-
-#!/usr/bin/env python3.9
-
-
@relaxed_decorator[0]
def f():
...
+# flags: --minimum-version=3.10
def http_status(status):
match status:
+# flags: --minimum-version=3.9
with (open("bla.txt")):
pass
+# flags: --skip-magic-trailing-comma
# We should not remove the trailing comma in a single-element subscript.
a: tuple[int,]
b = tuple[int,]
+# flags: --minimum-version=3.10
for x in *a, *b:
print(x)
+# flags: --pyi
X: int
def f(): ...
+# flags: --minimum-version=3.12
type A=int
type Gen[T]=list[T]
+# flags: --minimum-version=3.12
def func [T ](): pass
async def func [ T ] (): pass
class C[ T ] : pass
+# flags: --pyi
from typing import Union
@bird
)
def test_piping(self) -> None:
- source, expected = read_data_from_file(PROJECT_ROOT / "src/black/__init__.py")
+ _, source, expected = read_data_from_file(
+ PROJECT_ROOT / "src/black/__init__.py"
+ )
result = BlackRunner().invoke(
black.main,
[
r"(STDIN|STDOUT)\t\d\d\d\d-\d\d-\d\d \d\d:\d\d:\d\d\.\d\d\d\d\d\d"
r"\+\d\d:\d\d"
)
- source, _ = read_data("simple_cases", "expression.py")
- expected, _ = read_data("simple_cases", "expression.diff")
+ source, _ = read_data("cases", "expression.py")
+ expected, _ = read_data("cases", "expression.diff")
args = [
"-",
"--fast",
self.assertEqual(expected, actual)
def test_piping_diff_with_color(self) -> None:
- source, _ = read_data("simple_cases", "expression.py")
+ source, _ = read_data("cases", "expression.py")
args = [
"-",
"--fast",
black.assert_stable(source, actual, black.FileMode())
def test_pep_572_version_detection(self) -> None:
- source, _ = read_data("py_38", "pep_572")
+ source, _ = read_data("cases", "pep_572")
root = black.lib2to3_parse(source)
features = black.get_features_used(root)
self.assertIn(black.Feature.ASSIGNMENT_EXPRESSIONS, features)
def test_pep_695_version_detection(self) -> None:
for file in ("type_aliases", "type_params"):
- source, _ = read_data("py_312", file)
+ source, _ = read_data("cases", file)
root = black.lib2to3_parse(source)
features = black.get_features_used(root)
self.assertIn(black.Feature.TYPE_PARAMS, features)
self.assertIn(black.TargetVersion.PY312, versions)
def test_expression_ff(self) -> None:
- source, expected = read_data("simple_cases", "expression.py")
+ source, expected = read_data("cases", "expression.py")
tmp_file = Path(black.dump_to_file(source))
try:
self.assertTrue(ff(tmp_file, write_back=black.WriteBack.YES))
black.assert_stable(source, actual, DEFAULT_MODE)
def test_expression_diff(self) -> None:
- source, _ = read_data("simple_cases", "expression.py")
- expected, _ = read_data("simple_cases", "expression.diff")
+ source, _ = read_data("cases", "expression.py")
+ expected, _ = read_data("cases", "expression.diff")
tmp_file = Path(black.dump_to_file(source))
diff_header = re.compile(
rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
self.assertEqual(expected, actual, msg)
def test_expression_diff_with_color(self) -> None:
- source, _ = read_data("simple_cases", "expression.py")
- expected, _ = read_data("simple_cases", "expression.diff")
+ source, _ = read_data("cases", "expression.py")
+ expected, _ = read_data("cases", "expression.diff")
tmp_file = Path(black.dump_to_file(source))
try:
result = BlackRunner().invoke(
self.assertIn("\033[0m", actual)
def test_detect_pos_only_arguments(self) -> None:
- source, _ = read_data("py_38", "pep_570")
+ source, _ = read_data("cases", "pep_570")
root = black.lib2to3_parse(source)
features = black.get_features_used(root)
self.assertIn(black.Feature.POS_ONLY_ARGUMENTS, features)
self.assertEqual(test_file.read_bytes(), expected)
def test_skip_magic_trailing_comma(self) -> None:
- source, _ = read_data("simple_cases", "expression")
+ source, _ = read_data("cases", "expression")
expected, _ = read_data(
"miscellaneous", "expression_skip_magic_trailing_comma.diff"
)
@patch("black.dump_to_file", dump_to_stderr)
def test_async_as_identifier(self) -> None:
source_path = get_case_path("miscellaneous", "async_as_identifier")
- source, expected = read_data_from_file(source_path)
+ _, source, expected = read_data_from_file(source_path)
actual = fs(source)
self.assertFormatEqual(expected, actual)
major, minor = sys.version_info[:2]
@patch("black.dump_to_file", dump_to_stderr)
def test_python37(self) -> None:
- source_path = get_case_path("py_37", "python37")
- source, expected = read_data_from_file(source_path)
+ source_path = get_case_path("cases", "python37")
+ _, source, expected = read_data_from_file(source_path)
actual = fs(source)
self.assertFormatEqual(expected, actual)
major, minor = sys.version_info[:2]
self.assertEqual(black.get_features_used(node), {Feature.NUMERIC_UNDERSCORES})
node = black.lib2to3_parse("123456\n")
self.assertEqual(black.get_features_used(node), set())
- source, expected = read_data("simple_cases", "function")
+ source, expected = read_data("cases", "function")
node = black.lib2to3_parse(source)
expected_features = {
Feature.TRAILING_COMMA_IN_CALL,
self.assertEqual(black.get_features_used(node), expected_features)
node = black.lib2to3_parse(expected)
self.assertEqual(black.get_features_used(node), expected_features)
- source, expected = read_data("simple_cases", "expression")
+ source, expected = read_data("cases", "expression")
node = black.lib2to3_parse(source)
self.assertEqual(black.get_features_used(node), set())
node = black.lib2to3_parse(expected)
src1 = get_case_path("miscellaneous", "string_quotes")
self.invokeBlack([str(src1), "--diff", "--check"], exit_code=1)
# Files which will not be reformatted.
- src2 = get_case_path("simple_cases", "composition")
+ src2 = get_case_path("cases", "composition")
self.invokeBlack([str(src2), "--diff", "--check"])
# Multi file command.
self.invokeBlack([str(src1), str(src2), "--diff", "--check"], exit_code=1)
report = MagicMock()
# Even with an existing file, since we are forcing stdin, black
# should output to stdout and not modify the file inplace
- p = THIS_DIR / "data" / "simple_cases" / "collections.py"
+ p = THIS_DIR / "data" / "cases" / "collections.py"
# Make sure is_file actually returns True
self.assertTrue(p.is_file())
path = Path(f"__BLACK_STDIN_FILENAME__{p}")
@unittest_run_loop
async def test_blackd_pyi(self) -> None:
- source, expected = read_data("miscellaneous", "stub.pyi")
+ source, expected = read_data("cases", "stub.py")
response = await self.client.post(
"/", data=source, headers={blackd.PYTHON_VARIANT_HEADER: "pyi"}
)
-import re
from dataclasses import replace
from typing import Any, Iterator
from unittest.mock import patch
import pytest
import black
+from black.mode import TargetVersion
from tests.util import (
- DEFAULT_MODE,
- PY36_VERSIONS,
all_data_cases,
assert_format,
dump_to_stderr,
read_data,
+ read_data_with_mode,
)
yield
-def check_file(
- subdir: str, filename: str, mode: black.Mode, *, data: bool = True
-) -> None:
- source, expected = read_data(subdir, filename, data=data)
- assert_format(source, expected, mode, fast=False)
+def check_file(subdir: str, filename: str, *, data: bool = True) -> None:
+    """Format one data-driven test case and compare it with its expected output.
+
+    The mode, ``fast`` setting and minimum version all come from the case's
+    ``# flags:`` header. When ``--minimum-version`` is given, the case is checked
+    a second time with ``target_versions`` pinned to exactly that version.
+    """
+    args, source, expected = read_data_with_mode(subdir, filename, data=data)
+
+    def run(mode: black.Mode) -> None:
+        # Every pass shares the same source/expected text and header settings;
+        # only the Mode differs between passes.
+        assert_format(
+            source,
+            expected,
+            mode,
+            fast=args.fast,
+            minimum_version=args.minimum_version,
+        )
+
+    run(args.mode)
+    if args.minimum_version is None:
+        return
+    major, minor = args.minimum_version
+    pinned = TargetVersion[f"PY{major}{minor}"]
+    run(replace(args.mode, target_versions={pinned}))
@pytest.mark.filterwarnings("ignore:invalid escape sequence.*:DeprecationWarning")
-@pytest.mark.parametrize("filename", all_data_cases("simple_cases"))
+@pytest.mark.parametrize("filename", all_data_cases("cases"))
def test_simple_format(filename: str) -> None:
+    """Check each file in tests/data/cases using the flags from its own header."""
-    magic_trailing_comma = filename != "skip_magic_trailing_comma"
-    mode = black.Mode(
-        magic_trailing_comma=magic_trailing_comma, is_pyi=filename.endswith("_pyi")
-    )
-    check_file("simple_cases", filename, mode)
-
-
-@pytest.mark.parametrize("filename", all_data_cases("preview"))
-def test_preview_format(filename: str) -> None:
-    check_file("preview", filename, black.Mode(preview=True))
-
-
-def test_preview_context_managers_targeting_py38() -> None:
-    source, expected = read_data("preview_context_managers", "targeting_py38.py")
-    mode = black.Mode(preview=True, target_versions={black.TargetVersion.PY38})
-    assert_format(source, expected, mode, minimum_version=(3, 8))
-
-
-def test_preview_context_managers_targeting_py39() -> None:
-    source, expected = read_data("preview_context_managers", "targeting_py39.py")
-    mode = black.Mode(preview=True, target_versions={black.TargetVersion.PY39})
-    assert_format(source, expected, mode, minimum_version=(3, 9))
-
-
-@pytest.mark.parametrize("filename", all_data_cases("preview_py_310"))
-def test_preview_python_310(filename: str) -> None:
-    source, expected = read_data("preview_py_310", filename)
-    mode = black.Mode(target_versions={black.TargetVersion.PY310}, preview=True)
-    assert_format(source, expected, mode, minimum_version=(3, 10))
-
-
-@pytest.mark.parametrize(
-    "filename", all_data_cases("preview_context_managers/auto_detect")
-)
-def test_preview_context_managers_auto_detect(filename: str) -> None:
-    match = re.match(r"features_3_(\d+)", filename)
-    assert match is not None, "Unexpected filename format: %s" % filename
-    source, expected = read_data("preview_context_managers/auto_detect", filename)
-    mode = black.Mode(preview=True)
-    assert_format(source, expected, mode, minimum_version=(3, int(match.group(1))))
+    check_file("cases", filename)
# =============== #
-# Complex cases
-# ============= #
+# Unusual cases
+# =============== #
def test_empty() -> None:
assert_format(source, expected)
-@pytest.mark.parametrize("filename", all_data_cases("py_36"))
-def test_python_36(filename: str) -> None:
- source, expected = read_data("py_36", filename)
- mode = black.Mode(target_versions=PY36_VERSIONS)
- assert_format(source, expected, mode, minimum_version=(3, 6))
-
-
-@pytest.mark.parametrize("filename", all_data_cases("py_37"))
-def test_python_37(filename: str) -> None:
- source, expected = read_data("py_37", filename)
- mode = black.Mode(target_versions={black.TargetVersion.PY37})
- assert_format(source, expected, mode, minimum_version=(3, 7))
-
-
-@pytest.mark.parametrize("filename", all_data_cases("py_38"))
-def test_python_38(filename: str) -> None:
- source, expected = read_data("py_38", filename)
- mode = black.Mode(target_versions={black.TargetVersion.PY38})
- assert_format(source, expected, mode, minimum_version=(3, 8))
-
-
-@pytest.mark.parametrize("filename", all_data_cases("py_39"))
-def test_python_39(filename: str) -> None:
- source, expected = read_data("py_39", filename)
- mode = black.Mode(target_versions={black.TargetVersion.PY39})
- assert_format(source, expected, mode, minimum_version=(3, 9))
-
-
-@pytest.mark.parametrize("filename", all_data_cases("py_310"))
-def test_python_310(filename: str) -> None:
- source, expected = read_data("py_310", filename)
- mode = black.Mode(target_versions={black.TargetVersion.PY310})
- assert_format(source, expected, mode, minimum_version=(3, 10))
-
-
-@pytest.mark.parametrize("filename", all_data_cases("py_310"))
-def test_python_310_without_target_version(filename: str) -> None:
- source, expected = read_data("py_310", filename)
- mode = black.Mode()
- assert_format(source, expected, mode, minimum_version=(3, 10))
-
-
def test_patma_invalid() -> None:
source, expected = read_data("miscellaneous", "pattern_matching_invalid")
mode = black.Mode(target_versions={black.TargetVersion.PY310})
assert_format(source, expected, mode, minimum_version=(3, 10))
exc_info.match("Cannot parse: 10:11")
-
-
-@pytest.mark.parametrize("filename", all_data_cases("py_311"))
-def test_python_311(filename: str) -> None:
- source, expected = read_data("py_311", filename)
- mode = black.Mode(target_versions={black.TargetVersion.PY311})
- assert_format(source, expected, mode, minimum_version=(3, 11))
-
-
-@pytest.mark.parametrize("filename", all_data_cases("py_312"))
-def test_python_312(filename: str) -> None:
- source, expected = read_data("py_312", filename)
- mode = black.Mode(target_versions={black.TargetVersion.PY312})
- assert_format(source, expected, mode, minimum_version=(3, 12))
-
-
-@pytest.mark.parametrize("filename", all_data_cases("fast"))
-def test_fast_cases(filename: str) -> None:
- source, expected = read_data("fast", filename)
- assert_format(source, expected, fast=True)
-
-
-@pytest.mark.filterwarnings("ignore:invalid escape sequence.*:DeprecationWarning")
-def test_docstring_no_string_normalization() -> None:
- """Like test_docstring but with string normalization off."""
- source, expected = read_data("miscellaneous", "docstring_no_string_normalization")
- mode = replace(DEFAULT_MODE, string_normalization=False)
- assert_format(source, expected, mode)
-
-
-def test_docstring_line_length_6() -> None:
- """Like test_docstring but with line length set to 6."""
- source, expected = read_data("miscellaneous", "linelength6")
- mode = black.Mode(line_length=6)
- assert_format(source, expected, mode)
-
-
-def test_preview_docstring_no_string_normalization() -> None:
- """
- Like test_docstring but with string normalization off *and* the preview style
- enabled.
- """
- source, expected = read_data(
- "miscellaneous", "docstring_preview_no_string_normalization"
- )
- mode = replace(DEFAULT_MODE, string_normalization=False, preview=True)
- assert_format(source, expected, mode)
-
-
-def test_long_strings_flag_disabled() -> None:
- """Tests for turning off the string processing logic."""
- source, expected = read_data("miscellaneous", "long_strings_flag_disabled")
- mode = replace(DEFAULT_MODE, experimental_string_processing=False)
- assert_format(source, expected, mode)
-
-
-def test_stub() -> None:
- mode = replace(DEFAULT_MODE, is_pyi=True)
- source, expected = read_data("miscellaneous", "stub.pyi")
- assert_format(source, expected, mode)
-
-
-def test_nested_stub() -> None:
- mode = replace(DEFAULT_MODE, is_pyi=True, preview=True)
- source, expected = read_data("miscellaneous", "nested_stub.pyi")
- assert_format(source, expected, mode)
-
-
-def test_power_op_newline() -> None:
- # requires line_length=0
- source, expected = read_data("miscellaneous", "power_op_newline")
- assert_format(source, expected, mode=black.Mode(line_length=0))
-
-
-def test_type_comment_syntax_error() -> None:
- """Test that black is able to format python code with type comment syntax errors."""
- source, expected = read_data("type_comments", "type_comment_syntax_error")
- assert_format(source, expected)
- black.assert_equivalent(source, expected)
+import argparse
+import functools
import os
+import shlex
import sys
import unittest
from contextlib import contextmanager
-from dataclasses import replace
+from dataclasses import dataclass, field, replace
from functools import partial
from pathlib import Path
from typing import Any, Iterator, List, Optional, Tuple
import black
+from black.const import DEFAULT_LINE_LENGTH
from black.debug import DebugVisitor
from black.mode import TargetVersion
from black.output import diff, err, out
fs = partial(black.format_str, mode=DEFAULT_MODE)
+@dataclass
+class TestCaseArgs:
+    """Settings for one data-driven test case, parsed from its ``# flags:`` line.
+
+    NOTE(review): because the name starts with ``Test``, pytest may try to collect
+    this class if a test module imports it by name; setting ``__test__ = False``
+    would prevent that -- confirm whether it is needed.
+    """
+
+    # Formatting mode passed to Black; a default black.Mode() when no flags are set.
+    mode: black.Mode = field(default_factory=black.Mode)
+    # Forwarded as assert_format(..., fast=...); set by the ``--fast`` flag.
+    fast: bool = False
+    # (major, minor) from ``--minimum-version``; None when the flag is absent.
+    minimum_version: Optional[Tuple[int, int]] = None
+
+
def _assert_format_equal(expected: str, actual: str) -> None:
if actual != expected and (conftest.PRINT_FULL_TREE or conftest.PRINT_TREE_DIFF):
bdv: DebugVisitor[Any]
return case_path
+def read_data_with_mode(
+    subdir_name: str, name: str, data: bool = True
+) -> Tuple[TestCaseArgs, str, str]:
+    """read_data_with_mode('test_name') -> TestCaseArgs(), 'input', 'output'
+
+    Returns the arguments parsed from the case's optional ``# flags:`` header
+    (defaults when there is none), followed by the input and expected output.
+    """
+    return read_data_from_file(get_case_path(subdir_name, name, data))
+
+
def read_data(subdir_name: str, name: str, data: bool = True) -> Tuple[str, str]:
    """read_data('test_name') -> 'input', 'output'"""
-    return read_data_from_file(get_case_path(subdir_name, name, data))
+    # The parsed ``# flags:`` arguments are dropped: callers of this helper only
+    # need the source text and the expected output. Avoid naming the first value
+    # ``input`` so the builtin is not shadowed.
+    _, source, expected = read_data_with_mode(subdir_name, name, data)
+    return source, expected
+
+
+def _parse_minimum_version(version: str) -> Tuple[int, int]:
+    """Parse a ``MAJOR.MINOR`` string such as ``"3.10"`` into ``(3, 10)``.
+
+    Used as an argparse ``type=`` callable. Raising ArgumentTypeError lets
+    argparse report the offending header value with a readable message instead
+    of the generic "invalid _parse_minimum_version value".
+
+    Raises:
+        argparse.ArgumentTypeError: if ``version`` is not two dot-separated ints.
+    """
+    try:
+        major, minor = version.split(".")
+        return int(major), int(minor)
+    except ValueError:
+        raise argparse.ArgumentTypeError(
+            f"expected a version of the form MAJOR.MINOR, got {version!r}"
+        ) from None
-def read_data_from_file(file_name: Path) -> Tuple[str, str]:
+@functools.lru_cache()
+def get_flags_parser() -> argparse.ArgumentParser:
+    """Build the parser for the ``# flags:`` header line of a test-case file.
+
+    The result is cached, so every test case shares a single parser instance;
+    callers must not mutate it.
+    """
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--target-version",
+        action="append",
+        type=lambda val: TargetVersion[val.upper()],
+        # argparse's "append" action copies the current value and calls .append()
+        # on the copy. A list default is copied safely; a tuple default would make
+        # any use of --target-version fail with AttributeError.
+        default=[],
+    )
+    parser.add_argument("--line-length", default=DEFAULT_LINE_LENGTH, type=int)
+    parser.add_argument(
+        "--skip-string-normalization", default=False, action="store_true"
+    )
+    parser.add_argument("--pyi", default=False, action="store_true")
+    parser.add_argument("--ipynb", default=False, action="store_true")
+    parser.add_argument(
+        "--skip-magic-trailing-comma", default=False, action="store_true"
+    )
+    parser.add_argument("--preview", default=False, action="store_true")
+    parser.add_argument("--fast", default=False, action="store_true")
+    parser.add_argument(
+        "--minimum-version",
+        type=_parse_minimum_version,
+        default=None,
+        help=(
+            "Minimum version of Python where this test case is parseable. If this is"
+            " set, the test case will be run twice: once with the specified"
+            " --target-version, and once with --target-version set to exactly the"
+            " specified version. This ensures that Black's autodetection of the target"
+            " version works correctly."
+        ),
+    )
+    return parser
+
+
+def parse_mode(flags_line: str) -> TestCaseArgs:
+    """Turn the text after ``# flags: `` into the settings for one test case.
+
+    The line is tokenised shell-style, parsed with the shared flags parser, and
+    mapped onto a black.Mode plus the test-only ``fast``/``minimum_version``.
+    """
+    opts = get_flags_parser().parse_args(shlex.split(flags_line))
+    # The --skip-* flags are negations of the corresponding Mode fields.
+    black_mode = black.Mode(
+        target_versions=set(opts.target_version),
+        line_length=opts.line_length,
+        string_normalization=not opts.skip_string_normalization,
+        is_pyi=opts.pyi,
+        is_ipynb=opts.ipynb,
+        magic_trailing_comma=not opts.skip_magic_trailing_comma,
+        preview=opts.preview,
+    )
+    return TestCaseArgs(
+        mode=black_mode,
+        fast=opts.fast,
+        minimum_version=opts.minimum_version,
+    )
+
+
+def read_data_from_file(file_name: Path) -> Tuple[TestCaseArgs, str, str]:
with open(file_name, "r", encoding="utf8") as test:
lines = test.readlines()
_input: List[str] = []
_output: List[str] = []
result = _input
+ mode = TestCaseArgs()
for line in lines:
+ if not _input and line.startswith("# flags: "):
+ mode = parse_mode(line[len("# flags: ") :])
+ continue
line = line.replace(EMPTY_LINE, "")
if line.rstrip() == "# output":
result = _output
if _input and not _output:
# If there's no output marker, treat the entire file as already pre-formatted.
_output = _input[:]
- return "".join(_input).strip() + "\n", "".join(_output).strip() + "\n"
+ return mode, "".join(_input).strip() + "\n", "".join(_output).strip() + "\n"
def read_jupyter_notebook(subdir_name: str, name: str, data: bool = True) -> str: