Union,
cast,
)
+from typing_extensions import Final
+from mypy_extensions import mypyc_attr
from appdirs import user_cache_dir
-from attr import dataclass, evolve, Factory
+from dataclasses import dataclass, field, replace
import click
import toml
from typed_ast import ast3, ast27
@dataclass
-class FileMode:
- target_versions: Set[TargetVersion] = Factory(set)
+class Mode:
+ target_versions: Set[TargetVersion] = field(default_factory=set)
line_length: int = DEFAULT_LINE_LENGTH
string_normalization: bool = True
is_pyi: bool = False
return ".".join(parts)
+# Legacy name, left for integrations.
+FileMode = Mode
+
+
def supports_feature(target_versions: Set[TargetVersion], feature: Feature) -> bool:
return all(feature in VERSION_TO_FEATURES[version] for version in target_versions)
+def find_pyproject_toml(path_search_start: Iterable[str]) -> Optional[str]:
+    """Find the path to the project's pyproject.toml, if that file exists.
+
+    `path_search_start` holds the source paths the search starts from (the
+    `src` values when called from `read_pyproject_toml`); it is handed to
+    `find_project_root` unchanged.
+
+    Returns the path as a string when "pyproject.toml" is a regular file in
+    the project root, None otherwise.
+    """
+    path_project_root = find_project_root(path_search_start)
+    path_pyproject_toml = path_project_root / "pyproject.toml"
+    return str(path_pyproject_toml) if path_pyproject_toml.is_file() else None
+
+
+def parse_pyproject_toml(path_config: str) -> Dict[str, Any]:
+    """Load a pyproject.toml file and return its Black settings.
+
+    The settings are read from the "black" table nested under "tool"; if either
+    table is missing the result is an empty dict.  Every key is normalized by
+    removing any "--" and turning each remaining "-" into "_".
+
+    If parsing fails, a toml.TomlDecodeError is raised.
+    """
+    black_section = toml.load(path_config).get("tool", {}).get("black", {})
+    normalized: Dict[str, Any] = {}
+    for option, setting in black_section.items():
+        normalized[option.replace("--", "").replace("-", "_")] = setting
+    return normalized
+
+
def read_pyproject_toml(
- ctx: click.Context, param: click.Parameter, value: Union[str, int, bool, None]
+ ctx: click.Context, param: click.Parameter, value: Optional[str]
) -> Optional[str]:
"""Inject Black configuration from "pyproject.toml" into defaults in `ctx`.
Returns the path to a successfully found and read configuration file, None
otherwise.
"""
- assert not isinstance(value, (int, bool)), "Invalid parameter type passed"
if not value:
- root = find_project_root(ctx.params.get("src", ()))
- path = root / "pyproject.toml"
- if path.is_file():
- value = str(path)
- else:
+ value = find_pyproject_toml(ctx.params.get("src", ()))
+ if value is None:
return None
try:
- pyproject_toml = toml.load(value)
- config = pyproject_toml.get("tool", {}).get("black", {})
+ config = parse_pyproject_toml(value)
except (toml.TomlDecodeError, OSError) as e:
raise click.FileError(
filename=value, hint=f"Error reading configuration file: {e}"
if not config:
return None
- if ctx.default_map is None:
- ctx.default_map = {}
- ctx.default_map.update( # type: ignore # bad types in .pyi
- {k.replace("--", "").replace("-", "_"): v for k, v in config.items()}
- )
+ default_map: Dict[str, Any] = {}
+ if ctx.default_map:
+ default_map.update(ctx.default_map)
+ default_map.update(config)
+
+ ctx.default_map = default_map
return value
+def target_version_option_callback(
+    c: click.Context, p: Union[click.Option, click.Parameter], v: Tuple[str, ...]
+) -> List[TargetVersion]:
+    """Translate the values of a --target-version flag into target versions.
+
+    Each value is upper-cased and looked up by name in `TargetVersion`.
+
+    Kept as a named function rather than a lambda because mypy could not infer
+    the lambda's type correctly, which caused trouble for mypyc.
+    """
+    versions: List[TargetVersion] = []
+    for name in v:
+        versions.append(TargetVersion[name.upper()])
+    return versions
+
+
@click.command(context_settings=dict(help_option_names=["-h", "--help"]))
@click.option("-c", "--code", type=str, help="Format the code passed in as a string.")
@click.option(
"-t",
"--target-version",
type=click.Choice([v.name.lower() for v in TargetVersion]),
- callback=lambda c, p, v: [TargetVersion[val.upper()] for val in v],
+ callback=target_version_option_callback,
multiple=True,
help=(
"Python versions that should be supported by Black's output. [default: "
@click.option(
"--config",
type=click.Path(
- exists=False, file_okay=True, dir_okay=False, readable=True, allow_dash=False
+ exists=True,
+ file_okay=True,
+ dir_okay=False,
+ readable=True,
+ allow_dash=False,
+ path_type=str,
),
is_eager=True,
callback=read_pyproject_toml,
verbose: bool,
include: str,
exclude: str,
- src: Tuple[str],
+ src: Tuple[str, ...],
config: Optional[str],
) -> None:
"""The uncompromising code formatter."""
write_back = WriteBack.from_configuration(check=check, diff=diff)
if target_version:
if py36:
- err(f"Cannot use both --target-version and --py36")
+ err("Cannot use both --target-version and --py36")
ctx.exit(2)
else:
versions = set(target_version)
else:
# We'll autodetect later.
versions = set()
- mode = FileMode(
+ mode = Mode(
target_versions=versions,
line_length=line_length,
is_pyi=pyi,
except re.error:
err(f"Invalid regular expression for exclude given: {exclude!r}")
ctx.exit(2)
- report = Report(check=check, quiet=quiet, verbose=verbose)
+ report = Report(check=check, diff=diff, quiet=quiet, verbose=verbose)
root = find_project_root(src)
sources: Set[Path] = set()
path_empty(src, quiet, verbose, ctx)
ctx.exit(report.return_code)
-def path_empty(src: Tuple[str], quiet: bool, verbose: bool, ctx: click.Context) -> None:
+def path_empty(
+ src: Tuple[str, ...], quiet: bool, verbose: bool, ctx: click.Context
+) -> None:
"""
Exit if there is no `src` provided for formatting
"""
def reformat_one(
- src: Path, fast: bool, write_back: WriteBack, mode: FileMode, report: "Report"
+ src: Path, fast: bool, write_back: WriteBack, mode: Mode, report: "Report"
) -> None:
"""Reformat a single file under `src` without spawning child processes.
def reformat_many(
- sources: Set[Path],
- fast: bool,
- write_back: WriteBack,
- mode: FileMode,
- report: "Report",
+ sources: Set[Path], fast: bool, write_back: WriteBack, mode: Mode, report: "Report"
) -> None:
"""Reformat multiple files using a ProcessPoolExecutor."""
loop = asyncio.get_event_loop()
sources: Set[Path],
fast: bool,
write_back: WriteBack,
- mode: FileMode,
+ mode: Mode,
report: "Report",
loop: asyncio.AbstractEventLoop,
executor: Executor,
): src
for src in sorted(sources)
}
- pending: Iterable[asyncio.Future] = tasks.keys()
+ pending: Iterable["asyncio.Future[bool]"] = tasks.keys()
try:
loop.add_signal_handler(signal.SIGINT, cancel, pending)
loop.add_signal_handler(signal.SIGTERM, cancel, pending)
def format_file_in_place(
src: Path,
fast: bool,
- mode: FileMode,
+ mode: Mode,
write_back: WriteBack = WriteBack.NO,
lock: Any = None, # multiprocessing.Manager().Lock() is some crazy proxy
) -> bool:
`mode` and `fast` options are passed to :func:`format_file_contents`.
"""
if src.suffix == ".pyi":
- mode = evolve(mode, is_pyi=True)
+ mode = replace(mode, is_pyi=True)
then = datetime.utcfromtimestamp(src.stat().st_mtime)
with open(src, "rb") as buf:
except NothingChanged:
return False
- if write_back == write_back.YES:
+ if write_back == WriteBack.YES:
with open(src, "w", encoding=encoding, newline=newline) as f:
f.write(dst_contents)
- elif write_back == write_back.DIFF:
+ elif write_back == WriteBack.DIFF:
now = datetime.utcnow()
src_name = f"{src}\t{then} +0000"
dst_name = f"{src}\t{now} +0000"
def format_stdin_to_stdout(
- fast: bool, *, write_back: WriteBack = WriteBack.NO, mode: FileMode
+ fast: bool, *, write_back: WriteBack = WriteBack.NO, mode: Mode
) -> bool:
"""Format file on stdin. Return True if changed.
f.detach()
-def format_file_contents(
- src_contents: str, *, fast: bool, mode: FileMode
-) -> FileContent:
+def format_file_contents(src_contents: str, *, fast: bool, mode: Mode) -> FileContent:
"""Reformat contents a file and return new contents.
If `fast` is False, additionally confirm that the reformatted code is
return dst_contents
-def format_str(src_contents: str, *, mode: FileMode) -> FileContent:
+def format_str(src_contents: str, *, mode: Mode) -> FileContent:
"""Reformat a string and return new contents.
`mode` determines formatting options, such as how many characters per line are
- allowed.
+ allowed. Example:
+
+ >>> import black
+ >>> print(black.format_str("def f(arg:str='')->None:...", mode=black.Mode()))
+ def f(arg: str = "") -> None:
+ ...
+
+ A more complex example:
+ >>> print(
+ ... black.format_str(
+ ... "def f(arg:str='')->None: hey",
+ ... mode=black.Mode(
+ ... target_versions={black.TargetVersion.PY36},
+ ... line_length=10,
+ ... string_normalization=False,
+ ... is_pyi=False,
+ ... ),
+ ... ),
+ ... )
+ def f(
+ arg: str = '',
+ ) -> None:
+ hey
+
"""
src_node = lib2to3_parse(src_contents.lstrip(), mode.target_versions)
dst_contents = []
if supports_feature(versions, feature)
}
for current_line in lines.visit(src_node):
- for _ in range(after):
- dst_contents.append(str(empty_line))
+ dst_contents.append(str(empty_line) * after)
before, after = elt.maybe_empty_lines(current_line)
- for _ in range(before):
- dst_contents.append(str(empty_line))
+ dst_contents.append(str(empty_line) * before)
for line in split_line(
current_line, line_length=mode.line_length, features=split_line_features
):
if node.type < 256:
name = token.tok_name[node.type]
else:
- name = type_repr(node.type)
- yield from getattr(self, f"visit_{name}", self.visit_default)(node)
+ name = str(type_repr(node.type))
+ # We explicitly branch on whether a visitor exists (instead of
+ # using self.visit_default as the default arg to getattr) to avoid
+ # creating a bound method object and so that mypyc can generate a
+ # native call to visit_default.
+ visitf = getattr(self, f"visit_{name}", None)
+ if visitf:
+ yield from visitf(node)
+ else:
+ yield from self.visit_default(node)
def visit_default(self, node: LN) -> Iterator[T]:
"""Default `visit_*()` implementation. Recurses to children of `node`."""
list(v.visit(code))
-WHITESPACE = {token.DEDENT, token.INDENT, token.NEWLINE}
-STATEMENT = {
+WHITESPACE: Final = {token.DEDENT, token.INDENT, token.NEWLINE}
+STATEMENT: Final = {
syms.if_stmt,
syms.while_stmt,
syms.for_stmt,
syms.funcdef,
syms.classdef,
}
-STANDALONE_COMMENT = 153
+STANDALONE_COMMENT: Final = 153
token.tok_name[STANDALONE_COMMENT] = "STANDALONE_COMMENT"
-LOGIC_OPERATORS = {"and", "or"}
-COMPARATORS = {
+LOGIC_OPERATORS: Final = {"and", "or"}
+COMPARATORS: Final = {
token.LESS,
token.GREATER,
token.EQEQUAL,
token.LESSEQUAL,
token.GREATEREQUAL,
}
-MATH_OPERATORS = {
+MATH_OPERATORS: Final = {
token.VBAR,
token.CIRCUMFLEX,
token.AMPER,
token.TILDE,
token.DOUBLESTAR,
}
-STARS = {token.STAR, token.DOUBLESTAR}
-VARARGS_SPECIALS = STARS | {token.SLASH}
-VARARGS_PARENTS = {
+STARS: Final = {token.STAR, token.DOUBLESTAR}
+VARARGS_SPECIALS: Final = STARS | {token.SLASH}
+VARARGS_PARENTS: Final = {
syms.arglist,
syms.argument, # double star in arglist
syms.trailer, # single argument to call
syms.typedargslist,
syms.varargslist, # lambdas
}
-UNPACKING_PARENTS = {
+UNPACKING_PARENTS: Final = {
syms.atom, # single element of a list or set literal
syms.dictsetmaker,
syms.listmaker,
syms.testlist_gexp,
syms.testlist_star_expr,
}
-TEST_DESCENDANTS = {
+TEST_DESCENDANTS: Final = {
syms.test,
syms.lambdef,
syms.or_test,
syms.term,
syms.power,
}
-ASSIGNMENTS = {
+ASSIGNMENTS: Final = {
"=",
"+=",
"-=",
"**=",
"//=",
}
-COMPREHENSION_PRIORITY = 20
-COMMA_PRIORITY = 18
-TERNARY_PRIORITY = 16
-LOGIC_PRIORITY = 14
-STRING_PRIORITY = 12
-COMPARATOR_PRIORITY = 10
-MATH_PRIORITIES = {
+COMPREHENSION_PRIORITY: Final = 20
+COMMA_PRIORITY: Final = 18
+TERNARY_PRIORITY: Final = 16
+LOGIC_PRIORITY: Final = 14
+STRING_PRIORITY: Final = 12
+COMPARATOR_PRIORITY: Final = 10
+MATH_PRIORITIES: Final = {
token.VBAR: 9,
token.CIRCUMFLEX: 8,
token.AMPER: 7,
token.TILDE: 3,
token.DOUBLESTAR: 2,
}
-DOT_PRIORITY = 1
+DOT_PRIORITY: Final = 1
@dataclass
"""Keeps track of brackets on a line."""
depth: int = 0
- bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = Factory(dict)
- delimiters: Dict[LeafID, Priority] = Factory(dict)
+ bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = field(default_factory=dict)
+ delimiters: Dict[LeafID, Priority] = field(default_factory=dict)
previous: Optional[Leaf] = None
- _for_loop_depths: List[int] = Factory(list)
- _lambda_argument_depths: List[int] = Factory(list)
+ _for_loop_depths: List[int] = field(default_factory=list)
+ _lambda_argument_depths: List[int] = field(default_factory=list)
def mark(self, leaf: Leaf) -> None:
"""Mark `leaf` with bracket-related metadata. Keep track of delimiters.
"""Holds leaves and comments. Can be printed with `str(line)`."""
depth: int = 0
- leaves: List[Leaf] = Factory(list)
- comments: Dict[LeafID, List[Leaf]] = Factory(dict) # keys ordered like `leaves`
- bracket_tracker: BracketTracker = Factory(BracketTracker)
+ leaves: List[Leaf] = field(default_factory=list)
+ # keys ordered like `leaves`
+ comments: Dict[LeafID, List[Leaf]] = field(default_factory=dict)
+ bracket_tracker: BracketTracker = field(default_factory=BracketTracker)
inside_brackets: bool = False
should_explode: bool = False
for leaf_id, comments in self.comments.items():
for comment in comments:
if is_type_comment(comment):
- if leaf_id not in ignored_ids or comment_seen:
+ if comment_seen or (
+ not is_type_comment(comment, " ignore")
+ and leaf_id not in ignored_ids
+ ):
return True
comment_seen = True
return False
def contains_multiline_strings(self) -> bool:
- for leaf in self.leaves:
- if is_multiline_string(leaf):
- return True
-
- return False
+ return any(is_multiline_string(leaf) for leaf in self.leaves)
def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
"""Remove trailing comma if there is one and it's safe."""
is_pyi: bool = False
previous_line: Optional[Line] = None
previous_after: int = 0
- previous_defs: List[int] = Factory(list)
+ previous_defs: List[int] = field(default_factory=list)
def maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
"""Return the number of extra empty lines before and after the `current_line`.
is_pyi: bool = False
normalize_strings: bool = True
- current_line: Line = Factory(Line)
+ current_line: Line = field(default_factory=Line)
remove_u_prefix: bool = False
def line(self, indent: int = 0) -> Iterator[Line]:
self.current_line.append(node)
yield from super().visit_default(node)
- def visit_INDENT(self, node: Node) -> Iterator[Line]:
+ def visit_INDENT(self, node: Leaf) -> Iterator[Line]:
"""Increase indentation level, maybe yield a line."""
# In blib2to3 INDENT never holds comments.
yield from self.line(+1)
yield from self.visit_default(node)
- def visit_DEDENT(self, node: Node) -> Iterator[Line]:
+ def visit_DEDENT(self, node: Leaf) -> Iterator[Line]:
"""Decrease indentation level, maybe yield a line."""
# The current line might still wait for trailing comments. At DEDENT time
# there won't be any (they would be prefixes on the preceding NEWLINE).
node.insert_child(index, Node(syms.atom, [lpar, operand, rpar]))
yield from self.visit_default(node)
- def __attrs_post_init__(self) -> None:
+ def __post_init__(self) -> None:
"""You are in a twisty little maze of passages."""
v = self.visit_stmt
Ø: Set[str] = set()
body_leaves: List[Leaf] = []
head_leaves: List[Leaf] = []
current_leaves = head_leaves
- matching_bracket = None
+ matching_bracket: Optional[Leaf] = None
for leaf in line.leaves:
if (
current_leaves is body_leaves
body_leaves: List[Leaf] = []
head_leaves: List[Leaf] = []
current_leaves = tail_leaves
- opening_bracket = None
- closing_bracket = None
+ opening_bracket: Optional[Leaf] = None
+ closing_bracket: Optional[Leaf] = None
for leaf in reversed(line.leaves):
if current_leaves is body_leaves:
if leaf is opening_bracket:
match = re.match(r"^([furbFURB]*)(.*)$", leaf.value, re.DOTALL)
assert match is not None, f"failed to match string {leaf.value!r}"
orig_prefix = match.group(1)
- new_prefix = orig_prefix.lower()
+ new_prefix = orig_prefix.replace("F", "f").replace("B", "b").replace("U", "u")
if remove_u_prefix:
new_prefix = new_prefix.replace("u", "")
leaf.value = f"{new_prefix}{match.group(2)}"
# That happens when one of the `ignored_nodes` ended with a NEWLINE
# leaf (possibly followed by a DEDENT).
hidden_value = hidden_value[:-1]
- first_idx = None
+ first_idx: Optional[int] = None
for ignored in ignored_nodes:
index = ignored.remove()
if first_idx is None:
yield omit
length = 4 * line.depth
- opening_bracket = None
- closing_bracket = None
+ opening_bracket: Optional[Leaf] = None
+ closing_bracket: Optional[Leaf] = None
inner_brackets: Set[LeafID] = set()
for index, leaf, leaf_length in enumerate_with_length(line, reversed=True):
length += leaf_length
for child in path.iterdir():
# First ignore files matching .gitignore
if gitignore.match_file(child.as_posix()):
- report.path_ignored(child, f"matches the .gitignore file content")
+ report.path_ignored(child, "matches the .gitignore file content")
continue
# Then ignore with `exclude` option.
exclude_match = exclude.search(normalized_path)
if exclude_match and exclude_match.group(0):
- report.path_ignored(child, f"matches the --exclude regular expression")
+ report.path_ignored(child, "matches the --exclude regular expression")
continue
if child.is_dir():
# Append a fake file so `parents` below returns `common_base_dir`, too.
common_base /= "fake-file"
for directory in common_base.parents:
- if (directory / ".git").is_dir():
+ if (directory / ".git").exists():
return directory
if (directory / ".hg").is_dir():
"""Provides a reformatting counter. Can be rendered with `str(report)`."""
check: bool = False
+ diff: bool = False
quiet: bool = False
verbose: bool = False
change_count: int = 0
def done(self, src: Path, changed: Changed) -> None:
"""Increment the counter for successful reformatting. Write out a message."""
if changed is Changed.YES:
- reformatted = "would reformat" if self.check else "reformatted"
+ reformatted = "would reformat" if self.check or self.diff else "reformatted"
if self.verbose or not self.quiet:
out(f"{reformatted} {src}")
self.change_count += 1
Use `click.unstyle` to remove colors.
"""
- if self.check:
+ if self.check or self.diff:
reformatted = "would be reformatted"
unchanged = "would be left unchanged"
failed = "would fail to reformat"
yield f"{' ' * depth}{node.__class__.__name__}("
- for field in sorted(node._fields):
+ for field in sorted(node._fields): # noqa: F402
# TypeIgnore has only one field 'lineno' which breaks this comparison
type_ignore_classes = (ast3.TypeIgnore, ast27.TypeIgnore)
if sys.version_info >= (3, 8):
) from None
-def assert_stable(src: str, dst: str, mode: FileMode) -> None:
+def assert_stable(src: str, dst: str, mode: Mode) -> None:
"""Raise AssertionError if `dst` reformats differently the second time."""
newdst = format_str(dst, mode=mode)
if dst != newdst:
) from None
+@mypyc_attr(patchable=True)
def dump_to_file(*output: str) -> str:
"""Dump `output` to a temporary file. Return path to the file."""
with tempfile.NamedTemporaryFile(
"""Return a unified diff string between strings `a` and `b`."""
import difflib
- a_lines = [line + "\n" for line in a.split("\n")]
- b_lines = [line + "\n" for line in b.split("\n")]
+ a_lines = [line + "\n" for line in a.splitlines()]
+ b_lines = [line + "\n" for line in b.splitlines()]
return "".join(
difflib.unified_diff(a_lines, b_lines, fromfile=a_name, tofile=b_name, n=5)
)
-def cancel(tasks: Iterable[asyncio.Task]) -> None:
+def cancel(tasks: Iterable["asyncio.Task[Any]"]) -> None:
"""asyncio signal handler that cancels all `tasks` and reports to stderr."""
err("Aborted!")
for task in tasks:
return False
-def get_cache_file(mode: FileMode) -> Path:
+def get_cache_file(mode: Mode) -> Path:
return CACHE_DIR / f"cache.{mode.get_cache_key()}.pickle"
-def read_cache(mode: FileMode) -> Cache:
+def read_cache(mode: Mode) -> Cache:
"""Read the cache if it exists and is well formed.
If it is not well formed, the call to write_cache later should resolve the issue.
return todo, done
-def write_cache(cache: Cache, sources: Iterable[Path], mode: FileMode) -> None:
+def write_cache(cache: Cache, sources: Iterable[Path], mode: Mode) -> None:
"""Update the cache file."""
cache_file = get_cache_file(mode)
try: