import click
from click.core import ParameterSource
from mypy_extensions import mypyc_attr
+from pathspec import PathSpec
from pathspec.patterns.gitwildmatch import GitWildMatchPatternError
from _black_version import version as __version__
unmask_cell,
)
from black.linegen import LN, LineGenerator, transform_line
-from black.lines import EmptyLineTracker, Line
+from black.lines import EmptyLineTracker, LinesBlock
from black.mode import (
FUTURE_FLAG_TO_FEATURE,
VERSION_TO_FEATURES,
),
default=[],
)
+@click.option(
+ "-x",
+ "--skip-source-first-line",
+ is_flag=True,
+ help="Skip the first line of the source code.",
+)
@click.option(
"-S",
"--skip-string-normalization",
pyi: bool,
ipynb: bool,
python_cell_magics: Sequence[str],
+ skip_source_first_line: bool,
skip_string_normalization: bool,
skip_magic_trailing_comma: bool,
experimental_string_processing: bool,
)
normalized = [
- (source, source)
- if source == "-"
- else (normalize_path_maybe_ignore(Path(source), root), source)
+ (
+ (source, source)
+ if source == "-"
+ else (normalize_path_maybe_ignore(Path(source), root), source)
+ )
for source in src
]
srcs_string = ", ".join(
[
- f'"{_norm}"'
- if _norm
- else f'\033[31m"{source} (skipping - invalid)"\033[34m'
+ (
+ f'"{_norm}"'
+ if _norm
+ else f'\033[31m"{source} (skipping - invalid)"\033[34m'
+ )
for _norm, source in normalized
]
)
user_level_config = str(find_user_pyproject_toml())
if config == user_level_config:
out(
- "Using configuration from user-level config at "
- f"'{user_level_config}'.",
+ (
+ "Using configuration from user-level config at "
+ f"'{user_level_config}'."
+ ),
fg="blue",
)
elif config_source in (
out("Using configuration from project root.", fg="blue")
else:
out(f"Using configuration in '{config}'.", fg="blue")
+ if ctx.default_map:
+ for param, value in ctx.default_map.items():
+ out(f"{param}: {value}")
error_msg = "Oh no! 💥 💔 💥"
if (
line_length=line_length,
is_pyi=pyi,
is_ipynb=ipynb,
+ skip_source_first_line=skip_source_first_line,
string_normalization=not skip_string_normalization,
magic_trailing_comma=not skip_magic_trailing_comma,
experimental_string_processing=experimental_string_processing,
sources: Set[Path] = set()
root = ctx.obj["root"]
+ using_default_exclude = exclude is None
+ exclude = re_compile_maybe_verbose(DEFAULT_EXCLUDES) if exclude is None else exclude
+ gitignore: Optional[Dict[Path, PathSpec]] = None
+ root_gitignore = get_gitignore(root)
+
for s in src:
if s == "-" and stdin_filename:
p = Path(stdin_filename)
sources.add(p)
elif p.is_dir():
- if exclude is None:
- exclude = re_compile_maybe_verbose(DEFAULT_EXCLUDES)
- gitignore = get_gitignore(root)
- p_gitignore = get_gitignore(p)
- # No need to use p's gitignore if it is identical to root's gitignore
- # (i.e. root and p point to the same directory).
- if gitignore != p_gitignore:
- gitignore += p_gitignore
- else:
- gitignore = None
+ p = root / normalize_path_maybe_ignore(p, ctx.obj["root"], report)
+ if using_default_exclude:
+ gitignore = {
+ root: root_gitignore,
+ p: get_gitignore(p),
+ }
sources.update(
gen_python_files(
p.iterdir(),
mode = replace(mode, is_ipynb=True)
then = datetime.utcfromtimestamp(src.stat().st_mtime)
+ header = b""
with open(src, "rb") as buf:
+ if mode.skip_source_first_line:
+ header = buf.readline()
src_contents, encoding, newline = decode_bytes(buf.read())
try:
dst_contents = format_file_contents(src_contents, fast=fast, mode=mode)
raise ValueError(
f"File '{src}' cannot be parsed as valid Jupyter notebook."
) from None
+ src_contents = header.decode(encoding) + src_contents
+ dst_contents = header.decode(encoding) + dst_contents
if write_back == WriteBack.YES:
with open(src, "w", encoding=encoding, newline=newline) as f:
valid by calling :func:`assert_equivalent` and :func:`assert_stable` on it.
`mode` is passed to :func:`format_str`.
"""
- if not src_contents.strip():
+ if not mode.preview and not src_contents.strip():
raise NothingChanged
if mode.is_ipynb:
Operate cell-by-cell, only on code cells, only for Python notebooks.
If the ``.ipynb`` originally had a trailing newline, it'll be preserved.
"""
+ if mode.preview and not src_contents:
+ raise NothingChanged
+
trailing_newline = src_contents[-1] == "\n"
modified = False
nb = json.loads(src_contents)
def _format_str_once(src_contents: str, *, mode: Mode) -> str:
# Run one formatting pass over `src_contents` and return the result as a
# single string.
# NOTE(review): this span is shown in patch form -- lines starting with "+"
# are additions and lines starting with "-" are removals; unprefixed lines
# are unchanged context. Indentation is not visible here.
#
# The source is parsed with its leading whitespace stripped (`lstrip()`).
src_node = lib2to3_parse(src_contents.lstrip(), mode.target_versions)
- dst_contents = []
+ dst_blocks: List[LinesBlock] = []
# Use the explicitly configured target versions when there are any;
# otherwise detect them from the parsed tree and its __future__ imports.
if mode.target_versions:
versions = mode.target_versions
else:
future_imports = get_future_imports(src_node)
versions = detect_target_versions(src_node, future_imports=future_imports)
# Keep PARENTHESIZED_CONTEXT_MANAGERS only if `supports_feature` accepts it
# for `versions`; the resulting set is handed to the LineGenerator below.
+ context_manager_features = {
+ feature
+ for feature in {Feature.PARENTHESIZED_CONTEXT_MANAGERS}
+ if supports_feature(versions, feature)
+ }
normalize_fmt_off(src_node, preview=mode.preview)
- lines = LineGenerator(mode=mode)
- elt = EmptyLineTracker(is_pyi=mode.is_pyi)
- empty_line = Line(mode=mode)
- after = 0
+ lines = LineGenerator(mode=mode, features=context_manager_features)
+ elt = EmptyLineTracker(mode=mode)
# Same filtering for the trailing-comma features; this set is passed to
# `transform_line` for every generated line.
split_line_features = {
feature
for feature in {Feature.TRAILING_COMMA_IN_CALL, Feature.TRAILING_COMMA_IN_DEF}
if supports_feature(versions, feature)
}
+ block: Optional[LinesBlock] = None
# For each line produced by visiting the tree, ask the tracker for a block,
# record it, and append the stringified transformed lines to that block's
# `content_lines`.
for current_line in lines.visit(src_node):
- dst_contents.append(str(empty_line) * after)
- before, after = elt.maybe_empty_lines(current_line)
- dst_contents.append(str(empty_line) * before)
+ block = elt.maybe_empty_lines(current_line)
+ dst_blocks.append(block)
for line in transform_line(
current_line, mode=mode, features=split_line_features
):
- dst_contents.append(str(line))
+ block.content_lines.append(str(line))
# Reset `after` on the final block to 0.
# NOTE(review): presumably `after` is the number of blank lines emitted
# after a block, so this avoids trailing blank lines -- confirm in LinesBlock.
+ if dst_blocks:
+ dst_blocks[-1].after = 0
# Flatten every block's `all_lines()` into one list of output strings.
+ dst_contents = []
+ for block in dst_blocks:
+ dst_contents.extend(block.all_lines())
# Preview mode with no output lines: return the detected newline sequence
# if the decoded source contains "\n", otherwise return an empty string.
+ if mode.preview and not dst_contents:
+ # Use decode_bytes to retrieve the correct source newline (CRLF or LF),
+ # and check if normalized_content has more than one line
+ normalized_content, _, newline = decode_bytes(src_contents.encode("utf-8"))
+ if "\n" in normalized_content:
+ return newline
+ return ""
return "".join(dst_contents)
- relaxed decorator syntax;
- usage of __future__ flags (annotations);
- print / exec statements;
+ - parenthesized context managers;
+ - match statements;
+ - except* clause;
+ - variadic generics;
"""
features: Set[Feature] = set()
if future_imports:
):
features.add(Feature.ANN_ASSIGN_EXTENDED_RHS)
+ elif (
+ n.type == syms.with_stmt
+ and len(n.children) > 2
+ and n.children[1].type == syms.atom
+ ):
+ atom_children = n.children[1].children
+ if (
+ len(atom_children) == 3
+ and atom_children[0].type == token.LPAR
+ and atom_children[1].type == syms.testlist_gexp
+ and atom_children[2].type == token.RPAR
+ ):
+ features.add(Feature.PARENTHESIZED_CONTEXT_MANAGERS)
+
+ elif n.type == syms.match_stmt:
+ features.add(Feature.PATTERN_MATCHING)
+
elif (
n.type == syms.except_clause
and len(n.children) >= 2
for module in modules:
if hasattr(module, "_verify_python3_env"):
- module._verify_python3_env = lambda: None # type: ignore
+ module._verify_python3_env = lambda: None
if hasattr(module, "_verify_python_env"):
- module._verify_python_env = lambda: None # type: ignore
+ module._verify_python_env = lambda: None
def patched_main() -> None: