import colorama # noqa: F401
DEFAULT_LINE_LENGTH = 88
-DEFAULT_EXCLUDES = r"/(\.direnv|\.eggs|\.git|\.hg|\.mypy_cache|\.nox|\.tox|\.venv|\.svn|_build|buck-out|build|dist)/" # noqa: B950
+# Default exclusion regex (presumably the default for `--exclude` -- confirm):
+# matches whole directory names between slashes, listing both `.venv` and `venv`.
+DEFAULT_EXCLUDES = r"/(\.direnv|\.eggs|\.git|\.hg|\.mypy_cache|\.nox|\.tox|\.venv|venv|\.svn|_build|buck-out|build|dist)/" # noqa: B950
+# Default inclusion regex: matches paths ending in `.py` or `.pyi`.
DEFAULT_INCLUDES = r"\.pyi?$"
CACHE_DIR = Path(user_cache_dir("black", version=__version__))
STDIN_PLACEHOLDER = "__BLACK_STDIN_FILENAME__"
return all(feature in VERSION_TO_FEATURES[version] for version in target_versions)
-def find_pyproject_toml(path_search_start: Iterable[str]) -> Optional[str]:
+def find_pyproject_toml(path_search_start: Tuple[str, ...]) -> Optional[str]:
-    """Find the absolute filepath to a pyproject.toml if it exists"""
+    """Find the absolute filepath to a pyproject.toml if it exists.
+
+    The pyproject.toml in the project root takes precedence. If there is none,
+    fall back to the user-level configuration file returned by
+    `find_user_pyproject_toml()`. Return None if neither path is a file.
+
+    `path_search_start` must be a tuple (not an arbitrary iterable) because it
+    is passed on to the `lru_cache`-decorated `find_project_root`, whose
+    arguments have to be hashable.
+    """
    path_project_root = find_project_root(path_search_start)
    path_pyproject_toml = path_project_root / "pyproject.toml"
-    return str(path_pyproject_toml) if path_pyproject_toml.is_file() else None
+    if path_pyproject_toml.is_file():
+        return str(path_pyproject_toml)
+
+    path_user_pyproject_toml = find_user_pyproject_toml()
+    return str(path_user_pyproject_toml) if path_user_pyproject_toml.is_file() else None
def parse_pyproject_toml(path_config: str) -> Dict[str, Any]:
"--check",
is_flag=True,
help=(
- "Don't write the files back, just return the status. Return code 0 means"
- " nothing would change. Return code 1 means some files would be reformatted."
+ "Don't write the files back, just return the status. Return code 0 means"
+ " nothing would change. Return code 1 means some files would be reformatted."
" Return code 123 means there was an internal error."
),
)
callback=validate_regex,
help=(
"A regular expression that matches files and directories that should be"
- " included on recursive searches. An empty value means all files are included"
- " regardless of the name. Use forward slashes for directories on all platforms"
- " (Windows, too). Exclusions are calculated first, inclusions later."
+ " included on recursive searches. An empty value means all files are included"
+ " regardless of the name. Use forward slashes for directories on all platforms"
+ " (Windows, too). Exclusions are calculated first, inclusions later."
),
show_default=True,
)
callback=validate_regex,
help=(
"A regular expression that matches files and directories that should be"
- " excluded on recursive searches. An empty value means no paths are excluded."
- " Use forward slashes for directories on all platforms (Windows, too). "
+ " excluded on recursive searches. An empty value means no paths are excluded."
+ " Use forward slashes for directories on all platforms (Windows, too)."
" Exclusions are calculated first, inclusions later."
),
show_default=True,
except (ImportError, OSError):
# we arrive here if the underlying system does not support multi-processing
# like in AWS Lambda or Termux, in which case we gracefully fallback to
- # a ThreadPollExecutor with just a single worker (more workers would not do us
+ # a ThreadPoolExecutor with just a single worker (more workers would not do us
# any good due to the Global Interpreter Lock)
executor = ThreadPoolExecutor(max_workers=1)
): src
for src in sorted(sources)
}
- pending: Iterable["asyncio.Future[bool]"] = tasks.keys()
+ pending = tasks.keys()
try:
loop.add_signal_handler(signal.SIGINT, cancel, pending)
loop.add_signal_handler(signal.SIGTERM, cancel, pending)
if not fast:
assert_equivalent(src_contents, dst_contents)
- assert_stable(src_contents, dst_contents, mode=mode)
+
+ # Forced second pass to work around optional trailing commas (becoming
+ # forced trailing commas on pass 2) interacting differently with optional
+ # parentheses. Admittedly ugly.
+ dst_contents_pass2 = format_str(dst_contents, mode=mode)
+ if dst_contents != dst_contents_pass2:
+ dst_contents = dst_contents_pass2
+ assert_equivalent(src_contents, dst_contents, pass_num=2)
+ assert_stable(src_contents, dst_contents, mode=mode)
+ # Note: no need to explicitly call `assert_stable` if `dst_contents` was
+ # the same as `dst_contents_pass2`.
return dst_contents
# We're ignoring docstrings with backslash newline escapes because changing
# indentation of those changes the AST representation of the code.
prefix = get_string_prefix(leaf.value)
- lead_len = len(prefix) + 3
- tail_len = -3
- indent = " " * 4 * self.current_line.depth
- docstring = fix_docstring(leaf.value[lead_len:tail_len], indent)
+ docstring = leaf.value[len(prefix) :] # Remove the prefix
+ quote_char = docstring[0]
+ # A natural way to remove the outer quotes is to do:
+ # docstring = docstring.strip(quote_char)
+ # but that breaks on """""x""" (which is '""x').
+ # So we actually need to remove the first character and the next two
+ # characters but only if they are the same as the first.
+ quote_len = 1 if docstring[1] != quote_char else 3
+ docstring = docstring[quote_len:-quote_len]
+
+ if is_multiline_string(leaf):
+ indent = " " * 4 * self.current_line.depth
+ docstring = fix_docstring(docstring, indent)
+ else:
+ docstring = docstring.strip()
+
if docstring:
- if leaf.value[lead_len - 1] == docstring[0]:
+ # Add some padding if the docstring starts / ends with a quote mark.
+ if docstring[0] == quote_char:
docstring = " " + docstring
- if leaf.value[tail_len + 1] == docstring[-1]:
+ if docstring[-1] == quote_char:
docstring = docstring + " "
- leaf.value = leaf.value[0:lead_len] + docstring + leaf.value[tail_len:]
+ else:
+ # Add some padding if the docstring is empty.
+ docstring = " "
+
+ # We could enforce triple quotes at this point.
+ quote = quote_char * quote_len
+ leaf.value = prefix + quote + docstring + quote
yield from self.visit_default(leaf)
if content[0] == "#":
content = content[1:]
+ NON_BREAKING_SPACE = " "
+ if (
+ content
+ and content[0] == NON_BREAKING_SPACE
+ and not content.lstrip().startswith("type:")
+ ):
+ content = " " + content[1:] # Replace NBSP by a simple space
if content and content[0] not in " !:#'%":
content = " " + content
return "#" + content
def format_hex(text: str) -> str:
    """
-    Formats a hexadecimal string like "0x12b3"
-
-    Uses lowercase because of similarity between "B" and "8", which
-    can cause security issues.
-    see: https://github.com/psf/black/issues/1692
+    Formats a hexadecimal string like "0x12B3"
+
+    The first two characters (the "0x" prefix) are kept verbatim; every
+    character after them is uppercased.
    """
