@dataclass
-class FileMode:
+class Mode:
target_versions: Set[TargetVersion] = field(default_factory=set)
line_length: int = DEFAULT_LINE_LENGTH
string_normalization: bool = True
return ".".join(parts)
+# Legacy alias: `Mode` used to be called `FileMode`. The old name is kept so
+# that integrations which still import or reference `FileMode` keep working.
+FileMode = Mode
+
+
def supports_feature(target_versions: Set[TargetVersion], feature: Feature) -> bool:
+    """Return True if every version in `target_versions` supports `feature`.
+
+    Support is looked up in VERSION_TO_FEATURES. Because this is an `all()`
+    over the versions, an empty `target_versions` yields True.
+    """
return all(feature in VERSION_TO_FEATURES[version] for version in target_versions)
+def find_pyproject_toml(path_search_start: Iterable[str]) -> Optional[str]:
+    """Return the path to the project's pyproject.toml, or None if there is none.
+
+    `path_search_start` is the collection of source paths from which the
+    project root is discovered via find_project_root. read_pyproject_toml
+    passes the `src` click parameter here, defaulting to an empty tuple.
+    A single path given as a bare string is also accepted and treated as a
+    one-element collection.
+
+    The candidate is `<project root>/pyproject.toml`. Its string form is
+    returned only when that path is an existing file.
+    """
+    if isinstance(path_search_start, str):
+        # This parameter used to be annotated as `str`. find_project_root is
+        # handed a collection of paths (the caller passes a tuple), so wrap a
+        # lone path rather than let it be read character by character.
+        path_search_start = (path_search_start,)
+    path_project_root = find_project_root(path_search_start)
+    path_pyproject_toml = path_project_root / "pyproject.toml"
+    return str(path_pyproject_toml) if path_pyproject_toml.is_file() else None
+
+
+def parse_pyproject_toml(path_config: str) -> Dict[str, Any]:
+    """Load `path_config` and return Black's settings from its [tool.black] table.
+
+    A missing `tool` or `tool.black` table yields an empty dict. Keys are
+    normalized for use as click parameter names: every "--" is removed and
+    each remaining "-" becomes "_" (e.g. "line-length" -> "line_length").
+
+    Raises toml.TomlDecodeError if the file is not valid TOML. Errors from
+    reading the file (OSError, which the caller also handles) propagate.
+    """
+    document = toml.load(path_config)
+    tool_section = document.get("tool", {})
+    black_section = tool_section.get("black", {})
+    return {
+        option.replace("--", "").replace("-", "_"): setting
+        for option, setting in black_section.items()
+    }
+
+
def read_pyproject_toml(
ctx: click.Context, param: click.Parameter, value: Union[str, int, bool, None]
) -> Optional[str]:
"""
assert not isinstance(value, (int, bool)), "Invalid parameter type passed"
if not value:
- root = find_project_root(ctx.params.get("src", ()))
- path = root / "pyproject.toml"
- if path.is_file():
- value = str(path)
- else:
+ value = find_pyproject_toml(ctx.params.get("src", ()))
+ if value is None:
return None
try:
- pyproject_toml = toml.load(value)
- config = pyproject_toml.get("tool", {}).get("black", {})
+ config = parse_pyproject_toml(value)
except (toml.TomlDecodeError, OSError) as e:
raise click.FileError(
filename=value, hint=f"Error reading configuration file: {e}"
if ctx.default_map is None:
ctx.default_map = {}
- ctx.default_map.update( # type: ignore # bad types in .pyi
- {k.replace("--", "").replace("-", "_"): v for k, v in config.items()}
- )
+ ctx.default_map.update(config) # type: ignore # bad types in .pyi
return value
else:
# We'll autodetect later.
versions = set()
- mode = FileMode(
+ mode = Mode(
target_versions=versions,
line_length=line_length,
is_pyi=pyi,
def reformat_one(
- src: Path, fast: bool, write_back: WriteBack, mode: FileMode, report: "Report"
+ src: Path, fast: bool, write_back: WriteBack, mode: Mode, report: "Report"
) -> None:
"""Reformat a single file under `src` without spawning child processes.
def reformat_many(
- sources: Set[Path],
- fast: bool,
- write_back: WriteBack,
- mode: FileMode,
- report: "Report",
+ sources: Set[Path], fast: bool, write_back: WriteBack, mode: Mode, report: "Report"
) -> None:
"""Reformat multiple files using a ProcessPoolExecutor."""
loop = asyncio.get_event_loop()
sources: Set[Path],
fast: bool,
write_back: WriteBack,
- mode: FileMode,
+ mode: Mode,
report: "Report",
loop: asyncio.AbstractEventLoop,
executor: Executor,
def format_file_in_place(
src: Path,
fast: bool,
- mode: FileMode,
+ mode: Mode,
write_back: WriteBack = WriteBack.NO,
lock: Any = None, # multiprocessing.Manager().Lock() is some crazy proxy
) -> bool:
def format_stdin_to_stdout(
- fast: bool, *, write_back: WriteBack = WriteBack.NO, mode: FileMode
+ fast: bool, *, write_back: WriteBack = WriteBack.NO, mode: Mode
) -> bool:
"""Format file on stdin. Return True if changed.
f.detach()
-def format_file_contents(
- src_contents: str, *, fast: bool, mode: FileMode
-) -> FileContent:
+def format_file_contents(src_contents: str, *, fast: bool, mode: Mode) -> FileContent:
"""Reformat contents a file and return new contents.
If `fast` is False, additionally confirm that the reformatted code is
return dst_contents
-def format_str(src_contents: str, *, mode: FileMode) -> FileContent:
+def format_str(src_contents: str, *, mode: Mode) -> FileContent:
"""Reformat a string and return new contents.
`mode` determines formatting options, such as how many characters per line are
- allowed.
+ allowed. Example:
+
+ >>> import black
+ >>> print(black.format_str("def f(arg:str='')->None:...", mode=Mode()))
+ def f(arg: str = "") -> None:
+ ...
+
+ A more complex example:
+ >>> print(
+ ... black.format_str(
+ ... "def f(arg:str='')->None: hey",
+ ... mode=black.Mode(
+ ... target_versions={black.TargetVersion.PY36},
+ ... line_length=10,
+ ... string_normalization=False,
+ ... is_pyi=False,
+ ... ),
+ ... ),
+ ... )
+ def f(
+ arg: str = '',
+ ) -> None:
+ hey
+
"""
src_node = lib2to3_parse(src_contents.lstrip(), mode.target_versions)
dst_contents = []
match = re.match(r"^([furbFURB]*)(.*)$", leaf.value, re.DOTALL)
assert match is not None, f"failed to match string {leaf.value!r}"
orig_prefix = match.group(1)
- new_prefix = orig_prefix.lower()
+ new_prefix = orig_prefix.replace("F", "f").replace("B", "b").replace("U", "u")
if remove_u_prefix:
new_prefix = new_prefix.replace("u", "")
leaf.value = f"{new_prefix}{match.group(2)}"
) from None
-def assert_stable(src: str, dst: str, mode: FileMode) -> None:
+def assert_stable(src: str, dst: str, mode: Mode) -> None:
"""Raise AssertionError if `dst` reformats differently the second time."""
newdst = format_str(dst, mode=mode)
if dst != newdst:
return False
-def get_cache_file(mode: FileMode) -> Path:
+def get_cache_file(mode: Mode) -> Path:
+    """Return the pickle cache path for `mode`, named by `mode.get_cache_key()`."""
return CACHE_DIR / f"cache.{mode.get_cache_key()}.pickle"
-def read_cache(mode: FileMode) -> Cache:
+def read_cache(mode: Mode) -> Cache:
"""Read the cache if it exists and is well formed.
If it is not well formed, the call to write_cache later should resolve the issue.
return todo, done
-def write_cache(cache: Cache, sources: Iterable[Path], mode: FileMode) -> None:
+def write_cache(cache: Cache, sources: Iterable[Path], mode: Mode) -> None:
"""Update the cache file."""
cache_file = get_cache_file(mode)
try: