X-Git-Url: https://git.madduck.net/etc/vim.git/blobdiff_plain/544ea9c217cad9459d7b60665db787e94b52f93d..6695fd892f8f1a9e0d3879423bb0361e07eedd24:/black.py diff --git a/black.py b/black.py index ff4f098..2668747 100644 --- a/black.py +++ b/black.py @@ -2,7 +2,7 @@ import ast import asyncio from abc import ABC, abstractmethod from collections import defaultdict -from concurrent.futures import Executor, ProcessPoolExecutor +from concurrent.futures import Executor, ThreadPoolExecutor, ProcessPoolExecutor from contextlib import contextmanager from datetime import datetime from enum import Enum @@ -39,6 +39,7 @@ from typing import ( TypeVar, Union, cast, + TYPE_CHECKING, ) from typing_extensions import Final from mypy_extensions import mypyc_attr @@ -59,6 +60,9 @@ from blib2to3.pgen2.parse import ParseError from _black_version import version as __version__ +if TYPE_CHECKING: + import colorama # noqa: F401 + DEFAULT_LINE_LENGTH = 88 DEFAULT_EXCLUDES = r"/(\.eggs|\.git|\.hg|\.mypy_cache|\.nox|\.tox|\.venv|\.svn|_build|buck-out|build|dist)/" # noqa: B950 DEFAULT_INCLUDES = r"\.pyi?$" @@ -140,12 +144,18 @@ class WriteBack(Enum): YES = 1 DIFF = 2 CHECK = 3 + COLOR_DIFF = 4 @classmethod - def from_configuration(cls, *, check: bool, diff: bool) -> "WriteBack": + def from_configuration( + cls, *, check: bool, diff: bool, color: bool = False + ) -> "WriteBack": if check and not diff: return cls.CHECK + if diff and color: + return cls.COLOR_DIFF + return cls.DIFF if diff else cls.YES @@ -296,6 +306,12 @@ def read_pyproject_toml( if not config: return None + target_version = config.get("target_version") + if target_version is not None and not isinstance(target_version, list): + raise click.BadOptionUsage( + "target-version", f"Config key target-version must be a list" + ) + default_map: Dict[str, Any] = {} if ctx.default_map: default_map.update(ctx.default_map) @@ -374,6 +390,11 @@ def target_version_option_callback( is_flag=True, help="Don't write the files back, just output a diff for each file on stdout.", ) +@click.option( + "--color/--no-color", + is_flag=True, + help="Show colored diff. Only applies when `--diff` is given.", +) @click.option( "--fast/--safe", is_flag=True, @@ -452,6 +473,7 @@ def main( target_version: List[TargetVersion], check: bool, diff: bool, + color: bool, fast: bool, pyi: bool, py36: bool, @@ -464,7 +486,7 @@ def main( config: Optional[str], ) -> None: """The uncompromising code formatter.""" - write_back = WriteBack.from_configuration(check=check, diff=diff) + write_back = WriteBack.from_configuration(check=check, diff=diff, color=color) if target_version: if py36: err("Cannot use both --target-version and --py36") @@ -591,12 +613,21 @@ def reformat_many( sources: Set[Path], fast: bool, write_back: WriteBack, mode: Mode, report: "Report" ) -> None: """Reformat multiple files using a ProcessPoolExecutor.""" + executor: Executor loop = asyncio.get_event_loop() worker_count = os.cpu_count() if sys.platform == "win32": # Work around https://bugs.python.org/issue26903 worker_count = min(worker_count, 61) - executor = ProcessPoolExecutor(max_workers=worker_count) + try: + executor = ProcessPoolExecutor(max_workers=worker_count) + except OSError: + # we arrive here if the underlying system does not support multi-processing + # like in AWS Lambda, in which case we gracefully fallback to + # a ThreadPollExecutor with just a single worker (more workers would not do us + # any good due to the Global Interpreter Lock) + executor = ThreadPoolExecutor(max_workers=1) + try: loop.run_until_complete( schedule_formatting( @@ -611,7 +642,8 @@ def reformat_many( ) finally: shutdown(loop) - executor.shutdown() + if executor is not None: + executor.shutdown() async def schedule_formatting( @@ -712,12 +744,15 @@ def format_file_in_place( if write_back == WriteBack.YES: with open(src, "w", encoding=encoding, newline=newline) as f: f.write(dst_contents) - elif write_back == WriteBack.DIFF: + elif write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF): now = datetime.utcnow() src_name = f"{src}\t{then} +0000" dst_name = f"{src}\t{now} +0000" diff_contents = diff(src_contents, dst_contents, src_name, dst_name) + if write_back == write_back.COLOR_DIFF: + diff_contents = color_diff(diff_contents) + with lock or nullcontext(): f = io.TextIOWrapper( sys.stdout.buffer, @@ -725,12 +760,57 @@ def format_file_in_place( newline=newline, write_through=True, ) + f = wrap_stream_for_windows(f) f.write(diff_contents) f.detach() return True +def color_diff(contents: str) -> str: + """Inject the ANSI color codes to the diff.""" + lines = contents.split("\n") + for i, line in enumerate(lines): + if line.startswith("+++") or line.startswith("---"): + line = "\033[1;37m" + line + "\033[0m" # bold white, reset + if line.startswith("@@"): + line = "\033[36m" + line + "\033[0m" # cyan, reset + if line.startswith("+"): + line = "\033[32m" + line + "\033[0m" # green, reset + elif line.startswith("-"): + line = "\033[31m" + line + "\033[0m" # red, reset + lines[i] = line + return "\n".join(lines) + + +def wrap_stream_for_windows( + f: io.TextIOWrapper, +) -> Union[io.TextIOWrapper, "colorama.AnsiToWin32.AnsiToWin32"]: + """ + Wrap the stream in colorama's wrap_stream so colors are shown on Windows. + + If `colorama` is not found, then no change is made. If `colorama` does + exist, then it handles the logic to determine whether or not to change + things. + """ + try: + from colorama import initialise + + # We set `strip=False` so that we can don't have to modify + # test_express_diff_with_color. + f = initialise.wrap_stream( + f, convert=None, strip=False, autoreset=False, wrap=True + ) + + # wrap_stream returns a `colorama.AnsiToWin32.AnsiToWin32` object + # which does not have a `detach()` method. So we fake one. + f.detach = lambda *args, **kwargs: None # type: ignore + except ImportError: + pass + + return f + + def format_stdin_to_stdout( fast: bool, *, write_back: WriteBack = WriteBack.NO, mode: Mode ) -> bool: @@ -756,11 +836,15 @@ def format_stdin_to_stdout( ) if write_back == WriteBack.YES: f.write(dst) - elif write_back == WriteBack.DIFF: + elif write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF): now = datetime.utcnow() src_name = f"STDIN\t{then} +0000" dst_name = f"STDOUT\t{now} +0000" - f.write(diff(src, dst, src_name, dst_name)) + d = diff(src, dst, src_name, dst_name) + if write_back == WriteBack.COLOR_DIFF: + d = color_diff(d) + f = wrap_stream_for_windows(f) + f.write(d) f.detach() @@ -1949,6 +2033,18 @@ class LineGenerator(Visitor[Line]): node.insert_child(index, Node(syms.atom, [lpar, operand, rpar])) yield from self.visit_default(node) + def visit_STRING(self, leaf: Leaf) -> Iterator[Line]: + # Check if it's a docstring + if prev_siblings_are( + leaf.parent, [None, token.NEWLINE, token.INDENT, syms.simple_stmt] + ) and is_multiline_string(leaf): + prefix = " " * self.current_line.depth + docstring = fix_docstring(leaf.value[3:-3], prefix) + leaf.value = leaf.value[0:3] + docstring + leaf.value[-3:] + normalize_string_quotes(leaf) + + yield from self.visit_default(leaf) + def __post_init__(self) -> None: """You are in a twisty little maze of passages.""" v = self.visit_stmt @@ -2230,6 +2326,22 @@ def preceding_leaf(node: Optional[LN]) -> Optional[Leaf]: return None +def prev_siblings_are(node: Optional[LN], tokens: List[Optional[NodeType]]) -> bool: + """Return if the `node` and its previous siblings match types against the provided + list of tokens; the provided `node`has its type matched against the last element in + the list. `None` can be used as the first element to declare that the start of the + list is anchored at the start of its parent's children.""" + if not tokens: + return True + if tokens[-1] is None: + return node is None + if not node: + return False + if node.type != tokens[-1]: + return False + return prev_siblings_are(node.prev_sibling, tokens[:-1]) + + def child_towards(ancestor: Node, descendant: LN) -> Optional[LN]: """Return the child of `ancestor` that contains `descendant`.""" node: Optional[LN] = descendant @@ -2510,6 +2622,7 @@ def transform_line( is_line_short_enough(line, line_length=line_length, line_str=line_str) or line.contains_unsplittable_type_ignore() ) + and not (line.contains_standalone_comments() and line.inside_brackets) ): # Only apply basic string preprocessing, since lines shouldn't be split here. transformers = [string_merge, string_paren_strip] @@ -5121,7 +5234,7 @@ def generate_ignored_nodes(leaf: Leaf) -> Iterator[LN]: """ container: Optional[LN] = container_of(leaf) while container is not None and container.type != token.ENDMARKER: - if fmt_on(container): + if is_fmt_on(container): return # fix for fmt: on in children @@ -5135,17 +5248,21 @@ def generate_ignored_nodes(leaf: Leaf) -> Iterator[LN]: container = container.next_sibling -def fmt_on(container: LN) -> bool: - is_fmt_on = False +def is_fmt_on(container: LN) -> bool: + """Determine whether formatting is switched on within a container. + Determined by whether the last `# fmt:` comment is `on` or `off`. + """ + fmt_on = False for comment in list_comments(container.prefix, is_endmarker=False): if comment.value in FMT_ON: - is_fmt_on = True + fmt_on = True elif comment.value in FMT_OFF: - is_fmt_on = False - return is_fmt_on + fmt_on = False + return fmt_on def contains_fmt_on_at_column(container: LN, column: int) -> bool: + """Determine if children at a given column have formatting switched on.""" for child in container.children: if ( isinstance(child, Node) @@ -5153,13 +5270,14 @@ def contains_fmt_on_at_column(container: LN, column: int) -> bool: or isinstance(child, Leaf) and child.column == column ): - if fmt_on(child): + if is_fmt_on(child): return True return False def first_leaf_column(node: Node) -> Optional[int]: + """Returns the column of the first leaf child of a node.""" for child in node.children: if isinstance(child, Leaf): return child.column @@ -5798,54 +5916,66 @@ def _fixup_ast_constants( return node -def assert_equivalent(src: str, dst: str) -> None: - """Raise AssertionError if `src` and `dst` aren't equivalent.""" +def _stringify_ast( + node: Union[ast.AST, ast3.AST, ast27.AST], depth: int = 0 +) -> Iterator[str]: + """Simple visitor generating strings to compare ASTs by content.""" - def _v(node: Union[ast.AST, ast3.AST, ast27.AST], depth: int = 0) -> Iterator[str]: - """Simple visitor generating strings to compare ASTs by content.""" + node = _fixup_ast_constants(node) - node = _fixup_ast_constants(node) + yield f"{' ' * depth}{node.__class__.__name__}(" - yield f"{' ' * depth}{node.__class__.__name__}(" - - for field in sorted(node._fields): # noqa: F402 - # TypeIgnore has only one field 'lineno' which breaks this comparison - type_ignore_classes = (ast3.TypeIgnore, ast27.TypeIgnore) - if sys.version_info >= (3, 8): - type_ignore_classes += (ast.TypeIgnore,) - if isinstance(node, type_ignore_classes): - break + for field in sorted(node._fields): # noqa: F402 + # TypeIgnore has only one field 'lineno' which breaks this comparison + type_ignore_classes = (ast3.TypeIgnore, ast27.TypeIgnore) + if sys.version_info >= (3, 8): + type_ignore_classes += (ast.TypeIgnore,) + if isinstance(node, type_ignore_classes): + break - try: - value = getattr(node, field) - except AttributeError: - continue + try: + value = getattr(node, field) + except AttributeError: + continue - yield f"{' ' * (depth+1)}{field}=" + yield f"{' ' * (depth+1)}{field}=" - if isinstance(value, list): - for item in value: - # Ignore nested tuples within del statements, because we may insert - # parentheses and they change the AST. - if ( - field == "targets" - and isinstance(node, (ast.Delete, ast3.Delete, ast27.Delete)) - and isinstance(item, (ast.Tuple, ast3.Tuple, ast27.Tuple)) - ): - for item in item.elts: - yield from _v(item, depth + 2) + if isinstance(value, list): + for item in value: + # Ignore nested tuples within del statements, because we may insert + # parentheses and they change the AST. + if ( + field == "targets" + and isinstance(node, (ast.Delete, ast3.Delete, ast27.Delete)) + and isinstance(item, (ast.Tuple, ast3.Tuple, ast27.Tuple)) + ): + for item in item.elts: + yield from _stringify_ast(item, depth + 2) - elif isinstance(item, (ast.AST, ast3.AST, ast27.AST)): - yield from _v(item, depth + 2) + elif isinstance(item, (ast.AST, ast3.AST, ast27.AST)): + yield from _stringify_ast(item, depth + 2) - elif isinstance(value, (ast.AST, ast3.AST, ast27.AST)): - yield from _v(value, depth + 2) + elif isinstance(value, (ast.AST, ast3.AST, ast27.AST)): + yield from _stringify_ast(value, depth + 2) + else: + # Constant strings may be indented across newlines, if they are + # docstrings; fold spaces after newlines when comparing + if ( + isinstance(node, ast.Constant) + and field == "value" + and isinstance(value, str) + ): + normalized = re.sub(r"\n[ \t]+", "\n ", value) else: - yield f"{' ' * (depth+2)}{value!r}, # {value.__class__.__name__}" + normalized = value + yield f"{' ' * (depth+2)}{normalized!r}, # {value.__class__.__name__}" - yield f"{' ' * depth}) # /{node.__class__.__name__}" + yield f"{' ' * depth}) # /{node.__class__.__name__}" + +def assert_equivalent(src: str, dst: str) -> None: + """Raise AssertionError if `src` and `dst` aren't equivalent.""" try: src_ast = parse_ast(src) except Exception as exc: @@ -5864,8 +5994,8 @@ def assert_equivalent(src: str, dst: str) -> None: f" helpful: {log}" ) from None - src_ast_str = "\n".join(_v(src_ast)) - dst_ast_str = "\n".join(_v(dst_ast)) + src_ast_str = "\n".join(_stringify_ast(src_ast)) + dst_ast_str = "\n".join(_stringify_ast(dst_ast)) if src_ast_str != dst_ast_str: log = dump_to_file(diff(src_ast_str, dst_ast_str, "src", "dst")) raise AssertionError( @@ -6230,5 +6360,32 @@ def patched_main() -> None: main() +def fix_docstring(docstring: str, prefix: str) -> str: + # https://www.python.org/dev/peps/pep-0257/#handling-docstring-indentation + if not docstring: + return "" + # Convert tabs to spaces (following the normal Python rules) + # and split into a list of lines: + lines = docstring.expandtabs().splitlines() + # Determine minimum indentation (first line doesn't count): + indent = sys.maxsize + for line in lines[1:]: + stripped = line.lstrip() + if stripped: + indent = min(indent, len(line) - len(stripped)) + # Remove indentation (first line is special): + trimmed = [lines[0].strip()] + if indent < sys.maxsize: + last_line_idx = len(lines) - 2 + for i, line in enumerate(lines[1:]): + stripped_line = line[indent:].rstrip() + if stripped_line or i == last_line_idx: + trimmed.append(prefix + stripped_line) + else: + trimmed.append("") + # Return a single string: + return "\n".join(trimmed) + + if __name__ == "__main__": patched_main()