from dataclasses import dataclass, field, replace
import click
import toml
-from typed_ast import ast3, ast27
+
+try:
+ from typed_ast import ast3, ast27
+except ImportError:
+ if sys.version_info < (3, 8):
+ print(
+ "The typed_ast package is not installed.\n"
+ "You can install it with `python3 -m pip install typed-ast`.",
+ file=sys.stderr,
+ )
+ sys.exit(1)
+ else:
+ ast3 = ast27 = ast
+
from pathspec import PathSpec
# lib2to3 fork
"""Find the absolute filepath to a pyproject.toml if it exists"""
path_project_root = find_project_root(path_search_start)
path_pyproject_toml = path_project_root / "pyproject.toml"
- return str(path_pyproject_toml) if path_pyproject_toml.is_file() else None
+ if path_pyproject_toml.is_file():
+ return str(path_pyproject_toml)
+
+ path_user_pyproject_toml = find_user_pyproject_toml()
+ return str(path_user_pyproject_toml) if path_user_pyproject_toml.is_file() else None
def parse_pyproject_toml(path_config: str) -> Dict[str, Any]:
return [TargetVersion[val.upper()] for val in v]
+def validate_regex(
+    ctx: click.Context,
+    param: click.Parameter,
+    value: Optional[str],
+) -> Optional[Pattern[str]]:
+    """Compile the value of a regex command line option.
+
+    Has the `(ctx, param, value)` signature of a click option callback; `ctx`
+    and `param` are not used.
+
+    Returns None when `value` is None, otherwise the pattern produced by
+    `re_compile_maybe_verbose(value)`.
+
+    Raises click.BadParameter when `value` is not a valid regular expression.
+    """
+    if value is None:
+        return None
+
+    try:
+        return re_compile_maybe_verbose(value)
+    except re.error as e:
+        # Report what is wrong with the pattern; `from None` keeps the internal
+        # `re` traceback out of the user-facing error.
+        raise click.BadParameter(f"Not a valid regular expression: {e}") from None
+
+
@click.command(context_settings=dict(help_option_names=["-h", "--help"]))
@click.option("-c", "--code", type=str, help="Format the code passed in as a string.")
@click.option(
"--check",
is_flag=True,
help=(
- "Don't write the files back, just return the status. Return code 0 means"
- " nothing would change. Return code 1 means some files would be reformatted."
+ "Don't write the files back, just return the status. Return code 0 means"
+ " nothing would change. Return code 1 means some files would be reformatted."
" Return code 123 means there was an internal error."
),
)
"--include",
type=str,
default=DEFAULT_INCLUDES,
+ callback=validate_regex,
help=(
"A regular expression that matches files and directories that should be"
- " included on recursive searches. An empty value means all files are included"
- " regardless of the name. Use forward slashes for directories on all platforms"
- " (Windows, too). Exclusions are calculated first, inclusions later."
+ " included on recursive searches. An empty value means all files are included"
+ " regardless of the name. Use forward slashes for directories on all platforms"
+ " (Windows, too). Exclusions are calculated first, inclusions later."
),
show_default=True,
)
"--exclude",
type=str,
default=DEFAULT_EXCLUDES,
+ callback=validate_regex,
help=(
"A regular expression that matches files and directories that should be"
- " excluded on recursive searches. An empty value means no paths are excluded."
- " Use forward slashes for directories on all platforms (Windows, too). "
+ " excluded on recursive searches. An empty value means no paths are excluded."
+ " Use forward slashes for directories on all platforms (Windows, too)."
" Exclusions are calculated first, inclusions later."
),
show_default=True,
)
+@click.option(
+ "--extend-exclude",
+ type=str,
+ callback=validate_regex,
+ help=(
+ "Like --exclude, but adds additional files and directories on top of the"
+ " excluded ones. (Useful if you simply want to add to the default)"
+ ),
+)
@click.option(
"--force-exclude",
type=str,
+ callback=validate_regex,
help=(
"Like --exclude, but files and directories matching this regex will be "
"excluded even when they are passed explicitly as arguments."
is_flag=True,
help=(
"Also emit messages to stderr about files that were not changed or were ignored"
- " due to --exclude=."
+ " due to exclusion patterns."
),
)
@click.version_option(version=__version__)
experimental_string_processing: bool,
quiet: bool,
verbose: bool,
- include: str,
- exclude: str,
- force_exclude: Optional[str],
+ include: Pattern,
+ exclude: Pattern,
+ extend_exclude: Optional[Pattern],
+ force_exclude: Optional[Pattern],
stdin_filename: Optional[str],
src: Tuple[str, ...],
config: Optional[str],
verbose=verbose,
include=include,
exclude=exclude,
+ extend_exclude=extend_exclude,
force_exclude=force_exclude,
report=report,
stdin_filename=stdin_filename,
src: Tuple[str, ...],
quiet: bool,
verbose: bool,
- include: str,
- exclude: str,
- force_exclude: Optional[str],
+ include: Pattern[str],
+ exclude: Pattern[str],
+ extend_exclude: Optional[Pattern[str]],
+ force_exclude: Optional[Pattern[str]],
report: "Report",
stdin_filename: Optional[str],
) -> Set[Path]:
"""Compute the set of files to be formatted."""
- try:
- include_regex = re_compile_maybe_verbose(include)
- except re.error:
- err(f"Invalid regular expression for include given: {include!r}")
- ctx.exit(2)
- try:
- exclude_regex = re_compile_maybe_verbose(exclude)
- except re.error:
- err(f"Invalid regular expression for exclude given: {exclude!r}")
- ctx.exit(2)
- try:
- force_exclude_regex = (
- re_compile_maybe_verbose(force_exclude) if force_exclude else None
- )
- except re.error:
- err(f"Invalid regular expression for force_exclude given: {force_exclude!r}")
- ctx.exit(2)
root = find_project_root(src)
sources: Set[Path] = set()
normalized_path = "/" + normalized_path
# Hard-exclude any files that matches the `--force-exclude` regex.
- if force_exclude_regex:
- force_exclude_match = force_exclude_regex.search(normalized_path)
+ if force_exclude:
+ force_exclude_match = force_exclude.search(normalized_path)
else:
force_exclude_match = None
if force_exclude_match and force_exclude_match.group(0):
gen_python_files(
p.iterdir(),
root,
- include_regex,
- exclude_regex,
- force_exclude_regex,
+ include,
+ exclude,
+ extend_exclude,
+ force_exclude,
report,
gitignore,
)
dst_name = f"{src}\t{now} +0000"
diff_contents = diff(src_contents, dst_contents, src_name, dst_name)
- if write_back == write_back.COLOR_DIFF:
+ if write_back == WriteBack.COLOR_DIFF:
diff_contents = color_diff(diff_contents)
with lock or nullcontext():
comments: Dict[LeafID, List[Leaf]] = field(default_factory=dict)
bracket_tracker: BracketTracker = field(default_factory=BracketTracker)
inside_brackets: bool = False
- should_explode: bool = False
+ should_split_rhs: bool = False
+ magic_trailing_comma: Optional[Leaf] = None
def append(self, leaf: Leaf, preformatted: bool = False) -> None:
"""Add a new `leaf` to the end of the line.
self.bracket_tracker.mark(leaf)
if self.mode.magic_trailing_comma:
if self.has_magic_trailing_comma(leaf):
- self.should_explode = True
+ self.magic_trailing_comma = leaf
elif self.has_magic_trailing_comma(leaf, ensure_removable=True):
self.remove_trailing_comma()
if not self.append_comment(leaf):
mode=self.mode,
depth=self.depth,
inside_brackets=self.inside_brackets,
- should_explode=self.should_explode,
+ should_split_rhs=self.should_split_rhs,
+ magic_trailing_comma=self.magic_trailing_comma,
)
def __str__(self) -> str:
def visit_simple_stmt(self, node: Node) -> Iterator[Line]:
"""Visit a statement without nested statements."""
+ if first_child_is_arith(node):
+ wrap_in_parentheses(node, node.children[0], visible=False)
is_suite_like = node.parent and node.parent.type in STATEMENT
if is_suite_like:
if self.mode.is_pyi and is_stub_body(node):
transformers: List[Transformer]
if (
not line.contains_uncollapsable_type_comments()
- and not line.should_explode
+ and not line.should_split_rhs
+ and not line.magic_trailing_comma
and (
is_line_short_enough(line, line_length=mode.line_length, line_str=line_str)
or line.contains_unsplittable_type_ignore()
mode=line.mode,
depth=line.depth + 1,
inside_brackets=True,
- should_explode=line.should_explode,
+ should_split_rhs=line.should_split_rhs,
+ magic_trailing_comma=line.magic_trailing_comma,
)
string_leaf = Leaf(token.STRING, string_value)
insert_str_child(string_leaf)
result.append(leaf, preformatted=True)
for comment_after in original.comments_after(leaf):
result.append(comment_after, preformatted=True)
- if is_body and should_split(result, opening_bracket):
- result.should_explode = True
+ if is_body and should_split_line(result, opening_bracket):
+ result.should_split_rhs = True
return result
check_lpar = True
if check_lpar:
- if is_walrus_assignment(child):
- pass
-
- elif child.type == syms.atom:
+ if child.type == syms.atom:
if maybe_make_parens_invisible_in_atom(child, parent=node):
wrap_in_parentheses(node, child, visible=False)
elif is_one_tuple(child):
Returns whether the node should itself be wrapped in invisible parentheses.
"""
+
if (
node.type != syms.atom
or is_empty_tuple(node)
):
return False
+ if is_walrus_assignment(node):
+ if parent.type in [syms.annassign, syms.expr_stmt]:
+ return False
+
first = node.children[0]
last = node.children[-1]
if first.type == token.LPAR and last.type == token.RPAR:
return wrapped
+def first_child_is_arith(node: Node) -> bool:
+    """Return True if `node` has children and the first one is an arithmetic
+    or binary arithmetic expression (arith, shift, xor or and expression)."""
+    if not node.children:
+        return False
+
+    first_type = node.children[0].type
+    return first_type in (
+        syms.arith_expr,
+        syms.shift_expr,
+        syms.xor_expr,
+        syms.and_expr,
+    )
+
+
def wrap_in_parentheses(parent: Node, child: LN, *, visible: bool = True) -> None:
"""Wrap `child` in parentheses.
leaf.value = ")"
-def should_split(line: Line, opening_bracket: Leaf) -> bool:
+def should_split_line(line: Line, opening_bracket: Leaf) -> bool:
"""Should `line` be immediately split with `delimiter_split()` after RHS?"""
if not (opening_bracket.parent and opening_bracket.value in "[{("):
"""
omit: Set[LeafID] = set()
- if not line.should_explode:
+ if not line.magic_trailing_comma:
yield omit
length = 4 * line.depth
elif leaf.type in CLOSING_BRACKETS:
prev = line.leaves[index - 1] if index > 0 else None
if (
- line.should_explode
- and prev
+ prev
and prev.type == token.COMMA
and not is_one_tuple_between(
leaf.opening_bracket, leaf, line.leaves
yield omit
if (
- line.should_explode
- and prev
+ prev
and prev.type == token.COMMA
and not is_one_tuple_between(leaf.opening_bracket, leaf, line.leaves)
):
return normalized_path
+def path_is_excluded(
+    normalized_path: str,
+    pattern: Optional[Pattern[str]],
+) -> bool:
+    """Return True if `pattern` is given and matches a non-empty part of the path.
+
+    A missing pattern, no match at all, or a match of the empty string all
+    count as "not excluded".
+    """
+    if not pattern:
+        return False
+
+    found = pattern.search(normalized_path)
+    if found is None:
+        return False
+    return bool(found.group(0))
+
+
def gen_python_files(
paths: Iterable[Path],
root: Path,
include: Optional[Pattern[str]],
exclude: Pattern[str],
+ extend_exclude: Optional[Pattern[str]],
force_exclude: Optional[Pattern[str]],
report: "Report",
gitignore: PathSpec,
) -> Iterator[Path]:
"""Generate all files under `path` whose paths are not excluded by the
- `exclude_regex` or `force_exclude` regexes, but are included by the `include` regex.
+ `exclude`, `extend_exclude`, or `force_exclude` regexes,
+ but are included by the `include` regex.
Symbolic links pointing outside of the `root` directory are ignored.
report.path_ignored(child, "matches the .gitignore file content")
continue
- # Then ignore with `--exclude` and `--force-exclude` options.
+ # Then ignore with the `--exclude`, `--extend-exclude` and `--force-exclude` options.
normalized_path = "/" + normalized_path
if child.is_dir():
normalized_path += "/"
- exclude_match = exclude.search(normalized_path) if exclude else None
- if exclude_match and exclude_match.group(0):
+ if path_is_excluded(normalized_path, exclude):
report.path_ignored(child, "matches the --exclude regular expression")
continue
- force_exclude_match = (
- force_exclude.search(normalized_path) if force_exclude else None
- )
- if force_exclude_match and force_exclude_match.group(0):
+ if path_is_excluded(normalized_path, extend_exclude):
+ report.path_ignored(
+ child, "matches the --extend-exclude regular expression"
+ )
+ continue
+
+ if path_is_excluded(normalized_path, force_exclude):
report.path_ignored(child, "matches the --force-exclude regular expression")
continue
root,
include,
exclude,
+ extend_exclude,
force_exclude,
report,
gitignore,
return directory
+@lru_cache()
+def find_user_pyproject_toml() -> Path:
+    r"""Return the path to the top-level user configuration for black.
+
+    This looks for ~\.black on Windows. On Linux and other Unix systems it
+    looks for $XDG_CONFIG_HOME/black, falling back to ~/.config/black when
+    XDG_CONFIG_HOME is unset or empty.
+
+    The returned path is resolved and the result is cached, so the environment
+    is only consulted on the first call. No existence check is performed here.
+    """
+    if sys.platform == "win32":
+        # Windows
+        user_config_path = Path.home() / ".black"
+    else:
+        # An empty XDG_CONFIG_HOME must be treated like an unset one (XDG Base
+        # Directory spec); otherwise the result would be relative to the cwd.
+        config_root = os.environ.get("XDG_CONFIG_HOME") or "~/.config"
+        user_config_path = Path(config_root).expanduser() / "black"
+    return user_config_path.resolve()
+
+
@dataclass
class Report:
"""Provides a reformatting counter. Can be rendered with `str(report)`."""
return ast3.parse(src, filename, feature_version=feature_version)
except SyntaxError:
continue
-
+ if ast27.__name__ == "ast":
+ raise SyntaxError(
+ "The requested source code has invalid Python 3 syntax.\n"
+ "If you are trying to format Python 2 files please reinstall Black"
+ " with the 'python2' extra: `python3 -m pip install black[python2]`."
+ )
return ast27.parse(src)
@mypyc_attr(patchable=True)
-def dump_to_file(*output: str) -> str:
+def dump_to_file(*output: str, ensure_final_newline: bool = True) -> str:
"""Dump `output` to a temporary file. Return path to the file."""
+ # delete=False: the file is not removed on close, so the returned path
+ # still points at the written content.
with tempfile.NamedTemporaryFile(
mode="w", prefix="blk_", suffix=".log", delete=False, encoding="utf8"
) as f:
for lines in output:
f.write(lines)
- if lines and lines[-1] != "\n":
+ # Terminate every non-empty chunk with a newline unless the caller opted
+ # out by passing ensure_final_newline=False.
+ if ensure_final_newline and lines and lines[-1] != "\n":
f.write("\n")
return f.name
"""Return a unified diff string between strings `a` and `b`."""
import difflib
- a_lines = [line + "\n" for line in a.splitlines()]
- b_lines = [line + "\n" for line in b.splitlines()]
- return "".join(
- difflib.unified_diff(a_lines, b_lines, fromfile=a_name, tofile=b_name, n=5)
- )
+ a_lines = [line for line in a.splitlines(keepends=True)]
+ b_lines = [line for line in b.splitlines(keepends=True)]
+ diff_lines = []
+ for line in difflib.unified_diff(
+ a_lines, b_lines, fromfile=a_name, tofile=b_name, n=5
+ ):
+ # Work around https://bugs.python.org/issue2142
+ # See https://www.gnu.org/software/diffutils/manual/html_node/Incomplete-Lines.html
+ if line[-1] == "\n":
+ diff_lines.append(line)
+ else:
+ diff_lines.append(line + "\n")
+ diff_lines.append("\\ No newline at end of file\n")
+ return "".join(diff_lines)
def cancel(tasks: Iterable["asyncio.Task[Any]"]) -> None:
penultimate = line.leaves[-2]
last = line.leaves[-1]
- if line.should_explode:
+ if line.magic_trailing_comma:
try:
penultimate, last = last_two_except(line.leaves, omit=omit_on_explode)
except LookupError:
# unnecessary.
return True
- if line.should_explode and penultimate.type == token.COMMA:
+ if line.magic_trailing_comma and penultimate.type == token.COMMA:
# The rightmost non-omitted bracket pair is the one we want to explode on.
return True