cast,
TYPE_CHECKING,
)
-from typing_extensions import Final
from mypy_extensions import mypyc_attr
from appdirs import user_cache_dir
from dataclasses import dataclass, field, replace
import click
import toml
-from typed_ast import ast3, ast27
+
+try:
+ from typed_ast import ast3, ast27
+except ImportError:
+ if sys.version_info < (3, 8):
+ print(
+ "The typed_ast package is not installed.\n"
+ "You can install it with `python3 -m pip install typed-ast`.",
+ file=sys.stderr,
+ )
+ sys.exit(1)
+ else:
+ ast3 = ast27 = ast
+
from pathspec import PathSpec
# lib2to3 fork
from _black_version import version as __version__
+if sys.version_info < (3, 8):
+ from typing_extensions import Final
+else:
+ from typing import Final
+
if TYPE_CHECKING:
import colorama # noqa: F401
DEFAULT_LINE_LENGTH = 88
-DEFAULT_EXCLUDES = r"/(\.direnv|\.eggs|\.git|\.hg|\.mypy_cache|\.nox|\.tox|\.venv|\.svn|_build|buck-out|build|dist)/" # noqa: B950
+DEFAULT_EXCLUDES = r"/(\.direnv|\.eggs|\.git|\.hg|\.mypy_cache|\.nox|\.tox|\.venv|venv|\.svn|_build|buck-out|build|dist)/" # noqa: B950
DEFAULT_INCLUDES = r"\.pyi?$"
CACHE_DIR = Path(user_cache_dir("black", version=__version__))
+STDIN_PLACEHOLDER = "__BLACK_STDIN_FILENAME__"
STRING_PREFIX_CHARS: Final = "furbFURB" # All possible string prefix characters.
Timestamp = float
FileSize = int
CacheInfo = Tuple[Timestamp, FileSize]
-Cache = Dict[Path, CacheInfo]
+Cache = Dict[str, CacheInfo]
out = partial(click.secho, bold=True, err=True)
err = partial(click.secho, fg="red", err=True)
PY36 = 6
PY37 = 7
PY38 = 8
+ PY39 = 9
def is_python2(self) -> bool:
return self is TargetVersion.PY27
-PY36_VERSIONS = {TargetVersion.PY36, TargetVersion.PY37, TargetVersion.PY38}
-
-
class Feature(Enum):
# All string literals are unicode
UNICODE_LITERALS = 1
ASYNC_KEYWORDS = 7
ASSIGNMENT_EXPRESSIONS = 8
POS_ONLY_ARGUMENTS = 9
+ RELAXED_DECORATORS = 10
FORCE_OPTIONAL_PARENTHESES = 50
Feature.ASSIGNMENT_EXPRESSIONS,
Feature.POS_ONLY_ARGUMENTS,
},
+ TargetVersion.PY39: {
+ Feature.UNICODE_LITERALS,
+ Feature.F_STRINGS,
+ Feature.NUMERIC_UNDERSCORES,
+ Feature.TRAILING_COMMA_IN_CALL,
+ Feature.TRAILING_COMMA_IN_DEF,
+ Feature.ASYNC_KEYWORDS,
+ Feature.ASSIGNMENT_EXPRESSIONS,
+ Feature.RELAXED_DECORATORS,
+ Feature.POS_ONLY_ARGUMENTS,
+ },
}
target_versions: Set[TargetVersion] = field(default_factory=set)
line_length: int = DEFAULT_LINE_LENGTH
string_normalization: bool = True
- experimental_string_processing: bool = False
is_pyi: bool = False
+ magic_trailing_comma: bool = True
+ experimental_string_processing: bool = False
def get_cache_key(self) -> str:
if self.target_versions:
str(self.line_length),
str(int(self.string_normalization)),
str(int(self.is_pyi)),
+ str(int(self.magic_trailing_comma)),
+ str(int(self.experimental_string_processing)),
]
return ".".join(parts)
return all(feature in VERSION_TO_FEATURES[version] for version in target_versions)
-def find_pyproject_toml(path_search_start: Iterable[str]) -> Optional[str]:
+def find_pyproject_toml(path_search_start: Tuple[str, ...]) -> Optional[str]:
"""Find the absolute filepath to a pyproject.toml if it exists"""
path_project_root = find_project_root(path_search_start)
path_pyproject_toml = path_project_root / "pyproject.toml"
- return str(path_pyproject_toml) if path_pyproject_toml.is_file() else None
+ if path_pyproject_toml.is_file():
+ return str(path_pyproject_toml)
+
+ path_user_pyproject_toml = find_user_pyproject_toml()
+ return str(path_user_pyproject_toml) if path_user_pyproject_toml.is_file() else None
def parse_pyproject_toml(path_config: str) -> Dict[str, Any]:
return [TargetVersion[val.upper()] for val in v]
+def validate_regex(
+    ctx: click.Context,
+    param: click.Parameter,
+    value: Optional[str],
+) -> Optional[Pattern]:
+    """Click option callback that compiles the option's value into a regex.
+
+    A missing option (``None``) is passed through unchanged; otherwise the
+    string is compiled with ``re_compile_maybe_verbose``.  An invalid pattern
+    is reported to the user as a ``click.BadParameter`` error.
+    """
+    if value is None:
+        return None
+    try:
+        return re_compile_maybe_verbose(value)
+    except re.error:
+        raise click.BadParameter("Not a valid regular expression")
+
+
@click.command(context_settings=dict(help_option_names=["-h", "--help"]))
@click.option("-c", "--code", type=str, help="Format the code passed in as a string.")
@click.option(
is_flag=True,
help="Don't normalize string quotes or prefixes.",
)
+@click.option(
+ "-C",
+ "--skip-magic-trailing-comma",
+ is_flag=True,
+ help="Don't use trailing commas as a reason to split lines.",
+)
@click.option(
"--experimental-string-processing",
is_flag=True,
"--check",
is_flag=True,
help=(
- "Don't write the files back, just return the status. Return code 0 means"
- " nothing would change. Return code 1 means some files would be reformatted."
+ "Don't write the files back, just return the status. Return code 0 means"
+ " nothing would change. Return code 1 means some files would be reformatted."
" Return code 123 means there was an internal error."
),
)
"--include",
type=str,
default=DEFAULT_INCLUDES,
+ callback=validate_regex,
help=(
"A regular expression that matches files and directories that should be"
- " included on recursive searches. An empty value means all files are included"
- " regardless of the name. Use forward slashes for directories on all platforms"
- " (Windows, too). Exclusions are calculated first, inclusions later."
+ " included on recursive searches. An empty value means all files are included"
+ " regardless of the name. Use forward slashes for directories on all platforms"
+ " (Windows, too). Exclusions are calculated first, inclusions later."
),
show_default=True,
)
"--exclude",
type=str,
default=DEFAULT_EXCLUDES,
+ callback=validate_regex,
help=(
"A regular expression that matches files and directories that should be"
- " excluded on recursive searches. An empty value means no paths are excluded."
- " Use forward slashes for directories on all platforms (Windows, too). "
+ " excluded on recursive searches. An empty value means no paths are excluded."
+ " Use forward slashes for directories on all platforms (Windows, too)."
" Exclusions are calculated first, inclusions later."
),
show_default=True,
)
+@click.option(
+ "--extend-exclude",
+ type=str,
+ callback=validate_regex,
+ help=(
+ "Like --exclude, but adds additional files and directories on top of the"
+ " excluded ones. (Useful if you simply want to add to the default)"
+ ),
+)
@click.option(
"--force-exclude",
type=str,
+ callback=validate_regex,
help=(
"Like --exclude, but files and directories matching this regex will be "
- "excluded even when they are passed explicitly as arguments"
+ "excluded even when they are passed explicitly as arguments."
+ ),
+)
+@click.option(
+ "--stdin-filename",
+ type=str,
+ help=(
+ "The name of the file when passing it through stdin. Useful to make "
+ "sure Black will respect --force-exclude option on some "
+ "editors that rely on using stdin."
),
)
@click.option(
is_flag=True,
help=(
"Also emit messages to stderr about files that were not changed or were ignored"
- " due to --exclude=."
+ " due to exclusion patterns."
),
)
@click.version_option(version=__version__)
fast: bool,
pyi: bool,
skip_string_normalization: bool,
+ skip_magic_trailing_comma: bool,
experimental_string_processing: bool,
quiet: bool,
verbose: bool,
- include: str,
- exclude: str,
- force_exclude: Optional[str],
+ include: Pattern,
+ exclude: Pattern,
+ extend_exclude: Optional[Pattern],
+ force_exclude: Optional[Pattern],
+ stdin_filename: Optional[str],
src: Tuple[str, ...],
config: Optional[str],
) -> None:
line_length=line_length,
is_pyi=pyi,
string_normalization=not skip_string_normalization,
+ magic_trailing_comma=not skip_magic_trailing_comma,
experimental_string_processing=experimental_string_processing,
)
if config and verbose:
verbose=verbose,
include=include,
exclude=exclude,
+ extend_exclude=extend_exclude,
force_exclude=force_exclude,
report=report,
+ stdin_filename=stdin_filename,
)
path_empty(
src: Tuple[str, ...],
quiet: bool,
verbose: bool,
- include: str,
- exclude: str,
- force_exclude: Optional[str],
+ include: Pattern[str],
+ exclude: Pattern[str],
+ extend_exclude: Optional[Pattern[str]],
+ force_exclude: Optional[Pattern[str]],
report: "Report",
+ stdin_filename: Optional[str],
) -> Set[Path]:
"""Compute the set of files to be formatted."""
- try:
- include_regex = re_compile_maybe_verbose(include)
- except re.error:
- err(f"Invalid regular expression for include given: {include!r}")
- ctx.exit(2)
- try:
- exclude_regex = re_compile_maybe_verbose(exclude)
- except re.error:
- err(f"Invalid regular expression for exclude given: {exclude!r}")
- ctx.exit(2)
- try:
- force_exclude_regex = (
- re_compile_maybe_verbose(force_exclude) if force_exclude else None
- )
- except re.error:
- err(f"Invalid regular expression for force_exclude given: {force_exclude!r}")
- ctx.exit(2)
root = find_project_root(src)
sources: Set[Path] = set()
gitignore = get_gitignore(root)
for s in src:
- p = Path(s)
- if p.is_dir():
- sources.update(
- gen_python_files(
- p.iterdir(),
- root,
- include_regex,
- exclude_regex,
- force_exclude_regex,
- report,
- gitignore,
- )
- )
- elif s == "-":
- sources.add(p)
- elif p.is_file():
+ if s == "-" and stdin_filename:
+ p = Path(stdin_filename)
+ is_stdin = True
+ else:
+ p = Path(s)
+ is_stdin = False
+
+ if is_stdin or p.is_file():
normalized_path = normalize_path_maybe_ignore(p, root, report)
if normalized_path is None:
continue
normalized_path = "/" + normalized_path
# Hard-exclude any files that matches the `--force-exclude` regex.
- if force_exclude_regex:
- force_exclude_match = force_exclude_regex.search(normalized_path)
+ if force_exclude:
+ force_exclude_match = force_exclude.search(normalized_path)
else:
force_exclude_match = None
if force_exclude_match and force_exclude_match.group(0):
report.path_ignored(p, "matches the --force-exclude regular expression")
continue
+ if is_stdin:
+ p = Path(f"{STDIN_PLACEHOLDER}{str(p)}")
+
+ sources.add(p)
+ elif p.is_dir():
+ sources.update(
+ gen_python_files(
+ p.iterdir(),
+ root,
+ include,
+ exclude,
+ extend_exclude,
+ force_exclude,
+ report,
+ gitignore,
+ )
+ )
+ elif s == "-":
sources.add(p)
else:
err(f"invalid path: {s}")
"""
try:
changed = Changed.NO
- if not src.is_file() and str(src) == "-":
+
+ if str(src) == "-":
+ is_stdin = True
+ elif str(src).startswith(STDIN_PLACEHOLDER):
+ is_stdin = True
+ # Use the original name again in case we want to print something
+ # to the user
+ src = Path(str(src)[len(STDIN_PLACEHOLDER) :])
+ else:
+ is_stdin = False
+
+ if is_stdin:
if format_stdin_to_stdout(fast=fast, write_back=write_back, mode=mode):
changed = Changed.YES
else:
if write_back not in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
cache = read_cache(mode)
res_src = src.resolve()
- if res_src in cache and cache[res_src] == get_cache_info(res_src):
+ res_src_s = str(res_src)
+ if res_src_s in cache and cache[res_src_s] == get_cache_info(res_src):
changed = Changed.CACHED
if changed is not Changed.CACHED and format_file_in_place(
src, fast=fast, write_back=write_back, mode=mode
worker_count = os.cpu_count()
if sys.platform == "win32":
# Work around https://bugs.python.org/issue26903
- worker_count = min(worker_count, 61)
+ worker_count = min(worker_count, 60)
try:
executor = ProcessPoolExecutor(max_workers=worker_count)
except (ImportError, OSError):
# we arrive here if the underlying system does not support multi-processing
# like in AWS Lambda or Termux, in which case we gracefully fallback to
- # a ThreadPollExecutor with just a single worker (more workers would not do us
+ # a ThreadPoolExecutor with just a single worker (more workers would not do us
# any good due to the Global Interpreter Lock)
executor = ThreadPoolExecutor(max_workers=1)
): src
for src in sorted(sources)
}
- pending: Iterable["asyncio.Future[bool]"] = tasks.keys()
+ pending = tasks.keys()
try:
loop.add_signal_handler(signal.SIGINT, cancel, pending)
loop.add_signal_handler(signal.SIGTERM, cancel, pending)
dst_name = f"{src}\t{now} +0000"
diff_contents = diff(src_contents, dst_contents, src_name, dst_name)
- if write_back == write_back.COLOR_DIFF:
+ if write_back == WriteBack.COLOR_DIFF:
diff_contents = color_diff(diff_contents)
with lock or nullcontext():
for i, line in enumerate(lines):
if line.startswith("+++") or line.startswith("---"):
line = "\033[1;37m" + line + "\033[0m" # bold white, reset
- if line.startswith("@@"):
+ elif line.startswith("@@"):
line = "\033[36m" + line + "\033[0m" # cyan, reset
- if line.startswith("+"):
+ elif line.startswith("+"):
line = "\033[32m" + line + "\033[0m" # green, reset
elif line.startswith("-"):
line = "\033[31m" + line + "\033[0m" # red, reset
def wrap_stream_for_windows(
f: io.TextIOWrapper,
-) -> Union[io.TextIOWrapper, "colorama.AnsiToWin32.AnsiToWin32"]:
+) -> Union[io.TextIOWrapper, "colorama.AnsiToWin32"]:
"""
- Wrap the stream in colorama's wrap_stream so colors are shown on Windows.
+ Wrap stream with colorama's wrap_stream so colors are shown on Windows.
- If `colorama` is not found, then no change is made. If `colorama` does
- exist, then it handles the logic to determine whether or not to change
- things.
+ If `colorama` is unavailable, the original stream is returned unmodified.
+ Otherwise, the `wrap_stream()` function determines whether the stream needs
+ to be wrapped for a Windows environment and will accordingly either return
+ an `AnsiToWin32` wrapper or the original stream.
"""
try:
- from colorama import initialise
-
- # We set `strip=False` so that we can don't have to modify
- # test_express_diff_with_color.
- f = initialise.wrap_stream(
- f, convert=None, strip=False, autoreset=False, wrap=True
- )
-
- # wrap_stream returns a `colorama.AnsiToWin32.AnsiToWin32` object
- # which does not have a `detach()` method. So we fake one.
- f.detach = lambda *args, **kwargs: None # type: ignore
+ from colorama.initialise import wrap_stream
except ImportError:
- pass
-
- return f
+ return f
+ else:
+ # Set `strip=False` to avoid needing to modify test_express_diff_with_color.
+ return wrap_stream(f, convert=None, strip=False, autoreset=False, wrap=True)
def format_stdin_to_stdout(
if not fast:
assert_equivalent(src_contents, dst_contents)
- assert_stable(src_contents, dst_contents, mode=mode)
+
+ # Forced second pass to work around optional trailing commas (becoming
+ # forced trailing commas on pass 2) interacting differently with optional
+ # parentheses. Admittedly ugly.
+ dst_contents_pass2 = format_str(dst_contents, mode=mode)
+ if dst_contents != dst_contents_pass2:
+ dst_contents = dst_contents_pass2
+ assert_equivalent(src_contents, dst_contents, pass_num=2)
+ assert_stable(src_contents, dst_contents, mode=mode)
+ # Note: no need to explicitly call `assert_stable` if `dst_contents` was
+ # the same as `dst_contents_pass2`.
return dst_contents
allowed. Example:
>>> import black
- >>> print(black.format_str("def f(arg:str='')->None:...", mode=Mode()))
+ >>> print(black.format_str("def f(arg:str='')->None:...", mode=black.Mode()))
def f(arg: str = "") -> None:
...
versions = detect_target_versions(src_node)
normalize_fmt_off(src_node)
lines = LineGenerator(
+ mode=mode,
remove_u_prefix="unicode_literals" in future_imports
or supports_feature(versions, Feature.UNICODE_LITERALS),
- is_pyi=mode.is_pyi,
- normalize_strings=mode.string_normalization,
)
elt = EmptyLineTracker(is_pyi=mode.is_pyi)
- empty_line = Line()
+ empty_line = Line(mode=mode)
after = 0
split_line_features = {
feature
class Line:
"""Holds leaves and comments. Can be printed with `str(line)`."""
+ mode: Mode
depth: int = 0
leaves: List[Leaf] = field(default_factory=list)
# keys ordered like `leaves`
comments: Dict[LeafID, List[Leaf]] = field(default_factory=dict)
bracket_tracker: BracketTracker = field(default_factory=BracketTracker)
inside_brackets: bool = False
- should_explode: bool = False
+ should_split_rhs: bool = False
+ magic_trailing_comma: Optional[Leaf] = None
def append(self, leaf: Leaf, preformatted: bool = False) -> None:
"""Add a new `leaf` to the end of the line.
)
if self.inside_brackets or not preformatted:
self.bracket_tracker.mark(leaf)
- if self.maybe_should_explode(leaf):
- self.should_explode = True
+ if self.mode.magic_trailing_comma:
+ if self.has_magic_trailing_comma(leaf):
+ self.magic_trailing_comma = leaf
+ elif self.has_magic_trailing_comma(leaf, ensure_removable=True):
+ self.remove_trailing_comma()
if not self.append_comment(leaf):
self.leaves.append(leaf)
def contains_multiline_strings(self) -> bool:
return any(is_multiline_string(leaf) for leaf in self.leaves)
- def maybe_should_explode(self, closing: Leaf) -> bool:
- """Return True if this line should explode (always be split), that is when:
- - there's a trailing comma here; and
- - it's not a one-tuple.
+ def has_magic_trailing_comma(
+ self, closing: Leaf, ensure_removable: bool = False
+ ) -> bool:
+ """Return True if we have a magic trailing comma, that is when:
+ - there's a trailing comma here
+ - it's not a one-tuple
+ Additionally, if ensure_removable:
+ - it's not from square bracket indexing
"""
if not (
closing.type in CLOSING_BRACKETS
):
return False
- if closing.type in {token.RBRACE, token.RSQB}:
+ if closing.type == token.RBRACE:
return True
+ if closing.type == token.RSQB:
+ if not ensure_removable:
+ return True
+ comma = self.leaves[-1]
+ return bool(comma.parent and comma.parent.type == syms.listmaker)
+
if self.is_import:
return True
def clone(self) -> "Line":
return Line(
+ mode=self.mode,
depth=self.depth,
inside_brackets=self.inside_brackets,
- should_explode=self.should_explode,
+ should_split_rhs=self.should_split_rhs,
+ magic_trailing_comma=self.magic_trailing_comma,
)
def __str__(self) -> str:
in ways that will no longer stringify to valid Python code on the tree.
"""
- is_pyi: bool = False
- normalize_strings: bool = True
- current_line: Line = field(default_factory=Line)
+ mode: Mode
remove_u_prefix: bool = False
+ current_line: Line = field(init=False)
def line(self, indent: int = 0) -> Iterator[Line]:
"""Generate a line.
return # Line is empty, don't emit. Creating a new one unnecessary.
complete_line = self.current_line
- self.current_line = Line(depth=complete_line.depth + indent)
+ self.current_line = Line(mode=self.mode, depth=complete_line.depth + indent)
yield complete_line
def visit_default(self, node: LN) -> Iterator[Line]:
yield from self.line()
normalize_prefix(node, inside_brackets=any_open_brackets)
- if self.normalize_strings and node.type == token.STRING:
+ if self.mode.string_normalization and node.type == token.STRING:
normalize_string_prefix(node, remove_u_prefix=self.remove_u_prefix)
normalize_string_quotes(node)
if node.type == token.NUMBER:
def visit_suite(self, node: Node) -> Iterator[Line]:
"""Visit a suite."""
- if self.is_pyi and is_stub_suite(node):
+ if self.mode.is_pyi and is_stub_suite(node):
yield from self.visit(node.children[2])
else:
yield from self.visit_default(node)
def visit_simple_stmt(self, node: Node) -> Iterator[Line]:
"""Visit a statement without nested statements."""
+ if first_child_is_arith(node):
+ wrap_in_parentheses(node, node.children[0], visible=False)
is_suite_like = node.parent and node.parent.type in STATEMENT
if is_suite_like:
- if self.is_pyi and is_stub_body(node):
+ if self.mode.is_pyi and is_stub_body(node):
yield from self.visit_default(node)
else:
yield from self.line(+1)
yield from self.line(-1)
else:
- if not self.is_pyi or not node.parent or not is_stub_suite(node.parent):
+ if (
+ not self.mode.is_pyi
+ or not node.parent
+ or not is_stub_suite(node.parent)
+ ):
yield from self.line()
yield from self.visit_default(node)
# We're ignoring docstrings with backslash newline escapes because changing
# indentation of those changes the AST representation of the code.
prefix = get_string_prefix(leaf.value)
- lead_len = len(prefix) + 3
- tail_len = -3
- indent = " " * 4 * self.current_line.depth
- docstring = fix_docstring(leaf.value[lead_len:tail_len], indent)
+ docstring = leaf.value[len(prefix) :] # Remove the prefix
+ quote_char = docstring[0]
+ # A natural way to remove the outer quotes is to do:
+ # docstring = docstring.strip(quote_char)
+ # but that breaks on """""x""" (which is '""x').
+ # So we actually need to remove the first character and the next two
+ # characters but only if they are the same as the first.
+ quote_len = 1 if docstring[1] != quote_char else 3
+ docstring = docstring[quote_len:-quote_len]
+
+ if is_multiline_string(leaf):
+ indent = " " * 4 * self.current_line.depth
+ docstring = fix_docstring(docstring, indent)
+ else:
+ docstring = docstring.strip()
+
if docstring:
- if leaf.value[lead_len - 1] == docstring[0]:
+ # Add some padding if the docstring starts / ends with a quote mark.
+ if docstring[0] == quote_char:
docstring = " " + docstring
- if leaf.value[tail_len + 1] == docstring[-1]:
+ if docstring[-1] == quote_char:
docstring = docstring + " "
- leaf.value = leaf.value[0:lead_len] + docstring + leaf.value[tail_len:]
+ else:
+ # Add some padding if the docstring is empty.
+ docstring = " "
+
+ # We could enforce triple quotes at this point.
+ quote = quote_char * quote_len
+ leaf.value = prefix + quote + docstring + quote
yield from self.visit_default(leaf)
def __post_init__(self) -> None:
"""You are in a twisty little maze of passages."""
+ self.current_line = Line(mode=self.mode)
+
v = self.visit_stmt
Ø: Set[str] = set()
self.visit_assert_stmt = partial(v, keywords={"assert"}, parens={"assert", ","})
):
# Python 2 print chevron
return NO
+ elif prevp.type == token.AT and p.parent and p.parent.type == syms.decorator:
+ # no space in decorators
+ return NO
elif prev.type in OPENING_BRACKETS:
return NO
FMT_OFF = {"# fmt: off", "# fmt:off", "# yapf: disable"}
+FMT_SKIP = {"# fmt: skip", "# fmt:skip"}
+FMT_PASS = {*FMT_OFF, *FMT_SKIP}
FMT_ON = {"# fmt: on", "# fmt:on", "# yapf: enable"}
consumed = 0
nlines = 0
ignored_lines = 0
- for index, line in enumerate(prefix.split("\n")):
+ for index, line in enumerate(re.split("\r?\n", prefix)):
consumed += len(line) + 1 # adding the length of the split '\n'
line = line.lstrip()
if not line:
if content[0] == "#":
content = content[1:]
+ NON_BREAKING_SPACE = " "
+ if (
+ content
+ and content[0] == NON_BREAKING_SPACE
+ and not content.lstrip().startswith("type:")
+ ):
+ content = " " + content[1:] # Replace NBSP by a simple space
if content and content[0] not in " !:#'%":
content = " " + content
return "#" + content
transformers: List[Transformer]
if (
not line.contains_uncollapsable_type_comments()
- and not line.should_explode
+ and not line.should_split_rhs
+ and not line.magic_trailing_comma
and (
is_line_short_enough(line, line_length=mode.line_length, line_str=line_str)
or line.contains_unsplittable_type_ignore()
Requirements:
The line contains a string which is surrounded by parentheses and:
- - The target string is NOT the only argument to a function call).
+ - The target string is NOT the only argument to a function call.
+ - The target string is NOT a "pointless" string.
- If the target string contains a PERCENT, the brackets are not
preceeded or followed by an operator with higher precedence than
PERCENT.
if leaf.type != token.STRING:
continue
+ # If this is a "pointless" string...
+ if (
+ leaf.parent
+ and leaf.parent.parent
+ and leaf.parent.parent.type == syms.simple_stmt
+ ):
+ continue
+
# Should be preceded by a non-empty LPAR...
if (
not is_valid_index(idx - 1)
MIN_SUBSTR_SIZE characters.
The string will ONLY be split on spaces (i.e. each new substring should
- start with a space).
+ start with a space). Note that the string will NOT be split on a space
+ which is escaped with a backslash.
If the string is an f-string, it will NOT be split in the middle of an
f-expression (e.g. in f"FooBar: {foo() if x else bar()}", {foo() if x
MIN_SUBSTR_SIZE = 6
# Matches an "f-expression" (e.g. {var}) that might be found in an f-string.
RE_FEXPR = r"""
- (?<!\{)\{
+ (?<!\{) (?:\{\{)* \{ (?!\{)
(?:
[^\{\}]
| \{\{
| \}\}
+ | (?R)
)+?
- (?<!\})(?:\}\})*\}(?!\})
+ (?<!\}) \} (?:\}\})* (?!\})
"""
def do_splitter_match(self, line: Line) -> TMatchResult:
section of this classes' docstring would be be met by returning @i.
"""
is_space = string[i] == " "
+
+ is_not_escaped = True
+ j = i - 1
+ while is_valid_index(j) and string[j] == "\\":
+ is_not_escaped = not is_not_escaped
+ j -= 1
+
is_big_enough = (
len(string[i:]) >= self.MIN_SUBSTR_SIZE
and len(string[:i]) >= self.MIN_SUBSTR_SIZE
)
- return is_space and is_big_enough and not breaks_fstring_expression(i)
+ return (
+ is_space
+ and is_not_escaped
+ and is_big_enough
+ and not breaks_fstring_expression(i)
+ )
# First, we check all indices BELOW @max_break_idx.
break_idx = max_break_idx
# `StringSplitter` will break it down further if necessary.
string_value = LL[string_idx].value
string_line = Line(
+ mode=line.mode,
depth=line.depth + 1,
inside_brackets=True,
- should_explode=line.should_explode,
+ should_split_rhs=line.should_split_rhs,
+ magic_trailing_comma=line.magic_trailing_comma,
)
string_leaf = Leaf(token.STRING, string_value)
insert_str_child(string_leaf)
If `is_body` is True, the result line is one-indented inside brackets and as such
has its first leaf's prefix normalized and a trailing comma added when expected.
"""
- result = Line(depth=original.depth)
+ result = Line(mode=original.mode, depth=original.depth)
if is_body:
result.inside_brackets = True
result.depth += 1
result.append(leaf, preformatted=True)
for comment_after in original.comments_after(leaf):
result.append(comment_after, preformatted=True)
- if is_body and should_split_body_explode(result, opening_bracket):
- result.should_explode = True
+ if is_body and should_split_line(result, opening_bracket):
+ result.should_split_rhs = True
return result
if bt.delimiter_count_with_priority(delimiter_priority) == 1:
raise CannotSplit("Splitting a single attribute from its owner looks wrong")
- current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
+ current_line = Line(
+ mode=line.mode, depth=line.depth, inside_brackets=line.inside_brackets
+ )
lowest_depth = sys.maxsize
trailing_comma_safe = True
except ValueError:
yield current_line
- current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
+ current_line = Line(
+ mode=line.mode, depth=line.depth, inside_brackets=line.inside_brackets
+ )
current_line.append(leaf)
for leaf in line.leaves:
if leaf_priority == delimiter_priority:
yield current_line
- current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
+ current_line = Line(
+ mode=line.mode, depth=line.depth, inside_brackets=line.inside_brackets
+ )
if current_line:
if (
trailing_comma_safe
if not line.contains_standalone_comments(0):
raise CannotSplit("Line does not have any standalone comments")
- current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
+ current_line = Line(
+ mode=line.mode, depth=line.depth, inside_brackets=line.inside_brackets
+ )
def append_to_line(leaf: Leaf) -> Iterator[Line]:
"""Append `leaf` to current line or to new line if appending impossible."""
except ValueError:
yield current_line
- current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
+ current_line = Line(
+ line.mode, depth=line.depth, inside_brackets=line.inside_brackets
+ )
current_line.append(leaf)
for leaf in line.leaves:
# Leave octal and binary literals alone.
pass
elif text.startswith("0x"):
- # Change hex literals to upper case.
- before, after = text[:2], text[2:]
- text = f"{before}{after.upper()}"
+ text = format_hex(text)
elif "e" in text:
- before, after = text.split("e")
- sign = ""
- if after.startswith("-"):
- after = after[1:]
- sign = "-"
- elif after.startswith("+"):
- after = after[1:]
- before = format_float_or_int_string(before)
- text = f"{before}e{sign}{after}"
+ text = format_scientific_notation(text)
elif text.endswith(("j", "l")):
- number = text[:-1]
- suffix = text[-1]
- # Capitalize in "2L" because "l" looks too similar to "1".
- if suffix == "l":
- suffix = "L"
- text = f"{format_float_or_int_string(number)}{suffix}"
+ text = format_long_or_complex_number(text)
else:
text = format_float_or_int_string(text)
leaf.value = text
+def format_hex(text: str) -> str:
+    """Return the hexadecimal literal ``text`` with its digits upper-cased.
+
+    The two-character prefix is kept verbatim, so ``0xabc`` becomes ``0xABC``.
+    """
+    prefix, digits = text[:2], text[2:]
+    return prefix + digits.upper()
+
+
+def format_scientific_notation(text: str) -> str:
+    """Format a numeric literal written in scientific notation (``XeY``).
+
+    The mantissa is normalized with ``format_float_or_int_string``.  In the
+    exponent, a leading ``-`` is kept and a redundant leading ``+`` is dropped.
+    """
+    mantissa, exponent = text.split("e")
+    sign = "-" if exponent[:1] == "-" else ""
+    if exponent[:1] in ("+", "-"):
+        # Strip the sign character; `sign` re-adds it only when negative.
+        exponent = exponent[1:]
+    return f"{format_float_or_int_string(mantissa)}e{sign}{exponent}"
+
+
+def format_long_or_complex_number(text: str) -> str:
+    """Format a long (``10l``/``10L``) or complex (``10j``) numeric literal.
+
+    Everything but the final suffix character goes through
+    ``format_float_or_int_string``; the suffix itself is preserved, except
+    that a lowercase ``l`` is upper-cased.
+    """
+    number, suffix = text[:-1], text[-1]
+    # "2l" reads too much like "21", so always spell the long suffix "L".
+    suffix = "L" if suffix == "l" else suffix
+    return f"{format_float_or_int_string(number)}{suffix}"
+
+
def format_float_or_int_string(text: str) -> str:
"""Formats a float string like "1.0"."""
if "." not in text:
check_lpar = True
if check_lpar:
- if is_walrus_assignment(child):
- pass
-
- elif child.type == syms.atom:
+ if child.type == syms.atom:
if maybe_make_parens_invisible_in_atom(child, parent=node):
wrap_in_parentheses(node, child, visible=False)
elif is_one_tuple(child):
for leaf in node.leaves():
previous_consumed = 0
for comment in list_comments(leaf.prefix, is_endmarker=False):
- if comment.value in FMT_OFF:
- # We only want standalone comments. If there's no previous leaf or
- # the previous leaf is indentation, it's a standalone comment in
- # disguise.
- if comment.type != STANDALONE_COMMENT:
- prev = preceding_leaf(leaf)
- if prev and prev.type not in WHITESPACE:
+ if comment.value not in FMT_PASS:
+ previous_consumed = comment.consumed
+ continue
+ # We only want standalone comments. If there's no previous leaf or
+ # the previous leaf is indentation, it's a standalone comment in
+ # disguise.
+ if comment.value in FMT_PASS and comment.type != STANDALONE_COMMENT:
+ prev = preceding_leaf(leaf)
+ if prev:
+ if comment.value in FMT_OFF and prev.type not in WHITESPACE:
+ continue
+ if comment.value in FMT_SKIP and prev.type in WHITESPACE:
continue
- ignored_nodes = list(generate_ignored_nodes(leaf))
- if not ignored_nodes:
- continue
-
- first = ignored_nodes[0] # Can be a container node with the `leaf`.
- parent = first.parent
- prefix = first.prefix
- first.prefix = prefix[comment.consumed :]
- hidden_value = (
- comment.value + "\n" + "".join(str(n) for n in ignored_nodes)
- )
- if hidden_value.endswith("\n"):
- # That happens when one of the `ignored_nodes` ended with a NEWLINE
- # leaf (possibly followed by a DEDENT).
- hidden_value = hidden_value[:-1]
- first_idx: Optional[int] = None
- for ignored in ignored_nodes:
- index = ignored.remove()
- if first_idx is None:
- first_idx = index
- assert parent is not None, "INTERNAL ERROR: fmt: on/off handling (1)"
- assert first_idx is not None, "INTERNAL ERROR: fmt: on/off handling (2)"
- parent.insert_child(
- first_idx,
- Leaf(
- STANDALONE_COMMENT,
- hidden_value,
- prefix=prefix[:previous_consumed] + "\n" * comment.newlines,
- ),
- )
- return True
+ ignored_nodes = list(generate_ignored_nodes(leaf, comment))
+ if not ignored_nodes:
+ continue
- previous_consumed = comment.consumed
+ first = ignored_nodes[0] # Can be a container node with the `leaf`.
+ parent = first.parent
+ prefix = first.prefix
+ first.prefix = prefix[comment.consumed :]
+ hidden_value = "".join(str(n) for n in ignored_nodes)
+ if comment.value in FMT_OFF:
+ hidden_value = comment.value + "\n" + hidden_value
+ if comment.value in FMT_SKIP:
+ hidden_value += " " + comment.value
+ if hidden_value.endswith("\n"):
+ # That happens when one of the `ignored_nodes` ended with a NEWLINE
+ # leaf (possibly followed by a DEDENT).
+ hidden_value = hidden_value[:-1]
+ first_idx: Optional[int] = None
+ for ignored in ignored_nodes:
+ index = ignored.remove()
+ if first_idx is None:
+ first_idx = index
+ assert parent is not None, "INTERNAL ERROR: fmt: on/off handling (1)"
+ assert first_idx is not None, "INTERNAL ERROR: fmt: on/off handling (2)"
+ parent.insert_child(
+ first_idx,
+ Leaf(
+ STANDALONE_COMMENT,
+ hidden_value,
+ prefix=prefix[:previous_consumed] + "\n" * comment.newlines,
+ ),
+ )
+ return True
return False
-def generate_ignored_nodes(leaf: Leaf) -> Iterator[LN]:
+def generate_ignored_nodes(leaf: Leaf, comment: ProtoComment) -> Iterator[LN]:
"""Starting from the container of `leaf`, generate all leaves until `# fmt: on`.
+ If comment is skip, returns leaf only.
Stops at the end of the block.
"""
container: Optional[LN] = container_of(leaf)
+ if comment.value in FMT_SKIP:
+ prev_sibling = leaf.prev_sibling
+ if comment.value in leaf.prefix and prev_sibling is not None:
+ leaf.prefix = leaf.prefix.replace(comment.value, "")
+ siblings = [prev_sibling]
+ while (
+ "\n" not in prev_sibling.prefix
+ and prev_sibling.prev_sibling is not None
+ ):
+ prev_sibling = prev_sibling.prev_sibling
+ siblings.insert(0, prev_sibling)
+ for sibling in siblings:
+ yield sibling
+ elif leaf.parent is not None:
+ yield leaf.parent
+ return
while container is not None and container.type != token.ENDMARKER:
if is_fmt_on(container):
return
Returns whether the node should itself be wrapped in invisible parentheses.
"""
+
if (
node.type != syms.atom
or is_empty_tuple(node)
):
return False
+ if is_walrus_assignment(node):
+ if parent.type in [
+ syms.annassign,
+ syms.expr_stmt,
+ syms.assert_stmt,
+ syms.return_stmt,
+ ]:
+ return False
+
first = node.children[0]
last = node.children[-1]
if first.type == token.LPAR and last.type == token.RPAR:
return wrapped
+def first_child_is_arith(node: Node) -> bool:
+    """Return True if the first child of ``node`` is an arithmetic expression.
+
+    "Arithmetic" here means one of the ``arith_expr``, ``shift_expr``,
+    ``xor_expr`` or ``and_expr`` grammar symbols.  A node without children
+    yields False.
+    """
+    if not node.children:
+        return False
+    return node.children[0].type in {
+        syms.arith_expr,
+        syms.shift_expr,
+        syms.xor_expr,
+        syms.and_expr,
+    }
+
+
def wrap_in_parentheses(parent: Node, child: LN, *, visible: bool = True) -> None:
"""Wrap `child` in parentheses.
return inner is not None and inner.type == syms.namedexpr_test
+def is_simple_decorator_trailer(node: LN, last: bool = False) -> bool:
+    """Return True iff `node` is a trailer valid in a simple decorator.
+
+    A "simple" decorator is one accepted by the old decorator grammar
+    (`@ dotted_name [arguments] NEWLINE`), so the only trailers allowed are:
+
+    - attribute access `.NAME`, anywhere in the chain;
+    - a call, only as the final trailer (`last=True`), with or without
+      arguments. An empty `()` trailer has just the two parenthesis children,
+      while a call with arguments has three.
+    """
+    return node.type == syms.trailer and (
+        (
+            len(node.children) == 2
+            and node.children[0].type == token.DOT
+            and node.children[1].type == token.NAME
+        )
+        # last trailer can be an argument-less call: `()`
+        or (
+            last
+            and len(node.children) == 2
+            and node.children[0].type == token.LPAR
+            and node.children[1].type == token.RPAR
+        )
+        # last trailer can be a call with arguments: `(...)`; the arguments
+        # themselves are not inspected
+        or (
+            last
+            and len(node.children) == 3
+            and node.children[0].type == token.LPAR
+            and node.children[2].type == token.RPAR
+        )
+    )
+
+
+def is_simple_decorator_expression(node: LN) -> bool:
+    """Return True iff `node` could be a 'dotted name' decorator.
+
+    `node` is the 'namedexpr_test' child of a decorator parsed with the new
+    (relaxed) decorator grammar; this checks whether the same expression would
+    also have been valid under the old, stricter grammar.
+
+    The old grammar was: decorator: @ dotted_name [arguments] NEWLINE
+    The new grammar is : decorator: @ namedexpr_test NEWLINE
+    """
+    # A bare name such as `@decorator` is always simple.
+    if node.type == token.NAME:
+        return True
+    if node.type != syms.power or not node.children:
+        return False
+    head, *trailers = node.children
+    if head.type != token.NAME:
+        return False
+    if not trailers:
+        return True
+    # Every trailer but the last must be `.NAME`; only the last may be a call.
+    *middle, final = trailers
+    return all(
+        is_simple_decorator_trailer(trailer) for trailer in middle
+    ) and is_simple_decorator_trailer(final, last=True)
+
+
def is_yield(node: LN) -> bool:
"""Return True if `node` holds a `yield` or `yield from` expression."""
if node.type == syms.yield_expr:
leaf.value = ")"
-def should_split_body_explode(line: Line, opening_bracket: Leaf) -> bool:
+def should_split_line(line: Line, opening_bracket: Leaf) -> bool:
"""Should `line` be immediately split with `delimiter_split()` after RHS?"""
if not (opening_bracket.parent and opening_bracket.value in "[{("):
return False
return max_priority == COMMA_PRIORITY and (
- trailing_comma
+ (line.mode.magic_trailing_comma and trailing_comma)
# always explode imports
or opening_bracket.parent.type in {syms.atom, syms.import_from}
)
- underscores in numeric literals;
- trailing commas after * or ** in function signatures and calls;
- positional only arguments in function signatures and lambdas;
+ - assignment expression;
+ - relaxed decorator syntax;
"""
features: Set[Feature] = set()
for n in node.pre_order():
elif n.type == token.COLONEQUAL:
features.add(Feature.ASSIGNMENT_EXPRESSIONS)
+ elif n.type == syms.decorator:
+ if len(n.children) > 1 and not is_simple_decorator_expression(
+ n.children[1]
+ ):
+ features.add(Feature.RELAXED_DECORATORS)
+
elif (
n.type in {syms.typedargslist, syms.arglist}
and n.children
"""
omit: Set[LeafID] = set()
- if not line.should_explode:
+ if not line.magic_trailing_comma:
yield omit
length = 4 * line.depth
elif leaf.type in CLOSING_BRACKETS:
prev = line.leaves[index - 1] if index > 0 else None
if (
- line.should_explode
- and prev
+ prev
and prev.type == token.COMMA
and not is_one_tuple_between(
leaf.opening_bracket, leaf, line.leaves
yield omit
if (
- line.should_explode
- and prev
+ prev
and prev.type == token.COMMA
and not is_one_tuple_between(leaf.opening_bracket, leaf, line.leaves)
):
@lru_cache()
def get_gitignore(root: Path) -> PathSpec:
- """ Return a PathSpec matching gitignore content if present."""
+ """Return a PathSpec matching gitignore content if present."""
gitignore = root / ".gitignore"
lines: List[str] = []
if gitignore.is_file():
return normalized_path
+def path_is_excluded(
+    normalized_path: str,
+    pattern: Optional[Pattern[str]],
+) -> bool:
+    """Return True iff `pattern` is given and finds a non-empty match in the path.
+
+    A missing pattern excludes nothing, and an empty match is not treated as an
+    exclusion either.
+    """
+    if not pattern:
+        return False
+    found = pattern.search(normalized_path)
+    return found is not None and bool(found.group(0))
+
+
def gen_python_files(
paths: Iterable[Path],
root: Path,
include: Optional[Pattern[str]],
exclude: Pattern[str],
+ extend_exclude: Optional[Pattern[str]],
force_exclude: Optional[Pattern[str]],
report: "Report",
gitignore: PathSpec,
) -> Iterator[Path]:
"""Generate all files under `path` whose paths are not excluded by the
- `exclude_regex` or `force_exclude` regexes, but are included by the `include` regex.
+ `exclude_regex`, `extend_exclude`, or `force_exclude` regexes,
+ but are included by the `include` regex.
Symbolic links pointing outside of the `root` directory are ignored.
report.path_ignored(child, "matches the .gitignore file content")
continue
- # Then ignore with `--exclude` and `--force-exclude` options.
+ # Then ignore with `--exclude` `--extend-exclude` and `--force-exclude` options.
normalized_path = "/" + normalized_path
if child.is_dir():
normalized_path += "/"
- exclude_match = exclude.search(normalized_path) if exclude else None
- if exclude_match and exclude_match.group(0):
+ if path_is_excluded(normalized_path, exclude):
report.path_ignored(child, "matches the --exclude regular expression")
continue
- force_exclude_match = (
- force_exclude.search(normalized_path) if force_exclude else None
- )
- if force_exclude_match and force_exclude_match.group(0):
+ if path_is_excluded(normalized_path, extend_exclude):
+ report.path_ignored(
+ child, "matches the --extend-exclude regular expression"
+ )
+ continue
+
+ if path_is_excluded(normalized_path, force_exclude):
report.path_ignored(child, "matches the --force-exclude regular expression")
continue
root,
include,
exclude,
+ extend_exclude,
force_exclude,
report,
gitignore,
@lru_cache()
-def find_project_root(srcs: Iterable[str]) -> Path:
+def find_project_root(srcs: Tuple[str, ...]) -> Path:
"""Return a directory containing .git, .hg, or pyproject.toml.
That directory will be a common parent of all files and directories
return directory
+@lru_cache()
+def find_user_pyproject_toml() -> Path:
+    r"""Return the path to the top-level user configuration for black.
+
+    This looks for ~\.black on Windows and ~/.config/black on Linux and other
+    Unix systems.
+
+    On non-Windows platforms $XDG_CONFIG_HOME is honoured; as the XDG Base
+    Directory Specification requires, an unset *or empty* value falls back to
+    ~/.config. The result is cached by lru_cache, so environment changes after
+    the first call are not picked up.
+    """
+    if sys.platform == "win32":
+        # Windows
+        user_config_path = Path.home() / ".black"
+    else:
+        # `or` rather than a .get() default so that XDG_CONFIG_HOME="" also
+        # falls back, instead of resolving to a relative "black" path in the cwd.
+        config_root = os.environ.get("XDG_CONFIG_HOME") or "~/.config"
+        user_config_path = Path(config_root).expanduser() / "black"
+    return user_config_path.resolve()
+
+
@dataclass
class Report:
"""Provides a reformatting counter. Can be rendered with `str(report)`."""
return ast3.parse(src, filename, feature_version=feature_version)
except SyntaxError:
continue
-
+ if ast27.__name__ == "ast":
+ raise SyntaxError(
+ "The requested source code has invalid Python 3 syntax.\n"
+ "If you are trying to format Python 2 files please reinstall Black"
+ " with the 'python2' extra: `python3 -m pip install black[python2]`."
+ )
return ast27.parse(src)
# Constant strings may be indented across newlines, if they are
# docstrings; fold spaces after newlines when comparing. Similarly,
# trailing and leading space may be removed.
+ # Note that when formatting Python 2 code, at least with Windows
+ # line-endings, docstrings can end up here as bytes instead of
+ # str so make sure that we handle both cases.
if (
isinstance(node, ast.Constant)
and field == "value"
- and isinstance(value, str)
+ and isinstance(value, (str, bytes))
):
- normalized = re.sub(r" *\n[ \t]*", "\n", value).strip()
+ lineend = "\n" if isinstance(value, str) else b"\n"
+ # To normalize, we strip any leading and trailing space from
+ # each line...
+ stripped = [line.strip() for line in value.splitlines()]
+ normalized = lineend.join(stripped) # type: ignore[attr-defined]
+ # ...and remove any blank lines at the beginning and end of
+ # the whole string
+ normalized = normalized.strip()
else:
normalized = value
yield f"{' ' * (depth+2)}{normalized!r}, # {value.__class__.__name__}"
yield f"{' ' * depth}) # /{node.__class__.__name__}"
-def assert_equivalent(src: str, dst: str) -> None:
+def assert_equivalent(src: str, dst: str, *, pass_num: int = 1) -> None:
"""Raise AssertionError if `src` and `dst` aren't equivalent."""
try:
src_ast = parse_ast(src)
except Exception as exc:
log = dump_to_file("".join(traceback.format_tb(exc.__traceback__)), dst)
raise AssertionError(
- f"INTERNAL ERROR: Black produced invalid code: {exc}. Please report a bug"
- " on https://github.com/psf/black/issues. This invalid output might be"
- f" helpful: {log}"
+ f"INTERNAL ERROR: Black produced invalid code on pass {pass_num}: {exc}. "
+ "Please report a bug on https://github.com/psf/black/issues. "
+ f"This invalid output might be helpful: {log}"
) from None
src_ast_str = "\n".join(_stringify_ast(src_ast))
log = dump_to_file(diff(src_ast_str, dst_ast_str, "src", "dst"))
raise AssertionError(
"INTERNAL ERROR: Black produced code that is not equivalent to the"
- " source. Please report a bug on https://github.com/psf/black/issues. "
- f" This diff might be helpful: {log}"
+ f" source on pass {pass_num}. Please report a bug on "
+ f"https://github.com/psf/black/issues. This diff might be helpful: {log}"
) from None
@mypyc_attr(patchable=True)
-def dump_to_file(*output: str) -> str:
+def dump_to_file(*output: str, ensure_final_newline: bool = True) -> str:
"""Dump `output` to a temporary file. Return path to the file."""
with tempfile.NamedTemporaryFile(
mode="w", prefix="blk_", suffix=".log", delete=False, encoding="utf8"
) as f:
for lines in output:
f.write(lines)
- if lines and lines[-1] != "\n":
+ if ensure_final_newline and lines and lines[-1] != "\n":
f.write("\n")
return f.name
"""Return a unified diff string between strings `a` and `b`."""
import difflib
- a_lines = [line + "\n" for line in a.splitlines()]
- b_lines = [line + "\n" for line in b.splitlines()]
- return "".join(
- difflib.unified_diff(a_lines, b_lines, fromfile=a_name, tofile=b_name, n=5)
- )
+ a_lines = [line for line in a.splitlines(keepends=True)]
+ b_lines = [line for line in b.splitlines(keepends=True)]
+ diff_lines = []
+ for line in difflib.unified_diff(
+ a_lines, b_lines, fromfile=a_name, tofile=b_name, n=5
+ ):
+ # Work around https://bugs.python.org/issue2142
+ # See https://www.gnu.org/software/diffutils/manual/html_node/Incomplete-Lines.html
+ if line[-1] == "\n":
+ diff_lines.append(line)
+ else:
+ diff_lines.append(line + "\n")
+ diff_lines.append("\\ No newline at end of file\n")
+ return "".join(diff_lines)
def cancel(tasks: Iterable["asyncio.Task[Any]"]) -> None:
penultimate = line.leaves[-2]
last = line.leaves[-1]
- if line.should_explode:
+ if line.magic_trailing_comma:
try:
penultimate, last = last_two_except(line.leaves, omit=omit_on_explode)
except LookupError:
# unnecessary.
return True
- if line.should_explode and penultimate.type == token.COMMA:
+ if line.magic_trailing_comma and penultimate.type == token.COMMA:
# The rightmost non-omitted bracket pair is the one we want to explode on.
return True
"""
todo, done = set(), set()
for src in sources:
- src = src.resolve()
- if cache.get(src) != get_cache_info(src):
+ res_src = src.resolve()
+ if cache.get(str(res_src)) != get_cache_info(res_src):
todo.add(src)
else:
done.add(src)
cache_file = get_cache_file(mode)
try:
CACHE_DIR.mkdir(parents=True, exist_ok=True)
- new_cache = {**cache, **{src.resolve(): get_cache_info(src) for src in sources}}
+ new_cache = {
+ **cache,
+ **{str(src.resolve()): get_cache_info(src) for src in sources},
+ }
with tempfile.NamedTemporaryFile(dir=str(cache_file.parent), delete=False) as f:
pickle.dump(new_cache, f, protocol=4)
os.replace(f.name, cache_file)
def is_docstring(leaf: Leaf) -> bool:
- if not is_multiline_string(leaf):
- # For the purposes of docstring re-indentation, we don't need to do anything
- # with single-line docstrings.
- return False
-
if prev_siblings_are(
leaf.parent, [None, token.NEWLINE, token.INDENT, syms.simple_stmt]
):
return False
+def lines_with_leading_tabs_expanded(s: str) -> List[str]:
+    """Split `s` into lines, expanding tabs only in each line's leading whitespace.
+
+    Tabs in the indentation are expanded with `str.expandtabs()` (following the
+    normal Python rules); everything from the first non-whitespace character
+    onward is left untouched. Lines that are empty or contain only whitespace
+    are returned unchanged.
+    """
+    lines = []
+    for line in s.splitlines():
+        # Find the indentation with `str.lstrip()` (linear time) rather than a
+        # regex like `\s*\t+\s*(\S)`, whose overlapping quantifiers backtrack
+        # polynomially on long whitespace-only lines.
+        stripped_line = line.lstrip()
+        if not stripped_line or stripped_line == line:
+            # Whitespace-only line, or no indentation to expand.
+            lines.append(line)
+        else:
+            prefix_length = len(line) - len(stripped_line)
+            prefix = line[:prefix_length].expandtabs()
+            lines.append(prefix + stripped_line)
+    return lines
+
+
def fix_docstring(docstring: str, prefix: str) -> str:
# https://www.python.org/dev/peps/pep-0257/#handling-docstring-indentation
if not docstring:
return ""
- # Convert tabs to spaces (following the normal Python rules)
- # and split into a list of lines:
- lines = docstring.expandtabs().splitlines()
+ lines = lines_with_leading_tabs_expanded(docstring)
# Determine minimum indentation (first line doesn't count):
indent = sys.maxsize
for line in lines[1:]: