import asyncio
from json.decoder import JSONDecodeError
import json
-from concurrent.futures import Executor, ThreadPoolExecutor, ProcessPoolExecutor
from contextlib import contextmanager
from datetime import datetime
from enum import Enum
import os
from pathlib import Path
from pathspec.patterns.gitwildmatch import GitWildMatchPatternError
+import platform
import re
import signal
import sys
import tokenize
import traceback
from typing import (
+ TYPE_CHECKING,
Any,
Dict,
Generator,
MutableMapping,
Optional,
Pattern,
+ Sequence,
Set,
Sized,
Tuple,
from black.const import DEFAULT_LINE_LENGTH, DEFAULT_INCLUDES, DEFAULT_EXCLUDES
from black.const import STDIN_PLACEHOLDER
from black.nodes import STARS, syms, is_simple_decorator_expression
-from black.nodes import is_string_token
+from black.nodes import is_string_token, is_number_token
from black.lines import Line, EmptyLineTracker
from black.linegen import transform_line, LineGenerator, LN
from black.comments import normalize_fmt_off
from black.concurrency import cancel, shutdown, maybe_install_uvloop
from black.output import dump_to_file, ipynb_diff, diff, color_diff, out, err
from black.report import Report, Changed, NothingChanged
-from black.files import find_project_root, find_pyproject_toml, parse_pyproject_toml
+from black.files import (
+ find_project_root,
+ find_pyproject_toml,
+ parse_pyproject_toml,
+ find_user_pyproject_toml,
+)
from black.files import gen_python_files, get_gitignore, normalize_path_maybe_ignore
from black.files import wrap_stream_for_windows
from black.parsing import InvalidInput # noqa F401
from _black_version import version as __version__
+if TYPE_CHECKING:
+ from concurrent.futures import Executor
+
COMPILED = Path(__file__).suffix in (".pyd", ".so")
# types
"(useful when piping source on standard input)."
),
)
+@click.option(
+ "--python-cell-magics",
+ multiple=True,
+ help=(
+ "When processing Jupyter Notebooks, add the given magic to the list"
+ f" of known python-magics ({', '.join(PYTHON_CELL_MAGICS)})."
+ " Useful for formatting cells with custom python magics."
+ ),
+ default=[],
+)
@click.option(
"-S",
"--skip-string-normalization",
"--preview",
is_flag=True,
help=(
- "Enable potentially disruptive style changes that will be added to Black's main"
+ "Enable potentially disruptive style changes that may be added to Black's main"
" functionality in the next major release."
),
)
type=str,
help=(
"Require a specific version of Black to be running (useful for unifying results"
- " across many environments e.g. with a pyproject.toml file)."
+ " across many environments e.g. with a pyproject.toml file). It can be"
+ " either a major version number or an exact version."
),
)
@click.option(
)
@click.version_option(
version=__version__,
- message=f"%(prog)s, %(version)s (compiled: {'yes' if COMPILED else 'no'})",
+ message=(
+ f"%(prog)s, %(version)s (compiled: {'yes' if COMPILED else 'no'})\n"
+ f"Python ({platform.python_implementation()}) {platform.python_version()}"
+ ),
)
@click.argument(
"src",
help="Read configuration from FILE path.",
)
@click.pass_context
-def main(
+def main( # noqa: C901
ctx: click.Context,
code: Optional[str],
line_length: int,
fast: bool,
pyi: bool,
ipynb: bool,
+ python_cell_magics: Sequence[str],
skip_string_normalization: bool,
skip_magic_trailing_comma: bool,
experimental_string_processing: bool,
) -> None:
"""The uncompromising code formatter."""
ctx.ensure_object(dict)
+
+ if src and code is not None:
+ out(
+ main.get_usage(ctx)
+ + "\n\n'SRC' and 'code' cannot be passed simultaneously."
+ )
+ ctx.exit(1)
+ if not src and code is None:
+ out(main.get_usage(ctx) + "\n\nOne of 'SRC' or 'code' is required.")
+ ctx.exit(1)
+
root, method = find_project_root(src) if code is None else (None, None)
ctx.obj["root"] = root
if config:
config_source = ctx.get_parameter_source("config")
- if config_source in (ParameterSource.DEFAULT, ParameterSource.DEFAULT_MAP):
+ user_level_config = str(find_user_pyproject_toml())
+ if config == user_level_config:
+ out(
+ "Using configuration from user-level config at "
+ f"'{user_level_config}'.",
+ fg="blue",
+ )
+ elif config_source in (
+ ParameterSource.DEFAULT,
+ ParameterSource.DEFAULT_MAP,
+ ):
out("Using configuration from project root.", fg="blue")
else:
out(f"Using configuration in '{config}'.", fg="blue")
error_msg = "Oh no! 💥 💔 💥"
- if required_version and required_version != __version__:
+ if (
+ required_version
+ and required_version != __version__
+ and required_version != __version__.split(".")[0]
+ ):
err(
f"{error_msg} The required version `{required_version}` does not match"
f" the running version `{__version__}`!"
magic_trailing_comma=not skip_magic_trailing_comma,
experimental_string_processing=experimental_string_processing,
preview=preview,
+ python_cell_magics=set(python_cell_magics),
)
if code is not None:
) -> Set[Path]:
"""Compute the set of files to be formatted."""
sources: Set[Path] = set()
- path_empty(src, "No Path provided. Nothing to do 😴", quiet, verbose, ctx)
if exclude is None:
exclude = re_compile_maybe_verbose(DEFAULT_EXCLUDES)
report.failed(path, str(exc))
+# diff-shades depends on being able to monkeypatch this function to operate. I know
+# it's not ideal, but this shouldn't cause any issues ... hopefully. ~ichard26
+@mypyc_attr(patchable=True)
def reformat_one(
src: Path, fast: bool, write_back: WriteBack, mode: Mode, report: "Report"
) -> None:
workers: Optional[int],
) -> None:
"""Reformat multiple files using a ProcessPoolExecutor."""
+ from concurrent.futures import Executor, ThreadPoolExecutor, ProcessPoolExecutor
+
executor: Executor
- loop = asyncio.get_event_loop()
worker_count = workers if workers is not None else DEFAULT_WORKERS
if sys.platform == "win32":
# Work around https://bugs.python.org/issue26903
# any good due to the Global Interpreter Lock)
executor = ThreadPoolExecutor(max_workers=1)
+ loop = asyncio.new_event_loop()
+ asyncio.set_event_loop(loop)
try:
loop.run_until_complete(
schedule_formatting(
)
)
finally:
- shutdown(loop)
+ try:
+ shutdown(loop)
+ finally:
+ asyncio.set_event_loop(None)
if executor is not None:
executor.shutdown()
mode: Mode,
report: "Report",
loop: asyncio.AbstractEventLoop,
- executor: Executor,
+ executor: "Executor",
) -> None:
"""Run formatting of `sources` in parallel using the provided `executor`.
content differently.
"""
assert_equivalent(src_contents, dst_contents)
-
- # Forced second pass to work around optional trailing commas (becoming
- # forced trailing commas on pass 2) interacting differently with optional
- # parentheses. Admittedly ugly.
- dst_contents_pass2 = format_str(dst_contents, mode=mode)
- if dst_contents != dst_contents_pass2:
- dst_contents = dst_contents_pass2
- assert_equivalent(src_contents, dst_contents, pass_num=2)
- assert_stable(src_contents, dst_contents, mode=mode)
- # Note: no need to explicitly call `assert_stable` if `dst_contents` was
- # the same as `dst_contents_pass2`.
+ assert_stable(src_contents, dst_contents, mode=mode)
def format_file_contents(src_contents: str, *, fast: bool, mode: Mode) -> FileContent:
return dst_contents
-def validate_cell(src: str) -> None:
+def validate_cell(src: str, mode: Mode) -> None:
"""Check that cell does not already contain TransformerManager transformations,
or non-Python cell magics, which might cause tokenizer_rt to break because of
indentations.
"""
if any(transformed_magic in src for transformed_magic in TRANSFORMED_MAGICS):
raise NothingChanged
- if src[:2] == "%%" and src.split()[0][2:] not in PYTHON_CELL_MAGICS:
+ if (
+ src[:2] == "%%"
+ and src.split()[0][2:] not in PYTHON_CELL_MAGICS | mode.python_cell_magics
+ ):
raise NothingChanged
could potentially be automagics or multi-line magics, which
are currently not supported.
"""
- validate_cell(src)
+ validate_cell(src, mode)
src_without_trailing_semicolon, has_trailing_semicolon = remove_trailing_semicolon(
src
)
raise NothingChanged
-def format_str(src_contents: str, *, mode: Mode) -> FileContent:
+def format_str(src_contents: str, *, mode: Mode) -> str:
"""Reformat a string and return new contents.
`mode` determines formatting options, such as how many characters per line are
hey
"""
+ dst_contents = _format_str_once(src_contents, mode=mode)
+ # Forced second pass to work around optional trailing commas (becoming
+ # forced trailing commas on pass 2) interacting differently with optional
+ # parentheses. Admittedly ugly.
+ if src_contents != dst_contents:
+ return _format_str_once(dst_contents, mode=mode)
+ return dst_contents
+
+
+def _format_str_once(src_contents: str, *, mode: Mode) -> str:
src_node = lib2to3_parse(src_contents.lstrip(), mode.target_versions)
dst_contents = []
- future_imports = get_future_imports(src_node)
if mode.target_versions:
versions = mode.target_versions
else:
+ future_imports = get_future_imports(src_node)
versions = detect_target_versions(src_node, future_imports=future_imports)
- normalize_fmt_off(src_node)
+ normalize_fmt_off(src_node, preview=mode.preview)
lines = LineGenerator(mode=mode)
elt = EmptyLineTracker(is_pyi=mode.is_pyi)
empty_line = Line(mode=mode)
if value_head in {'f"', 'F"', "f'", "F'", "rf", "fr", "RF", "FR"}:
features.add(Feature.F_STRINGS)
- elif n.type == token.NUMBER:
- assert isinstance(n, Leaf)
+ elif is_number_token(n):
if "_" in n.value:
features.add(Feature.NUMERIC_UNDERSCORES)
):
features.add(Feature.ANN_ASSIGN_EXTENDED_RHS)
+ elif (
+ n.type == syms.except_clause
+ and len(n.children) >= 2
+ and n.children[1].type == token.STAR
+ ):
+ features.add(Feature.EXCEPT_STAR)
+
+ elif n.type in {syms.subscriptlist, syms.trailer} and any(
+ child.type == syms.star_expr for child in n.children
+ ):
+ features.add(Feature.VARIADIC_GENERICS)
+
+ elif (
+ n.type == syms.tname_star
+ and len(n.children) == 3
+ and n.children[2].type == syms.star_expr
+ ):
+ features.add(Feature.VARIADIC_GENERICS)
+
return features
return imports
-def assert_equivalent(src: str, dst: str, *, pass_num: int = 1) -> None:
+def assert_equivalent(src: str, dst: str) -> None:
"""Raise AssertionError if `src` and `dst` aren't equivalent."""
try:
src_ast = parse_ast(src)
except Exception as exc:
raise AssertionError(
- f"cannot use --safe with this file; failed to parse source file AST: "
+ "cannot use --safe with this file; failed to parse source file AST: "
f"{exc}\n"
- f"This could be caused by running Black with an older Python version "
- f"that does not support new syntax used in your source file."
+ "This could be caused by running Black with an older Python version "
+ "that does not support new syntax used in your source file."
) from exc
try:
except Exception as exc:
log = dump_to_file("".join(traceback.format_tb(exc.__traceback__)), dst)
raise AssertionError(
- f"INTERNAL ERROR: Black produced invalid code on pass {pass_num}: {exc}. "
+ f"INTERNAL ERROR: Black produced invalid code: {exc}. "
"Please report a bug on https://github.com/psf/black/issues. "
f"This invalid output might be helpful: {log}"
) from None
log = dump_to_file(diff(src_ast_str, dst_ast_str, "src", "dst"))
raise AssertionError(
"INTERNAL ERROR: Black produced code that is not equivalent to the"
- f" source on pass {pass_num}. Please report a bug on "
+ " source. Please report a bug on "
f"https://github.com/psf/black/issues. This diff might be helpful: {log}"
) from None
def assert_stable(src: str, dst: str, mode: Mode) -> None:
"""Raise AssertionError if `dst` reformats differently the second time."""
- newdst = format_str(dst, mode=mode)
+ # We shouldn't call format_str() here, because that formats the string
+ # twice and may hide a bug where we bounce back and forth between two
+ # versions.
+ newdst = _format_str_once(dst, mode=mode)
if dst != newdst:
log = dump_to_file(
str(mode),
file paths is minimal since it's Python source code. Moreover, this crash was
spurious on Python 3.7 thanks to PEP 538 and PEP 540.
"""
+ modules: List[Any] = []
try:
from click import core
- from click import _unicodefun
- except ModuleNotFoundError:
- return
+ except ImportError:
+ pass
+ else:
+ modules.append(core)
+ try:
+ # Removed in Click 8.1.0 and newer; we keep this around for users who have
+ # older versions installed.
+ from click import _unicodefun # type: ignore
+ except ImportError:
+ pass
+ else:
+ modules.append(_unicodefun)
- for module in (core, _unicodefun):
+ for module in modules:
if hasattr(module, "_verify_python3_env"):
module._verify_python3_env = lambda: None # type: ignore
if hasattr(module, "_verify_python_env"):