return all(feature in VERSION_TO_FEATURES[version] for version in target_versions)
+def find_pyproject_toml(path_search_start: str) -> Optional[str]:
+ """Return the path to the project's pyproject.toml, or None if there is none.
+
+ The project root is obtained by calling `find_project_root` on
+ `path_search_start`. If a "pyproject.toml" directly inside that root is an
+ existing file, its path is returned as a string; otherwise None is returned.
+ The path is presumably absolute, but that depends on what
+ `find_project_root` returns -- TODO confirm.
+
+ NOTE(review): the call site visible in this patch passes
+ `ctx.params.get("src", ())`, whose default is a tuple, while this parameter
+ is annotated `str` -- confirm which type `find_project_root` expects.
+ """
+ path_project_root = find_project_root(path_search_start)
+ path_pyproject_toml = path_project_root / "pyproject.toml"
+ return str(path_pyproject_toml) if path_pyproject_toml.is_file() else None
+
+
+def parse_pyproject_toml(path_config: str) -> Dict[str, Any]:
+ """Parse a pyproject.toml file and return its [tool.black] options.
+
+ Option names are normalized: every "--" is removed and each remaining "-"
+ is replaced with "_". Values are returned unchanged. If the file has no
+ "tool" table, or no "black" table inside it, an empty dict is returned.
+
+ If parsing fails, will raise a toml.TomlDecodeError. The caller visible in
+ this patch also catches OSError around this call.
+ """
+ pyproject_toml = toml.load(path_config)
+ config = pyproject_toml.get("tool", {}).get("black", {})
+ return {k.replace("--", "").replace("-", "_"): v for k, v in config.items()}
+
+
def read_pyproject_toml(
ctx: click.Context, param: click.Parameter, value: Union[str, int, bool, None]
) -> Optional[str]:
"""
assert not isinstance(value, (int, bool)), "Invalid parameter type passed"
if not value:
- root = find_project_root(ctx.params.get("src", ()))
- path = root / "pyproject.toml"
- if path.is_file():
- value = str(path)
- else:
+ value = find_pyproject_toml(ctx.params.get("src", ()))
+ if value is None:
return None
try:
- pyproject_toml = toml.load(value)
- config = pyproject_toml.get("tool", {}).get("black", {})
+ config = parse_pyproject_toml(value)
except (toml.TomlDecodeError, OSError) as e:
raise click.FileError(
filename=value, hint=f"Error reading configuration file: {e}"
if ctx.default_map is None:
ctx.default_map = {}
- ctx.default_map.update( # type: ignore # bad types in .pyi
- {k.replace("--", "").replace("-", "_"): v for k, v in config.items()}
- )
+ ctx.default_map.update(config) # type: ignore # bad types in .pyi
return value
except re.error:
err(f"Invalid regular expression for exclude given: {exclude!r}")
ctx.exit(2)
- report = Report(check=check, quiet=quiet, verbose=verbose)
+ report = Report(check=check, diff=diff, quiet=quiet, verbose=verbose)
root = find_project_root(src)
sources: Set[Path] = set()
path_empty(src, quiet, verbose, ctx)
if supports_feature(versions, feature)
}
for current_line in lines.visit(src_node):
- for _ in range(after):
- dst_contents.append(str(empty_line))
+ dst_contents.append(str(empty_line) * after)
before, after = elt.maybe_empty_lines(current_line)
- for _ in range(before):
- dst_contents.append(str(empty_line))
+ dst_contents.append(str(empty_line) * before)
for line in split_line(
current_line, line_length=mode.line_length, features=split_line_features
):
return False
def contains_multiline_strings(self) -> bool:
- for leaf in self.leaves:
- if is_multiline_string(leaf):
- return True
-
- return False
+ return any(is_multiline_string(leaf) for leaf in self.leaves)
def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
"""Remove trailing comma if there is one and it's safe."""
match = re.match(r"^([furbFURB]*)(.*)$", leaf.value, re.DOTALL)
assert match is not None, f"failed to match string {leaf.value!r}"
orig_prefix = match.group(1)
- new_prefix = orig_prefix.lower()
+ new_prefix = orig_prefix.replace("F", "f").replace("B", "b").replace("U", "u")
if remove_u_prefix:
new_prefix = new_prefix.replace("u", "")
leaf.value = f"{new_prefix}{match.group(2)}"
"""Provides a reformatting counter. Can be rendered with `str(report)`."""
check: bool = False
+ diff: bool = False
quiet: bool = False
verbose: bool = False
change_count: int = 0
def done(self, src: Path, changed: Changed) -> None:
"""Increment the counter for successful reformatting. Write out a message."""
if changed is Changed.YES:
- reformatted = "would reformat" if self.check else "reformatted"
+ reformatted = "would reformat" if self.check or self.diff else "reformatted"
if self.verbose or not self.quiet:
out(f"{reformatted} {src}")
self.change_count += 1
Use `click.unstyle` to remove colors.
"""
- if self.check:
+ if self.check or self.diff:
reformatted = "would be reformatted"
unchanged = "would be left unchanged"
failed = "would fail to reformat"