X-Git-Url: https://git.madduck.net/etc/vim.git/blobdiff_plain/80bd2b3134b4f01da4e279d040a224326b3577e5..1d45f6e6a11675ee1ee5a3b0c3664cd7feec532b:/black.py?ds=inline diff --git a/black.py b/black.py index dc03e0a..c2c118a 100644 --- a/black.py +++ b/black.py @@ -3,15 +3,18 @@ import asyncio from asyncio.base_events import BaseEventLoop from concurrent.futures import Executor, ProcessPoolExecutor +from enum import Enum from functools import partial, wraps import keyword import logging +from multiprocessing import Manager import os from pathlib import Path import tokenize import signal import sys from typing import ( + Any, Callable, Dict, Generic, @@ -92,6 +95,12 @@ class FormatOff(FormatError): """Found a comment like `# fmt: off` in the file.""" +class WriteBack(Enum): + NO = 0 + YES = 1 + DIFF = 2 + + @click.command() @click.option( "-l", @@ -105,11 +114,16 @@ class FormatOff(FormatError): "--check", is_flag=True, help=( - "Don't write back the files, just return the status. Return code 0 " + "Don't write the files back, just return the status. Return code 0 " "means nothing would change. Return code 1 means some files would be " "reformatted. Return code 123 means there was an internal error." ), ) +@click.option( + "--diff", + is_flag=True, + help="Don't write the files back, just output a diff for each file on stdout.", +) @click.option( "--fast/--safe", is_flag=True, @@ -125,7 +139,12 @@ class FormatOff(FormatError): ) @click.pass_context def main( - ctx: click.Context, line_length: int, check: bool, fast: bool, src: List[str] + ctx: click.Context, + line_length: int, + check: bool, + diff: bool, + fast: bool, + src: List[str], ) -> None: """The uncompromising code formatter.""" sources: List[Path] = [] @@ -140,6 +159,17 @@ def main( sources.append(Path("-")) else: err(f"invalid path: {s}") + if check and diff: + exc = click.ClickException("Options --check and --diff are mutually exclusive") + exc.exit_code = 2 + raise exc + + if check: + write_back = WriteBack.NO + elif diff: + write_back = WriteBack.DIFF + else: + write_back = WriteBack.YES if len(sources) == 0: ctx.exit(0) elif len(sources) == 1: @@ -148,11 +178,11 @@ def main( try: if not p.is_file() and str(p) == "-": changed = format_stdin_to_stdout( - line_length=line_length, fast=fast, write_back=not check + line_length=line_length, fast=fast, write_back=write_back ) else: changed = format_file_in_place( - p, line_length=line_length, fast=fast, write_back=not check + p, line_length=line_length, fast=fast, write_back=write_back ) report.done(p, changed) except Exception as exc: @@ -165,7 +195,7 @@ def main( try: return_code = loop.run_until_complete( schedule_formatting( - sources, line_length, not check, fast, loop, executor + sources, line_length, write_back, fast, loop, executor ) ) finally: @@ -176,7 +206,7 @@ def main( async def schedule_formatting( sources: List[Path], line_length: int, - write_back: bool, + write_back: WriteBack, fast: bool, loop: BaseEventLoop, executor: Executor, @@ -188,9 +218,15 @@ async def schedule_formatting( `line_length`, `write_back`, and `fast` options are passed to :func:`format_file_in_place`. """ + lock = None + if write_back == WriteBack.DIFF: + # For diff output, we need locks to ensure we don't interleave output + # from different processes. + manager = Manager() + lock = manager.Lock() tasks = { src: loop.run_in_executor( - executor, format_file_in_place, src, line_length, fast, write_back + executor, format_file_in_place, src, line_length, fast, write_back, lock ) for src in sources } @@ -220,7 +256,11 @@ async def schedule_formatting( def format_file_in_place( - src: Path, line_length: int, fast: bool, write_back: bool = False + src: Path, + line_length: int, + fast: bool, + write_back: WriteBack = WriteBack.NO, + lock: Any = None, # multiprocessing.Manager().Lock() is some crazy proxy ) -> bool: """Format file under `src` path. Return True if changed. @@ -230,37 +270,53 @@ def format_file_in_place( with tokenize.open(src) as src_buffer: src_contents = src_buffer.read() try: - contents = format_file_contents( + dst_contents = format_file_contents( src_contents, line_length=line_length, fast=fast ) except NothingChanged: return False - if write_back: + if write_back == write_back.YES: with open(src, "w", encoding=src_buffer.encoding) as f: - f.write(contents) + f.write(dst_contents) + elif write_back == write_back.DIFF: + src_name = f"{src.name} (original)" + dst_name = f"{src.name} (formatted)" + diff_contents = diff(src_contents, dst_contents, src_name, dst_name) + if lock: + lock.acquire() + try: + sys.stdout.write(diff_contents) + finally: + if lock: + lock.release() return True def format_stdin_to_stdout( - line_length: int, fast: bool, write_back: bool = False + line_length: int, fast: bool, write_back: WriteBack = WriteBack.NO ) -> bool: """Format file on stdin. Return True if changed. If `write_back` is True, write reformatted code back to stdout. `line_length` and `fast` arguments are passed to :func:`format_file_contents`. """ - contents = sys.stdin.read() + src = sys.stdin.read() try: - contents = format_file_contents(contents, line_length=line_length, fast=fast) + dst = format_file_contents(src, line_length=line_length, fast=fast) return True except NothingChanged: + dst = src return False finally: - if write_back: - sys.stdout.write(contents) + if write_back == WriteBack.YES: + sys.stdout.write(dst) + elif write_back == WriteBack.DIFF: + src_name = " (original)" + dst_name = " (formatted)" + sys.stdout.write(diff(src, dst, src_name, dst_name)) def format_file_contents( @@ -451,6 +507,7 @@ MATH_OPERATORS = { token.DOUBLESTAR, token.DOUBLESLASH, } +VARARGS = {token.STAR, token.DOUBLESTAR} COMPREHENSION_PRIORITY = 20 COMMA_PRIORITY = 10 LOGIC_PRIORITY = 5 @@ -492,32 +549,12 @@ class BracketTracker: leaf.opening_bracket = opening_bracket leaf.bracket_depth = self.depth if self.depth == 0: - delim = is_delimiter(leaf) - if delim: - self.delimiters[id(leaf)] = delim - elif self.previous is not None: - if leaf.type == token.STRING and self.previous.type == token.STRING: - self.delimiters[id(self.previous)] = STRING_PRIORITY - elif ( - leaf.type == token.NAME - and leaf.value == "for" - and leaf.parent - and leaf.parent.type in {syms.comp_for, syms.old_comp_for} - ): - self.delimiters[id(self.previous)] = COMPREHENSION_PRIORITY - elif ( - leaf.type == token.NAME - and leaf.value == "if" - and leaf.parent - and leaf.parent.type in {syms.comp_if, syms.old_comp_if} - ): - self.delimiters[id(self.previous)] = COMPREHENSION_PRIORITY - elif ( - leaf.type == token.NAME - and leaf.value in LOGIC_OPERATORS - and leaf.parent - ): - self.delimiters[id(self.previous)] = LOGIC_PRIORITY + after_delim = is_split_after_delimiter(leaf, self.previous) + before_delim = is_split_before_delimiter(leaf, self.previous) + if after_delim > before_delim: + self.delimiters[id(leaf)] = after_delim + elif before_delim > after_delim and self.previous is not None: + self.delimiters[id(self.previous)] = before_delim if leaf.type in OPENING_BRACKETS: self.bracket_match[self.depth, BRACKET[leaf.type]] = leaf self.depth += 1 @@ -1374,17 +1411,35 @@ def preceding_leaf(node: Optional[LN]) -> Optional[Leaf]: return None -def is_delimiter(leaf: Leaf) -> int: - """Return the priority of the `leaf` delimiter. Return 0 if not delimiter. +def is_split_after_delimiter(leaf: Leaf, previous: Leaf = None) -> int: + """Return the priority of the `leaf` delimiter, given a line break after it. + + The delimiter priorities returned here are from those delimiters that would + cause a line break after themselves. Higher numbers are higher priority. """ if leaf.type == token.COMMA: return COMMA_PRIORITY - if leaf.type in COMPARATORS: - return COMPARATOR_PRIORITY + if ( + leaf.type in VARARGS + and leaf.parent + and leaf.parent.type in {syms.argument, syms.typedargslist} + ): + return MATH_PRIORITY + return 0 + + +def is_split_before_delimiter(leaf: Leaf, previous: Leaf = None) -> int: + """Return the priority of the `leaf` delimiter, given a line before after it. + + The delimiter priorities returned here are from those delimiters that would + cause a line break before themselves. + + Higher numbers are higher priority. + """ if ( leaf.type in MATH_OPERATORS and leaf.parent @@ -1392,9 +1447,49 @@ def is_delimiter(leaf: Leaf) -> int: ): return MATH_PRIORITY + if leaf.type in COMPARATORS: + return COMPARATOR_PRIORITY + + if ( + leaf.type == token.STRING + and previous is not None + and previous.type == token.STRING + ): + return STRING_PRIORITY + + if ( + leaf.type == token.NAME + and leaf.value == "for" + and leaf.parent + and leaf.parent.type in {syms.comp_for, syms.old_comp_for} + ): + return COMPREHENSION_PRIORITY + + if ( + leaf.type == token.NAME + and leaf.value == "if" + and leaf.parent + and leaf.parent.type in {syms.comp_if, syms.old_comp_if} + ): + return COMPREHENSION_PRIORITY + + if leaf.type == token.NAME and leaf.value in LOGIC_OPERATORS and leaf.parent: + return LOGIC_PRIORITY + return 0 +def is_delimiter(leaf: Leaf, previous: Leaf = None) -> int: + """Return the priority of the `leaf` delimiter. Return 0 if not delimiter. + + Higher numbers are higher priority. + """ + return max( + is_split_before_delimiter(leaf, previous), + is_split_after_delimiter(leaf, previous), + ) + + def generate_comments(leaf: Leaf) -> Iterator[Leaf]: """Clean the prefix of the `leaf` and generate comments from it, if any. @@ -1776,6 +1871,13 @@ def normalize_prefix(leaf: Leaf, *, inside_brackets: bool) -> None: def normalize_string_quotes(leaf: Leaf) -> None: + """Prefer double quotes but only if it doesn't cause more escaping. + + Adds or removes backslashes as appropriate. Doesn't parse and fix + strings nested in f-strings (yet). + + Note: Mutates its argument. + """ value = leaf.value.lstrip("furbFURB") if value[:3] == '"""': return @@ -2018,7 +2120,8 @@ def dump_to_file(*output: str) -> str: ) as f: for lines in output: f.write(lines) - f.write("\n") + if lines and lines[-1] != "\n": + f.write("\n") return f.name