import os
from pathlib import Path
from pathspec.patterns.gitwildmatch import GitWildMatchPatternError
-import regex as re
+import re
import signal
import sys
import tokenize
MutableMapping,
Optional,
Pattern,
+ Sequence,
Set,
Sized,
Tuple,
Union,
)
-from dataclasses import replace
import click
+from click.core import ParameterSource
+from dataclasses import replace
+from mypy_extensions import mypyc_attr
from black.const import DEFAULT_LINE_LENGTH, DEFAULT_INCLUDES, DEFAULT_EXCLUDES
from black.const import STDIN_PLACEHOLDER
from black.nodes import STARS, syms, is_simple_decorator_expression
+from black.nodes import is_string_token
from black.lines import Line, EmptyLineTracker
from black.linegen import transform_line, LineGenerator, LN
from black.comments import normalize_fmt_off
-from black.mode import Mode, TargetVersion
+from black.mode import FUTURE_FLAG_TO_FEATURE, Mode, TargetVersion
from black.mode import Feature, supports_feature, VERSION_TO_FEATURES
from black.cache import read_cache, write_cache, get_cache_info, filter_cached, Cache
from black.concurrency import cancel, shutdown, maybe_install_uvloop
remove_trailing_semicolon,
put_trailing_semicolon_back,
TRANSFORMED_MAGICS,
+ PYTHON_CELL_MAGICS,
jupyter_dependencies_are_installed,
)
from _black_version import version as __version__
+COMPILED = Path(__file__).suffix in (".pyd", ".so")
+
# types
FileContent = str
Encoding = str
# Legacy name, left for integrations.
FileMode = Mode
+DEFAULT_WORKERS = os.cpu_count()
+
def read_pyproject_toml(
ctx: click.Context, param: click.Parameter, value: Optional[str]
except (OSError, ValueError) as e:
raise click.FileError(
filename=value, hint=f"Error reading configuration file: {e}"
- )
+ ) from None
if not config:
return None
ctx: click.Context,
param: click.Parameter,
value: Optional[str],
-) -> Optional[Pattern]:
+) -> Optional[Pattern[str]]:
try:
return re_compile_maybe_verbose(value) if value is not None else None
- except re.error:
- raise click.BadParameter("Not a valid regular expression")
+ except re.error as e:
+ raise click.BadParameter(f"Not a valid regular expression: {e}") from None
-@click.command(context_settings=dict(help_option_names=["-h", "--help"]))
+@click.command(
+ context_settings={"help_option_names": ["-h", "--help"]},
+ # While Click does set this field automatically using the docstring, mypyc
+ # (annoyingly) strips 'em so we need to set it here too.
+ help="The uncompromising code formatter.",
+)
@click.option("-c", "--code", type=str, help="Format the code passed in as a string.")
@click.option(
"-l",
"(useful when piping source on standard input)."
),
)
+@click.option(
+ "--python-cell-magics",
+ multiple=True,
+ help=(
+ "When processing Jupyter Notebooks, add the given magic to the list"
+ f" of known python-magics ({', '.join(PYTHON_CELL_MAGICS)})."
+ " Useful for formatting cells with custom python magics."
+ ),
+ default=[],
+)
@click.option(
"-S",
"--skip-string-normalization",
"--experimental-string-processing",
is_flag=True,
hidden=True,
+ help="(DEPRECATED and now included in --preview) Normalize string literals.",
+)
+@click.option(
+ "--preview",
+ is_flag=True,
help=(
- "Experimental option that performs more normalization on string literals."
- " Currently disabled because it leads to some crashes."
+ "Enable potentially disruptive style changes that will be added to Black's main"
+ " functionality in the next major release."
),
)
@click.option(
"editors that rely on using stdin."
),
)
+@click.option(
+ "-W",
+ "--workers",
+ type=click.IntRange(min=1),
+ default=DEFAULT_WORKERS,
+ show_default=True,
+ help="Number of parallel workers",
+)
@click.option(
"-q",
"--quiet",
" due to exclusion patterns."
),
)
-@click.version_option(version=__version__)
+@click.version_option(
+ version=__version__,
+ message=f"%(prog)s, %(version)s (compiled: {'yes' if COMPILED else 'no'})",
+)
@click.argument(
"src",
nargs=-1,
fast: bool,
pyi: bool,
ipynb: bool,
+ python_cell_magics: Sequence[str],
skip_string_normalization: bool,
skip_magic_trailing_comma: bool,
experimental_string_processing: bool,
+ preview: bool,
quiet: bool,
verbose: bool,
- required_version: str,
- include: Pattern,
- exclude: Optional[Pattern],
- extend_exclude: Optional[Pattern],
- force_exclude: Optional[Pattern],
+ required_version: Optional[str],
+ include: Pattern[str],
+ exclude: Optional[Pattern[str]],
+ extend_exclude: Optional[Pattern[str]],
+ force_exclude: Optional[Pattern[str]],
stdin_filename: Optional[str],
+ workers: int,
src: Tuple[str, ...],
config: Optional[str],
) -> None:
"""The uncompromising code formatter."""
- if config and verbose:
- out(f"Using configuration from {config}.", bold=False, fg="blue")
+ ctx.ensure_object(dict)
+
+ if src and code is not None:
+ out(
+ main.get_usage(ctx)
+ + "\n\n'SRC' and 'code' cannot be passed simultaneously."
+ )
+ ctx.exit(1)
+ if not src and code is None:
+ out(main.get_usage(ctx) + "\n\nOne of 'SRC' or 'code' is required.")
+ ctx.exit(1)
+
+ root, method = find_project_root(src) if code is None else (None, None)
+ ctx.obj["root"] = root
+
+ if verbose:
+ if root:
+ out(
+ f"Identified `{root}` as project root containing a {method}.",
+ fg="blue",
+ )
+
+ normalized = [
+ (normalize_path_maybe_ignore(Path(source), root), source)
+ for source in src
+ ]
+ srcs_string = ", ".join(
+ [
+ f'"{_norm}"'
+ if _norm
+ else f'\033[31m"{source} (skipping - invalid)"\033[34m'
+ for _norm, source in normalized
+ ]
+ )
+ out(f"Sources to be formatted: {srcs_string}", fg="blue")
+
+ if config:
+ config_source = ctx.get_parameter_source("config")
+ if config_source in (ParameterSource.DEFAULT, ParameterSource.DEFAULT_MAP):
+ out("Using configuration from project root.", fg="blue")
+ else:
+ out(f"Using configuration in '{config}'.", fg="blue")
error_msg = "Oh no! 💥 💔 💥"
if required_version and required_version != __version__:
string_normalization=not skip_string_normalization,
magic_trailing_comma=not skip_magic_trailing_comma,
experimental_string_processing=experimental_string_processing,
+ preview=preview,
+ python_cell_magics=set(python_cell_magics),
)
if code is not None:
write_back=write_back,
mode=mode,
report=report,
+ workers=workers,
)
if verbose or not quiet:
+ if code is None and (verbose or report.change_count or report.failure_count):
+ out()
out(error_msg if report.return_code else "All done! ✨ 🍰 ✨")
if code is None:
click.echo(str(report), err=True)
stdin_filename: Optional[str],
) -> Set[Path]:
"""Compute the set of files to be formatted."""
-
- root = find_project_root(src)
sources: Set[Path] = set()
- path_empty(src, "No Path provided. Nothing to do 😴", quiet, verbose, ctx)
if exclude is None:
exclude = re_compile_maybe_verbose(DEFAULT_EXCLUDES)
- gitignore = get_gitignore(root)
+ gitignore = get_gitignore(ctx.obj["root"])
else:
gitignore = None
is_stdin = False
if is_stdin or p.is_file():
- normalized_path = normalize_path_maybe_ignore(p, root, report)
+ normalized_path = normalize_path_maybe_ignore(p, ctx.obj["root"], report)
if normalized_path is None:
continue
sources.update(
gen_python_files(
p.iterdir(),
- root,
+ ctx.obj["root"],
include,
exclude,
extend_exclude,
report.failed(src, str(exc))
+# diff-shades depends on being able to monkeypatch this function to operate. I know
+# it's not ideal, but this shouldn't cause any issues ... hopefully. ~ichard26
+@mypyc_attr(patchable=True)
def reformat_many(
- sources: Set[Path], fast: bool, write_back: WriteBack, mode: Mode, report: "Report"
+ sources: Set[Path],
+ fast: bool,
+ write_back: WriteBack,
+ mode: Mode,
+ report: "Report",
+ workers: Optional[int],
) -> None:
"""Reformat multiple files using a ProcessPoolExecutor."""
executor: Executor
loop = asyncio.get_event_loop()
- worker_count = os.cpu_count()
+ worker_count = workers if workers is not None else DEFAULT_WORKERS
if sys.platform == "win32":
# Work around https://bugs.python.org/issue26903
+ assert worker_count is not None
worker_count = min(worker_count, 60)
try:
executor = ProcessPoolExecutor(max_workers=worker_count)
- except (ImportError, OSError):
+ except (ImportError, NotImplementedError, OSError):
# we arrive here if the underlying system does not support multi-processing
# like in AWS Lambda or Termux, in which case we gracefully fallback to
# a ThreadPoolExecutor with just a single worker (more workers would not do us
sources_to_cache.append(src)
report.done(src, changed)
if cancelled:
- await asyncio.gather(*cancelled, loop=loop, return_exceptions=True)
+ if sys.version_info >= (3, 7):
+ await asyncio.gather(*cancelled, return_exceptions=True)
+ else:
+ await asyncio.gather(*cancelled, loop=loop, return_exceptions=True)
if sources_to_cache:
write_cache(cache, sources_to_cache, mode)
except NothingChanged:
return False
except JSONDecodeError:
- raise ValueError(f"File '{src}' cannot be parsed as valid Jupyter notebook.")
+ raise ValueError(
+ f"File '{src}' cannot be parsed as valid Jupyter notebook."
+ ) from None
if write_back == WriteBack.YES:
with open(src, "w", encoding=encoding, newline=newline) as f:
content differently.
"""
assert_equivalent(src_contents, dst_contents)
-
- # Forced second pass to work around optional trailing commas (becoming
- # forced trailing commas on pass 2) interacting differently with optional
- # parentheses. Admittedly ugly.
- dst_contents_pass2 = format_str(dst_contents, mode=mode)
- if dst_contents != dst_contents_pass2:
- dst_contents = dst_contents_pass2
- assert_equivalent(src_contents, dst_contents, pass_num=2)
- assert_stable(src_contents, dst_contents, mode=mode)
- # Note: no need to explicitly call `assert_stable` if `dst_contents` was
- # the same as `dst_contents_pass2`.
+ assert_stable(src_contents, dst_contents, mode=mode)
def format_file_contents(src_contents: str, *, fast: bool, mode: Mode) -> FileContent:
return dst_contents
-def validate_cell(src: str) -> None:
- """Check that cell does not already contain TransformerManager transformations.
+def validate_cell(src: str, mode: Mode) -> None:
+ """Check that cell does not already contain TransformerManager transformations,
+ or non-Python cell magics, which might cause tokenize_rt to break because of
+ indentations.
If a cell contains ``!ls``, then it'll be transformed to
``get_ipython().system('ls')``. However, if the cell originally contained
"""
if any(transformed_magic in src for transformed_magic in TRANSFORMED_MAGICS):
raise NothingChanged
+ if (
+ src[:2] == "%%"
+ and src.split()[0][2:] not in PYTHON_CELL_MAGICS | mode.python_cell_magics
+ ):
+ raise NothingChanged
def format_cell(src: str, *, fast: bool, mode: Mode) -> str:
could potentially be automagics or multi-line magics, which
are currently not supported.
"""
- validate_cell(src)
+ validate_cell(src, mode)
src_without_trailing_semicolon, has_trailing_semicolon = remove_trailing_semicolon(
src
)
try:
masked_src, replacements = mask_cell(src_without_trailing_semicolon)
except SyntaxError:
- raise NothingChanged
+ raise NothingChanged from None
masked_dst = format_str(masked_src, mode=mode)
if not fast:
check_stability_and_equivalence(masked_src, masked_dst, mode=mode)
)
dst = dst.rstrip("\n")
if dst == src:
- raise NothingChanged
+ raise NothingChanged from None
return dst
"""
language = nb.get("metadata", {}).get("language_info", {}).get("name", None)
if language is not None and language != "python":
- raise NothingChanged
+ raise NothingChanged from None
def format_ipynb_string(src_contents: str, *, fast: bool, mode: Mode) -> FileContent:
"""Format Jupyter notebook.
Operate cell-by-cell, only on code cells, only for Python notebooks.
- If the ``.ipynb`` originally had a trailing newline, it'll be preseved.
+ If the ``.ipynb`` originally had a trailing newline, it'll be preserved.
"""
trailing_newline = src_contents[-1] == "\n"
modified = False
raise NothingChanged
-def format_str(src_contents: str, *, mode: Mode) -> FileContent:
+def format_str(src_contents: str, *, mode: Mode) -> str:
"""Reformat a string and return new contents.
`mode` determines formatting options, such as how many characters per line are
hey
"""
+ dst_contents = _format_str_once(src_contents, mode=mode)
+ # Forced second pass to work around optional trailing commas (becoming
+ # forced trailing commas on pass 2) interacting differently with optional
+ # parentheses. Admittedly ugly.
+ if src_contents != dst_contents:
+ return _format_str_once(dst_contents, mode=mode)
+ return dst_contents
+
+
+def _format_str_once(src_contents: str, *, mode: Mode) -> str:
src_node = lib2to3_parse(src_contents.lstrip(), mode.target_versions)
dst_contents = []
future_imports = get_future_imports(src_node)
if mode.target_versions:
versions = mode.target_versions
else:
- versions = detect_target_versions(src_node)
+ versions = detect_target_versions(src_node, future_imports=future_imports)
+
normalize_fmt_off(src_node)
- lines = LineGenerator(
- mode=mode,
- remove_u_prefix="unicode_literals" in future_imports
- or supports_feature(versions, Feature.UNICODE_LITERALS),
- )
+ lines = LineGenerator(mode=mode)
elt = EmptyLineTracker(is_pyi=mode.is_pyi)
empty_line = Line(mode=mode)
after = 0
return tiow.read(), encoding, newline
-def get_features_used(node: Node) -> Set[Feature]:
+def get_features_used( # noqa: C901
+ node: Node, *, future_imports: Optional[Set[str]] = None
+) -> Set[Feature]:
"""Return a set of (relatively) new Python features used in this file.
Currently looking for:
- positional only arguments in function signatures and lambdas;
- assignment expression;
- relaxed decorator syntax;
+ - usage of __future__ flags (annotations);
+ - print / exec statements;
"""
features: Set[Feature] = set()
+ if future_imports:
+ features |= {
+ FUTURE_FLAG_TO_FEATURE[future_import]
+ for future_import in future_imports
+ if future_import in FUTURE_FLAG_TO_FEATURE
+ }
+
for n in node.pre_order():
- if n.type == token.STRING:
- value_head = n.value[:2] # type: ignore
+ if is_string_token(n):
+ value_head = n.value[:2]
if value_head in {'f"', 'F"', "f'", "F'", "rf", "fr", "RF", "FR"}:
features.add(Feature.F_STRINGS)
elif n.type == token.NUMBER:
- if "_" in n.value: # type: ignore
+ assert isinstance(n, Leaf)
+ if "_" in n.value:
features.add(Feature.NUMERIC_UNDERSCORES)
elif n.type == token.SLASH:
- if n.parent and n.parent.type in {syms.typedargslist, syms.arglist}:
+ if n.parent and n.parent.type in {
+ syms.typedargslist,
+ syms.arglist,
+ syms.varargslist,
+ }:
features.add(Feature.POS_ONLY_ARGUMENTS)
elif n.type == token.COLONEQUAL:
if argch.type in STARS:
features.add(feature)
+ elif (
+ n.type in {syms.return_stmt, syms.yield_expr}
+ and len(n.children) >= 2
+ and n.children[1].type == syms.testlist_star_expr
+ and any(child.type == syms.star_expr for child in n.children[1].children)
+ ):
+ features.add(Feature.UNPACKING_ON_FLOW)
+
+ elif (
+ n.type == syms.annassign
+ and len(n.children) >= 4
+ and n.children[3].type == syms.testlist_star_expr
+ ):
+ features.add(Feature.ANN_ASSIGN_EXTENDED_RHS)
+
return features
-def detect_target_versions(node: Node) -> Set[TargetVersion]:
+def detect_target_versions(
+ node: Node, *, future_imports: Optional[Set[str]] = None
+) -> Set[TargetVersion]:
"""Detect the version to target based on the nodes used."""
- features = get_features_used(node)
+ features = get_features_used(node, future_imports=future_imports)
return {
version for version in TargetVersion if features <= VERSION_TO_FEATURES[version]
}
src_ast = parse_ast(src)
except Exception as exc:
raise AssertionError(
- "cannot use --safe with this file; failed to parse source file. AST"
- f" error message: {exc}"
- )
+ f"cannot use --safe with this file; failed to parse source file AST: "
+ f"{exc}\n"
+ f"This could be caused by running Black with an older Python version "
+ f"that does not support new syntax used in your source file."
+ ) from exc
try:
dst_ast = parse_ast(dst)
def assert_stable(src: str, dst: str, mode: Mode) -> None:
"""Raise AssertionError if `dst` reformats differently the second time."""
- newdst = format_str(dst, mode=mode)
+ # We shouldn't call format_str() here, because that formats the string
+ # twice and may hide a bug where we bounce back and forth between two
+ # versions.
+ newdst = _format_str_once(dst, mode=mode)
if dst != newdst:
log = dump_to_file(
str(mode),
"""
try:
from click import core
- from click import _unicodefun # type: ignore
+ from click import _unicodefun
except ModuleNotFoundError:
return