-
-    before, after = text[:2], text[2:]
-    return f"{before}{after.lower()}"
+    prefix, digits = text[:2], text[2:]
+    return prefix + digits.upper()
def format_scientific_notation(text: str) -> str:
@lru_cache()
def get_gitignore(root: Path) -> PathSpec:
- """ Return a PathSpec matching gitignore content if present."""
+ """Return a PathSpec matching gitignore content if present."""
gitignore = root / ".gitignore"
lines: List[str] = []
if gitignore.is_file():
@lru_cache()
-def find_project_root(srcs: Iterable[str]) -> Path:
+def find_project_root(srcs: Tuple[str, ...]) -> Path:
"""Return a directory containing .git, .hg, or pyproject.toml.
That directory will be a common parent of all files and directories
return directory
+@lru_cache()
+def find_user_pyproject_toml() -> Path:
+    r"""Return the path to the top-level user configuration for black.
+
+    This looks for ~\.black on Windows and ~/.config/black on Linux and other
+    Unix systems. On non-Windows systems a non-empty ``XDG_CONFIG_HOME``
+    replaces ``~/.config``; an unset or empty value falls back to
+    ``~/.config``, as the XDG Base Directory Specification requires.
+
+    The returned path is resolved but is not guaranteed to exist; callers
+    must check ``is_file()`` themselves.
+    """
+    if sys.platform == "win32":
+        # Windows
+        user_config_path = Path.home() / ".black"
+    else:
+        # `os.environ.get(name, default)` would return "" for an empty
+        # XDG_CONFIG_HOME, and Path("") is ".", which would silently look for
+        # the config in the current working directory. Treat empty as unset.
+        config_root = os.environ.get("XDG_CONFIG_HOME") or "~/.config"
+        user_config_path = Path(config_root).expanduser() / "black"
+    return user_config_path.resolve()
+
+
@dataclass
class Report:
"""Provides a reformatting counter. Can be rendered with `str(report)`."""
# Constant strings may be indented across newlines, if they are
# docstrings; fold spaces after newlines when comparing. Similarly,
# trailing and leading space may be removed.
+ # Note that when formatting Python 2 code, at least with Windows
+ # line-endings, docstrings can end up here as bytes instead of
+ # str so make sure that we handle both cases.
if (
isinstance(node, ast.Constant)
and field == "value"
- and isinstance(value, str)
+ and isinstance(value, (str, bytes))
):
- normalized = re.sub(r" *\n[ \t]*", "\n", value).strip()
+ lineend = "\n" if isinstance(value, str) else b"\n"
+ # To normalize, we strip any leading and trailing space from
+ # each line...
+ stripped = [line.strip() for line in value.splitlines()]
+ normalized = lineend.join(stripped) # type: ignore[attr-defined]
+ # ...and remove any blank lines at the beginning and end of
+ # the whole string
+ normalized = normalized.strip()
else:
normalized = value
yield f"{' ' * (depth+2)}{normalized!r}, # {value.__class__.__name__}"
yield f"{' ' * depth}) # /{node.__class__.__name__}"
-def assert_equivalent(src: str, dst: str) -> None:
+def assert_equivalent(src: str, dst: str, *, pass_num: int = 1) -> None:
"""Raise AssertionError if `src` and `dst` aren't equivalent."""
try:
src_ast = parse_ast(src)
except Exception as exc:
log = dump_to_file("".join(traceback.format_tb(exc.__traceback__)), dst)
raise AssertionError(
- f"INTERNAL ERROR: Black produced invalid code: {exc}. Please report a bug"
- " on https://github.com/psf/black/issues. This invalid output might be"
- f" helpful: {log}"
+ f"INTERNAL ERROR: Black produced invalid code on pass {pass_num}: {exc}. "
+ "Please report a bug on https://github.com/psf/black/issues. "
+ f"This invalid output might be helpful: {log}"
) from None
src_ast_str = "\n".join(_stringify_ast(src_ast))
log = dump_to_file(diff(src_ast_str, dst_ast_str, "src", "dst"))
raise AssertionError(
"INTERNAL ERROR: Black produced code that is not equivalent to the"
- " source. Please report a bug on https://github.com/psf/black/issues. "
- f" This diff might be helpful: {log}"
+ f" source on pass {pass_num}. Please report a bug on "
+ f"https://github.com/psf/black/issues. This diff might be helpful: {log}"
) from None
def is_docstring(leaf: Leaf) -> bool:
- if not is_multiline_string(leaf):
- # For the purposes of docstring re-indentation, we don't need to do anything
- # with single-line docstrings.
- return False
-
if prev_siblings_are(
leaf.parent, [None, token.NEWLINE, token.INDENT, syms.simple_stmt]
):