@dataclass
-class FileMode:
+class Mode:
target_versions: Set[TargetVersion] = field(default_factory=set)
line_length: int = DEFAULT_LINE_LENGTH
string_normalization: bool = True
return ".".join(parts)
+# Legacy name, kept as an alias of `Mode` so integrations that still import
+# `FileMode` keep working; code in this module now uses `Mode` directly.
+FileMode = Mode
+
+
def supports_feature(target_versions: Set[TargetVersion], feature: Feature) -> bool:
+    """Return True if every version in `target_versions` supports `feature`.
+
+    Each version is looked up in VERSION_TO_FEATURES. An empty
+    `target_versions` yields True, since `all()` of nothing is True.
+    """
return all(feature in VERSION_TO_FEATURES[version] for version in target_versions)
+def find_pyproject_toml(path_search_start: Iterable[str]) -> Optional[str]:
+    """Return the path to the project's pyproject.toml, or None if absent.
+
+    The project root is located with `find_project_root` from the given
+    source paths. The result is `<root>/pyproject.toml` as a string when that
+    file exists, otherwise None.
+
+    `path_search_start` is a collection of source paths (the CLI passes its
+    `src` tuple). A bare string is accepted and treated as a single path.
+    `find_project_root` is otherwise handed an iterable of paths, so
+    forwarding a plain string would make its individual characters be
+    treated as separate paths.
+    """
+    if isinstance(path_search_start, str):
+        # Normalize a single path into the tuple shape the CLI already uses.
+        path_search_start = (path_search_start,)
+    path_project_root = find_project_root(path_search_start)
+    path_pyproject_toml = path_project_root / "pyproject.toml"
+    return str(path_pyproject_toml) if path_pyproject_toml.is_file() else None
+
+
+def parse_pyproject_toml(path_config: str) -> Dict[str, Any]:
+    """Load `path_config` and return only Black's `[tool.black]` settings.
+
+    Missing `tool` or `black` tables yield an empty dict. Each key is
+    normalized by dropping every `--` and turning the remaining `-` into
+    `_` (e.g. `line-length` becomes `line_length`), so the result can be
+    merged into click's default_map by the caller.
+
+    Raises toml.TomlDecodeError for malformed TOML. Errors from opening the
+    file inside `toml.load` propagate unchanged; the CLI caller also catches
+    OSError.
+    """
+    document = toml.load(path_config)
+    tool_table = document.get("tool", {})
+    black_table = tool_table.get("black", {})
+    settings: Dict[str, Any] = {}
+    for raw_key, raw_value in black_table.items():
+        option_name = raw_key.replace("--", "").replace("-", "_")
+        settings[option_name] = raw_value
+    return settings
+
+
def read_pyproject_toml(
ctx: click.Context, param: click.Parameter, value: Union[str, int, bool, None]
) -> Optional[str]:
"""
assert not isinstance(value, (int, bool)), "Invalid parameter type passed"
if not value:
- root = find_project_root(ctx.params.get("src", ()))
- path = root / "pyproject.toml"
- if path.is_file():
- value = str(path)
- else:
+ value = find_pyproject_toml(ctx.params.get("src", ()))
+ if value is None:
return None
try:
- pyproject_toml = toml.load(value)
- config = pyproject_toml.get("tool", {}).get("black", {})
+ config = parse_pyproject_toml(value)
except (toml.TomlDecodeError, OSError) as e:
raise click.FileError(
filename=value, hint=f"Error reading configuration file: {e}"
if ctx.default_map is None:
ctx.default_map = {}
- ctx.default_map.update( # type: ignore # bad types in .pyi
- {k.replace("--", "").replace("-", "_"): v for k, v in config.items()}
- )
+ ctx.default_map.update(config) # type: ignore # bad types in .pyi
return value
else:
# We'll autodetect later.
versions = set()
- mode = FileMode(
+ mode = Mode(
target_versions=versions,
line_length=line_length,
is_pyi=pyi,
except re.error:
err(f"Invalid regular expression for exclude given: {exclude!r}")
ctx.exit(2)
- report = Report(check=check, quiet=quiet, verbose=verbose)
+ report = Report(check=check, diff=diff, quiet=quiet, verbose=verbose)
root = find_project_root(src)
sources: Set[Path] = set()
path_empty(src, quiet, verbose, ctx)
def reformat_one(
- src: Path, fast: bool, write_back: WriteBack, mode: FileMode, report: "Report"
+ src: Path, fast: bool, write_back: WriteBack, mode: Mode, report: "Report"
) -> None:
"""Reformat a single file under `src` without spawning child processes.
def reformat_many(
- sources: Set[Path],
- fast: bool,
- write_back: WriteBack,
- mode: FileMode,
- report: "Report",
+ sources: Set[Path], fast: bool, write_back: WriteBack, mode: Mode, report: "Report"
) -> None:
"""Reformat multiple files using a ProcessPoolExecutor."""
loop = asyncio.get_event_loop()
sources: Set[Path],
fast: bool,
write_back: WriteBack,
- mode: FileMode,
+ mode: Mode,
report: "Report",
loop: asyncio.AbstractEventLoop,
executor: Executor,
def format_file_in_place(
src: Path,
fast: bool,
- mode: FileMode,
+ mode: Mode,
write_back: WriteBack = WriteBack.NO,
lock: Any = None, # multiprocessing.Manager().Lock() is some crazy proxy
) -> bool:
def format_stdin_to_stdout(
- fast: bool, *, write_back: WriteBack = WriteBack.NO, mode: FileMode
+ fast: bool, *, write_back: WriteBack = WriteBack.NO, mode: Mode
) -> bool:
"""Format file on stdin. Return True if changed.
f.detach()
-def format_file_contents(
- src_contents: str, *, fast: bool, mode: FileMode
-) -> FileContent:
+def format_file_contents(src_contents: str, *, fast: bool, mode: Mode) -> FileContent:
"""Reformat contents a file and return new contents.
If `fast` is False, additionally confirm that the reformatted code is
return dst_contents
-def format_str(src_contents: str, *, mode: FileMode) -> FileContent:
+def format_str(src_contents: str, *, mode: Mode) -> FileContent:
"""Reformat a string and return new contents.
`mode` determines formatting options, such as how many characters per line are
- allowed.
+ allowed. Example:
+
+ >>> import black
+ >>> print(black.format_str("def f(arg:str='')->None:...", mode=Mode()))
+ def f(arg: str = "") -> None:
+ ...
+
+ A more complex example:
+ >>> print(
+ ... black.format_str(
+ ... "def f(arg:str='')->None: hey",
+ ... mode=black.Mode(
+ ... target_versions={black.TargetVersion.PY36},
+ ... line_length=10,
+ ... string_normalization=False,
+ ... is_pyi=False,
+ ... ),
+ ... ),
+ ... )
+ def f(
+ arg: str = '',
+ ) -> None:
+ hey
+
"""
src_node = lib2to3_parse(src_contents.lstrip(), mode.target_versions)
dst_contents = []
if supports_feature(versions, feature)
}
for current_line in lines.visit(src_node):
- for _ in range(after):
- dst_contents.append(str(empty_line))
+ dst_contents.append(str(empty_line) * after)
before, after = elt.maybe_empty_lines(current_line)
- for _ in range(before):
- dst_contents.append(str(empty_line))
+ dst_contents.append(str(empty_line) * before)
for line in split_line(
current_line, line_length=mode.line_length, features=split_line_features
):
return False
def contains_multiline_strings(self) -> bool:
+    """Return True if any leaf of this line is a multiline string."""
- for leaf in self.leaves:
- if is_multiline_string(leaf):
- return True
-
- return False
+ return any(is_multiline_string(leaf) for leaf in self.leaves)
def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
"""Remove trailing comma if there is one and it's safe."""
match = re.match(r"^([furbFURB]*)(.*)$", leaf.value, re.DOTALL)
assert match is not None, f"failed to match string {leaf.value!r}"
orig_prefix = match.group(1)
- new_prefix = orig_prefix.lower()
+ new_prefix = orig_prefix.replace("F", "f").replace("B", "b").replace("U", "u")
if remove_u_prefix:
new_prefix = new_prefix.replace("u", "")
leaf.value = f"{new_prefix}{match.group(2)}"
# Append a fake file so `parents` below returns `common_base_dir`, too.
common_base /= "fake-file"
for directory in common_base.parents:
- if (directory / ".git").is_dir():
+ if (directory / ".git").exists():
return directory
if (directory / ".hg").is_dir():
"""Provides a reformatting counter. Can be rendered with `str(report)`."""
check: bool = False
+ diff: bool = False
quiet: bool = False
verbose: bool = False
change_count: int = 0
def done(self, src: Path, changed: Changed) -> None:
"""Increment the counter for successful reformatting. Write out a message."""
if changed is Changed.YES:
- reformatted = "would reformat" if self.check else "reformatted"
+ reformatted = "would reformat" if self.check or self.diff else "reformatted"
if self.verbose or not self.quiet:
out(f"{reformatted} {src}")
self.change_count += 1
Use `click.unstyle` to remove colors.
"""
- if self.check:
+ if self.check or self.diff:
reformatted = "would be reformatted"
unchanged = "would be left unchanged"
failed = "would fail to reformat"
) from None
-def assert_stable(src: str, dst: str, mode: FileMode) -> None:
+def assert_stable(src: str, dst: str, mode: Mode) -> None:
"""Raise AssertionError if `dst` reformats differently the second time."""
newdst = format_str(dst, mode=mode)
if dst != newdst:
return False
-def get_cache_file(mode: FileMode) -> Path:
+def get_cache_file(mode: Mode) -> Path:
+    """Return the pickle cache path under CACHE_DIR, keyed by `mode.get_cache_key()`."""
return CACHE_DIR / f"cache.{mode.get_cache_key()}.pickle"
-def read_cache(mode: FileMode) -> Cache:
+def read_cache(mode: Mode) -> Cache:
"""Read the cache if it exists and is well formed.
If it is not well formed, the call to write_cache later should resolve the issue.
return todo, done
-def write_cache(cache: Cache, sources: Iterable[Path], mode: FileMode) -> None:
+def write_cache(cache: Cache, sources: Iterable[Path], mode: Mode) -> None:
"""Update the cache file."""
cache_file = get_cache_file(mode)
try: