-import asyncio
-from json.decoder import JSONDecodeError
-import json
-from concurrent.futures import Executor, ThreadPoolExecutor, ProcessPoolExecutor
-from contextlib import contextmanager
-from datetime import datetime
-from enum import Enum
import io
-from multiprocessing import Manager, freeze_support
-import os
-from pathlib import Path
-from pathspec.patterns.gitwildmatch import GitWildMatchPatternError
-import regex as re
-import signal
+import json
+import platform
+import re
import sys
import tokenize
import traceback
+from contextlib import contextmanager
+from dataclasses import replace
+from datetime import datetime, timezone
+from enum import Enum
+from json.decoder import JSONDecodeError
+from pathlib import Path
from typing import (
Any,
Dict,
MutableMapping,
Optional,
Pattern,
+ Sequence,
Set,
Sized,
Tuple,
Union,
)
-from dataclasses import replace
import click
+from click.core import ParameterSource
+from mypy_extensions import mypyc_attr
+from pathspec import PathSpec
+from pathspec.patterns.gitwildmatch import GitWildMatchPatternError
-from black.const import DEFAULT_LINE_LENGTH, DEFAULT_INCLUDES, DEFAULT_EXCLUDES
-from black.const import STDIN_PLACEHOLDER
-from black.nodes import STARS, syms, is_simple_decorator_expression
-from black.lines import Line, EmptyLineTracker
-from black.linegen import transform_line, LineGenerator, LN
+from _black_version import version as __version__
+from black.cache import Cache, get_cache_info, read_cache, write_cache
from black.comments import normalize_fmt_off
-from black.mode import Mode, TargetVersion
-from black.mode import Feature, supports_feature, VERSION_TO_FEATURES
-from black.cache import read_cache, write_cache, get_cache_info, filter_cached, Cache
-from black.concurrency import cancel, shutdown, maybe_install_uvloop
-from black.output import dump_to_file, ipynb_diff, diff, color_diff, out, err
-from black.report import Report, Changed, NothingChanged
-from black.files import find_project_root, find_pyproject_toml, parse_pyproject_toml
-from black.files import gen_python_files, get_gitignore, normalize_path_maybe_ignore
-from black.files import wrap_stream_for_windows
-from black.parsing import InvalidInput # noqa F401
-from black.parsing import lib2to3_parse, parse_ast, stringify_ast
+from black.const import (
+ DEFAULT_EXCLUDES,
+ DEFAULT_INCLUDES,
+ DEFAULT_LINE_LENGTH,
+ STDIN_PLACEHOLDER,
+)
+from black.files import (
+ find_project_root,
+ find_pyproject_toml,
+ find_user_pyproject_toml,
+ gen_python_files,
+ get_gitignore,
+ normalize_path_maybe_ignore,
+ parse_pyproject_toml,
+ wrap_stream_for_windows,
+)
from black.handle_ipynb_magics import (
- mask_cell,
- unmask_cell,
- remove_trailing_semicolon,
- put_trailing_semicolon_back,
+ PYTHON_CELL_MAGICS,
TRANSFORMED_MAGICS,
jupyter_dependencies_are_installed,
+ mask_cell,
+ put_trailing_semicolon_back,
+ remove_trailing_semicolon,
+ unmask_cell,
)
-
-
-# lib2to3 fork
-from blib2to3.pytree import Node, Leaf
+from black.linegen import LN, LineGenerator, transform_line
+from black.lines import EmptyLineTracker, LinesBlock
+from black.mode import (
+ FUTURE_FLAG_TO_FEATURE,
+ VERSION_TO_FEATURES,
+ Feature,
+ Mode,
+ TargetVersion,
+ supports_feature,
+)
+from black.nodes import (
+ STARS,
+ is_number_token,
+ is_simple_decorator_expression,
+ is_string_token,
+ syms,
+)
+from black.output import color_diff, diff, dump_to_file, err, ipynb_diff, out
+from black.parsing import InvalidInput # noqa F401
+from black.parsing import lib2to3_parse, parse_ast, stringify_ast
+from black.report import Changed, NothingChanged, Report
+from black.trans import iter_fexpr_spans
from blib2to3.pgen2 import token
+from blib2to3.pytree import Leaf, Node
-from _black_version import version as __version__
+COMPILED = Path(__file__).suffix in (".pyd", ".so")
# types
FileContent = str
# Legacy name, left for integrations.
FileMode = Mode
-DEFAULT_WORKERS = os.cpu_count()
-
def read_pyproject_toml(
ctx: click.Context, param: click.Parameter, value: Optional[str]
otherwise.
"""
if not value:
- value = find_pyproject_toml(ctx.params.get("src", ()))
+ value = find_pyproject_toml(
+ ctx.params.get("src", ()), ctx.params.get("stdin_filename", None)
+ )
if value is None:
return None
) -> Optional[Pattern[str]]:
try:
return re_compile_maybe_verbose(value) if value is not None else None
- except re.error:
- raise click.BadParameter("Not a valid regular expression") from None
+ except re.error as e:
+ raise click.BadParameter(f"Not a valid regular expression: {e}") from None
-@click.command(context_settings=dict(help_option_names=["-h", "--help"]))
+@click.command(
+ context_settings={"help_option_names": ["-h", "--help"]},
+ # While Click does set this field automatically using the docstring, mypyc
+ # (annoyingly) strips 'em so we need to set it here too.
+ help="The uncompromising code formatter.",
+)
@click.option("-c", "--code", type=str, help="Format the code passed in as a string.")
@click.option(
"-l",
callback=target_version_option_callback,
multiple=True,
help=(
- "Python versions that should be supported by Black's output. [default: per-file"
- " auto-detection]"
+ "Python versions that should be supported by Black's output. By default, Black"
+ " will try to infer this from the project metadata in pyproject.toml. If this"
+ " does not yield conclusive results, Black will use per-file auto-detection."
),
)
@click.option(
"(useful when piping source on standard input)."
),
)
+@click.option(
+ "--python-cell-magics",
+ multiple=True,
+ help=(
+ "When processing Jupyter Notebooks, add the given magic to the list"
+ f" of known python-magics ({', '.join(sorted(PYTHON_CELL_MAGICS))})."
+ " Useful for formatting cells with custom python magics."
+ ),
+ default=[],
+)
+@click.option(
+ "-x",
+ "--skip-source-first-line",
+ is_flag=True,
+ help="Skip the first line of the source code.",
+)
@click.option(
"-S",
"--skip-string-normalization",
"--experimental-string-processing",
is_flag=True,
hidden=True,
+ help="(DEPRECATED and now included in --preview) Normalize string literals.",
+)
+@click.option(
+ "--preview",
+ is_flag=True,
help=(
- "Experimental option that performs more normalization on string literals."
- " Currently disabled because it leads to some crashes."
+ "Enable potentially disruptive style changes that may be added to Black's main"
+ " functionality in the next major release."
),
)
@click.option(
type=str,
help=(
"Require a specific version of Black to be running (useful for unifying results"
- " across many environments e.g. with a pyproject.toml file)."
+ " across many environments e.g. with a pyproject.toml file). It can be"
+ " either a major version number or an exact version."
),
)
@click.option(
@click.option(
"--stdin-filename",
type=str,
+ is_eager=True,
help=(
"The name of the file when passing it through stdin. Useful to make "
"sure Black will respect --force-exclude option on some "
"-W",
"--workers",
type=click.IntRange(min=1),
- default=DEFAULT_WORKERS,
- show_default=True,
- help="Number of parallel workers",
+ default=None,
+ help=(
+ "Number of parallel workers [default: BLACK_NUM_WORKERS environment variable "
+ "or number of CPUs in the system]"
+ ),
)
@click.option(
"-q",
" due to exclusion patterns."
),
)
-@click.version_option(version=__version__)
+@click.version_option(
+ version=__version__,
+ message=(
+ f"%(prog)s, %(version)s (compiled: {'yes' if COMPILED else 'no'})\n"
+ f"Python ({platform.python_implementation()}) {platform.python_version()}"
+ ),
+)
@click.argument(
"src",
nargs=-1,
help="Read configuration from FILE path.",
)
@click.pass_context
-def main(
+def main( # noqa: C901
ctx: click.Context,
code: Optional[str],
line_length: int,
fast: bool,
pyi: bool,
ipynb: bool,
+ python_cell_magics: Sequence[str],
+ skip_source_first_line: bool,
skip_string_normalization: bool,
skip_magic_trailing_comma: bool,
experimental_string_processing: bool,
+ preview: bool,
quiet: bool,
verbose: bool,
- required_version: str,
+ required_version: Optional[str],
include: Pattern[str],
exclude: Optional[Pattern[str]],
extend_exclude: Optional[Pattern[str]],
force_exclude: Optional[Pattern[str]],
stdin_filename: Optional[str],
- workers: int,
+ workers: Optional[int],
src: Tuple[str, ...],
config: Optional[str],
) -> None:
"""The uncompromising code formatter."""
- if config and verbose:
- out(f"Using configuration from {config}.", bold=False, fg="blue")
+ ctx.ensure_object(dict)
+
+ if src and code is not None:
+ out(
+ main.get_usage(ctx)
+ + "\n\n'SRC' and 'code' cannot be passed simultaneously."
+ )
+ ctx.exit(1)
+ if not src and code is None:
+ out(main.get_usage(ctx) + "\n\nOne of 'SRC' or 'code' is required.")
+ ctx.exit(1)
+
+ root, method = (
+ find_project_root(src, stdin_filename) if code is None else (None, None)
+ )
+ ctx.obj["root"] = root
+
+ if verbose:
+ if root:
+ out(
+ f"Identified `{root}` as project root containing a {method}.",
+ fg="blue",
+ )
+
+ if config:
+ config_source = ctx.get_parameter_source("config")
+ user_level_config = str(find_user_pyproject_toml())
+ if config == user_level_config:
+ out(
+ "Using configuration from user-level config at "
+ f"'{user_level_config}'.",
+ fg="blue",
+ )
+ elif config_source in (
+ ParameterSource.DEFAULT,
+ ParameterSource.DEFAULT_MAP,
+ ):
+ out("Using configuration from project root.", fg="blue")
+ else:
+ out(f"Using configuration in '{config}'.", fg="blue")
+ if ctx.default_map:
+ for param, value in ctx.default_map.items():
+ out(f"{param}: {value}")
error_msg = "Oh no! 💥 💔 💥"
- if required_version and required_version != __version__:
+ if (
+ required_version
+ and required_version != __version__
+ and required_version != __version__.split(".")[0]
+ ):
err(
f"{error_msg} The required version `{required_version}` does not match"
f" the running version `{__version__}`!"
line_length=line_length,
is_pyi=pyi,
is_ipynb=ipynb,
+ skip_source_first_line=skip_source_first_line,
string_normalization=not skip_string_normalization,
magic_trailing_comma=not skip_magic_trailing_comma,
experimental_string_processing=experimental_string_processing,
+ preview=preview,
+ python_cell_magics=set(python_cell_magics),
)
if code is not None:
report=report,
)
else:
+ from black.concurrency import reformat_many
+
reformat_many(
sources=sources,
fast=fast,
)
if verbose or not quiet:
+ if code is None and (verbose or report.change_count or report.failure_count):
+ out()
out(error_msg if report.return_code else "All done! ✨ 🍰 ✨")
if code is None:
click.echo(str(report), err=True)
stdin_filename: Optional[str],
) -> Set[Path]:
"""Compute the set of files to be formatted."""
-
- root = find_project_root(src)
sources: Set[Path] = set()
- path_empty(src, "No Path provided. Nothing to do 😴", quiet, verbose, ctx)
+ root = ctx.obj["root"]
- if exclude is None:
- exclude = re_compile_maybe_verbose(DEFAULT_EXCLUDES)
- gitignore = get_gitignore(root)
- else:
- gitignore = None
+ using_default_exclude = exclude is None
+ exclude = re_compile_maybe_verbose(DEFAULT_EXCLUDES) if exclude is None else exclude
+ gitignore: Optional[Dict[Path, PathSpec]] = None
+ root_gitignore = get_gitignore(root)
for s in src:
if s == "-" and stdin_filename:
is_stdin = False
if is_stdin or p.is_file():
- normalized_path = normalize_path_maybe_ignore(p, root, report)
+ normalized_path: Optional[str] = normalize_path_maybe_ignore(
+ p, ctx.obj["root"], report
+ )
if normalized_path is None:
+ if verbose:
+ out(f'Skipping invalid source: "{normalized_path}"', fg="red")
continue
+ if verbose:
+ out(f'Found input source: "{normalized_path}"', fg="blue")
normalized_path = "/" + normalized_path
# Hard-exclude any files that matches the `--force-exclude` regex.
sources.add(p)
elif p.is_dir():
+ p = root / normalize_path_maybe_ignore(p, ctx.obj["root"], report)
+ if verbose:
+ out(f'Found input source directory: "{p}"', fg="blue")
+
+ if using_default_exclude:
+ gitignore = {
+ root: root_gitignore,
+ p: get_gitignore(p),
+ }
sources.update(
gen_python_files(
p.iterdir(),
- root,
+ ctx.obj["root"],
include,
exclude,
extend_exclude,
)
)
elif s == "-":
+ if verbose:
+ out("Found input source stdin", fg="blue")
sources.add(p)
else:
err(f"invalid path: {s}")
+
return sources
report.failed(path, str(exc))
+# diff-shades depends on being able to monkeypatch this function to operate. I know
+# it's not ideal, but this shouldn't cause any issues ... hopefully. ~ichard26
+@mypyc_attr(patchable=True)
def reformat_one(
src: Path, fast: bool, write_back: WriteBack, mode: Mode, report: "Report"
) -> None:
report.failed(src, str(exc))
-def reformat_many(
- sources: Set[Path],
- fast: bool,
- write_back: WriteBack,
- mode: Mode,
- report: "Report",
- workers: Optional[int],
-) -> None:
- """Reformat multiple files using a ProcessPoolExecutor."""
- executor: Executor
- loop = asyncio.get_event_loop()
- worker_count = workers if workers is not None else DEFAULT_WORKERS
- if sys.platform == "win32":
- # Work around https://bugs.python.org/issue26903
- worker_count = min(worker_count, 60)
- try:
- executor = ProcessPoolExecutor(max_workers=worker_count)
- except (ImportError, OSError):
- # we arrive here if the underlying system does not support multi-processing
- # like in AWS Lambda or Termux, in which case we gracefully fallback to
- # a ThreadPoolExecutor with just a single worker (more workers would not do us
- # any good due to the Global Interpreter Lock)
- executor = ThreadPoolExecutor(max_workers=1)
-
- try:
- loop.run_until_complete(
- schedule_formatting(
- sources=sources,
- fast=fast,
- write_back=write_back,
- mode=mode,
- report=report,
- loop=loop,
- executor=executor,
- )
- )
- finally:
- shutdown(loop)
- if executor is not None:
- executor.shutdown()
-
-
-async def schedule_formatting(
- sources: Set[Path],
- fast: bool,
- write_back: WriteBack,
- mode: Mode,
- report: "Report",
- loop: asyncio.AbstractEventLoop,
- executor: Executor,
-) -> None:
- """Run formatting of `sources` in parallel using the provided `executor`.
-
- (Use ProcessPoolExecutors for actual parallelism.)
-
- `write_back`, `fast`, and `mode` options are passed to
- :func:`format_file_in_place`.
- """
- cache: Cache = {}
- if write_back not in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
- cache = read_cache(mode)
- sources, cached = filter_cached(cache, sources)
- for src in sorted(cached):
- report.done(src, Changed.CACHED)
- if not sources:
- return
-
- cancelled = []
- sources_to_cache = []
- lock = None
- if write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
- # For diff output, we need locks to ensure we don't interleave output
- # from different processes.
- manager = Manager()
- lock = manager.Lock()
- tasks = {
- asyncio.ensure_future(
- loop.run_in_executor(
- executor, format_file_in_place, src, fast, mode, write_back, lock
- )
- ): src
- for src in sorted(sources)
- }
- pending = tasks.keys()
- try:
- loop.add_signal_handler(signal.SIGINT, cancel, pending)
- loop.add_signal_handler(signal.SIGTERM, cancel, pending)
- except NotImplementedError:
- # There are no good alternatives for these on Windows.
- pass
- while pending:
- done, _ = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
- for task in done:
- src = tasks.pop(task)
- if task.cancelled():
- cancelled.append(task)
- elif task.exception():
- report.failed(src, str(task.exception()))
- else:
- changed = Changed.YES if task.result() else Changed.NO
- # If the file was written back or was successfully checked as
- # well-formatted, store this information in the cache.
- if write_back is WriteBack.YES or (
- write_back is WriteBack.CHECK and changed is Changed.NO
- ):
- sources_to_cache.append(src)
- report.done(src, changed)
- if cancelled:
- if sys.version_info >= (3, 7):
- await asyncio.gather(*cancelled, return_exceptions=True)
- else:
- await asyncio.gather(*cancelled, loop=loop, return_exceptions=True)
- if sources_to_cache:
- write_cache(cache, sources_to_cache, mode)
-
-
def format_file_in_place(
src: Path,
fast: bool,
elif src.suffix == ".ipynb":
mode = replace(mode, is_ipynb=True)
- then = datetime.utcfromtimestamp(src.stat().st_mtime)
+ then = datetime.fromtimestamp(src.stat().st_mtime, timezone.utc)
+ header = b""
with open(src, "rb") as buf:
+ if mode.skip_source_first_line:
+ header = buf.readline()
src_contents, encoding, newline = decode_bytes(buf.read())
try:
dst_contents = format_file_contents(src_contents, fast=fast, mode=mode)
raise ValueError(
f"File '{src}' cannot be parsed as valid Jupyter notebook."
) from None
+ src_contents = header.decode(encoding) + src_contents
+ dst_contents = header.decode(encoding) + dst_contents
if write_back == WriteBack.YES:
with open(src, "w", encoding=encoding, newline=newline) as f:
f.write(dst_contents)
elif write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
- now = datetime.utcnow()
- src_name = f"{src}\t{then} +0000"
- dst_name = f"{src}\t{now} +0000"
+ now = datetime.now(timezone.utc)
+ src_name = f"{src}\t{then}"
+ dst_name = f"{src}\t{now}"
if mode.is_ipynb:
diff_contents = ipynb_diff(src_contents, dst_contents, src_name, dst_name)
else:
write a diff to stdout. The `mode` argument is passed to
:func:`format_file_contents`.
"""
- then = datetime.utcnow()
+ then = datetime.now(timezone.utc)
if content is None:
src, encoding, newline = decode_bytes(sys.stdin.buffer.read())
dst += "\n"
f.write(dst)
elif write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
- now = datetime.utcnow()
- src_name = f"STDIN\t{then} +0000"
- dst_name = f"STDOUT\t{now} +0000"
+ now = datetime.now(timezone.utc)
+ src_name = f"STDIN\t{then}"
+ dst_name = f"STDOUT\t{now}"
d = diff(src, dst, src_name, dst_name)
if write_back == WriteBack.COLOR_DIFF:
d = color_diff(d)
content differently.
"""
assert_equivalent(src_contents, dst_contents)
-
- # Forced second pass to work around optional trailing commas (becoming
- # forced trailing commas on pass 2) interacting differently with optional
- # parentheses. Admittedly ugly.
- dst_contents_pass2 = format_str(dst_contents, mode=mode)
- if dst_contents != dst_contents_pass2:
- dst_contents = dst_contents_pass2
- assert_equivalent(src_contents, dst_contents, pass_num=2)
- assert_stable(src_contents, dst_contents, mode=mode)
- # Note: no need to explicitly call `assert_stable` if `dst_contents` was
- # the same as `dst_contents_pass2`.
+ assert_stable(src_contents, dst_contents, mode=mode)
def format_file_contents(src_contents: str, *, fast: bool, mode: Mode) -> FileContent:
valid by calling :func:`assert_equivalent` and :func:`assert_stable` on it.
`mode` is passed to :func:`format_str`.
"""
- if not src_contents.strip():
- raise NothingChanged
-
if mode.is_ipynb:
dst_contents = format_ipynb_string(src_contents, fast=fast, mode=mode)
else:
return dst_contents
-def validate_cell(src: str) -> None:
- """Check that cell does not already contain TransformerManager transformations.
+def validate_cell(src: str, mode: Mode) -> None:
+ """Check that cell does not already contain TransformerManager transformations,
+ or non-Python cell magics, which might cause tokenizer_rt to break because of
+ indentations.
If a cell contains ``!ls``, then it'll be transformed to
``get_ipython().system('ls')``. However, if the cell originally contained
"""
if any(transformed_magic in src for transformed_magic in TRANSFORMED_MAGICS):
raise NothingChanged
+ if (
+ src[:2] == "%%"
+ and src.split()[0][2:] not in PYTHON_CELL_MAGICS | mode.python_cell_magics
+ ):
+ raise NothingChanged
def format_cell(src: str, *, fast: bool, mode: Mode) -> str:
could potentially be automagics or multi-line magics, which
are currently not supported.
"""
- validate_cell(src)
+ validate_cell(src, mode)
src_without_trailing_semicolon, has_trailing_semicolon = remove_trailing_semicolon(
src
)
Operate cell-by-cell, only on code cells, only for Python notebooks.
If the ``.ipynb`` originally had a trailing newline, it'll be preserved.
"""
+ if not src_contents:
+ raise NothingChanged
+
trailing_newline = src_contents[-1] == "\n"
modified = False
nb = json.loads(src_contents)
raise NothingChanged
-def format_str(src_contents: str, *, mode: Mode) -> FileContent:
+def format_str(src_contents: str, *, mode: Mode) -> str:
"""Reformat a string and return new contents.
`mode` determines formatting options, such as how many characters per line are
hey
"""
+ dst_contents = _format_str_once(src_contents, mode=mode)
+ # Forced second pass to work around optional trailing commas (becoming
+ # forced trailing commas on pass 2) interacting differently with optional
+ # parentheses. Admittedly ugly.
+ if src_contents != dst_contents:
+ return _format_str_once(dst_contents, mode=mode)
+ return dst_contents
+
+
+def _format_str_once(src_contents: str, *, mode: Mode) -> str:
src_node = lib2to3_parse(src_contents.lstrip(), mode.target_versions)
- dst_contents = []
- future_imports = get_future_imports(src_node)
+ dst_blocks: List[LinesBlock] = []
if mode.target_versions:
versions = mode.target_versions
else:
- versions = detect_target_versions(src_node)
+ future_imports = get_future_imports(src_node)
+ versions = detect_target_versions(src_node, future_imports=future_imports)
+
+ context_manager_features = {
+ feature
+ for feature in {Feature.PARENTHESIZED_CONTEXT_MANAGERS}
+ if supports_feature(versions, feature)
+ }
normalize_fmt_off(src_node)
- lines = LineGenerator(
- mode=mode,
- remove_u_prefix="unicode_literals" in future_imports
- or supports_feature(versions, Feature.UNICODE_LITERALS),
- )
- elt = EmptyLineTracker(is_pyi=mode.is_pyi)
- empty_line = Line(mode=mode)
- after = 0
+ lines = LineGenerator(mode=mode, features=context_manager_features)
+ elt = EmptyLineTracker(mode=mode)
split_line_features = {
feature
for feature in {Feature.TRAILING_COMMA_IN_CALL, Feature.TRAILING_COMMA_IN_DEF}
if supports_feature(versions, feature)
}
+ block: Optional[LinesBlock] = None
for current_line in lines.visit(src_node):
- dst_contents.append(str(empty_line) * after)
- before, after = elt.maybe_empty_lines(current_line)
- dst_contents.append(str(empty_line) * before)
+ block = elt.maybe_empty_lines(current_line)
+ dst_blocks.append(block)
for line in transform_line(
current_line, mode=mode, features=split_line_features
):
- dst_contents.append(str(line))
+ block.content_lines.append(str(line))
+ if dst_blocks:
+ dst_blocks[-1].after = 0
+ dst_contents = []
+ for block in dst_blocks:
+ dst_contents.extend(block.all_lines())
+ if not dst_contents:
+ # Use decode_bytes to retrieve the correct source newline (CRLF or LF),
+ # and check if normalized_content has more than one line
+ normalized_content, _, newline = decode_bytes(src_contents.encode("utf-8"))
+ if "\n" in normalized_content:
+ return newline
+ return ""
return "".join(dst_contents)
return tiow.read(), encoding, newline
-def get_features_used(node: Node) -> Set[Feature]:
+def get_features_used( # noqa: C901
+ node: Node, *, future_imports: Optional[Set[str]] = None
+) -> Set[Feature]:
"""Return a set of (relatively) new Python features used in this file.
Currently looking for:
- f-strings;
+ - self-documenting expressions in f-strings (f"{x=}");
- underscores in numeric literals;
- trailing commas after * or ** in function signatures and calls;
- positional only arguments in function signatures and lambdas;
- assignment expression;
- relaxed decorator syntax;
+ - usage of __future__ flags (annotations);
+ - print / exec statements;
+ - parenthesized context managers;
+ - match statements;
+ - except* clause;
+ - variadic generics;
"""
features: Set[Feature] = set()
+ if future_imports:
+ features |= {
+ FUTURE_FLAG_TO_FEATURE[future_import]
+ for future_import in future_imports
+ if future_import in FUTURE_FLAG_TO_FEATURE
+ }
+
for n in node.pre_order():
- if n.type == token.STRING:
- value_head = n.value[:2] # type: ignore
+ if is_string_token(n):
+ value_head = n.value[:2]
if value_head in {'f"', 'F"', "f'", "F'", "rf", "fr", "RF", "FR"}:
features.add(Feature.F_STRINGS)
-
- elif n.type == token.NUMBER:
- if "_" in n.value: # type: ignore
+ if Feature.DEBUG_F_STRINGS not in features:
+ for span_beg, span_end in iter_fexpr_spans(n.value):
+ if n.value[span_beg : span_end - 1].rstrip().endswith("="):
+ features.add(Feature.DEBUG_F_STRINGS)
+ break
+
+ elif is_number_token(n):
+ if "_" in n.value:
features.add(Feature.NUMERIC_UNDERSCORES)
elif n.type == token.SLASH:
if argch.type in STARS:
features.add(feature)
+ elif (
+ n.type in {syms.return_stmt, syms.yield_expr}
+ and len(n.children) >= 2
+ and n.children[1].type == syms.testlist_star_expr
+ and any(child.type == syms.star_expr for child in n.children[1].children)
+ ):
+ features.add(Feature.UNPACKING_ON_FLOW)
+
+ elif (
+ n.type == syms.annassign
+ and len(n.children) >= 4
+ and n.children[3].type == syms.testlist_star_expr
+ ):
+ features.add(Feature.ANN_ASSIGN_EXTENDED_RHS)
+
+ elif (
+ n.type == syms.with_stmt
+ and len(n.children) > 2
+ and n.children[1].type == syms.atom
+ ):
+ atom_children = n.children[1].children
+ if (
+ len(atom_children) == 3
+ and atom_children[0].type == token.LPAR
+ and atom_children[1].type == syms.testlist_gexp
+ and atom_children[2].type == token.RPAR
+ ):
+ features.add(Feature.PARENTHESIZED_CONTEXT_MANAGERS)
+
+ elif n.type == syms.match_stmt:
+ features.add(Feature.PATTERN_MATCHING)
+
+ elif (
+ n.type == syms.except_clause
+ and len(n.children) >= 2
+ and n.children[1].type == token.STAR
+ ):
+ features.add(Feature.EXCEPT_STAR)
+
+ elif n.type in {syms.subscriptlist, syms.trailer} and any(
+ child.type == syms.star_expr for child in n.children
+ ):
+ features.add(Feature.VARIADIC_GENERICS)
+
+ elif (
+ n.type == syms.tname_star
+ and len(n.children) == 3
+ and n.children[2].type == syms.star_expr
+ ):
+ features.add(Feature.VARIADIC_GENERICS)
+
+ elif n.type in (syms.type_stmt, syms.typeparams):
+ features.add(Feature.TYPE_PARAMS)
+
return features
-def detect_target_versions(node: Node) -> Set[TargetVersion]:
+def detect_target_versions(
+ node: Node, *, future_imports: Optional[Set[str]] = None
+) -> Set[TargetVersion]:
"""Detect the version to target based on the nodes used."""
- features = get_features_used(node)
+ features = get_features_used(node, future_imports=future_imports)
return {
version for version in TargetVersion if features <= VERSION_TO_FEATURES[version]
}
return imports
-def assert_equivalent(src: str, dst: str, *, pass_num: int = 1) -> None:
+def assert_equivalent(src: str, dst: str) -> None:
"""Raise AssertionError if `src` and `dst` aren't equivalent."""
try:
src_ast = parse_ast(src)
except Exception as exc:
raise AssertionError(
- "cannot use --safe with this file; failed to parse source file."
+ "cannot use --safe with this file; failed to parse source file AST: "
+ f"{exc}\n"
+ "This could be caused by running Black with an older Python version "
+ "that does not support new syntax used in your source file."
) from exc
try:
except Exception as exc:
log = dump_to_file("".join(traceback.format_tb(exc.__traceback__)), dst)
raise AssertionError(
- f"INTERNAL ERROR: Black produced invalid code on pass {pass_num}: {exc}. "
+ f"INTERNAL ERROR: Black produced invalid code: {exc}. "
"Please report a bug on https://github.com/psf/black/issues. "
f"This invalid output might be helpful: {log}"
) from None
log = dump_to_file(diff(src_ast_str, dst_ast_str, "src", "dst"))
raise AssertionError(
"INTERNAL ERROR: Black produced code that is not equivalent to the"
- f" source on pass {pass_num}. Please report a bug on "
+ " source. Please report a bug on "
f"https://github.com/psf/black/issues. This diff might be helpful: {log}"
) from None
def assert_stable(src: str, dst: str, mode: Mode) -> None:
"""Raise AssertionError if `dst` reformats differently the second time."""
- newdst = format_str(dst, mode=mode)
+ # We shouldn't call format_str() here, because that formats the string
+ # twice and may hide a bug where we bounce back and forth between two
+ # versions.
+ newdst = _format_str_once(dst, mode=mode)
if dst != newdst:
log = dump_to_file(
str(mode),
file paths is minimal since it's Python source code. Moreover, this crash was
spurious on Python 3.7 thanks to PEP 538 and PEP 540.
"""
+ modules: List[Any] = []
try:
from click import core
- from click import _unicodefun
- except ModuleNotFoundError:
- return
+ except ImportError:
+ pass
+ else:
+ modules.append(core)
+ try:
+ # Removed in Click 8.1.0 and newer; we keep this around for users who have
+ # older versions installed.
+ from click import _unicodefun # type: ignore
+ except ImportError:
+ pass
+ else:
+ modules.append(_unicodefun)
- for module in (core, _unicodefun):
+ for module in modules:
if hasattr(module, "_verify_python3_env"):
- module._verify_python3_env = lambda: None # type: ignore
+ module._verify_python3_env = lambda: None
if hasattr(module, "_verify_python_env"):
- module._verify_python_env = lambda: None # type: ignore
+ module._verify_python_env = lambda: None
def patched_main() -> None:
- maybe_install_uvloop()
- freeze_support()
+ # PyInstaller patches multiprocessing to need freeze_support() even in non-Windows
+ # environments so just assume we always need to call it if frozen.
+ if getattr(sys, "frozen", False):
+ from multiprocessing import freeze_support
+
+ freeze_support()
+
patch_click()
main()