X-Git-Url: https://git.madduck.net/etc/vim.git/blobdiff_plain/7567cdf3b4f32d4fb12bd5ca0da838f7ff252cfc..de65741b8d49d78fa2675ef79b799cd35e92e7c1:/src/black/__init__.py diff --git a/src/black/__init__.py b/src/black/__init__.py index 1d0ad7d..4ebf288 100644 --- a/src/black/__init__.py +++ b/src/black/__init__.py @@ -1,66 +1,92 @@ -import asyncio -from concurrent.futures import Executor, ThreadPoolExecutor, ProcessPoolExecutor -from contextlib import contextmanager -from datetime import datetime -from enum import Enum import io -from multiprocessing import Manager, freeze_support -import os -from pathlib import Path -import regex as re -import signal +import json +import platform +import re import sys import tokenize import traceback +from contextlib import contextmanager +from dataclasses import replace +from datetime import datetime +from enum import Enum +from json.decoder import JSONDecodeError +from pathlib import Path from typing import ( Any, Dict, Generator, Iterator, List, + MutableMapping, Optional, Pattern, + Sequence, Set, Sized, Tuple, Union, ) -from dataclasses import replace import click +from click.core import ParameterSource +from mypy_extensions import mypyc_attr +from pathspec import PathSpec +from pathspec.patterns.gitwildmatch import GitWildMatchPatternError -from black.const import DEFAULT_LINE_LENGTH, DEFAULT_INCLUDES, DEFAULT_EXCLUDES -from black.const import STDIN_PLACEHOLDER -from black.nodes import STARS, syms, is_simple_decorator_expression -from black.lines import Line, EmptyLineTracker -from black.linegen import transform_line, LineGenerator, LN +from _black_version import version as __version__ +from black.cache import Cache, get_cache_info, read_cache, write_cache from black.comments import normalize_fmt_off -from black.mode import Mode, TargetVersion -from black.mode import Feature, supports_feature, VERSION_TO_FEATURES -from black.cache import read_cache, write_cache, get_cache_info, filter_cached, Cache -from black.concurrency import cancel, shutdown -from black.output import dump_to_file, diff, color_diff, out, err -from black.report import Report, Changed -from black.files import find_project_root, find_pyproject_toml, parse_pyproject_toml -from black.files import gen_python_files, get_gitignore, normalize_path_maybe_ignore -from black.files import wrap_stream_for_windows +from black.const import ( + DEFAULT_EXCLUDES, + DEFAULT_INCLUDES, + DEFAULT_LINE_LENGTH, + STDIN_PLACEHOLDER, +) +from black.files import ( + find_project_root, + find_pyproject_toml, + find_user_pyproject_toml, + gen_python_files, + get_gitignore, + normalize_path_maybe_ignore, + parse_pyproject_toml, + wrap_stream_for_windows, +) +from black.handle_ipynb_magics import ( + PYTHON_CELL_MAGICS, + TRANSFORMED_MAGICS, + jupyter_dependencies_are_installed, + mask_cell, + put_trailing_semicolon_back, + remove_trailing_semicolon, + unmask_cell, +) +from black.linegen import LN, LineGenerator, transform_line +from black.lines import EmptyLineTracker, LinesBlock +from black.mode import ( + FUTURE_FLAG_TO_FEATURE, + VERSION_TO_FEATURES, + Feature, + Mode, + TargetVersion, + supports_feature, +) +from black.nodes import ( + STARS, + is_number_token, + is_simple_decorator_expression, + is_string_token, + syms, +) +from black.output import color_diff, diff, dump_to_file, err, ipynb_diff, out from black.parsing import InvalidInput # noqa F401 from black.parsing import lib2to3_parse, parse_ast, stringify_ast - - -# lib2to3 fork -from blib2to3.pytree import Node, Leaf +from black.report import Changed, NothingChanged, Report +from black.trans import iter_fexpr_spans from blib2to3.pgen2 import token +from blib2to3.pytree import Leaf, Node -from _black_version import version as __version__ - -# If our environment has uvloop installed lets use it -try: - import uvloop - - uvloop.install() -except ImportError: - pass +COMPILED = Path(__file__).suffix in (".pyd", ".so") # types FileContent = str @@ -68,10 +94,6 @@ Encoding = str NewLine = str -class NothingChanged(UserWarning): - """Raised when reformatted code is the same as source.""" - - class WriteBack(Enum): NO = 0 YES = 1 @@ -114,7 +136,7 @@ def read_pyproject_toml( except (OSError, ValueError) as e: raise click.FileError( filename=value, hint=f"Error reading configuration file: {e}" - ) + ) from None if not config: return None @@ -168,14 +190,19 @@ def validate_regex( ctx: click.Context, param: click.Parameter, value: Optional[str], -) -> Optional[Pattern]: +) -> Optional[Pattern[str]]: try: return re_compile_maybe_verbose(value) if value is not None else None - except re.error: - raise click.BadParameter("Not a valid regular expression") + except re.error as e: + raise click.BadParameter(f"Not a valid regular expression: {e}") from None -@click.command(context_settings=dict(help_option_names=["-h", "--help"])) +@click.command( + context_settings={"help_option_names": ["-h", "--help"]}, + # While Click does set this field automatically using the docstring, mypyc + # (annoyingly) strips 'em so we need to set it here too. + help="The uncompromising code formatter.", +) @click.option("-c", "--code", type=str, help="Format the code passed in as a string.") @click.option( "-l", @@ -192,8 +219,9 @@ def validate_regex( callback=target_version_option_callback, multiple=True, help=( - "Python versions that should be supported by Black's output. [default: per-file" - " auto-detection]" + "Python versions that should be supported by Black's output. By default, Black" + " will try to infer this from the project metadata in pyproject.toml. If this" + " does not yield conclusive results, Black will use per-file auto-detection." ), ) @click.option( @@ -204,6 +232,30 @@ def validate_regex( " when piping source on standard input)." ), ) +@click.option( + "--ipynb", + is_flag=True, + help=( + "Format all input files like Jupyter Notebooks regardless of file extension " + "(useful when piping source on standard input)." + ), +) +@click.option( + "--python-cell-magics", + multiple=True, + help=( + "When processing Jupyter Notebooks, add the given magic to the list" + f" of known python-magics ({', '.join(sorted(PYTHON_CELL_MAGICS))})." + " Useful for formatting cells with custom python magics." + ), + default=[], +) +@click.option( + "-x", + "--skip-source-first-line", + is_flag=True, + help="Skip the first line of the source code.", +) @click.option( "-S", "--skip-string-normalization", @@ -220,9 +272,14 @@ def validate_regex( "--experimental-string-processing", is_flag=True, hidden=True, + help="(DEPRECATED and now included in --preview) Normalize string literals.", +) +@click.option( + "--preview", + is_flag=True, help=( - "Experimental option that performs more normalization on string literals." - " Currently disabled because it leads to some crashes." + "Enable potentially disruptive style changes that may be added to Black's main" + " functionality in the next major release." ), ) @click.option( @@ -249,6 +306,15 @@ def validate_regex( is_flag=True, help="If --fast given, skip temporary sanity checks. [default: --safe]", ) +@click.option( + "--required-version", + type=str, + help=( + "Require a specific version of Black to be running (useful for unifying results" + " across many environments e.g. with a pyproject.toml file). It can be" + " either a major version number or an exact version." + ), +) @click.option( "--include", type=str, @@ -302,6 +368,13 @@ def validate_regex( "editors that rely on using stdin." ), ) +@click.option( + "-W", + "--workers", + type=click.IntRange(min=1), + default=None, + help="Number of parallel workers [default: number of CPUs in the system]", +) @click.option( "-q", "--quiet", @@ -320,7 +393,13 @@ def validate_regex( " due to exclusion patterns." ), ) -@click.version_option(version=__version__) +@click.version_option( + version=__version__, + message=( + f"%(prog)s, %(version)s (compiled: {'yes' if COMPILED else 'no'})\n" + f"Python ({platform.python_implementation()}) {platform.python_version()}" + ), +) @click.argument( "src", nargs=-1, @@ -328,6 +407,7 @@ def validate_regex( exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True ), is_eager=True, + metavar="SRC ...", ) @click.option( "--config", @@ -344,7 +424,7 @@ def validate_regex( help="Read configuration from FILE path.", ) @click.pass_context -def main( +def main( # noqa: C901 ctx: click.Context, code: Optional[str], line_length: int, @@ -354,20 +434,107 @@ def main( color: bool, fast: bool, pyi: bool, + ipynb: bool, + python_cell_magics: Sequence[str], + skip_source_first_line: bool, skip_string_normalization: bool, skip_magic_trailing_comma: bool, experimental_string_processing: bool, + preview: bool, quiet: bool, verbose: bool, - include: Pattern, - exclude: Optional[Pattern], - extend_exclude: Optional[Pattern], - force_exclude: Optional[Pattern], + required_version: Optional[str], + include: Pattern[str], + exclude: Optional[Pattern[str]], + extend_exclude: Optional[Pattern[str]], + force_exclude: Optional[Pattern[str]], stdin_filename: Optional[str], + workers: Optional[int], src: Tuple[str, ...], config: Optional[str], ) -> None: """The uncompromising code formatter.""" + ctx.ensure_object(dict) + + if src and code is not None: + out( + main.get_usage(ctx) + + "\n\n'SRC' and 'code' cannot be passed simultaneously." + ) + ctx.exit(1) + if not src and code is None: + out(main.get_usage(ctx) + "\n\nOne of 'SRC' or 'code' is required.") + ctx.exit(1) + + root, method = ( + find_project_root(src, stdin_filename) if code is None else (None, None) + ) + ctx.obj["root"] = root + + if verbose: + if root: + out( + f"Identified `{root}` as project root containing a {method}.", + fg="blue", + ) + + normalized = [ + ( + (source, source) + if source == "-" + else (normalize_path_maybe_ignore(Path(source), root), source) + ) + for source in src + ] + srcs_string = ", ".join( + [ + ( + f'"{_norm}"' + if _norm + else f'\033[31m"{source} (skipping - invalid)"\033[34m' + ) + for _norm, source in normalized + ] + ) + out(f"Sources to be formatted: {srcs_string}", fg="blue") + + if config: + config_source = ctx.get_parameter_source("config") + user_level_config = str(find_user_pyproject_toml()) + if config == user_level_config: + out( + ( + "Using configuration from user-level config at " + f"'{user_level_config}'." + ), + fg="blue", + ) + elif config_source in ( + ParameterSource.DEFAULT, + ParameterSource.DEFAULT_MAP, + ): + out("Using configuration from project root.", fg="blue") + else: + out(f"Using configuration in '{config}'.", fg="blue") + if ctx.default_map: + for param, value in ctx.default_map.items(): + out(f"{param}: {value}") + + error_msg = "Oh no! 💥 💔 💥" + if ( + required_version + and required_version != __version__ + and required_version != __version__.split(".")[0] + ): + err( + f"{error_msg} The required version `{required_version}` does not match" + f" the running version `{__version__}`!" + ) + ctx.exit(1) + if ipynb and pyi: + err("Cannot pass both `pyi` and `ipynb` flags!") + ctx.exit(1) + write_back = WriteBack.from_configuration(check=check, diff=diff, color=color) if target_version: versions = set(target_version) @@ -378,12 +545,14 @@ def main( target_versions=versions, line_length=line_length, is_pyi=pyi, + is_ipynb=ipynb, + skip_source_first_line=skip_source_first_line, string_normalization=not skip_string_normalization, magic_trailing_comma=not skip_magic_trailing_comma, experimental_string_processing=experimental_string_processing, + preview=preview, + python_cell_magics=set(python_cell_magics), ) - if config and verbose: - out(f"Using configuration from {config}.", bold=False, fg="blue") if code is not None: # Run in quiet mode by default with -c; the extra output isn't useful. @@ -397,18 +566,21 @@ def main( content=code, fast=fast, write_back=write_back, mode=mode, report=report ) else: - sources = get_sources( - ctx=ctx, - src=src, - quiet=quiet, - verbose=verbose, - include=include, - exclude=exclude, - extend_exclude=extend_exclude, - force_exclude=force_exclude, - report=report, - stdin_filename=stdin_filename, - ) + try: + sources = get_sources( + ctx=ctx, + src=src, + quiet=quiet, + verbose=verbose, + include=include, + exclude=exclude, + extend_exclude=extend_exclude, + force_exclude=force_exclude, + report=report, + stdin_filename=stdin_filename, + ) + except GitWildMatchPatternError: + ctx.exit(1) path_empty( sources, @@ -427,18 +599,23 @@ def main( report=report, ) else: + from black.concurrency import reformat_many + reformat_many( sources=sources, fast=fast, write_back=write_back, mode=mode, report=report, + workers=workers, ) if verbose or not quiet: - out("Oh no! 💥 💔 💥" if report.return_code else "All done! ✨ 🍰 ✨") + if code is None and (verbose or report.change_count or report.failure_count): + out() + out(error_msg if report.return_code else "All done! ✨ 🍰 ✨") if code is None: - click.secho(str(report), err=True) + click.echo(str(report), err=True) ctx.exit(report.return_code) @@ -456,16 +633,13 @@ def get_sources( stdin_filename: Optional[str], ) -> Set[Path]: """Compute the set of files to be formatted.""" - - root = find_project_root(src) sources: Set[Path] = set() - path_empty(src, "No Path provided. Nothing to do 😴", quiet, verbose, ctx) + root = ctx.obj["root"] - if exclude is None: - exclude = re_compile_maybe_verbose(DEFAULT_EXCLUDES) - gitignore = get_gitignore(root) - else: - gitignore = None + using_default_exclude = exclude is None + exclude = re_compile_maybe_verbose(DEFAULT_EXCLUDES) if exclude is None else exclude + gitignore: Optional[Dict[Path, PathSpec]] = None + root_gitignore = get_gitignore(root) for s in src: if s == "-" and stdin_filename: @@ -476,7 +650,7 @@ def get_sources( is_stdin = False if is_stdin or p.is_file(): - normalized_path = normalize_path_maybe_ignore(p, root, report) + normalized_path = normalize_path_maybe_ignore(p, ctx.obj["root"], report) if normalized_path is None: continue @@ -493,18 +667,31 @@ def get_sources( if is_stdin: p = Path(f"{STDIN_PLACEHOLDER}{str(p)}") + if p.suffix == ".ipynb" and not jupyter_dependencies_are_installed( + verbose=verbose, quiet=quiet + ): + continue + sources.add(p) elif p.is_dir(): + p = root / normalize_path_maybe_ignore(p, ctx.obj["root"], report) + if using_default_exclude: + gitignore = { + root: root_gitignore, + p: get_gitignore(p), + } sources.update( gen_python_files( p.iterdir(), - root, + ctx.obj["root"], include, exclude, extend_exclude, force_exclude, report, gitignore, + verbose=verbose, + quiet=quiet, ) ) elif s == "-": @@ -550,6 +737,9 @@ def reformat_code( report.failed(path, str(exc)) +# diff-shades depends on being to monkeypatch this function to operate. I know it's +# not ideal, but this shouldn't cause any issues ... hopefully. ~ichard26 +@mypyc_attr(patchable=True) def reformat_one( src: Path, fast: bool, write_back: WriteBack, mode: Mode, report: "Report" ) -> None: @@ -574,6 +764,8 @@ def reformat_one( if is_stdin: if src.suffix == ".pyi": mode = replace(mode, is_pyi=True) + elif src.suffix == ".ipynb": + mode = replace(mode, is_ipynb=True) if format_stdin_to_stdout(fast=fast, write_back=write_back, mode=mode): changed = Changed.YES else: @@ -599,114 +791,6 @@ def reformat_one( report.failed(src, str(exc)) -def reformat_many( - sources: Set[Path], fast: bool, write_back: WriteBack, mode: Mode, report: "Report" -) -> None: - """Reformat multiple files using a ProcessPoolExecutor.""" - executor: Executor - loop = asyncio.get_event_loop() - worker_count = os.cpu_count() - if sys.platform == "win32": - # Work around https://bugs.python.org/issue26903 - worker_count = min(worker_count, 60) - try: - executor = ProcessPoolExecutor(max_workers=worker_count) - except (ImportError, OSError): - # we arrive here if the underlying system does not support multi-processing - # like in AWS Lambda or Termux, in which case we gracefully fallback to - # a ThreadPoolExecutor with just a single worker (more workers would not do us - # any good due to the Global Interpreter Lock) - executor = ThreadPoolExecutor(max_workers=1) - - try: - loop.run_until_complete( - schedule_formatting( - sources=sources, - fast=fast, - write_back=write_back, - mode=mode, - report=report, - loop=loop, - executor=executor, - ) - ) - finally: - shutdown(loop) - if executor is not None: - executor.shutdown() - - -async def schedule_formatting( - sources: Set[Path], - fast: bool, - write_back: WriteBack, - mode: Mode, - report: "Report", - loop: asyncio.AbstractEventLoop, - executor: Executor, -) -> None: - """Run formatting of `sources` in parallel using the provided `executor`. - - (Use ProcessPoolExecutors for actual parallelism.) - - `write_back`, `fast`, and `mode` options are passed to - :func:`format_file_in_place`. - """ - cache: Cache = {} - if write_back not in (WriteBack.DIFF, WriteBack.COLOR_DIFF): - cache = read_cache(mode) - sources, cached = filter_cached(cache, sources) - for src in sorted(cached): - report.done(src, Changed.CACHED) - if not sources: - return - - cancelled = [] - sources_to_cache = [] - lock = None - if write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF): - # For diff output, we need locks to ensure we don't interleave output - # from different processes. - manager = Manager() - lock = manager.Lock() - tasks = { - asyncio.ensure_future( - loop.run_in_executor( - executor, format_file_in_place, src, fast, mode, write_back, lock - ) - ): src - for src in sorted(sources) - } - pending = tasks.keys() - try: - loop.add_signal_handler(signal.SIGINT, cancel, pending) - loop.add_signal_handler(signal.SIGTERM, cancel, pending) - except NotImplementedError: - # There are no good alternatives for these on Windows. - pass - while pending: - done, _ = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED) - for task in done: - src = tasks.pop(task) - if task.cancelled(): - cancelled.append(task) - elif task.exception(): - report.failed(src, str(task.exception())) - else: - changed = Changed.YES if task.result() else Changed.NO - # If the file was written back or was successfully checked as - # well-formatted, store this information in the cache. - if write_back is WriteBack.YES or ( - write_back is WriteBack.CHECK and changed is Changed.NO - ): - sources_to_cache.append(src) - report.done(src, changed) - if cancelled: - await asyncio.gather(*cancelled, loop=loop, return_exceptions=True) - if sources_to_cache: - write_cache(cache, sources_to_cache, mode) - - def format_file_in_place( src: Path, fast: bool, @@ -722,14 +806,25 @@ def format_file_in_place( """ if src.suffix == ".pyi": mode = replace(mode, is_pyi=True) + elif src.suffix == ".ipynb": + mode = replace(mode, is_ipynb=True) then = datetime.utcfromtimestamp(src.stat().st_mtime) + header = b"" with open(src, "rb") as buf: + if mode.skip_source_first_line: + header = buf.readline() src_contents, encoding, newline = decode_bytes(buf.read()) try: dst_contents = format_file_contents(src_contents, fast=fast, mode=mode) except NothingChanged: return False + except JSONDecodeError: + raise ValueError( + f"File '{src}' cannot be parsed as valid Jupyter notebook." + ) from None + src_contents = header.decode(encoding) + src_contents + dst_contents = header.decode(encoding) + dst_contents if write_back == WriteBack.YES: with open(src, "w", encoding=encoding, newline=newline) as f: @@ -738,7 +833,10 @@ def format_file_in_place( now = datetime.utcnow() src_name = f"{src}\t{then} +0000" dst_name = f"{src}\t{now} +0000" - diff_contents = diff(src_contents, dst_contents, src_name, dst_name) + if mode.is_ipynb: + diff_contents = ipynb_diff(src_contents, dst_contents, src_name, dst_name) + else: + diff_contents = diff(src_contents, dst_contents, src_name, dst_name) if write_back == WriteBack.COLOR_DIFF: diff_contents = color_diff(diff_contents) @@ -793,7 +891,8 @@ def format_stdin_to_stdout( ) if write_back == WriteBack.YES: # Make sure there's a newline after the content - dst += "" if dst[-1] == "\n" else "\n" + if dst and dst[-1] != "\n": + dst += "\n" f.write(dst) elif write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF): now = datetime.utcnow() @@ -807,6 +906,19 @@ def format_stdin_to_stdout( f.detach() +def check_stability_and_equivalence( + src_contents: str, dst_contents: str, *, mode: Mode +) -> None: + """Perform stability and equivalence checks. + + Raise AssertionError if source and destination contents are not + equivalent, or if a second pass of the formatter would format the + content differently. + """ + assert_equivalent(src_contents, dst_contents) + assert_stable(src_contents, dst_contents, mode=mode) + + def format_file_contents(src_contents: str, *, fast: bool, mode: Mode) -> FileContent: """Reformat contents of a file and return new contents. @@ -814,30 +926,127 @@ def format_file_contents(src_contents: str, *, fast: bool, mode: Mode) -> FileCo valid by calling :func:`assert_equivalent` and :func:`assert_stable` on it. `mode` is passed to :func:`format_str`. """ - if not src_contents.strip(): + if mode.is_ipynb: + dst_contents = format_ipynb_string(src_contents, fast=fast, mode=mode) + else: + dst_contents = format_str(src_contents, mode=mode) + if src_contents == dst_contents: raise NothingChanged - dst_contents = format_str(src_contents, mode=mode) - if src_contents == dst_contents: + if not fast and not mode.is_ipynb: + # Jupyter notebooks will already have been checked above. + check_stability_and_equivalence(src_contents, dst_contents, mode=mode) + return dst_contents + + +def validate_cell(src: str, mode: Mode) -> None: + """Check that cell does not already contain TransformerManager transformations, + or non-Python cell magics, which might cause tokenizer_rt to break because of + indentations. + + If a cell contains ``!ls``, then it'll be transformed to + ``get_ipython().system('ls')``. However, if the cell originally contained + ``get_ipython().system('ls')``, then it would get transformed in the same way: + + >>> TransformerManager().transform_cell("get_ipython().system('ls')") + "get_ipython().system('ls')\n" + >>> TransformerManager().transform_cell("!ls") + "get_ipython().system('ls')\n" + + Due to the impossibility of safely roundtripping in such situations, cells + containing transformed magics will be ignored. + """ + if any(transformed_magic in src for transformed_magic in TRANSFORMED_MAGICS): + raise NothingChanged + if ( + src[:2] == "%%" + and src.split()[0][2:] not in PYTHON_CELL_MAGICS | mode.python_cell_magics + ): raise NothingChanged + +def format_cell(src: str, *, fast: bool, mode: Mode) -> str: + """Format code in given cell of Jupyter notebook. + + General idea is: + + - if cell has trailing semicolon, remove it; + - if cell has IPython magics, mask them; + - format cell; + - reinstate IPython magics; + - reinstate trailing semicolon (if originally present); + - strip trailing newlines. + + Cells with syntax errors will not be processed, as they + could potentially be automagics or multi-line magics, which + are currently not supported. + """ + validate_cell(src, mode) + src_without_trailing_semicolon, has_trailing_semicolon = remove_trailing_semicolon( + src + ) + try: + masked_src, replacements = mask_cell(src_without_trailing_semicolon) + except SyntaxError: + raise NothingChanged from None + masked_dst = format_str(masked_src, mode=mode) if not fast: - assert_equivalent(src_contents, dst_contents) - - # Forced second pass to work around optional trailing commas (becoming - # forced trailing commas on pass 2) interacting differently with optional - # parentheses. Admittedly ugly. - dst_contents_pass2 = format_str(dst_contents, mode=mode) - if dst_contents != dst_contents_pass2: - dst_contents = dst_contents_pass2 - assert_equivalent(src_contents, dst_contents, pass_num=2) - assert_stable(src_contents, dst_contents, mode=mode) - # Note: no need to explicitly call `assert_stable` if `dst_contents` was - # the same as `dst_contents_pass2`. - return dst_contents + check_stability_and_equivalence(masked_src, masked_dst, mode=mode) + dst_without_trailing_semicolon = unmask_cell(masked_dst, replacements) + dst = put_trailing_semicolon_back( + dst_without_trailing_semicolon, has_trailing_semicolon + ) + dst = dst.rstrip("\n") + if dst == src: + raise NothingChanged from None + return dst -def format_str(src_contents: str, *, mode: Mode) -> FileContent: +def validate_metadata(nb: MutableMapping[str, Any]) -> None: + """If notebook is marked as non-Python, don't format it. + + All notebook metadata fields are optional, see + https://nbformat.readthedocs.io/en/latest/format_description.html. So + if a notebook has empty metadata, we will try to parse it anyway. + """ + language = nb.get("metadata", {}).get("language_info", {}).get("name", None) + if language is not None and language != "python": + raise NothingChanged from None + + +def format_ipynb_string(src_contents: str, *, fast: bool, mode: Mode) -> FileContent: + """Format Jupyter notebook. + + Operate cell-by-cell, only on code cells, only for Python notebooks. + If the ``.ipynb`` originally had a trailing newline, it'll be preserved. + """ + if not src_contents: + raise NothingChanged + + trailing_newline = src_contents[-1] == "\n" + modified = False + nb = json.loads(src_contents) + validate_metadata(nb) + for cell in nb["cells"]: + if cell.get("cell_type", None) == "code": + try: + src = "".join(cell["source"]) + dst = format_cell(src, fast=fast, mode=mode) + except NothingChanged: + pass + else: + cell["source"] = dst.splitlines(keepends=True) + modified = True + if modified: + dst_contents = json.dumps(nb, indent=1, ensure_ascii=False) + if trailing_newline: + dst_contents = dst_contents + "\n" + return dst_contents + else: + raise NothingChanged + + +def format_str(src_contents: str, *, mode: Mode) -> str: """Reformat a string and return new contents. `mode` determines formatting options, such as how many characters per line are @@ -867,35 +1076,57 @@ def format_str(src_contents: str, *, mode: Mode) -> FileContent: hey """ + dst_contents = _format_str_once(src_contents, mode=mode) + # Forced second pass to work around optional trailing commas (becoming + # forced trailing commas on pass 2) interacting differently with optional + # parentheses. Admittedly ugly. + if src_contents != dst_contents: + return _format_str_once(dst_contents, mode=mode) + return dst_contents + + +def _format_str_once(src_contents: str, *, mode: Mode) -> str: src_node = lib2to3_parse(src_contents.lstrip(), mode.target_versions) - dst_contents = [] - future_imports = get_future_imports(src_node) + dst_blocks: List[LinesBlock] = [] if mode.target_versions: versions = mode.target_versions else: - versions = detect_target_versions(src_node) + future_imports = get_future_imports(src_node) + versions = detect_target_versions(src_node, future_imports=future_imports) + + context_manager_features = { + feature + for feature in {Feature.PARENTHESIZED_CONTEXT_MANAGERS} + if supports_feature(versions, feature) + } normalize_fmt_off(src_node) - lines = LineGenerator( - mode=mode, - remove_u_prefix="unicode_literals" in future_imports - or supports_feature(versions, Feature.UNICODE_LITERALS), - ) - elt = EmptyLineTracker(is_pyi=mode.is_pyi) - empty_line = Line(mode=mode) - after = 0 + lines = LineGenerator(mode=mode, features=context_manager_features) + elt = EmptyLineTracker(mode=mode) split_line_features = { feature for feature in {Feature.TRAILING_COMMA_IN_CALL, Feature.TRAILING_COMMA_IN_DEF} if supports_feature(versions, feature) } + block: Optional[LinesBlock] = None for current_line in lines.visit(src_node): - dst_contents.append(str(empty_line) * after) - before, after = elt.maybe_empty_lines(current_line) - dst_contents.append(str(empty_line) * before) + block = elt.maybe_empty_lines(current_line) + dst_blocks.append(block) for line in transform_line( current_line, mode=mode, features=split_line_features ): - dst_contents.append(str(line)) + block.content_lines.append(str(line)) + if dst_blocks: + dst_blocks[-1].after = 0 + dst_contents = [] + for block in dst_blocks: + dst_contents.extend(block.all_lines()) + if not dst_contents: + # Use decode_bytes to retrieve the correct source newline (CRLF or LF), + # and check if normalized_content has more than one line + normalized_content, _, newline = decode_bytes(src_contents.encode("utf-8")) + if "\n" in normalized_content: + return newline + return "" return "".join(dst_contents) @@ -916,30 +1147,55 @@ def decode_bytes(src: bytes) -> Tuple[FileContent, Encoding, NewLine]: return tiow.read(), encoding, newline -def get_features_used(node: Node) -> Set[Feature]: +def get_features_used( # noqa: C901 + node: Node, *, future_imports: Optional[Set[str]] = None +) -> Set[Feature]: """Return a set of (relatively) new Python features used in this file. Currently looking for: - f-strings; + - self-documenting expressions in f-strings (f"{x=}"); - underscores in numeric literals; - trailing commas after * or ** in function signatures and calls; - positional only arguments in function signatures and lambdas; - assignment expression; - relaxed decorator syntax; + - usage of __future__ flags (annotations); + - print / exec statements; + - parenthesized context managers; + - match statements; + - except* clause; + - variadic generics; """ features: Set[Feature] = set() + if future_imports: + features |= { + FUTURE_FLAG_TO_FEATURE[future_import] + for future_import in future_imports + if future_import in FUTURE_FLAG_TO_FEATURE + } + for n in node.pre_order(): - if n.type == token.STRING: - value_head = n.value[:2] # type: ignore + if is_string_token(n): + value_head = n.value[:2] if value_head in {'f"', 'F"', "f'", "F'", "rf", "fr", "RF", "FR"}: features.add(Feature.F_STRINGS) - - elif n.type == token.NUMBER: - if "_" in n.value: # type: ignore + if Feature.DEBUG_F_STRINGS not in features: + for span_beg, span_end in iter_fexpr_spans(n.value): + if n.value[span_beg : span_end - 1].rstrip().endswith("="): + features.add(Feature.DEBUG_F_STRINGS) + break + + elif is_number_token(n): + if "_" in n.value: features.add(Feature.NUMERIC_UNDERSCORES) elif n.type == token.SLASH: - if n.parent and n.parent.type in {syms.typedargslist, syms.arglist}: + if n.parent and n.parent.type in { + syms.typedargslist, + syms.arglist, + syms.varargslist, + }: features.add(Feature.POS_ONLY_ARGUMENTS) elif n.type == token.COLONEQUAL: @@ -970,12 +1226,65 @@ def get_features_used(node: Node) -> Set[Feature]: if argch.type in STARS: features.add(feature) + elif ( + n.type in {syms.return_stmt, syms.yield_expr} + and len(n.children) >= 2 + and n.children[1].type == syms.testlist_star_expr + and any(child.type == syms.star_expr for child in n.children[1].children) + ): + features.add(Feature.UNPACKING_ON_FLOW) + + elif ( + n.type == syms.annassign + and len(n.children) >= 4 + and n.children[3].type == syms.testlist_star_expr + ): + features.add(Feature.ANN_ASSIGN_EXTENDED_RHS) + + elif ( + n.type == syms.with_stmt + and len(n.children) > 2 + and n.children[1].type == syms.atom + ): + atom_children = n.children[1].children + if ( + len(atom_children) == 3 + and atom_children[0].type == token.LPAR + and atom_children[1].type == syms.testlist_gexp + and atom_children[2].type == token.RPAR + ): + features.add(Feature.PARENTHESIZED_CONTEXT_MANAGERS) + + elif n.type == syms.match_stmt: + features.add(Feature.PATTERN_MATCHING) + + elif ( + n.type == syms.except_clause + and len(n.children) >= 2 + and n.children[1].type == token.STAR + ): + features.add(Feature.EXCEPT_STAR) + + elif n.type in {syms.subscriptlist, syms.trailer} and any( + child.type == syms.star_expr for child in n.children + ): + features.add(Feature.VARIADIC_GENERICS) + + elif ( + n.type == syms.tname_star + and len(n.children) == 3 + and n.children[2].type == syms.star_expr + ): + features.add(Feature.VARIADIC_GENERICS) + return features -def detect_target_versions(node: Node) -> Set[TargetVersion]: +def detect_target_versions( + node: Node, *, future_imports: Optional[Set[str]] = None +) -> Set[TargetVersion]: """Detect the version to target based on the nodes used.""" - features = get_features_used(node) + features = get_features_used(node, future_imports=future_imports) return { version for version in TargetVersion if features <= VERSION_TO_FEATURES[version] } @@ -1031,22 +1340,24 @@ def get_future_imports(node: Node) -> Set[str]: return imports -def assert_equivalent(src: str, dst: str, *, pass_num: int = 1) -> None: +def assert_equivalent(src: str, dst: str) -> None: """Raise AssertionError if `src` and `dst` aren't equivalent.""" try: src_ast = parse_ast(src) except Exception as exc: raise AssertionError( - "cannot use --safe with this file; failed to parse source file. AST" - f" error message: {exc}" - ) + "cannot use --safe with this file; failed to parse source file AST: " + f"{exc}\n" + "This could be caused by running Black with an older Python version " + "that does not support new syntax used in your source file." + ) from exc try: dst_ast = parse_ast(dst) except Exception as exc: log = dump_to_file("".join(traceback.format_tb(exc.__traceback__)), dst) raise AssertionError( - f"INTERNAL ERROR: Black produced invalid code on pass {pass_num}: {exc}. " + f"INTERNAL ERROR: Black produced invalid code: {exc}. " "Please report a bug on https://github.com/psf/black/issues. " f"This invalid output might be helpful: {log}" ) from None @@ -1057,14 +1368,17 @@ def assert_equivalent(src: str, dst: str, *, pass_num: int = 1) -> None: log = dump_to_file(diff(src_ast_str, dst_ast_str, "src", "dst")) raise AssertionError( "INTERNAL ERROR: Black produced code that is not equivalent to the" - f" source on pass {pass_num}. Please report a bug on " + " source. Please report a bug on " f"https://github.com/psf/black/issues. This diff might be helpful: {log}" ) from None def assert_stable(src: str, dst: str, mode: Mode) -> None: """Raise AssertionError if `dst` reformats differently the second time.""" - newdst = format_str(dst, mode=mode) + # We shouldn't call format_str() here, because that formats the string + # twice and may hide a bug where we bounce back and forth between two + # versions. + newdst = _format_str_once(dst, mode=mode) if dst != newdst: log = dump_to_file( str(mode), @@ -1098,21 +1412,37 @@ def patch_click() -> None: file paths is minimal since it's Python source code. Moreover, this crash was spurious on Python 3.7 thanks to PEP 538 and PEP 540. """ + modules: List[Any] = [] try: from click import core + except ImportError: + pass + else: + modules.append(core) + try: + # Removed in Click 8.1.0 and newer; we keep this around for users who have + # older versions installed. from click import _unicodefun # type: ignore - except ModuleNotFoundError: - return + except ImportError: + pass + else: + modules.append(_unicodefun) - for module in (core, _unicodefun): + for module in modules: if hasattr(module, "_verify_python3_env"): - module._verify_python3_env = lambda: None # type: ignore + module._verify_python3_env = lambda: None if hasattr(module, "_verify_python_env"): - module._verify_python_env = lambda: None # type: ignore + module._verify_python_env = lambda: None def patched_main() -> None: - freeze_support() + # PyInstaller patches multiprocessing to need freeze_support() even in non-Windows + # environments so just assume we always need to call it if frozen. + if getattr(sys, "frozen", False): + from multiprocessing import freeze_support + + freeze_support() + patch_click() main()