import asyncio
+from json.decoder import JSONDecodeError
+import json
from concurrent.futures import Executor, ThreadPoolExecutor, ProcessPoolExecutor
from contextlib import contextmanager
from datetime import datetime
from multiprocessing import Manager, freeze_support
import os
from pathlib import Path
+from pathspec.patterns.gitwildmatch import GitWildMatchPatternError
import regex as re
import signal
import sys
Generator,
Iterator,
List,
+ MutableMapping,
Optional,
Pattern,
Set,
from black.mode import Feature, supports_feature, VERSION_TO_FEATURES
from black.cache import read_cache, write_cache, get_cache_info, filter_cached, Cache
from black.concurrency import cancel, shutdown, maybe_install_uvloop
-from black.output import dump_to_file, diff, color_diff, out, err
-from black.report import Report, Changed
+from black.output import dump_to_file, ipynb_diff, diff, color_diff, out, err
+from black.report import Report, Changed, NothingChanged
from black.files import find_project_root, find_pyproject_toml, parse_pyproject_toml
from black.files import gen_python_files, get_gitignore, normalize_path_maybe_ignore
from black.files import wrap_stream_for_windows
from black.parsing import InvalidInput # noqa F401
from black.parsing import lib2to3_parse, parse_ast, stringify_ast
+from black.handle_ipynb_magics import (
+ mask_cell,
+ unmask_cell,
+ remove_trailing_semicolon,
+ put_trailing_semicolon_back,
+ TRANSFORMED_MAGICS,
+ jupyter_dependencies_are_installed,
+)
# lib2to3 fork
NewLine = str
-class NothingChanged(UserWarning):
- """Raised when reformatted code is the same as source."""
-
-
class WriteBack(Enum):
NO = 0
YES = 1
# Legacy name, left for integrations.
FileMode = Mode
# Default size of the worker pool (used as the default of the ``--workers``
# option). ``os.cpu_count()`` returns None when the number of CPUs cannot be
# determined, so fall back to a single worker to keep this a positive int.
DEFAULT_WORKERS = os.cpu_count() or 1
+
def read_pyproject_toml(
ctx: click.Context, param: click.Parameter, value: Optional[str]
except (OSError, ValueError) as e:
raise click.FileError(
filename=value, hint=f"Error reading configuration file: {e}"
- )
+ ) from None
if not config:
return None
try:
return re_compile_maybe_verbose(value) if value is not None else None
except re.error:
- raise click.BadParameter("Not a valid regular expression")
+ raise click.BadParameter("Not a valid regular expression") from None
@click.command(context_settings=dict(help_option_names=["-h", "--help"]))
" when piping source on standard input)."
),
)
+@click.option(
+ "--ipynb",
+ is_flag=True,
+ help=(
+ "Format all input files like Jupyter Notebooks regardless of file extension "
+ "(useful when piping source on standard input)."
+ ),
+)
@click.option(
"-S",
"--skip-string-normalization",
"editors that rely on using stdin."
),
)
+@click.option(
+ "-W",
+ "--workers",
+ type=click.IntRange(min=1),
+ default=DEFAULT_WORKERS,
+ show_default=True,
+ help="Number of parallel workers",
+)
@click.option(
"-q",
"--quiet",
exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
),
is_eager=True,
+ metavar="SRC ...",
)
@click.option(
"--config",
color: bool,
fast: bool,
pyi: bool,
+ ipynb: bool,
skip_string_normalization: bool,
skip_magic_trailing_comma: bool,
experimental_string_processing: bool,
extend_exclude: Optional[Pattern],
force_exclude: Optional[Pattern],
stdin_filename: Optional[str],
+ workers: int,
src: Tuple[str, ...],
config: Optional[str],
) -> None:
f" the running version `{__version__}`!"
)
ctx.exit(1)
+ if ipynb and pyi:
+ err("Cannot pass both `pyi` and `ipynb` flags!")
+ ctx.exit(1)
write_back = WriteBack.from_configuration(check=check, diff=diff, color=color)
if target_version:
target_versions=versions,
line_length=line_length,
is_pyi=pyi,
+ is_ipynb=ipynb,
string_normalization=not skip_string_normalization,
magic_trailing_comma=not skip_magic_trailing_comma,
experimental_string_processing=experimental_string_processing,
content=code, fast=fast, write_back=write_back, mode=mode, report=report
)
else:
- sources = get_sources(
- ctx=ctx,
- src=src,
- quiet=quiet,
- verbose=verbose,
- include=include,
- exclude=exclude,
- extend_exclude=extend_exclude,
- force_exclude=force_exclude,
- report=report,
- stdin_filename=stdin_filename,
- )
+ try:
+ sources = get_sources(
+ ctx=ctx,
+ src=src,
+ quiet=quiet,
+ verbose=verbose,
+ include=include,
+ exclude=exclude,
+ extend_exclude=extend_exclude,
+ force_exclude=force_exclude,
+ report=report,
+ stdin_filename=stdin_filename,
+ )
+ except GitWildMatchPatternError:
+ ctx.exit(1)
path_empty(
sources,
write_back=write_back,
mode=mode,
report=report,
+ workers=workers,
)
if verbose or not quiet:
if is_stdin:
p = Path(f"{STDIN_PLACEHOLDER}{str(p)}")
+ if p.suffix == ".ipynb" and not jupyter_dependencies_are_installed(
+ verbose=verbose, quiet=quiet
+ ):
+ continue
+
sources.add(p)
elif p.is_dir():
sources.update(
force_exclude,
report,
gitignore,
+ verbose=verbose,
+ quiet=quiet,
)
)
elif s == "-":
if is_stdin:
if src.suffix == ".pyi":
mode = replace(mode, is_pyi=True)
+ elif src.suffix == ".ipynb":
+ mode = replace(mode, is_ipynb=True)
if format_stdin_to_stdout(fast=fast, write_back=write_back, mode=mode):
changed = Changed.YES
else:
def reformat_many(
- sources: Set[Path], fast: bool, write_back: WriteBack, mode: Mode, report: "Report"
+ sources: Set[Path],
+ fast: bool,
+ write_back: WriteBack,
+ mode: Mode,
+ report: "Report",
+ workers: Optional[int],
) -> None:
"""Reformat multiple files using a ProcessPoolExecutor."""
executor: Executor
loop = asyncio.get_event_loop()
- worker_count = os.cpu_count()
+ worker_count = workers if workers is not None else DEFAULT_WORKERS
if sys.platform == "win32":
# Work around https://bugs.python.org/issue26903
worker_count = min(worker_count, 60)
"""
if src.suffix == ".pyi":
mode = replace(mode, is_pyi=True)
+ elif src.suffix == ".ipynb":
+ mode = replace(mode, is_ipynb=True)
then = datetime.utcfromtimestamp(src.stat().st_mtime)
with open(src, "rb") as buf:
dst_contents = format_file_contents(src_contents, fast=fast, mode=mode)
except NothingChanged:
return False
+ except JSONDecodeError:
+ raise ValueError(
+ f"File '{src}' cannot be parsed as valid Jupyter notebook."
+ ) from None
if write_back == WriteBack.YES:
with open(src, "w", encoding=encoding, newline=newline) as f:
now = datetime.utcnow()
src_name = f"{src}\t{then} +0000"
dst_name = f"{src}\t{now} +0000"
- diff_contents = diff(src_contents, dst_contents, src_name, dst_name)
+ if mode.is_ipynb:
+ diff_contents = ipynb_diff(src_contents, dst_contents, src_name, dst_name)
+ else:
+ diff_contents = diff(src_contents, dst_contents, src_name, dst_name)
if write_back == WriteBack.COLOR_DIFF:
diff_contents = color_diff(diff_contents)
f.detach()
def check_stability_and_equivalence(
    src_contents: str, dst_contents: str, *, mode: Mode
) -> None:
    """Run the equivalence and stability safety checks on formatted output.

    Raise AssertionError if source and destination contents are not
    equivalent, or if a second pass of the formatter would format the
    content differently.
    """
    assert_equivalent(src_contents, dst_contents)

    # Forced second pass to work around optional trailing commas (becoming
    # forced trailing commas on pass 2) interacting differently with optional
    # parentheses. Admittedly ugly.
    second_pass = format_str(dst_contents, mode=mode)
    if second_pass == dst_contents:
        # Both passes agree, so there is no need to call `assert_stable`.
        return

    assert_equivalent(src_contents, second_pass, pass_num=2)
    assert_stable(src_contents, second_pass, mode=mode)
+
+
def format_file_contents(src_contents: str, *, fast: bool, mode: Mode) -> FileContent:
"""Reformat contents of a file and return new contents.
if not src_contents.strip():
raise NothingChanged
- dst_contents = format_str(src_contents, mode=mode)
+ if mode.is_ipynb:
+ dst_contents = format_ipynb_string(src_contents, fast=fast, mode=mode)
+ else:
+ dst_contents = format_str(src_contents, mode=mode)
if src_contents == dst_contents:
raise NothingChanged
- if not fast:
- assert_equivalent(src_contents, dst_contents)
-
- # Forced second pass to work around optional trailing commas (becoming
- # forced trailing commas on pass 2) interacting differently with optional
- # parentheses. Admittedly ugly.
- dst_contents_pass2 = format_str(dst_contents, mode=mode)
- if dst_contents != dst_contents_pass2:
- dst_contents = dst_contents_pass2
- assert_equivalent(src_contents, dst_contents, pass_num=2)
- assert_stable(src_contents, dst_contents, mode=mode)
- # Note: no need to explicitly call `assert_stable` if `dst_contents` was
- # the same as `dst_contents_pass2`.
+ if not fast and not mode.is_ipynb:
+ # Jupyter notebooks will already have been checked above.
+ check_stability_and_equivalence(src_contents, dst_contents, mode=mode)
return dst_contents
def validate_cell(src: str) -> None:
    r"""Reject cells that already contain TransformerManager transformations.

    If a cell contains ``!ls``, then it'll be transformed to
    ``get_ipython().system('ls')``. However, if the cell originally contained
    ``get_ipython().system('ls')``, then it would get transformed in the same way:

        >>> TransformerManager().transform_cell("get_ipython().system('ls')")
        "get_ipython().system('ls')\n"
        >>> TransformerManager().transform_cell("!ls")
        "get_ipython().system('ls')\n"

    Due to the impossibility of safely roundtripping in such situations, cells
    containing transformed magics will be ignored.

    Raises NothingChanged when ``src`` contains any of the TRANSFORMED_MAGICS
    substrings.
    """
    # NOTE: the docstring is a raw string so that the ``\n`` in the examples
    # stays a visible escape sequence instead of becoming a real line break.
    if any(transformed_magic in src for transformed_magic in TRANSFORMED_MAGICS):
        raise NothingChanged
