from asyncio.base_events import BaseEventLoop
from concurrent.futures import Executor, ProcessPoolExecutor
from datetime import datetime
-from enum import Enum, Flag
+from enum import Enum
from functools import lru_cache, partial, wraps
import io
import itertools
)
from appdirs import user_cache_dir
-from attr import dataclass, Factory
+from attr import dataclass, evolve, Factory
import click
import toml
from blib2to3.pytree import Node, Leaf, type_repr
from blib2to3 import pygram, pytree
from blib2to3.pgen2 import driver, token
+from blib2to3.pgen2.grammar import Grammar
from blib2to3.pgen2.parse import ParseError
YES = 2
-class FileMode(Flag):
- AUTO_DETECT = 0
- PYTHON36 = 1
- PYI = 2
- NO_STRING_NORMALIZATION = 4
- NO_NUMERIC_UNDERSCORE_NORMALIZATION = 8
+class TargetVersion(Enum):
+    """Python versions that Black's output can be asked to support."""
+
+    PYPY35 = 1
+    CPY27 = 2
+    CPY33 = 3
+    CPY34 = 4
+    CPY35 = 5
+    CPY36 = 6
+    CPY37 = 7
+    CPY38 = 8
+
+    def is_python2(self) -> bool:
+        """Return True only for the CPython 2.7 target."""
+        return self == TargetVersion.CPY27
+
+
+# Target versions implied by the legacy --py36 command-line flag (see main()).
+PY36_VERSIONS = {TargetVersion.CPY36, TargetVersion.CPY37, TargetVersion.CPY38}
+
+
+class Feature(Enum):
+    """Syntax features that only some target versions understand."""
+
+    # All string literals are unicode, so a `u` prefix is redundant.
+    UNICODE_LITERALS = 1
+    # f"..." / rf"..." formatted string literals.
+    F_STRINGS = 2
+    # Underscores inside numeric literals, e.g. 1_000_000.
+    NUMERIC_UNDERSCORES = 3
+    # A trailing comma after *args / **kwargs in signatures and calls.
+    TRAILING_COMMA = 4
+
+
+# Syntax features available on each target version. Black only emits a feature's
+# syntax when every requested (or auto-detected) target version supports it.
+VERSION_TO_FEATURES: Dict[TargetVersion, Set[Feature]] = {
+    TargetVersion.CPY27: set(),
+    TargetVersion.PYPY35: {Feature.UNICODE_LITERALS, Feature.F_STRINGS},
+    TargetVersion.CPY33: {Feature.UNICODE_LITERALS},
+    TargetVersion.CPY34: {Feature.UNICODE_LITERALS},
+    # Python 3.5 accepts a trailing comma after *args/**kwargs in calls but not
+    # in function signatures (allowed only from 3.6). TRAILING_COMMA covers both,
+    # so claiming it here would let Black emit `def f(*args, **kwargs,)`, which is
+    # a SyntaxError on 3.5.
+    TargetVersion.CPY35: {Feature.UNICODE_LITERALS},
+    TargetVersion.CPY36: {
+        Feature.UNICODE_LITERALS,
+        Feature.F_STRINGS,
+        Feature.NUMERIC_UNDERSCORES,
+        Feature.TRAILING_COMMA,
+    },
+    TargetVersion.CPY37: {
+        Feature.UNICODE_LITERALS,
+        Feature.F_STRINGS,
+        Feature.NUMERIC_UNDERSCORES,
+        Feature.TRAILING_COMMA,
+    },
+    TargetVersion.CPY38: {
+        Feature.UNICODE_LITERALS,
+        Feature.F_STRINGS,
+        Feature.NUMERIC_UNDERSCORES,
+        Feature.TRAILING_COMMA,
+    },
+}
- @classmethod
- def from_configuration(
- cls,
- *,
- py36: bool,
- pyi: bool,
- skip_string_normalization: bool,
- skip_numeric_underscore_normalization: bool,
- ) -> "FileMode":
- mode = cls.AUTO_DETECT
- if py36:
- mode |= cls.PYTHON36
- if pyi:
- mode |= cls.PYI
- if skip_string_normalization:
- mode |= cls.NO_STRING_NORMALIZATION
- if skip_numeric_underscore_normalization:
- mode |= cls.NO_NUMERIC_UNDERSCORE_NORMALIZATION
- return mode
+
+@dataclass
+class FileMode:
+    """Formatting options for a file; get_cache_key() encodes them for caching."""
+
+    target_versions: Set[TargetVersion] = Factory(set)
+    line_length: int = DEFAULT_LINE_LENGTH
+    numeric_underscore_normalization: bool = True
+    string_normalization: bool = True
+    is_pyi: bool = False
+
+    def get_cache_key(self) -> str:
+        """Return a dot-separated string that identifies this mode.
+
+        The first part lists the target versions' enum values in ascending order,
+        comma-separated, or "-" when no versions are set (auto-detection). It is
+        followed by the line length and the three boolean options rendered as 0/1.
+        """
+        ordered = sorted(self.target_versions, key=lambda v: v.value)
+        # Every value renders to a non-empty string, so the join is empty only
+        # when there are no target versions at all.
+        version_str = ",".join(str(v.value) for v in ordered) or "-"
+        flags = (
+            self.numeric_underscore_normalization,
+            self.string_normalization,
+            self.is_pyi,
+        )
+        return ".".join(
+            [version_str, str(self.line_length), *(str(int(f)) for f in flags)]
+        )
+
+
+def supports_feature(target_versions: Set[TargetVersion], feature: Feature) -> bool:
+    """Return True if every version in `target_versions` supports `feature`.
+
+    An empty `target_versions` vacuously supports every feature.
+    """
+    return not any(
+        feature not in VERSION_TO_FEATURES[version] for version in target_versions
+    )
def read_pyproject_toml(
help="How many characters per line to allow.",
show_default=True,
)
+@click.option(
+ "-t",
+ "--target-version",
+ type=click.Choice([v.name.lower() for v in TargetVersion]),
+ callback=lambda c, p, v: [TargetVersion[val.upper()] for val in v],
+ multiple=True,
+ help=(
+ "Python versions that should be supported by Black's output. [default: "
+ "per-file auto-detection]"
+ ),
+)
@click.option(
"--py36",
is_flag=True,
def main(
ctx: click.Context,
line_length: int,
+ target_version: List[TargetVersion],
check: bool,
diff: bool,
fast: bool,
) -> None:
"""The uncompromising code formatter."""
write_back = WriteBack.from_configuration(check=check, diff=diff)
- mode = FileMode.from_configuration(
- py36=py36,
- pyi=pyi,
- skip_string_normalization=skip_string_normalization,
- skip_numeric_underscore_normalization=skip_numeric_underscore_normalization,
+ if target_version:
+ if py36:
+ err(f"Cannot use both --target-version and --py36")
+ ctx.exit(2)
+ else:
+ versions = set(target_version)
+ elif py36:
+ versions = PY36_VERSIONS
+ else:
+ # We'll autodetect later.
+ versions = set()
+ mode = FileMode(
+ target_versions=versions,
+ line_length=line_length,
+ is_pyi=pyi,
+ string_normalization=not skip_string_normalization,
+ numeric_underscore_normalization=not skip_numeric_underscore_normalization,
)
if config and verbose:
out(f"Using configuration from {config}.", bold=False, fg="blue")
if len(sources) == 1:
reformat_one(
src=sources.pop(),
- line_length=line_length,
fast=fast,
write_back=write_back,
mode=mode,
loop.run_until_complete(
schedule_formatting(
sources=sources,
- line_length=line_length,
fast=fast,
write_back=write_back,
mode=mode,
def reformat_one(
- src: Path,
- line_length: int,
- fast: bool,
- write_back: WriteBack,
- mode: FileMode,
- report: "Report",
+ src: Path, fast: bool, write_back: WriteBack, mode: FileMode, report: "Report"
) -> None:
"""Reformat a single file under `src` without spawning child processes.
try:
changed = Changed.NO
if not src.is_file() and str(src) == "-":
- if format_stdin_to_stdout(
- line_length=line_length, fast=fast, write_back=write_back, mode=mode
- ):
+ if format_stdin_to_stdout(fast=fast, write_back=write_back, mode=mode):
changed = Changed.YES
else:
cache: Cache = {}
if write_back != WriteBack.DIFF:
- cache = read_cache(line_length, mode)
+ cache = read_cache(mode)
res_src = src.resolve()
if res_src in cache and cache[res_src] == get_cache_info(res_src):
changed = Changed.CACHED
if changed is not Changed.CACHED and format_file_in_place(
- src,
- line_length=line_length,
- fast=fast,
- write_back=write_back,
- mode=mode,
+ src, fast=fast, write_back=write_back, mode=mode
):
changed = Changed.YES
if (write_back is WriteBack.YES and changed is not Changed.CACHED) or (
write_back is WriteBack.CHECK and changed is Changed.NO
):
- write_cache(cache, [src], line_length, mode)
+ write_cache(cache, [src], mode)
report.done(src, changed)
except Exception as exc:
report.failed(src, str(exc))
async def schedule_formatting(
sources: Set[Path],
- line_length: int,
fast: bool,
write_back: WriteBack,
mode: FileMode,
"""
cache: Cache = {}
if write_back != WriteBack.DIFF:
- cache = read_cache(line_length, mode)
+ cache = read_cache(mode)
sources, cached = filter_cached(cache, sources)
for src in sorted(cached):
report.done(src, Changed.CACHED)
lock = manager.Lock()
tasks = {
loop.run_in_executor(
- executor,
- format_file_in_place,
- src,
- line_length,
- fast,
- write_back,
- mode,
- lock,
+ executor, format_file_in_place, src, fast, mode, write_back, lock
): src
for src in sorted(sources)
}
if cancelled:
await asyncio.gather(*cancelled, loop=loop, return_exceptions=True)
if sources_to_cache:
- write_cache(cache, sources_to_cache, line_length, mode)
+ write_cache(cache, sources_to_cache, mode)
def format_file_in_place(
src: Path,
- line_length: int,
fast: bool,
+ mode: FileMode,
write_back: WriteBack = WriteBack.NO,
- mode: FileMode = FileMode.AUTO_DETECT,
lock: Any = None, # multiprocessing.Manager().Lock() is some crazy proxy
) -> bool:
"""Format file under `src` path. Return True if changed.
`line_length` and `fast` options are passed to :func:`format_file_contents`.
"""
if src.suffix == ".pyi":
- mode |= FileMode.PYI
+ mode = evolve(mode, is_pyi=True)
then = datetime.utcfromtimestamp(src.stat().st_mtime)
with open(src, "rb") as buf:
src_contents, encoding, newline = decode_bytes(buf.read())
try:
- dst_contents = format_file_contents(
- src_contents, line_length=line_length, fast=fast, mode=mode
- )
+ dst_contents = format_file_contents(src_contents, fast=fast, mode=mode)
except NothingChanged:
return False
def format_stdin_to_stdout(
- line_length: int,
- fast: bool,
- write_back: WriteBack = WriteBack.NO,
- mode: FileMode = FileMode.AUTO_DETECT,
+ fast: bool, *, write_back: WriteBack = WriteBack.NO, mode: FileMode
) -> bool:
"""Format file on stdin. Return True if changed.
If `write_back` is YES, write reformatted code back to stdout. If it is DIFF,
- write a diff to stdout.
- `line_length`, `fast`, `is_pyi`, and `force_py36` arguments are passed to
+ write a diff to stdout. The `mode` argument is passed to
:func:`format_file_contents`.
"""
then = datetime.utcnow()
src, encoding, newline = decode_bytes(sys.stdin.buffer.read())
dst = src
try:
- dst = format_file_contents(src, line_length=line_length, fast=fast, mode=mode)
+ dst = format_file_contents(src, fast=fast, mode=mode)
return True
except NothingChanged:
def format_file_contents(
- src_contents: str,
- *,
- line_length: int,
- fast: bool,
- mode: FileMode = FileMode.AUTO_DETECT,
+ src_contents: str, *, fast: bool, mode: FileMode
) -> FileContent:
"""Reformat contents a file and return new contents.
if src_contents.strip() == "":
raise NothingChanged
- dst_contents = format_str(src_contents, line_length=line_length, mode=mode)
+ dst_contents = format_str(src_contents, mode=mode)
if src_contents == dst_contents:
raise NothingChanged
if not fast:
assert_equivalent(src_contents, dst_contents)
- assert_stable(src_contents, dst_contents, line_length=line_length, mode=mode)
+ assert_stable(src_contents, dst_contents, mode=mode)
return dst_contents
-def format_str(
- src_contents: str, line_length: int, *, mode: FileMode = FileMode.AUTO_DETECT
-) -> FileContent:
+def format_str(src_contents: str, *, mode: FileMode) -> FileContent:
"""Reformat a string and return new contents.
`line_length` determines how many characters per line are allowed.
"""
- src_node = lib2to3_parse(src_contents.lstrip())
+ src_node = lib2to3_parse(src_contents.lstrip(), mode.target_versions)
dst_contents = ""
future_imports = get_future_imports(src_node)
- is_pyi = bool(mode & FileMode.PYI)
- py36 = bool(mode & FileMode.PYTHON36) or is_python36(src_node)
- normalize_strings = not bool(mode & FileMode.NO_STRING_NORMALIZATION)
+ if mode.target_versions:
+ versions = mode.target_versions
+ else:
+ versions = detect_target_versions(src_node)
normalize_fmt_off(src_node)
lines = LineGenerator(
- remove_u_prefix=py36 or "unicode_literals" in future_imports,
- is_pyi=is_pyi,
- normalize_strings=normalize_strings,
- allow_underscores=py36
- and not bool(mode & FileMode.NO_NUMERIC_UNDERSCORE_NORMALIZATION),
+ remove_u_prefix="unicode_literals" in future_imports
+ or supports_feature(versions, Feature.UNICODE_LITERALS),
+ is_pyi=mode.is_pyi,
+ normalize_strings=mode.string_normalization,
+ allow_underscores=mode.numeric_underscore_normalization
+ and supports_feature(versions, Feature.NUMERIC_UNDERSCORES),
)
- elt = EmptyLineTracker(is_pyi=is_pyi)
+ elt = EmptyLineTracker(is_pyi=mode.is_pyi)
empty_line = Line()
after = 0
for current_line in lines.visit(src_node):
before, after = elt.maybe_empty_lines(current_line)
for _ in range(before):
dst_contents += str(empty_line)
- for line in split_line(current_line, line_length=line_length, py36=py36):
+ for line in split_line(
+ current_line,
+ line_length=mode.line_length,
+ supports_trailing_commas=supports_feature(versions, Feature.TRAILING_COMMA),
+ ):
dst_contents += str(line)
return dst_contents
]
-def lib2to3_parse(src_txt: str) -> Node:
+def get_grammars(target_versions: Set[TargetVersion]) -> List[Grammar]:
+    """Return the lib2to3 grammars to try, in order, for `target_versions`.
+
+    An empty set means the targets will be auto-detected, so every grammar in
+    GRAMMARS is tried.
+    """
+    if not target_versions:
+        return GRAMMARS
+    elif all(not version.is_python2() for version in target_versions):
+        # Python 3-only targets: `print` and `exec` are never statements there.
+        return [
+            pygram.python_grammar_no_print_statement_no_exec_statement,
+            pygram.python_grammar_no_print_statement,
+        ]
+    else:
+        # At least one Python 2 target. Try the print-function grammar first so
+        # code using `from __future__ import print_function` (print(x, end=""))
+        # still parses, then fall back to the classic Python 2 grammar.
+        return [pygram.python_grammar_no_print_statement, pygram.python_grammar]
+
+
+def lib2to3_parse(src_txt: str, target_versions: Iterable[TargetVersion] = ()) -> Node:
"""Given a string with source, return the lib2to3 Node."""
if src_txt[-1:] != "\n":
src_txt += "\n"
- for grammar in GRAMMARS:
+
+ for grammar in get_grammars(set(target_versions)):
drv = driver.Driver(grammar, pytree.convert)
try:
result = drv.parse_string(src_txt, True)
def split_line(
- line: Line, line_length: int, inner: bool = False, py36: bool = False
+ line: Line,
+ line_length: int,
+ inner: bool = False,
+ supports_trailing_commas: bool = False,
) -> Iterator[Line]:
"""Split a `line` into potentially many lines.
current `line`, possibly transitively. This means we can fallback to splitting
by delimiters if the LHS/RHS don't yield any results.
- If `py36` is True, splitting may generate syntax that is only compatible
- with Python 3.6 and later.
+ If `supports_trailing_commas` is True, splitting may use the TRAILING_COMMA feature.
"""
if line.is_comment:
yield line
split_funcs = [left_hand_split]
else:
- def rhs(line: Line, py36: bool = False) -> Iterator[Line]:
+ def rhs(line: Line, supports_trailing_commas: bool = False) -> Iterator[Line]:
for omit in generate_trailers_to_omit(line, line_length):
- lines = list(right_hand_split(line, line_length, py36, omit=omit))
+ lines = list(
+ right_hand_split(
+ line, line_length, supports_trailing_commas, omit=omit
+ )
+ )
if is_line_short_enough(lines[0], line_length=line_length):
yield from lines
return
# All splits failed, best effort split with no omits.
# This mostly happens to multiline strings that are by definition
# reported as not fitting a single line.
- yield from right_hand_split(line, py36)
+ yield from right_hand_split(line, supports_trailing_commas)
if line.inside_brackets:
split_funcs = [delimiter_split, standalone_comment_split, rhs]
# split altogether.
result: List[Line] = []
try:
- for l in split_func(line, py36):
+ for l in split_func(line, supports_trailing_commas):
if str(l).strip("\n") == line_str:
raise CannotSplit("Split function returned an unchanged result")
result.extend(
- split_line(l, line_length=line_length, inner=True, py36=py36)
+ split_line(
+ l,
+ line_length=line_length,
+ inner=True,
+ supports_trailing_commas=supports_trailing_commas,
+ )
)
except CannotSplit:
continue
yield line
-def left_hand_split(line: Line, py36: bool = False) -> Iterator[Line]:
+def left_hand_split(
+ line: Line, supports_trailing_commas: bool = False
+) -> Iterator[Line]:
"""Split line into many lines, starting with the first matching bracket pair.
Note: this usually looks weird, only use this for function definitions.
def right_hand_split(
- line: Line, line_length: int, py36: bool = False, omit: Collection[LeafID] = ()
+ line: Line,
+ line_length: int,
+ supports_trailing_commas: bool = False,
+ omit: Collection[LeafID] = (),
) -> Iterator[Line]:
"""Split line into many lines, starting with the last matching bracket pair.
):
omit = {id(closing_bracket), *omit}
try:
- yield from right_hand_split(line, line_length, py36=py36, omit=omit)
+ yield from right_hand_split(
+ line,
+ line_length,
+ supports_trailing_commas=supports_trailing_commas,
+ omit=omit,
+ )
return
except CannotSplit:
"""
@wraps(split_func)
- def split_wrapper(line: Line, py36: bool = False) -> Iterator[Line]:
- for l in split_func(line, py36):
+ def split_wrapper(
+ line: Line, supports_trailing_commas: bool = False
+ ) -> Iterator[Line]:
+ for l in split_func(line, supports_trailing_commas):
normalize_prefix(l.leaves[0], inside_brackets=True)
yield l
@dont_increase_indentation
-def delimiter_split(line: Line, py36: bool = False) -> Iterator[Line]:
+def delimiter_split(
+ line: Line, supports_trailing_commas: bool = False
+) -> Iterator[Line]:
"""Split according to delimiters of the highest priority.
If `py36` is True, the split will add trailing commas also in function
if leaf.bracket_depth == lowest_depth and is_vararg(
leaf, within=VARARGS_PARENTS
):
- trailing_comma_safe = trailing_comma_safe and py36
+ trailing_comma_safe = trailing_comma_safe and supports_trailing_commas
leaf_priority = bt.delimiters.get(id(leaf))
if leaf_priority == delimiter_priority:
yield current_line
@dont_increase_indentation
-def standalone_comment_split(line: Line, py36: bool = False) -> Iterator[Line]:
+def standalone_comment_split(
+ line: Line, supports_trailing_commas: bool = False
+) -> Iterator[Line]:
"""Split standalone comments from the rest of the line."""
if not line.contains_standalone_comments(0):
raise CannotSplit("Line does not have any standalone comments")
return max_priority == COMMA_PRIORITY
-def is_python36(node: Node) -> bool:
- """Return True if the current file is using Python 3.6+ features.
+def get_features_used(node: Node) -> Set[Feature]:
+ """Return a set of (relatively) new Python features used in this file.
Currently looking for:
- f-strings;
- underscores in numeric literals; and
- trailing commas after * or ** in function signatures and calls.
"""
+ features: Set[Feature] = set()
for n in node.pre_order():
if n.type == token.STRING:
value_head = n.value[:2] # type: ignore
if value_head in {'f"', 'F"', "f'", "F'", "rf", "fr", "RF", "FR"}:
- return True
+ features.add(Feature.F_STRINGS)
elif n.type == token.NUMBER:
if "_" in n.value: # type: ignore
- return True
+ features.add(Feature.NUMERIC_UNDERSCORES)
elif (
n.type in {syms.typedargslist, syms.arglist}
):
for ch in n.children:
if ch.type in STARS:
- return True
+ features.add(Feature.TRAILING_COMMA)
if ch.type == syms.argument:
for argch in ch.children:
if argch.type in STARS:
- return True
+ features.add(Feature.TRAILING_COMMA)
- return False
+ return features
+
+
+def detect_target_versions(node: Node) -> Set[TargetVersion]:
+    """Return every target version that supports all features used by `node`."""
+    used = get_features_used(node)
+    compatible: Set[TargetVersion] = set()
+    for version in TargetVersion:
+        if used.issubset(VERSION_TO_FEATURES[version]):
+            compatible.add(version)
+    return compatible
def generate_trailers_to_omit(line: Line, line_length: int) -> Iterator[Set[LeafID]]:
) from None
-def assert_stable(
- src: str, dst: str, line_length: int, mode: FileMode = FileMode.AUTO_DETECT
-) -> None:
+def assert_stable(src: str, dst: str, mode: FileMode) -> None:
"""Raise AssertionError if `dst` reformats differently the second time."""
- newdst = format_str(dst, line_length=line_length, mode=mode)
+ newdst = format_str(dst, mode=mode)
if dst != newdst:
log = dump_to_file(
diff(src, dst, "source", "first pass"),
return False
-def get_cache_file(line_length: int, mode: FileMode) -> Path:
- return CACHE_DIR / f"cache.{line_length}.{mode.value}.pickle"
+def get_cache_file(mode: FileMode) -> Path:
+    """Return the path of the cache file dedicated to `mode` inside CACHE_DIR."""
+    filename = f"cache.{mode.get_cache_key()}.pickle"
+    return CACHE_DIR / filename
-def read_cache(line_length: int, mode: FileMode) -> Cache:
+def read_cache(mode: FileMode) -> Cache:
"""Read the cache if it exists and is well formed.
If it is not well formed, the call to write_cache later should resolve the issue.
"""
- cache_file = get_cache_file(line_length, mode)
+ cache_file = get_cache_file(mode)
if not cache_file.exists():
return {}
return todo, done
-def write_cache(
- cache: Cache, sources: Iterable[Path], line_length: int, mode: FileMode
-) -> None:
+def write_cache(cache: Cache, sources: Iterable[Path], mode: FileMode) -> None:
"""Update the cache file."""
- cache_file = get_cache_file(line_length, mode)
+ cache_file = get_cache_file(mode)
try:
CACHE_DIR.mkdir(parents=True, exist_ok=True)
new_cache = {**cache, **{src.resolve(): get_cache_info(src) for src in sources}}
from click.testing import CliRunner
import black
+from black import Feature
try:
import blackd
has_blackd_deps = True
-ll = 88
-ff = partial(black.format_file_in_place, line_length=ll, fast=True)
-fs = partial(black.format_str, line_length=ll)
+ff = partial(black.format_file_in_place, mode=black.FileMode(), fast=True)
+fs = partial(black.format_str, mode=black.FileMode())
THIS_FILE = Path(__file__)
THIS_DIR = THIS_FILE.parent
EMPTY_LINE = "# EMPTY LINE WITH WHITESPACE" + " (this comment will be removed)"
black.err(str(ve))
self.assertEqual(expected, actual)
+    def invokeBlack(
+        self, args: List[str], exit_code: int = 0, ignore_config: bool = True
+    ) -> None:
+        """Invoke black.main with `args` and assert the resulting exit code.
+
+        When `ignore_config` is true, `--config` pointing at THIS_DIR/empty.toml is
+        placed ahead of `args`. The runner's stderr is used as the failure message.
+        """
+        config_args = (
+            ["--config", str(THIS_DIR / "empty.toml")] if ignore_config else []
+        )
+        runner = BlackRunner()
+        result = runner.invoke(black.main, [*config_args, *args])
+        self.assertEqual(result.exit_code, exit_code, msg=runner.stderr_bytes.decode())
+
@patch("black.dump_to_file", dump_to_stderr)
def test_empty(self) -> None:
source = expected = ""
actual = fs(source)
self.assertFormatEqual(expected, actual)
black.assert_equivalent(source, actual)
- black.assert_stable(source, actual, line_length=ll)
+ black.assert_stable(source, actual, black.FileMode())
def test_empty_ff(self) -> None:
expected = ""
actual = fs(source)
self.assertFormatEqual(expected, actual)
black.assert_equivalent(source, actual)
- black.assert_stable(source, actual, line_length=ll)
+ black.assert_stable(source, actual, black.FileMode())
self.assertFalse(ff(THIS_FILE))
@patch("black.dump_to_file", dump_to_stderr)
actual = fs(source)
self.assertFormatEqual(expected, actual)
black.assert_equivalent(source, actual)
- black.assert_stable(source, actual, line_length=ll)
+ black.assert_stable(source, actual, black.FileMode())
self.assertFalse(ff(THIS_DIR / ".." / "black.py"))
def test_piping(self) -> None:
source, expected = read_data("../black", data=False)
result = BlackRunner().invoke(
black.main,
- ["-", "--fast", f"--line-length={ll}"],
+ ["-", "--fast", f"--line-length={black.DEFAULT_LINE_LENGTH}"],
input=BytesIO(source.encode("utf8")),
)
self.assertEqual(result.exit_code, 0)
self.assertFormatEqual(expected, result.output)
black.assert_equivalent(source, result.output)
- black.assert_stable(source, result.output, line_length=ll)
+ black.assert_stable(source, result.output, black.FileMode())
def test_piping_diff(self) -> None:
diff_header = re.compile(
source, _ = read_data("expression.py")
expected, _ = read_data("expression.diff")
config = THIS_DIR / "data" / "empty_pyproject.toml"
- args = ["-", "--fast", f"--line-length={ll}", "--diff", f"--config={config}"]
+ args = [
+ "-",
+ "--fast",
+ f"--line-length={black.DEFAULT_LINE_LENGTH}",
+ "--diff",
+ f"--config={config}",
+ ]
result = BlackRunner().invoke(
black.main, args, input=BytesIO(source.encode("utf8"))
)
actual = fs(source)
self.assertFormatEqual(expected, actual)
black.assert_equivalent(source, actual)
- black.assert_stable(source, actual, line_length=ll)
+ black.assert_stable(source, actual, black.FileMode())
self.assertFalse(ff(THIS_DIR / ".." / "setup.py"))
@patch("black.dump_to_file", dump_to_stderr)
actual = fs(source)
self.assertFormatEqual(expected, actual)
black.assert_equivalent(source, actual)
- black.assert_stable(source, actual, line_length=ll)
+ black.assert_stable(source, actual, black.FileMode())
@patch("black.dump_to_file", dump_to_stderr)
def test_function2(self) -> None:
actual = fs(source)
self.assertFormatEqual(expected, actual)
black.assert_equivalent(source, actual)
- black.assert_stable(source, actual, line_length=ll)
+ black.assert_stable(source, actual, black.FileMode())
@patch("black.dump_to_file", dump_to_stderr)
def test_expression(self) -> None:
actual = fs(source)
self.assertFormatEqual(expected, actual)
black.assert_equivalent(source, actual)
- black.assert_stable(source, actual, line_length=ll)
+ black.assert_stable(source, actual, black.FileMode())
def test_expression_ff(self) -> None:
source, expected = read_data("expression")
self.assertFormatEqual(expected, actual)
with patch("black.dump_to_file", dump_to_stderr):
black.assert_equivalent(source, actual)
- black.assert_stable(source, actual, line_length=ll)
+ black.assert_stable(source, actual, black.FileMode())
def test_expression_diff(self) -> None:
source, _ = read_data("expression.py")
actual = fs(source)
self.assertFormatEqual(expected, actual)
black.assert_equivalent(source, actual)
- black.assert_stable(source, actual, line_length=ll)
+ black.assert_stable(source, actual, black.FileMode())
@patch("black.dump_to_file", dump_to_stderr)
def test_string_quotes(self) -> None:
actual = fs(source)
self.assertFormatEqual(expected, actual)
black.assert_equivalent(source, actual)
- black.assert_stable(source, actual, line_length=ll)
- mode = black.FileMode.NO_STRING_NORMALIZATION
+ black.assert_stable(source, actual, black.FileMode())
+ mode = black.FileMode(string_normalization=False)
not_normalized = fs(source, mode=mode)
self.assertFormatEqual(source, not_normalized)
black.assert_equivalent(source, not_normalized)
- black.assert_stable(source, not_normalized, line_length=ll, mode=mode)
+ black.assert_stable(source, not_normalized, mode=mode)
@patch("black.dump_to_file", dump_to_stderr)
def test_slices(self) -> None:
actual = fs(source)
self.assertFormatEqual(expected, actual)
black.assert_equivalent(source, actual)
- black.assert_stable(source, actual, line_length=ll)
+ black.assert_stable(source, actual, black.FileMode())
@patch("black.dump_to_file", dump_to_stderr)
def test_comments(self) -> None:
actual = fs(source)
self.assertFormatEqual(expected, actual)
black.assert_equivalent(source, actual)
- black.assert_stable(source, actual, line_length=ll)
+ black.assert_stable(source, actual, black.FileMode())
@patch("black.dump_to_file", dump_to_stderr)
def test_comments2(self) -> None:
actual = fs(source)
self.assertFormatEqual(expected, actual)
black.assert_equivalent(source, actual)
- black.assert_stable(source, actual, line_length=ll)
+ black.assert_stable(source, actual, black.FileMode())
@patch("black.dump_to_file", dump_to_stderr)
def test_comments3(self) -> None:
actual = fs(source)
self.assertFormatEqual(expected, actual)
black.assert_equivalent(source, actual)
- black.assert_stable(source, actual, line_length=ll)
+ black.assert_stable(source, actual, black.FileMode())
@patch("black.dump_to_file", dump_to_stderr)
def test_comments4(self) -> None:
actual = fs(source)
self.assertFormatEqual(expected, actual)
black.assert_equivalent(source, actual)
- black.assert_stable(source, actual, line_length=ll)
+ black.assert_stable(source, actual, black.FileMode())
@patch("black.dump_to_file", dump_to_stderr)
def test_comments5(self) -> None:
actual = fs(source)
self.assertFormatEqual(expected, actual)
black.assert_equivalent(source, actual)
- black.assert_stable(source, actual, line_length=ll)
+ black.assert_stable(source, actual, black.FileMode())
@patch("black.dump_to_file", dump_to_stderr)
def test_comments6(self) -> None:
actual = fs(source)
self.assertFormatEqual(expected, actual)
black.assert_equivalent(source, actual)
- black.assert_stable(source, actual, line_length=ll)
+ black.assert_stable(source, actual, black.FileMode())
@patch("black.dump_to_file", dump_to_stderr)
def test_cantfit(self) -> None:
actual = fs(source)
self.assertFormatEqual(expected, actual)
black.assert_equivalent(source, actual)
- black.assert_stable(source, actual, line_length=ll)
+ black.assert_stable(source, actual, black.FileMode())
@patch("black.dump_to_file", dump_to_stderr)
def test_import_spacing(self) -> None:
actual = fs(source)
self.assertFormatEqual(expected, actual)
black.assert_equivalent(source, actual)
- black.assert_stable(source, actual, line_length=ll)
+ black.assert_stable(source, actual, black.FileMode())
@patch("black.dump_to_file", dump_to_stderr)
def test_composition(self) -> None:
actual = fs(source)
self.assertFormatEqual(expected, actual)
black.assert_equivalent(source, actual)
- black.assert_stable(source, actual, line_length=ll)
+ black.assert_stable(source, actual, black.FileMode())
@patch("black.dump_to_file", dump_to_stderr)
def test_empty_lines(self) -> None:
actual = fs(source)
self.assertFormatEqual(expected, actual)
black.assert_equivalent(source, actual)
- black.assert_stable(source, actual, line_length=ll)
+ black.assert_stable(source, actual, black.FileMode())
@patch("black.dump_to_file", dump_to_stderr)
def test_string_prefixes(self) -> None:
actual = fs(source)
self.assertFormatEqual(expected, actual)
black.assert_equivalent(source, actual)
- black.assert_stable(source, actual, line_length=ll)
+ black.assert_stable(source, actual, black.FileMode())
@patch("black.dump_to_file", dump_to_stderr)
def test_numeric_literals(self) -> None:
source, expected = read_data("numeric_literals")
- actual = fs(source, mode=black.FileMode.PYTHON36)
+ mode = black.FileMode(target_versions=black.PY36_VERSIONS)
+ actual = fs(source, mode=mode)
self.assertFormatEqual(expected, actual)
black.assert_equivalent(source, actual)
- black.assert_stable(source, actual, line_length=ll)
+ black.assert_stable(source, actual, mode)
@patch("black.dump_to_file", dump_to_stderr)
def test_numeric_literals_ignoring_underscores(self) -> None:
source, expected = read_data("numeric_literals_skip_underscores")
- mode = (
- black.FileMode.PYTHON36 | black.FileMode.NO_NUMERIC_UNDERSCORE_NORMALIZATION
+ mode = black.FileMode(
+ numeric_underscore_normalization=False, target_versions=black.PY36_VERSIONS
)
actual = fs(source, mode=mode)
self.assertFormatEqual(expected, actual)
black.assert_equivalent(source, actual)
- black.assert_stable(source, actual, line_length=ll, mode=mode)
+ black.assert_stable(source, actual, mode)
@patch("black.dump_to_file", dump_to_stderr)
def test_numeric_literals_py2(self) -> None:
source, expected = read_data("numeric_literals_py2")
actual = fs(source)
self.assertFormatEqual(expected, actual)
- black.assert_stable(source, actual, line_length=ll)
+ black.assert_stable(source, actual, black.FileMode())
@patch("black.dump_to_file", dump_to_stderr)
def test_python2(self) -> None:
actual = fs(source)
self.assertFormatEqual(expected, actual)
# black.assert_equivalent(source, actual)
- black.assert_stable(source, actual, line_length=ll)
+ black.assert_stable(source, actual, black.FileMode())
@patch("black.dump_to_file", dump_to_stderr)
def test_python2_unicode_literals(self) -> None:
source, expected = read_data("python2_unicode_literals")
actual = fs(source)
self.assertFormatEqual(expected, actual)
- black.assert_stable(source, actual, line_length=ll)
+ black.assert_stable(source, actual, black.FileMode())
@patch("black.dump_to_file", dump_to_stderr)
def test_stub(self) -> None:
- mode = black.FileMode.PYI
+ mode = black.FileMode(is_pyi=True)
source, expected = read_data("stub.pyi")
actual = fs(source, mode=mode)
self.assertFormatEqual(expected, actual)
- black.assert_stable(source, actual, line_length=ll, mode=mode)
+ black.assert_stable(source, actual, mode)
@patch("black.dump_to_file", dump_to_stderr)
def test_python37(self) -> None:
major, minor = sys.version_info[:2]
if major > 3 or (major == 3 and minor >= 7):
black.assert_equivalent(source, actual)
- black.assert_stable(source, actual, line_length=ll)
+ black.assert_stable(source, actual, black.FileMode())
@patch("black.dump_to_file", dump_to_stderr)
def test_fmtonoff(self) -> None:
actual = fs(source)
self.assertFormatEqual(expected, actual)
black.assert_equivalent(source, actual)
- black.assert_stable(source, actual, line_length=ll)
+ black.assert_stable(source, actual, black.FileMode())
@patch("black.dump_to_file", dump_to_stderr)
def test_fmtonoff2(self) -> None:
actual = fs(source)
self.assertFormatEqual(expected, actual)
black.assert_equivalent(source, actual)
- black.assert_stable(source, actual, line_length=ll)
+ black.assert_stable(source, actual, black.FileMode())
@patch("black.dump_to_file", dump_to_stderr)
def test_remove_empty_parentheses_after_class(self) -> None:
actual = fs(source)
self.assertFormatEqual(expected, actual)
black.assert_equivalent(source, actual)
- black.assert_stable(source, actual, line_length=ll)
+ black.assert_stable(source, actual, black.FileMode())
@patch("black.dump_to_file", dump_to_stderr)
def test_new_line_between_class_and_code(self) -> None:
actual = fs(source)
self.assertFormatEqual(expected, actual)
black.assert_equivalent(source, actual)
- black.assert_stable(source, actual, line_length=ll)
+ black.assert_stable(source, actual, black.FileMode())
@patch("black.dump_to_file", dump_to_stderr)
def test_bracket_match(self) -> None:
actual = fs(source)
self.assertFormatEqual(expected, actual)
black.assert_equivalent(source, actual)
- black.assert_stable(source, actual, line_length=ll)
+ black.assert_stable(source, actual, black.FileMode())
def test_comment_indentation(self) -> None:
contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t# comment\n\tpass\n"
"2 files would fail to reformat.",
)
- def test_is_python36(self) -> None:
+ def test_get_features_used(self) -> None:
node = black.lib2to3_parse("def f(*, arg): ...\n")
- self.assertFalse(black.is_python36(node))
+ self.assertEqual(black.get_features_used(node), set())
node = black.lib2to3_parse("def f(*, arg,): ...\n")
- self.assertTrue(black.is_python36(node))
+ self.assertEqual(black.get_features_used(node), {Feature.TRAILING_COMMA})
node = black.lib2to3_parse("def f(*, arg): f'string'\n")
- self.assertTrue(black.is_python36(node))
+ self.assertEqual(black.get_features_used(node), {Feature.F_STRINGS})
node = black.lib2to3_parse("123_456\n")
- self.assertTrue(black.is_python36(node))
+ self.assertEqual(black.get_features_used(node), {Feature.NUMERIC_UNDERSCORES})
node = black.lib2to3_parse("123456\n")
- self.assertFalse(black.is_python36(node))
+ self.assertEqual(black.get_features_used(node), set())
source, expected = read_data("function")
node = black.lib2to3_parse(source)
- self.assertTrue(black.is_python36(node))
+ self.assertEqual(
+ black.get_features_used(node), {Feature.TRAILING_COMMA, Feature.F_STRINGS}
+ )
node = black.lib2to3_parse(expected)
- self.assertTrue(black.is_python36(node))
+ self.assertEqual(
+ black.get_features_used(node),
+ {Feature.TRAILING_COMMA, Feature.F_STRINGS, Feature.NUMERIC_UNDERSCORES},
+ )
source, expected = read_data("expression")
node = black.lib2to3_parse(source)
- self.assertFalse(black.is_python36(node))
+ self.assertEqual(black.get_features_used(node), set())
node = black.lib2to3_parse(expected)
- self.assertFalse(black.is_python36(node))
+ self.assertEqual(black.get_features_used(node), set())
def test_get_future_imports(self) -> None:
node = black.lib2to3_parse("\n")
def test_format_file_contents(self) -> None:
empty = ""
+ mode = black.FileMode()
with self.assertRaises(black.NothingChanged):
- black.format_file_contents(empty, line_length=ll, fast=False)
+ black.format_file_contents(empty, mode=mode, fast=False)
just_nl = "\n"
with self.assertRaises(black.NothingChanged):
- black.format_file_contents(just_nl, line_length=ll, fast=False)
+ black.format_file_contents(just_nl, mode=mode, fast=False)
same = "l = [1, 2, 3]\n"
with self.assertRaises(black.NothingChanged):
- black.format_file_contents(same, line_length=ll, fast=False)
+ black.format_file_contents(same, mode=mode, fast=False)
different = "l = [1,2,3]"
expected = same
- actual = black.format_file_contents(different, line_length=ll, fast=False)
+ actual = black.format_file_contents(different, mode=mode, fast=False)
self.assertEqual(expected, actual)
invalid = "return if you can"
with self.assertRaises(black.InvalidInput) as e:
- black.format_file_contents(invalid, line_length=ll, fast=False)
+ black.format_file_contents(invalid, mode=mode, fast=False)
self.assertEqual(str(e.exception), "Cannot parse: 1:7: return if you can")
def test_endmarker(self) -> None:
self.assertEqual("".join(err_lines), "")
def test_cache_broken_file(self) -> None:
- mode = black.FileMode.AUTO_DETECT
+ mode = black.FileMode()
with cache_dir() as workspace:
- cache_file = black.get_cache_file(black.DEFAULT_LINE_LENGTH, mode)
+ cache_file = black.get_cache_file(mode)
with cache_file.open("w") as fobj:
fobj.write("this is not a pickle")
- self.assertEqual(black.read_cache(black.DEFAULT_LINE_LENGTH, mode), {})
+ self.assertEqual(black.read_cache(mode), {})
src = (workspace / "test.py").resolve()
with src.open("w") as fobj:
fobj.write("print('hello')")
- result = CliRunner().invoke(black.main, [str(src)])
- self.assertEqual(result.exit_code, 0)
- cache = black.read_cache(black.DEFAULT_LINE_LENGTH, mode)
+ self.invokeBlack([str(src)])
+ cache = black.read_cache(mode)
self.assertIn(src, cache)
def test_cache_single_file_already_cached(self) -> None:
- mode = black.FileMode.AUTO_DETECT
+ mode = black.FileMode()
with cache_dir() as workspace:
src = (workspace / "test.py").resolve()
with src.open("w") as fobj:
fobj.write("print('hello')")
- black.write_cache({}, [src], black.DEFAULT_LINE_LENGTH, mode)
- result = CliRunner().invoke(black.main, [str(src)])
- self.assertEqual(result.exit_code, 0)
+ black.write_cache({}, [src], mode)
+ self.invokeBlack([str(src)])
with src.open("r") as fobj:
self.assertEqual(fobj.read(), "print('hello')")
@event_loop(close=False)
def test_cache_multiple_files(self) -> None:
- mode = black.FileMode.AUTO_DETECT
+ mode = black.FileMode()
with cache_dir() as workspace, patch(
"black.ProcessPoolExecutor", new=ThreadPoolExecutor
):
two = (workspace / "two.py").resolve()
with two.open("w") as fobj:
fobj.write("print('hello')")
- black.write_cache({}, [one], black.DEFAULT_LINE_LENGTH, mode)
- result = CliRunner().invoke(black.main, [str(workspace)])
- self.assertEqual(result.exit_code, 0)
+ black.write_cache({}, [one], mode)
+ self.invokeBlack([str(workspace)])
with one.open("r") as fobj:
self.assertEqual(fobj.read(), "print('hello')")
with two.open("r") as fobj:
self.assertEqual(fobj.read(), 'print("hello")\n')
- cache = black.read_cache(black.DEFAULT_LINE_LENGTH, mode)
+ cache = black.read_cache(mode)
self.assertIn(one, cache)
self.assertIn(two, cache)
def test_no_cache_when_writeback_diff(self) -> None:
- mode = black.FileMode.AUTO_DETECT
+ mode = black.FileMode()
with cache_dir() as workspace:
src = (workspace / "test.py").resolve()
with src.open("w") as fobj:
fobj.write("print('hello')")
- result = CliRunner().invoke(black.main, [str(src), "--diff"])
- self.assertEqual(result.exit_code, 0)
- cache_file = black.get_cache_file(black.DEFAULT_LINE_LENGTH, mode)
+ self.invokeBlack([str(src), "--diff"])
+ cache_file = black.get_cache_file(mode)
self.assertFalse(cache_file.exists())
def test_no_cache_when_stdin(self) -> None:
- mode = black.FileMode.AUTO_DETECT
+ mode = black.FileMode()
with cache_dir():
result = CliRunner().invoke(
black.main, ["-"], input=BytesIO(b"print('hello')")
)
self.assertEqual(result.exit_code, 0)
- cache_file = black.get_cache_file(black.DEFAULT_LINE_LENGTH, mode)
+ cache_file = black.get_cache_file(mode)
self.assertFalse(cache_file.exists())
def test_read_cache_no_cachefile(self) -> None:
- mode = black.FileMode.AUTO_DETECT
+ mode = black.FileMode()
with cache_dir():
- self.assertEqual(black.read_cache(black.DEFAULT_LINE_LENGTH, mode), {})
+ self.assertEqual(black.read_cache(mode), {})
def test_write_cache_read_cache(self) -> None:
- mode = black.FileMode.AUTO_DETECT
+ mode = black.FileMode()
with cache_dir() as workspace:
src = (workspace / "test.py").resolve()
src.touch()
- black.write_cache({}, [src], black.DEFAULT_LINE_LENGTH, mode)
- cache = black.read_cache(black.DEFAULT_LINE_LENGTH, mode)
+ black.write_cache({}, [src], mode)
+ cache = black.read_cache(mode)
self.assertIn(src, cache)
self.assertEqual(cache[src], black.get_cache_info(src))
self.assertEqual(done, {cached})
def test_write_cache_creates_directory_if_needed(self) -> None:
- mode = black.FileMode.AUTO_DETECT
+ mode = black.FileMode()
with cache_dir(exists=False) as workspace:
self.assertFalse(workspace.exists())
- black.write_cache({}, [], black.DEFAULT_LINE_LENGTH, mode)
+ black.write_cache({}, [], mode)
self.assertTrue(workspace.exists())
@event_loop(close=False)
def test_failed_formatting_does_not_get_cached(self) -> None:
- mode = black.FileMode.AUTO_DETECT
+ mode = black.FileMode()
with cache_dir() as workspace, patch(
"black.ProcessPoolExecutor", new=ThreadPoolExecutor
):
clean = (workspace / "clean.py").resolve()
with clean.open("w") as fobj:
fobj.write('print("hello")\n')
- result = CliRunner().invoke(black.main, [str(workspace)])
- self.assertEqual(result.exit_code, 123)
- cache = black.read_cache(black.DEFAULT_LINE_LENGTH, mode)
+ self.invokeBlack([str(workspace)], exit_code=123)
+ cache = black.read_cache(mode)
self.assertNotIn(failing, cache)
self.assertIn(clean, cache)
def test_write_cache_write_fail(self) -> None:
- mode = black.FileMode.AUTO_DETECT
+ mode = black.FileMode()
with cache_dir(), patch.object(Path, "open") as mock:
mock.side_effect = OSError
- black.write_cache({}, [], black.DEFAULT_LINE_LENGTH, mode)
+ black.write_cache({}, [], mode)
@event_loop(close=False)
def test_check_diff_use_together(self) -> None:
with cache_dir():
# Files which will be reformatted.
src1 = (THIS_DIR / "data" / "string_quotes.py").resolve()
- result = CliRunner().invoke(black.main, [str(src1), "--diff", "--check"])
- self.assertEqual(result.exit_code, 1, result.output)
+ self.invokeBlack([str(src1), "--diff", "--check"], exit_code=1)
# Files which will not be reformatted.
src2 = (THIS_DIR / "data" / "composition.py").resolve()
- result = CliRunner().invoke(black.main, [str(src2), "--diff", "--check"])
- self.assertEqual(result.exit_code, 0, result.output)
+ self.invokeBlack([str(src2), "--diff", "--check"])
# Multi file command.
- result = CliRunner().invoke(
- black.main, [str(src1), str(src2), "--diff", "--check"]
- )
- self.assertEqual(result.exit_code, 1, result.output)
+ self.invokeBlack([str(src1), str(src2), "--diff", "--check"], exit_code=1)
def test_no_files(self) -> None:
with cache_dir():
# Without an argument, black exits with error code 0.
- result = CliRunner().invoke(black.main, [])
- self.assertEqual(result.exit_code, 0)
+ self.invokeBlack([])
def test_broken_symlink(self) -> None:
with cache_dir() as workspace:
symlink.symlink_to("nonexistent.py")
except OSError as e:
self.skipTest(f"Can't create symlinks: {e}")
- result = CliRunner().invoke(black.main, [str(workspace.resolve())])
- self.assertEqual(result.exit_code, 0)
+ self.invokeBlack([str(workspace.resolve())])
def test_read_cache_line_lengths(self) -> None:
- mode = black.FileMode.AUTO_DETECT
+ mode = black.FileMode()
+ short_mode = black.FileMode(line_length=1)
with cache_dir() as workspace:
path = (workspace / "file.py").resolve()
path.touch()
- black.write_cache({}, [path], 1, mode)
- one = black.read_cache(1, mode)
+ black.write_cache({}, [path], mode)
+ one = black.read_cache(mode)
self.assertIn(path, one)
- two = black.read_cache(2, mode)
+ two = black.read_cache(short_mode)
self.assertNotIn(path, two)
def test_single_file_force_pyi(self) -> None:
- reg_mode = black.FileMode.AUTO_DETECT
- pyi_mode = black.FileMode.PYI
+ reg_mode = black.FileMode()
+ pyi_mode = black.FileMode(is_pyi=True)
contents, expected = read_data("force_pyi")
with cache_dir() as workspace:
path = (workspace / "file.py").resolve()
with open(path, "w") as fh:
fh.write(contents)
- result = CliRunner().invoke(black.main, [str(path), "--pyi"])
- self.assertEqual(result.exit_code, 0)
+ self.invokeBlack([str(path), "--pyi"])
with open(path, "r") as fh:
actual = fh.read()
# verify cache with --pyi is separate
- pyi_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, pyi_mode)
+ pyi_cache = black.read_cache(pyi_mode)
self.assertIn(path, pyi_cache)
- normal_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, reg_mode)
+ normal_cache = black.read_cache(reg_mode)
self.assertNotIn(path, normal_cache)
self.assertEqual(actual, expected)
@event_loop(close=False)
def test_multi_file_force_pyi(self) -> None:
- reg_mode = black.FileMode.AUTO_DETECT
- pyi_mode = black.FileMode.PYI
+ reg_mode = black.FileMode()
+ pyi_mode = black.FileMode(is_pyi=True)
contents, expected = read_data("force_pyi")
with cache_dir() as workspace:
paths = [
for path in paths:
with open(path, "w") as fh:
fh.write(contents)
- result = CliRunner().invoke(black.main, [str(p) for p in paths] + ["--pyi"])
- self.assertEqual(result.exit_code, 0)
+ self.invokeBlack([str(p) for p in paths] + ["--pyi"])
for path in paths:
with open(path, "r") as fh:
actual = fh.read()
self.assertEqual(actual, expected)
# verify cache with --pyi is separate
- pyi_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, pyi_mode)
- normal_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, reg_mode)
+ pyi_cache = black.read_cache(pyi_mode)
+ normal_cache = black.read_cache(reg_mode)
for path in paths:
self.assertIn(path, pyi_cache)
self.assertNotIn(path, normal_cache)
self.assertFormatEqual(actual, expected)
def test_single_file_force_py36(self) -> None:
- reg_mode = black.FileMode.AUTO_DETECT
- py36_mode = black.FileMode.PYTHON36
+ reg_mode = black.FileMode()
+ py36_mode = black.FileMode(target_versions=black.PY36_VERSIONS)
source, expected = read_data("force_py36")
with cache_dir() as workspace:
path = (workspace / "file.py").resolve()
with open(path, "w") as fh:
fh.write(source)
- result = CliRunner().invoke(black.main, [str(path), "--py36"])
- self.assertEqual(result.exit_code, 0)
+ self.invokeBlack([str(path), "--py36"])
with open(path, "r") as fh:
actual = fh.read()
# verify cache with --py36 is separate
- py36_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, py36_mode)
+ py36_cache = black.read_cache(py36_mode)
self.assertIn(path, py36_cache)
- normal_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, reg_mode)
+ normal_cache = black.read_cache(reg_mode)
self.assertNotIn(path, normal_cache)
self.assertEqual(actual, expected)
@event_loop(close=False)
def test_multi_file_force_py36(self) -> None:
- reg_mode = black.FileMode.AUTO_DETECT
- py36_mode = black.FileMode.PYTHON36
+ reg_mode = black.FileMode()
+ py36_mode = black.FileMode(target_versions=black.PY36_VERSIONS)
source, expected = read_data("force_py36")
with cache_dir() as workspace:
paths = [
for path in paths:
with open(path, "w") as fh:
fh.write(source)
- result = CliRunner().invoke(
- black.main, [str(p) for p in paths] + ["--py36"]
- )
- self.assertEqual(result.exit_code, 0)
+ self.invokeBlack([str(p) for p in paths] + ["--py36"])
for path in paths:
with open(path, "r") as fh:
actual = fh.read()
self.assertEqual(actual, expected)
# verify cache with --py36 is separate
- pyi_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, py36_mode)
- normal_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, reg_mode)
+ pyi_cache = black.read_cache(py36_mode)
+ normal_cache = black.read_cache(reg_mode)
for path in paths:
self.assertIn(path, pyi_cache)
self.assertNotIn(path, normal_cache)
def test_invalid_include_exclude(self) -> None:
for option in ["--include", "--exclude"]:
- result = CliRunner().invoke(black.main, ["-", option, "**()(!!*)"])
- self.assertEqual(result.exit_code, 2)
+ self.invokeBlack(["-", option, "**()(!!*)"], exit_code=2)
def test_preserves_line_endings(self) -> None:
with TemporaryDirectory() as workspace:
async def test_blackd_invalid_python_variant(self) -> None:
app = blackd.make_app()
async with TestClient(TestServer(app)) as client:
- response = await client.post(
- "/", data=b"what", headers={blackd.PYTHON_VARIANT_HEADER: "lol"}
- )
- self.assertEqual(response.status, 400)
+
+ async def check(header_value: str, expected_status: int = 400) -> None:
+ response = await client.post(
+ "/",
+ data=b"what",
+ headers={blackd.PYTHON_VARIANT_HEADER: header_value},
+ )
+ self.assertEqual(response.status, expected_status)
+
+ await check("lol")
+ await check("ruby3.5")
+ await check("pyi3.6")
+ await check("cpy1.5")
+ await check("2.8")
+ await check("cpy2.8")
+ await check("3.0")
+ await check("pypy3.0")
+ await check("jython3.4")
@unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
@async_test
@unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
@async_test
- async def test_blackd_py36(self) -> None:
+ async def test_blackd_python_variant(self) -> None:
app = blackd.make_app()
+ code = (
+ "def f(\n"
+ " and_has_a_bunch_of,\n"
+ " very_long_arguments_too,\n"
+ " and_lots_of_them_as_well_lol,\n"
+ " **and_very_long_keyword_arguments\n"
+ "):\n"
+ " pass\n"
+ )
async with TestClient(TestServer(app)) as client:
- response = await client.post(
- "/",
- data=(
- "def f(\n"
- " and_has_a_bunch_of,\n"
- " very_long_arguments_too,\n"
- " and_lots_of_them_as_well_lol,\n"
- " **and_very_long_keyword_arguments\n"
- "):\n"
- " pass\n"
- ),
- headers={blackd.PYTHON_VARIANT_HEADER: "3.6"},
- )
- self.assertEqual(response.status, 200)
- response = await client.post(
- "/",
- data=(
- "def f(\n"
- " and_has_a_bunch_of,\n"
- " very_long_arguments_too,\n"
- " and_lots_of_them_as_well_lol,\n"
- " **and_very_long_keyword_arguments\n"
- "):\n"
- " pass\n"
- ),
- headers={blackd.PYTHON_VARIANT_HEADER: "3.5"},
- )
- self.assertEqual(response.status, 204)
- response = await client.post(
- "/",
- data=(
- "def f(\n"
- " and_has_a_bunch_of,\n"
- " very_long_arguments_too,\n"
- " and_lots_of_them_as_well_lol,\n"
- " **and_very_long_keyword_arguments\n"
- "):\n"
- " pass\n"
- ),
- headers={blackd.PYTHON_VARIANT_HEADER: "2"},
- )
- self.assertEqual(response.status, 204)
+
+ async def check(header_value: str, expected_status: int) -> None:
+ response = await client.post(
+ "/", data=code, headers={blackd.PYTHON_VARIANT_HEADER: header_value}
+ )
+ self.assertEqual(response.status, expected_status)
+
+ await check("3.6", 200)
+ await check("cpy3.6", 200)
+ await check("3.5,3.7", 200)
+ await check("3.5,cpy3.7", 200)
+
+ await check("2", 204)
+ await check("2.7", 204)
+ await check("cpy2.7", 204)
+ await check("pypy2.7", 204)
+ await check("3.4", 204)
+ await check("cpy3.4", 204)
+ await check("pypy3.4", 204)
@unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
@async_test