From 788268bc39a87d37a24d203fa5ee7b3953af3446 Mon Sep 17 00:00:00 2001 From: =?utf8?q?=C5=81ukasz=20Langa?= Date: Tue, 29 Oct 2019 00:50:42 +0100 Subject: [PATCH] Re-implement magic trailing comma handling: - when a trailing comma is specified in any bracket pair, that signals to Black that this bracket pair needs to be always exploded, e.g. presented as "one item per line"; - this causes some changes to previously formatted code that erroneously left trailing commas embedded into single-line expressions; - internally, Black needs to be able to identify trailing commas that it put itself compared to pre-existing trailing commas. We do this by using/abusing lib2to3's `was_checked` attribute. It's True for internally generated trailing commas and False for pre-existing ones (in fact, for all pre-existing leaves and nodes). Fixes #1288 --- CHANGES.md | 3 + gallery/gallery.py | 5 +- src/black/__init__.py | 126 +++++++++++----------- src/blib2to3/pgen2/driver.py | 2 +- tests/data/collections.py | 35 ++++-- tests/data/comments2.py | 8 +- tests/data/comments7.py | 9 +- tests/data/expression.diff | 27 ++++- tests/data/expression.py | 22 +++- tests/data/fmtonoff4.py | 7 +- tests/data/function.py | 5 +- tests/data/function2.py | 5 +- tests/data/function_trailing_comma.py | 58 ++++++++-- tests/data/function_trailing_comma_wip.py | 5 + tests/data/import_spacing.py | 8 +- tests/data/long_strings.py | 32 +++++- tests/data/long_strings__regression.py | 12 ++- tests/data/long_strings_flag_disabled.py | 18 +++- tests/test_black.py | 52 ++++++++- 19 files changed, 336 insertions(+), 103 deletions(-) create mode 100644 tests/data/function_trailing_comma_wip.py diff --git a/CHANGES.md b/CHANGES.md index 6d418b9..eb6d1c2 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -4,6 +4,9 @@ #### _Black_ +- re-implemented support for explicit trailing commas: now it works consistently within + any bracket pair, including nested structures (#1288 and duplicates) + - reindent docstrings when reindenting code around it (#1053) - show colored diffs (#1266) diff --git a/gallery/gallery.py b/gallery/gallery.py index 2a56b4e..6b42ec3 100755 --- a/gallery/gallery.py +++ b/gallery/gallery.py @@ -127,7 +127,10 @@ DEFAULT_SLICE = slice(None) # for flake8 def download_and_extract_top_packages( - directory: Path, days: Days = 365, workers: int = 8, limit: slice = DEFAULT_SLICE, + directory: Path, + days: Days = 365, + workers: int = 8, + limit: slice = DEFAULT_SLICE, ) -> Generator[Path, None, None]: with ThreadPoolExecutor(max_workers=workers) as executor: bound_downloader = partial(get_package, version=None, directory=directory) diff --git a/src/black/__init__.py b/src/black/__init__.py index 2250943..8d0c70f 100644 --- a/src/black/__init__.py +++ b/src/black/__init__.py @@ -1442,7 +1442,8 @@ class Line: ) if self.inside_brackets or not preformatted: self.bracket_tracker.mark(leaf) - self.maybe_remove_trailing_comma(leaf) + if self.maybe_should_explode(leaf): + self.should_explode = True if not self.append_comment(leaf): self.leaves.append(leaf) @@ -1618,59 +1619,26 @@ class Line: def contains_multiline_strings(self) -> bool: return any(is_multiline_string(leaf) for leaf in self.leaves) - def maybe_remove_trailing_comma(self, closing: Leaf) -> bool: - """Remove trailing comma if there is one and it's safe.""" + def maybe_should_explode(self, closing: Leaf) -> bool: + """Return True if this line should explode (always be split), that is when: + - there's a pre-existing trailing comma here; and + - it's not a one-tuple. + """ if not ( - self.leaves + closing.type in CLOSING_BRACKETS + and self.leaves and self.leaves[-1].type == token.COMMA - and closing.type in CLOSING_BRACKETS + and not self.leaves[-1].was_checked # pre-existing ): return False - if closing.type == token.RBRACE: - self.remove_trailing_comma() + if closing.type in {token.RBRACE, token.RSQB}: return True - if closing.type == token.RSQB: - comma = self.leaves[-1] - if comma.parent and comma.parent.type == syms.listmaker: - self.remove_trailing_comma() - return True - - # For parens let's check if it's safe to remove the comma. - # Imports are always safe. if self.is_import: - self.remove_trailing_comma() return True - # Otherwise, if the trailing one is the only one, we might mistakenly - # change a tuple into a different type by removing the comma. - depth = closing.bracket_depth + 1 - commas = 0 - opening = closing.opening_bracket - for _opening_index, leaf in enumerate(self.leaves): - if leaf is opening: - break - - else: - return False - - for leaf in self.leaves[_opening_index + 1 :]: - if leaf is closing: - break - - bracket_depth = leaf.bracket_depth - if bracket_depth == depth and leaf.type == token.COMMA: - commas += 1 - if leaf.parent and leaf.parent.type in { - syms.arglist, - syms.typedargslist, - }: - commas += 1 - break - - if commas > 1: - self.remove_trailing_comma() + if not is_one_tuple_between(closing.opening_bracket, closing, self.leaves): return True return False @@ -2647,7 +2615,7 @@ def transform_line( is_line_short_enough(line, line_length=mode.line_length, line_str=line_str) or line.contains_unsplittable_type_ignore() ) - and not (line.contains_standalone_comments() and line.inside_brackets) + and not (line.inside_brackets and line.contains_standalone_comments()) ): # Only apply basic string preprocessing, since lines shouldn't be split here. if mode.experimental_string_processing: @@ -4772,10 +4740,8 @@ def right_hand_split( tail = bracket_split_build_line(tail_leaves, line, opening_bracket) bracket_split_succeeded_or_raise(head, body, tail) if ( - # the body shouldn't be exploded - not body.should_explode # the opening bracket is an optional paren - and opening_bracket.type == token.LPAR + opening_bracket.type == token.LPAR and not opening_bracket.value # the closing bracket is an optional paren and closing_bracket.type == token.RPAR @@ -4872,7 +4838,9 @@ def bracket_split_build_line( continue if leaves[i].type != token.COMMA: - leaves.insert(i + 1, Leaf(token.COMMA, ",")) + new_comma = Leaf(token.COMMA, ",") + new_comma.was_checked = True + leaves.insert(i + 1, new_comma) break # Populate the line @@ -4880,8 +4848,8 @@ def bracket_split_build_line( result.append(leaf, preformatted=True) for comment_after in original.comments_after(leaf): result.append(comment_after, preformatted=True) - if is_body: - result.should_explode = should_explode(result, opening_bracket) + if is_body and should_split_body_explode(result, opening_bracket): + result.should_explode = True return result @@ -4966,7 +4934,9 @@ def delimiter_split(line: Line, features: Collection[Feature] = ()) -> Iterator[ and current_line.leaves[-1].type != token.COMMA and current_line.leaves[-1].type != STANDALONE_COMMENT ): - current_line.append(Leaf(token.COMMA, ",")) + new_comma = Leaf(token.COMMA, ",") + new_comma.was_checked = True + current_line.append(new_comma) yield current_line @@ -5588,24 +5558,60 @@ def ensure_visible(leaf: Leaf) -> None: leaf.value = ")" -def should_explode(line: Line, opening_bracket: Leaf) -> bool: +def should_split_body_explode(line: Line, opening_bracket: Leaf) -> bool: """Should `line` immediately be split with `delimiter_split()` after RHS?""" - if not ( - opening_bracket.parent - and opening_bracket.parent.type in {syms.atom, syms.import_from} - and opening_bracket.value in "[{(" - ): + if not (opening_bracket.parent and opening_bracket.value in "[{("): return False + # We're essentially checking if the body is delimited by commas and there's more + # than one of them (we're excluding the trailing comma and if the delimiter priority + # is still commas, that means there's more). + exclude = set() + pre_existing_trailing_comma = False try: last_leaf = line.leaves[-1] - exclude = {id(last_leaf)} if last_leaf.type == token.COMMA else set() + if last_leaf.type == token.COMMA: + pre_existing_trailing_comma = not last_leaf.was_checked + exclude.add(id(last_leaf)) max_priority = line.bracket_tracker.max_delimiter_priority(exclude=exclude) except (IndexError, ValueError): return False - return max_priority == COMMA_PRIORITY + return max_priority == COMMA_PRIORITY and ( + # always explode imports + opening_bracket.parent.type in {syms.atom, syms.import_from} + or pre_existing_trailing_comma + ) + + +def is_one_tuple_between(opening: Leaf, closing: Leaf, leaves: List[Leaf]) -> bool: + """Return True if content between `opening` and `closing` looks like a one-tuple.""" + depth = closing.bracket_depth + 1 + for _opening_index, leaf in enumerate(leaves): + if leaf is opening: + break + + else: + raise LookupError("Opening paren not found in `leaves`") + + commas = 0 + _opening_index += 1 + for leaf in leaves[_opening_index:]: + if leaf is closing: + break + + bracket_depth = leaf.bracket_depth + if bracket_depth == depth and leaf.type == token.COMMA: + commas += 1 + if leaf.parent and leaf.parent.type in { + syms.arglist, + syms.typedargslist, + }: + commas += 1 + break + + return commas < 2 def get_features_used(node: Node) -> Set[Feature]: diff --git a/src/blib2to3/pgen2/driver.py b/src/blib2to3/pgen2/driver.py index 052c948..81940f7 100644 --- a/src/blib2to3/pgen2/driver.py +++ b/src/blib2to3/pgen2/driver.py @@ -128,7 +128,7 @@ class Driver(object): return self.parse_stream_raw(stream, debug) def parse_file( - self, filename: Path, encoding: Optional[Text] = None, debug: bool = False, + self, filename: Path, encoding: Optional[Text] = None, debug: bool = False ) -> NL: """Parse a file and return the syntax tree.""" with io.open(filename, "r", encoding=encoding) as stream: diff --git a/tests/data/collections.py b/tests/data/collections.py index ebe8d3c..6843166 100644 --- a/tests/data/collections.py +++ b/tests/data/collections.py @@ -2,18 +2,18 @@ import core, time, a from . import A, B, C -# unwraps +# keeps existing trailing comma from foo import ( bar, ) -# stays wrapped +# also keeps existing structure from foo import ( baz, qux, ) -# as doesn't get confusing when unwrapped +# `as` works as well from foo import ( xyzzy as magic, ) @@ -77,17 +77,21 @@ import core, time, a from . import A, B, C -# unwraps -from foo import bar +# keeps existing trailing comma +from foo import ( + bar, +) -# stays wrapped +# also keeps existing structure from foo import ( baz, qux, ) -# as doesn't get confusing when unwrapped -from foo import xyzzy as magic +# `as` works as well +from foo import ( + xyzzy as magic, +) a = { 1, @@ -151,11 +155,20 @@ if True: if True: ec2client.get_waiter("instance_stopped").wait( - InstanceIds=[instance.id], WaiterConfig={"Delay": 5,} + InstanceIds=[instance.id], + WaiterConfig={ + "Delay": 5, + }, ) ec2client.get_waiter("instance_stopped").wait( - InstanceIds=[instance.id], WaiterConfig={"Delay": 5,}, + InstanceIds=[instance.id], + WaiterConfig={ + "Delay": 5, + }, ) ec2client.get_waiter("instance_stopped").wait( - InstanceIds=[instance.id], WaiterConfig={"Delay": 5,}, + InstanceIds=[instance.id], + WaiterConfig={ + "Delay": 5, + }, ) diff --git a/tests/data/comments2.py b/tests/data/comments2.py index 89c2910..221cb3f 100644 --- a/tests/data/comments2.py +++ b/tests/data/comments2.py @@ -316,7 +316,13 @@ short ) -CONFIG_FILES = [CONFIG_FILE,] + SHARED_CONFIG_FILES + USER_CONFIG_FILES # type: Final +CONFIG_FILES = ( + [ + CONFIG_FILE, + ] + + SHARED_CONFIG_FILES + + USER_CONFIG_FILES +) # type: Final class Test: diff --git a/tests/data/comments7.py b/tests/data/comments7.py index 436df1a..a7bd281 100644 --- a/tests/data/comments7.py +++ b/tests/data/comments7.py @@ -97,7 +97,14 @@ result = ( # aaa def func(): - c = call(0.0123, 0.0456, 0.0789, 0.0123, 0.0789, a[-1],) # type: ignore + c = call( + 0.0123, + 0.0456, + 0.0789, + 0.0123, + 0.0789, + a[-1], # type: ignore + ) # The type: ignore exception only applies to line length, not # other types of formatting. diff --git a/tests/data/expression.diff b/tests/data/expression.diff index f47ee1c..684f92c 100644 --- a/tests/data/expression.diff +++ b/tests/data/expression.diff @@ -130,15 +130,21 @@ call(**self.screen_kwargs) call(b, **self.screen_kwargs) lukasz.langa.pl -@@ -94,23 +127,25 @@ +@@ -94,26 +127,29 @@ 1.0 .real ....__class__ list[str] dict[str, int] tuple[str, ...] ++tuple[str, int, float, dict[str, int]] + tuple[ +- str, int, float, dict[str, int] +-] -tuple[str, int, float, dict[str, int],] -+tuple[ -+ str, int, float, dict[str, int], ++ str, ++ int, ++ float, ++ dict[str, int], +] very_long_variable_name_filters: t.List[ t.Tuple[str, t.Union[str, t.List[t.Optional[str]]]], @@ -160,7 +166,7 @@ slice[0:1:2] slice[:] slice[:-1] -@@ -134,112 +169,170 @@ +@@ -137,113 +173,180 @@ numpy[-(c + 1) :, d] numpy[:, l[-2]] numpy[:, ::-1] @@ -200,6 +206,7 @@ g = 1, *"ten" -what_is_up_with_those_new_coord_names = (coord_names + set(vars_to_create)) + set(vars_to_remove) -what_is_up_with_those_new_coord_names = (coord_names | set(vars_to_create)) - set(vars_to_remove) +-result = session.query(models.Customer.id).filter(models.Customer.account_id == account_id, models.Customer.email == email_address).order_by(models.Customer.id.asc()).all() -result = session.query(models.Customer.id).filter(models.Customer.account_id == account_id, models.Customer.email == email_address).order_by(models.Customer.id.asc(),).all() +what_is_up_with_those_new_coord_names = (coord_names + set(vars_to_create)) + set( + vars_to_remove @@ -212,7 +219,17 @@ + .filter( + models.Customer.account_id == account_id, models.Customer.email == email_address + ) -+ .order_by(models.Customer.id.asc(),) ++ .order_by(models.Customer.id.asc()) ++ .all() ++) ++result = ( ++ session.query(models.Customer.id) ++ .filter( ++ models.Customer.account_id == account_id, models.Customer.email == email_address ++ ) ++ .order_by( ++ models.Customer.id.asc(), ++ ) + .all() +) Ø = set() diff --git a/tests/data/expression.py b/tests/data/expression.py index 6a04db8..8e63bdc 100644 --- a/tests/data/expression.py +++ b/tests/data/expression.py @@ -96,6 +96,9 @@ call.me(maybe) list[str] dict[str, int] tuple[str, ...] +tuple[ + str, int, float, dict[str, int] +] tuple[str, int, float, dict[str, int],] very_long_variable_name_filters: t.List[ t.Tuple[str, t.Union[str, t.List[t.Optional[str]]]], @@ -157,6 +160,7 @@ f = 1, *range(10) g = 1, *"ten" what_is_up_with_those_new_coord_names = (coord_names + set(vars_to_create)) + set(vars_to_remove) what_is_up_with_those_new_coord_names = (coord_names | set(vars_to_create)) - set(vars_to_remove) +result = session.query(models.Customer.id).filter(models.Customer.account_id == account_id, models.Customer.email == email_address).order_by(models.Customer.id.asc()).all() result = session.query(models.Customer.id).filter(models.Customer.account_id == account_id, models.Customer.email == email_address).order_by(models.Customer.id.asc(),).all() Ø = set() authors.łukasz.say_thanks() @@ -379,8 +383,12 @@ call.me(maybe) list[str] dict[str, int] tuple[str, ...] +tuple[str, int, float, dict[str, int]] tuple[ - str, int, float, dict[str, int], + str, + int, + float, + dict[str, int], ] very_long_variable_name_filters: t.List[ t.Tuple[str, t.Union[str, t.List[t.Optional[str]]]], @@ -459,7 +467,17 @@ result = ( .filter( models.Customer.account_id == account_id, models.Customer.email == email_address ) - .order_by(models.Customer.id.asc(),) + .order_by(models.Customer.id.asc()) + .all() +) +result = ( + session.query(models.Customer.id) + .filter( + models.Customer.account_id == account_id, models.Customer.email == email_address + ) + .order_by( + models.Customer.id.asc(), + ) .all() ) Ø = set() diff --git a/tests/data/fmtonoff4.py b/tests/data/fmtonoff4.py index 54673c0..4ca7079 100644 --- a/tests/data/fmtonoff4.py +++ b/tests/data/fmtonoff4.py @@ -25,7 +25,12 @@ def f(): @test( - [1, 2, 3, 4,] + [ + 1, + 2, + 3, + 4, + ] ) def f(): pass diff --git a/tests/data/function.py b/tests/data/function.py index 51234a1..2d642c8 100644 --- a/tests/data/function.py +++ b/tests/data/function.py @@ -230,7 +230,10 @@ def trailing_comma(): } -def f(a, **kwargs,) -> A: +def f( + a, + **kwargs, +) -> A: return ( yield from A( very_long_argument_name1=very_long_value_for_the_argument, diff --git a/tests/data/function2.py b/tests/data/function2.py index a6773d4..cfc259e 100644 --- a/tests/data/function2.py +++ b/tests/data/function2.py @@ -25,7 +25,10 @@ def h(): # output -def f(a, **kwargs,) -> A: +def f( + a, + **kwargs, +) -> A: with cache_dir(): if something: result = CliRunner().invoke( diff --git a/tests/data/function_trailing_comma.py b/tests/data/function_trailing_comma.py index fcd81ad..314a56c 100644 --- a/tests/data/function_trailing_comma.py +++ b/tests/data/function_trailing_comma.py @@ -1,25 +1,67 @@ def f(a,): - ... + d = {'key': 'value',} + tup = (1,) + +def f2(a,b,): + d = {'key': 'value', 'key2': 'value2',} + tup = (1,2,) def f(a:int=1,): - ... + call(arg={'explode': 'this',}) + call2(arg=[1,2,3],) def xxxxxxxxxxxxxxxxxxxxxxxxxxxx() -> Set[ "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx" ]: - pass + json = {"k": {"k2": {"k3": [1,]}}} # output -def f(a,): - ... +def f( + a, +): + d = { + "key": "value", + } + tup = (1,) + + +def f2( + a, + b, +): + d = { + "key": "value", + "key2": "value2", + } + tup = ( + 1, + 2, + ) -def f(a: int = 1,): - ... +def f( + a: int = 1, +): + call( + arg={ + "explode": "this", + } + ) + call2( + arg=[1, 2, 3], + ) def xxxxxxxxxxxxxxxxxxxxxxxxxxxx() -> Set[ "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx" ]: - pass + json = { + "k": { + "k2": { + "k3": [ + 1, + ] + } + } + } \ No newline at end of file diff --git a/tests/data/function_trailing_comma_wip.py b/tests/data/function_trailing_comma_wip.py new file mode 100644 index 0000000..c41fc70 --- /dev/null +++ b/tests/data/function_trailing_comma_wip.py @@ -0,0 +1,5 @@ +CONFIG_FILES = [CONFIG_FILE] + SHARED_CONFIG_FILES + USER_CONFIG_FILES # type: Final + +# output + +CONFIG_FILES = [CONFIG_FILE] + SHARED_CONFIG_FILES + USER_CONFIG_FILES # type: Final \ No newline at end of file diff --git a/tests/data/import_spacing.py b/tests/data/import_spacing.py index 51cfda2..8e6e23c 100644 --- a/tests/data/import_spacing.py +++ b/tests/data/import_spacing.py @@ -2,6 +2,9 @@ # flake8: noqa +from logging import ( + WARNING +) from logging import ( ERROR, ) @@ -53,7 +56,10 @@ __all__ = ( # flake8: noqa -from logging import ERROR +from logging import WARNING +from logging import ( + ERROR, +) import sys # This relies on each of the submodules having an __all__ variable. diff --git a/tests/data/long_strings.py b/tests/data/long_strings.py index 5da460b..e1ed90f 100644 --- a/tests/data/long_strings.py +++ b/tests/data/long_strings.py @@ -137,6 +137,20 @@ func_with_bad_comma( ), # comment after comma ) +func_with_bad_parens_that_wont_fit_in_one_line( + ("short string that should have parens stripped"), + x, + y, + z +) + +func_with_bad_parens_that_wont_fit_in_one_line( + x, + y, + ("short string that should have parens stripped"), + z +) + func_with_bad_parens( ("short string that should have parens stripped"), x, @@ -487,12 +501,26 @@ func_with_bad_comma( " which should NOT be there.", # comment after comma ) +func_with_bad_parens_that_wont_fit_in_one_line( + "short string that should have parens stripped", x, y, z +) + +func_with_bad_parens_that_wont_fit_in_one_line( + x, y, "short string that should have parens stripped", z +) + func_with_bad_parens( - "short string that should have parens stripped", x, y, z, + "short string that should have parens stripped", + x, + y, + z, ) func_with_bad_parens( - x, y, "short string that should have parens stripped", z, + x, + y, + "short string that should have parens stripped", + z, ) annotated_variable: Final = ( diff --git a/tests/data/long_strings__regression.py b/tests/data/long_strings__regression.py index 8dbc58a..044bb4a 100644 --- a/tests/data/long_strings__regression.py +++ b/tests/data/long_strings__regression.py @@ -528,17 +528,23 @@ class A: xxxxxxxx = [ xxxxxxxxxxxxxxxx( "xxxx", - xxxxxxxxxxx={"xxxx": 1.0,}, + xxxxxxxxxxx={ + "xxxx": 1.0, + }, xxxxxx={"xxxxxx 1": xxxxxx(xxxx="xxxxxx 1", xxxxxx=600.0)}, xxxxxxxx_xxxxxxx=0.0, ), xxxxxxxxxxxxxxxx( "xxxxxxx", - xxxxxxxxxxx={"xxxx": 1.0,}, + xxxxxxxxxxx={ + "xxxx": 1.0, + }, xxxxxx={"xxxxxx 1": xxxxxx(xxxx="xxxxxx 1", xxxxxx=200.0)}, xxxxxxxx_xxxxxxx=0.0, ), - xxxxxxxxxxxxxxxx("xxxx",), + xxxxxxxxxxxxxxxx( + "xxxx", + ), ] diff --git a/tests/data/long_strings_flag_disabled.py b/tests/data/long_strings_flag_disabled.py index 1ea864d..db3954e 100644 --- a/tests/data/long_strings_flag_disabled.py +++ b/tests/data/long_strings_flag_disabled.py @@ -225,12 +225,26 @@ func_with_bad_comma( ), # comment after comma ) +func_with_bad_parens_that_wont_fit_in_one_line( + ("short string that should have parens stripped"), x, y, z +) + +func_with_bad_parens_that_wont_fit_in_one_line( + x, y, ("short string that should have parens stripped"), z +) + func_with_bad_parens( - ("short string that should have parens stripped"), x, y, z, + ("short string that should have parens stripped"), + x, + y, + z, ) func_with_bad_parens( - x, y, ("short string that should have parens stripped"), z, + x, + y, + ("short string that should have parens stripped"), + z, ) annotated_variable: Final = ( diff --git a/tests/test_black.py b/tests/test_black.py index 686232a..7793b0e 100644 --- a/tests/test_black.py +++ b/tests/test_black.py @@ -5,13 +5,25 @@ from concurrent.futures import ThreadPoolExecutor from contextlib import contextmanager from dataclasses import replace from functools import partial +import inspect from io import BytesIO, TextIOWrapper import os from pathlib import Path import regex as re import sys from tempfile import TemporaryDirectory -from typing import Any, BinaryIO, Dict, Generator, List, Tuple, Iterator, TypeVar +import types +from typing import ( + Any, + BinaryIO, + Callable, + Dict, + Generator, + List, + Tuple, + Iterator, + TypeVar, +) import unittest from unittest.mock import patch, MagicMock @@ -153,6 +165,7 @@ class BlackRunner(CliRunner): class BlackTestCase(unittest.TestCase): maxDiff = None + _diffThreshold = 2 ** 20 def assertFormatEqual(self, expected: str, actual: str) -> None: if actual != expected and not os.environ.get("SKIP_AST_PRINT"): @@ -171,7 +184,7 @@ class BlackTestCase(unittest.TestCase): list(bdv.visit(exp_node)) except Exception as ve: black.err(str(ve)) - self.assertEqual(expected, actual) + self.assertMultiLineEqual(expected, actual) def invokeBlack( self, args: List[str], exit_code: int = 0, ignore_config: bool = True @@ -332,6 +345,16 @@ class BlackTestCase(unittest.TestCase): black.assert_equivalent(source, actual) black.assert_stable(source, actual, DEFAULT_MODE) + @patch("black.dump_to_file", dump_to_stderr) + def test_function_trailing_comma_wip(self) -> None: + source, expected = read_data("function_trailing_comma_wip") + # sys.settrace(tracefunc) + actual = fs(source) + # sys.settrace(None) + self.assertFormatEqual(expected, actual) + black.assert_equivalent(source, actual) + black.assert_stable(source, actual, black.FileMode()) + @patch("black.dump_to_file", dump_to_stderr) def test_function_trailing_comma(self) -> None: source, expected = read_data("function_trailing_comma") @@ -2039,5 +2062,30 @@ class BlackDTestCase(AioHTTPTestCase): self.assertIsNotNone(response.headers.get(blackd.BLACK_VERSION_HEADER)) +with open(black.__file__, "r") as _bf: + black_source_lines = _bf.readlines() + + +def tracefunc(frame: types.FrameType, event: str, arg: Any) -> Callable: + """Show function calls `from black/__init__.py` as they happen. + + Register this with `sys.settrace()` in a test you're debugging. + """ + if event != "call": + return tracefunc + + stack = len(inspect.stack()) - 19 + filename = frame.f_code.co_filename + lineno = frame.f_lineno + func_sig_lineno = lineno - 1 + funcname = black_source_lines[func_sig_lineno].strip() + while funcname.startswith("@"): + func_sig_lineno += 1 + funcname = black_source_lines[func_sig_lineno].strip() + if "black/__init__.py" in filename: + print(f"{' ' * stack}{lineno}:{funcname}") + return tracefunc + + if __name__ == "__main__": unittest.main(module="test_black") -- 2.39.5