+
+
def format_cell(src: str, *, fast: bool, mode: Mode) -> str:
    """Format the code of a single Jupyter notebook cell.

    The steps are:

      - drop a trailing semicolon, remembering whether there was one;
      - mask IPython magics;
      - format the masked code;
      - put the IPython magics back;
      - put the trailing semicolon back (if originally present);
      - strip trailing newlines.

    Cells with syntax errors will not be processed, as they
    could potentially be automagics or multi-line magics, which
    are currently not supported.

    Raises NothingChanged when the cell is skipped or when the result is
    identical to ``src``.
    """
    validate_cell(src)
    stripped_src, had_semicolon = remove_trailing_semicolon(src)
    try:
        masked_src, replacements = mask_cell(stripped_src)
    except SyntaxError:
        raise NothingChanged from None

    masked_dst = format_str(masked_src, mode=mode)
    if not fast:
        check_stability_and_equivalence(masked_src, masked_dst, mode=mode)

    unmasked_dst = unmask_cell(masked_dst, replacements)
    dst = put_trailing_semicolon_back(unmasked_dst, had_semicolon).rstrip("\n")
    if dst == src:
        raise NothingChanged from None
    return dst
+
+
def validate_metadata(nb: MutableMapping[str, Any]) -> None:
    """Refuse to format notebooks explicitly declared as non-Python.

    All notebook metadata fields are optional, see
    https://nbformat.readthedocs.io/en/latest/format_description.html. A
    notebook that declares no language is therefore still processed.

    Raises NothingChanged when a language other than "python" is declared.
    """
    metadata = nb.get("metadata", {})
    language_info = metadata.get("language_info", {})
    language = language_info.get("name", None)
    if language is None or language == "python":
        return
    raise NothingChanged from None
