X-Git-Url: https://git.madduck.net/etc/vim.git/blobdiff_plain/80bd2b3134b4f01da4e279d040a224326b3577e5..00cadd43eeeaa24d7f01badacacd5dd5f8c21d5f:/black.py?ds=inline diff --git a/black.py b/black.py index dc03e0a..9b144ed 100644 --- a/black.py +++ b/black.py @@ -3,15 +3,19 @@ import asyncio from asyncio.base_events import BaseEventLoop from concurrent.futures import Executor, ProcessPoolExecutor +from enum import Enum from functools import partial, wraps import keyword import logging +from multiprocessing import Manager import os from pathlib import Path +import re import tokenize import signal import sys from typing import ( + Any, Callable, Dict, Generic, @@ -35,7 +39,7 @@ from blib2to3 import pygram, pytree from blib2to3.pgen2 import driver, token from blib2to3.pgen2.parse import ParseError -__version__ = "18.3a4" +__version__ = "18.4a0" DEFAULT_LINE_LENGTH = 88 # types syms = pygram.python_symbols @@ -92,6 +96,12 @@ class FormatOff(FormatError): """Found a comment like `# fmt: off` in the file.""" +class WriteBack(Enum): + NO = 0 + YES = 1 + DIFF = 2 + + @click.command() @click.option( "-l", @@ -105,16 +115,30 @@ class FormatOff(FormatError): "--check", is_flag=True, help=( - "Don't write back the files, just return the status. Return code 0 " + "Don't write the files back, just return the status. Return code 0 " "means nothing would change. Return code 1 means some files would be " "reformatted. Return code 123 means there was an internal error." ), ) +@click.option( + "--diff", + is_flag=True, + help="Don't write the files back, just output a diff for each file on stdout.", +) @click.option( "--fast/--safe", is_flag=True, help="If --fast given, skip temporary sanity checks. [default: --safe]", ) +@click.option( + "-q", + "--quiet", + is_flag=True, + help=( + "Don't emit non-error messages to stderr. Errors are still emitted, " + "silence those with 2>/dev/null." + ), +) @click.version_option(version=__version__) @click.argument( "src", @@ -125,7 +149,13 @@ class FormatOff(FormatError): ) @click.pass_context def main( - ctx: click.Context, line_length: int, check: bool, fast: bool, src: List[str] + ctx: click.Context, + line_length: int, + check: bool, + diff: bool, + fast: bool, + quiet: bool, + src: List[str], ) -> None: """The uncompromising code formatter.""" sources: List[Path] = [] @@ -140,19 +170,30 @@ def main( sources.append(Path("-")) else: err(f"invalid path: {s}") + if check and diff: + exc = click.ClickException("Options --check and --diff are mutually exclusive") + exc.exit_code = 2 + raise exc + + if check: + write_back = WriteBack.NO + elif diff: + write_back = WriteBack.DIFF + else: + write_back = WriteBack.YES if len(sources) == 0: ctx.exit(0) elif len(sources) == 1: p = sources[0] - report = Report(check=check) + report = Report(check=check, quiet=quiet) try: if not p.is_file() and str(p) == "-": changed = format_stdin_to_stdout( - line_length=line_length, fast=fast, write_back=not check + line_length=line_length, fast=fast, write_back=write_back ) else: changed = format_file_in_place( - p, line_length=line_length, fast=fast, write_back=not check + p, line_length=line_length, fast=fast, write_back=write_back ) report.done(p, changed) except Exception as exc: @@ -165,7 +206,7 @@ def main( try: return_code = loop.run_until_complete( schedule_formatting( - sources, line_length, not check, fast, loop, executor + sources, line_length, write_back, fast, quiet, loop, executor ) ) finally: @@ -176,8 +217,9 @@ def main( async def schedule_formatting( sources: List[Path], line_length: int, - write_back: bool, + write_back: WriteBack, fast: bool, + quiet: bool, loop: BaseEventLoop, executor: Executor, ) -> int: @@ -188,9 +230,15 @@ async def schedule_formatting( `line_length`, `write_back`, and `fast` options are passed to :func:`format_file_in_place`. """ + lock = None + if write_back == WriteBack.DIFF: + # For diff output, we need locks to ensure we don't interleave output + # from different processes. + manager = Manager() + lock = manager.Lock() tasks = { src: loop.run_in_executor( - executor, format_file_in_place, src, line_length, fast, write_back + executor, format_file_in_place, src, line_length, fast, write_back, lock ) for src in sources } @@ -199,7 +247,7 @@ async def schedule_formatting( loop.add_signal_handler(signal.SIGTERM, cancel, _task_values) await asyncio.wait(tasks.values()) cancelled = [] - report = Report(check=not write_back) + report = Report(check=write_back is WriteBack.NO, quiet=quiet) for src, task in tasks.items(): if not task.done(): report.failed(src, "timed out, cancelling") @@ -213,14 +261,19 @@ async def schedule_formatting( report.done(src, task.result()) if cancelled: await asyncio.gather(*cancelled, loop=loop, return_exceptions=True) - else: + elif not quiet: out("All done! ✨ 🍰 ✨") - click.echo(str(report)) + if not quiet: + click.echo(str(report)) return report.return_code def format_file_in_place( - src: Path, line_length: int, fast: bool, write_back: bool = False + src: Path, + line_length: int, + fast: bool, + write_back: WriteBack = WriteBack.NO, + lock: Any = None, # multiprocessing.Manager().Lock() is some crazy proxy ) -> bool: """Format file under `src` path. Return True if changed. @@ -230,37 +283,53 @@ def format_file_in_place( with tokenize.open(src) as src_buffer: src_contents = src_buffer.read() try: - contents = format_file_contents( + dst_contents = format_file_contents( src_contents, line_length=line_length, fast=fast ) except NothingChanged: return False - if write_back: + if write_back == write_back.YES: with open(src, "w", encoding=src_buffer.encoding) as f: - f.write(contents) + f.write(dst_contents) + elif write_back == write_back.DIFF: + src_name = f"{src.name} (original)" + dst_name = f"{src.name} (formatted)" + diff_contents = diff(src_contents, dst_contents, src_name, dst_name) + if lock: + lock.acquire() + try: + sys.stdout.write(diff_contents) + finally: + if lock: + lock.release() return True def format_stdin_to_stdout( - line_length: int, fast: bool, write_back: bool = False + line_length: int, fast: bool, write_back: WriteBack = WriteBack.NO ) -> bool: """Format file on stdin. Return True if changed. If `write_back` is True, write reformatted code back to stdout. `line_length` and `fast` arguments are passed to :func:`format_file_contents`. """ - contents = sys.stdin.read() + src = sys.stdin.read() try: - contents = format_file_contents(contents, line_length=line_length, fast=fast) + dst = format_file_contents(src, line_length=line_length, fast=fast) return True except NothingChanged: + dst = src return False finally: - if write_back: - sys.stdout.write(contents) + if write_back == WriteBack.YES: + sys.stdout.write(dst) + elif write_back == WriteBack.DIFF: + src_name = " (original)" + dst_name = " (formatted)" + sys.stdout.write(diff(src, dst, src_name, dst_name)) def format_file_contents( @@ -451,6 +520,7 @@ MATH_OPERATORS = { token.DOUBLESTAR, token.DOUBLESLASH, } +VARARGS = {token.STAR, token.DOUBLESTAR} COMPREHENSION_PRIORITY = 20 COMMA_PRIORITY = 10 LOGIC_PRIORITY = 5 @@ -492,32 +562,13 @@ class BracketTracker: leaf.opening_bracket = opening_bracket leaf.bracket_depth = self.depth if self.depth == 0: - delim = is_delimiter(leaf) - if delim: - self.delimiters[id(leaf)] = delim - elif self.previous is not None: - if leaf.type == token.STRING and self.previous.type == token.STRING: - self.delimiters[id(self.previous)] = STRING_PRIORITY - elif ( - leaf.type == token.NAME - and leaf.value == "for" - and leaf.parent - and leaf.parent.type in {syms.comp_for, syms.old_comp_for} - ): - self.delimiters[id(self.previous)] = COMPREHENSION_PRIORITY - elif ( - leaf.type == token.NAME - and leaf.value == "if" - and leaf.parent - and leaf.parent.type in {syms.comp_if, syms.old_comp_if} - ): - self.delimiters[id(self.previous)] = COMPREHENSION_PRIORITY - elif ( - leaf.type == token.NAME - and leaf.value in LOGIC_OPERATORS - and leaf.parent - ): - self.delimiters[id(self.previous)] = LOGIC_PRIORITY + delim = is_split_before_delimiter(leaf, self.previous) + if delim and self.previous is not None: + self.delimiters[id(self.previous)] = delim + else: + delim = is_split_after_delimiter(leaf, self.previous) + if delim: + self.delimiters[id(leaf)] = delim if leaf.type in OPENING_BRACKETS: self.bracket_match[self.depth, BRACKET[leaf.type]] = leaf self.depth += 1 @@ -1097,6 +1148,10 @@ class LineGenerator(Visitor[Line]): yield from self.line() yield from self.visit(node) + if node.type == token.ENDMARKER: + # somebody decided not to put a final `# fmt: on` + yield from self.line() + def __attrs_post_init__(self) -> None: """You are in a twisty little maze of passages.""" v = self.visit_stmt @@ -1374,16 +1429,36 @@ def preceding_leaf(node: Optional[LN]) -> Optional[Leaf]: return None -def is_delimiter(leaf: Leaf) -> int: - """Return the priority of the `leaf` delimiter. Return 0 if not delimiter. +def is_split_after_delimiter(leaf: Leaf, previous: Leaf = None) -> int: + """Return the priority of the `leaf` delimiter, given a line break after it. + + The delimiter priorities returned here are from those delimiters that would + cause a line break after themselves. Higher numbers are higher priority. """ if leaf.type == token.COMMA: return COMMA_PRIORITY - if leaf.type in COMPARATORS: - return COMPARATOR_PRIORITY + return 0 + + +def is_split_before_delimiter(leaf: Leaf, previous: Leaf = None) -> int: + """Return the priority of the `leaf` delimiter, given a line before after it. + + The delimiter priorities returned here are from those delimiters that would + cause a line break before themselves. + + Higher numbers are higher priority. + """ + if ( + leaf.type in VARARGS + and leaf.parent + and leaf.parent.type in {syms.argument, syms.typedargslist} + ): + # * and ** might also be MATH_OPERATORS but in this case they are not. + # Don't treat them as a delimiter. + return 0 if ( leaf.type in MATH_OPERATORS @@ -1392,9 +1467,49 @@ def is_delimiter(leaf: Leaf) -> int: ): return MATH_PRIORITY + if leaf.type in COMPARATORS: + return COMPARATOR_PRIORITY + + if ( + leaf.type == token.STRING + and previous is not None + and previous.type == token.STRING + ): + return STRING_PRIORITY + + if ( + leaf.type == token.NAME + and leaf.value == "for" + and leaf.parent + and leaf.parent.type in {syms.comp_for, syms.old_comp_for} + ): + return COMPREHENSION_PRIORITY + + if ( + leaf.type == token.NAME + and leaf.value == "if" + and leaf.parent + and leaf.parent.type in {syms.comp_if, syms.old_comp_if} + ): + return COMPREHENSION_PRIORITY + + if leaf.type == token.NAME and leaf.value in LOGIC_OPERATORS and leaf.parent: + return LOGIC_PRIORITY + return 0 +def is_delimiter(leaf: Leaf, previous: Leaf = None) -> int: + """Return the priority of the `leaf` delimiter. Return 0 if not delimiter. + + Higher numbers are higher priority. + """ + return max( + is_split_before_delimiter(leaf, previous), + is_split_after_delimiter(leaf, previous), + ) + + def generate_comments(leaf: Leaf) -> Iterator[Leaf]: """Clean the prefix of the `leaf` and generate comments from it, if any. @@ -1442,7 +1557,12 @@ def generate_comments(leaf: Leaf) -> Iterator[Leaf]: raise FormatOn(consumed) if comment in {"# fmt: off", "# yapf: disable"}: - raise FormatOff(consumed) + if comment_type == STANDALONE_COMMENT: + raise FormatOff(consumed) + + prev = preceding_leaf(leaf) + if not prev or prev.type in WHITESPACE: # standalone comment in disguise + raise FormatOff(consumed) nlines = 0 @@ -1701,9 +1821,10 @@ def delimiter_split(line: Line, py36: bool = False) -> Iterator[Line]: current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets) if current_line: if ( - delimiter_priority == COMMA_PRIORITY + trailing_comma_safe + and delimiter_priority == COMMA_PRIORITY and current_line.leaves[-1].type != token.COMMA - and trailing_comma_safe + and current_line.leaves[-1].type != STANDALONE_COMMENT ): current_line.append(Leaf(token.COMMA, ",")) yield current_line @@ -1776,6 +1897,13 @@ def normalize_prefix(leaf: Leaf, *, inside_brackets: bool) -> None: def normalize_string_quotes(leaf: Leaf) -> None: + """Prefer double quotes but only if it doesn't cause more escaping. + + Adds or removes backslashes as appropriate. Doesn't parse and fix + strings nested in f-strings (yet). + + Note: Mutates its argument. + """ value = leaf.value.lstrip("furbFURB") if value[:3] == '"""': return @@ -1793,10 +1921,21 @@ def normalize_string_quotes(leaf: Leaf) -> None: if first_quote_pos == -1: return # There's an internal error + prefix = leaf.value[:first_quote_pos] body = leaf.value[first_quote_pos + len(orig_quote):-len(orig_quote)] - new_body = body.replace(f"\\{orig_quote}", orig_quote).replace( - new_quote, f"\\{new_quote}" - ) + unescaped_new_quote = re.compile(r"(([^\\]|^)(\\\\)*)" + new_quote) + escaped_orig_quote = re.compile(r"\\(\\\\)*" + orig_quote) + if "r" in prefix.casefold(): + if unescaped_new_quote.search(body): + # There's at least one unescaped new_quote in this raw string + # so converting is impossible + return + + # Do not introduce or remove backslashes in raw strings + new_body = body + else: + new_body = escaped_orig_quote.sub(f"\\1{orig_quote}", body) + new_body = unescaped_new_quote.sub(f"\\1\\\\{new_quote}", new_body) if new_quote == '"""' and new_body[-1] == '"': # edge case: new_body = new_body[:-1] + '\\"' @@ -1808,7 +1947,6 @@ def normalize_string_quotes(leaf: Leaf) -> None: if new_escape_count == orig_escape_count and orig_quote == '"': return # Prefer double quotes - prefix = leaf.value[:first_quote_pos] leaf.value = f"{prefix}{new_quote}{new_body}{new_quote}" @@ -1862,6 +2000,7 @@ def gen_python_files_in_dir(path: Path) -> Iterator[Path]: class Report: """Provides a reformatting counter. Can be rendered with `str(report)`.""" check: bool = False + quiet: bool = False change_count: int = 0 same_count: int = 0 failure_count: int = 0 @@ -1870,10 +2009,12 @@ class Report: """Increment the counter for successful reformatting. Write out a message.""" if changed: reformatted = "would reformat" if self.check else "reformatted" - out(f"{reformatted} {src}") + if not self.quiet: + out(f"{reformatted} {src}") self.change_count += 1 else: - out(f"{src} already well formatted, good job.", bold=False) + if not self.quiet: + out(f"{src} already well formatted, good job.", bold=False) self.same_count += 1 def failed(self, src: Path, message: str) -> None: @@ -2018,7 +2159,8 @@ def dump_to_file(*output: str) -> str: ) as f: for lines in output: f.write(lines) - f.write("\n") + if lines and lines[-1] != "\n": + f.write("\n") return f.name