+
+
def format_ipynb_string(src_contents: str, *, fast: bool, mode: Mode) -> FileContent:
    """Format Jupyter notebook.

    Operate cell-by-cell, only on code cells, only for Python notebooks.
    If the ``.ipynb`` originally had a trailing newline, it'll be preserved.

    Raises NothingChanged if ``src_contents`` is empty, if the notebook is
    rejected by ``validate_metadata``, or if no code cell was reformatted.
    Invalid JSON propagates as the decoding error raised by ``json.loads``.
    """
    if not src_contents:
        # Nothing to parse; previously this crashed on ``src_contents[-1]``.
        raise NothingChanged

    trailing_newline = src_contents.endswith("\n")
    nb = json.loads(src_contents)
    validate_metadata(nb)

    modified = False
    for cell in nb["cells"]:
        if cell.get("cell_type", None) != "code":
            continue
        try:
            src = "".join(cell["source"])
            dst = format_cell(src, fast=fast, mode=mode)
        except NothingChanged:
            # This cell was skipped or is already formatted; leave it as is.
            continue
        cell["source"] = dst.splitlines(keepends=True)
        modified = True

    if not modified:
        raise NothingChanged

    dst_contents = json.dumps(nb, indent=1, ensure_ascii=False)
    if trailing_newline:
        dst_contents += "\n"
    return dst_contents
+
+
def format_str(src_contents: str, *, mode: Mode) -> FileContent:
"""Reformat a string and return new contents.
features.add(Feature.NUMERIC_UNDERSCORES)
elif n.type == token.SLASH:
- if n.parent and n.parent.type in {syms.typedargslist, syms.arglist}:
+ if n.parent and n.parent.type in {
+ syms.typedargslist,
+ syms.arglist,
+ syms.varargslist,
+ }:
features.add(Feature.POS_ONLY_ARGUMENTS)
elif n.type == token.COLONEQUAL:
src_ast = parse_ast(src)
except Exception as exc:
raise AssertionError(
- "cannot use --safe with this file; failed to parse source file. AST"
- f" error message: {exc}"
- )
+ "cannot use --safe with this file; failed to parse source file."
+ ) from exc
try:
dst_ast = parse_ast(dst)