From a20a3eeb0f738d3434efe3be8932db11722757a4 Mon Sep 17 00:00:00 2001 From: =?utf8?q?=C5=81ukasz=20Langa?= Date: Sat, 31 Mar 2018 02:24:01 -0700 Subject: [PATCH 1/1] Support --diff for both files and stdin Fixes #87 --- MANIFEST.in | 2 +- README.md | 8 +- black.py | 91 +++++++++++++--- tests/expression.diff | 238 ++++++++++++++++++++++++++++++++++++++++++ tests/test_black.py | 54 +++++++++- 5 files changed, 371 insertions(+), 22 deletions(-) create mode 100644 tests/expression.diff diff --git a/MANIFEST.in b/MANIFEST.in index e097e40..3d95555 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,3 +1,3 @@ include *.rst *.md LICENSE recursive-include blib2to3 *.txt *.py -recursive-include tests *.txt *.out *.py +recursive-include tests *.txt *.out *.diff *.py diff --git a/README.md b/README.md index 205163f..a898efa 100644 --- a/README.md +++ b/README.md @@ -53,11 +53,13 @@ black [OPTIONS] [SRC]... Options: -l, --line-length INTEGER Where to wrap around. [default: 88] - --check Don't write back the files, just return the + --check Don't write the files back, just return the status. Return code 0 means nothing would change. Return code 1 means some files would be reformatted. Return code 123 means there was an internal error. + --diff Don't write the files back, just output a diff + for each file on stdout. --fast / --safe If --fast given, skip temporary sanity checks. [default: --safe] --version Show the version and exit. @@ -394,8 +396,10 @@ More details can be found in [CONTRIBUTING](CONTRIBUTING.md). ### 18.3a5 (unreleased) +* added `--diff` (#87) + * add line breaks before all delimiters, except in cases like commas, to better - comply with PEP8 (#73) + comply with PEP 8 (#73) * fixed handling of standalone comments within nested bracketed expressions; Black will no longer produce super long lines or put all diff --git a/black.py b/black.py index 3bb83da..c2c118a 100644 --- a/black.py +++ b/black.py @@ -3,15 +3,18 @@ import asyncio from asyncio.base_events import BaseEventLoop from concurrent.futures import Executor, ProcessPoolExecutor +from enum import Enum from functools import partial, wraps import keyword import logging +from multiprocessing import Manager import os from pathlib import Path import tokenize import signal import sys from typing import ( + Any, Callable, Dict, Generic, @@ -92,6 +95,12 @@ class FormatOff(FormatError): """Found a comment like `# fmt: off` in the file.""" +class WriteBack(Enum): + NO = 0 + YES = 1 + DIFF = 2 + + @click.command() @click.option( "-l", @@ -105,11 +114,16 @@ class FormatOff(FormatError): "--check", is_flag=True, help=( - "Don't write back the files, just return the status. Return code 0 " + "Don't write the files back, just return the status. Return code 0 " "means nothing would change. Return code 1 means some files would be " "reformatted. Return code 123 means there was an internal error." ), ) +@click.option( + "--diff", + is_flag=True, + help="Don't write the files back, just output a diff for each file on stdout.", +) @click.option( "--fast/--safe", is_flag=True, @@ -125,7 +139,12 @@ class FormatOff(FormatError): ) @click.pass_context def main( - ctx: click.Context, line_length: int, check: bool, fast: bool, src: List[str] + ctx: click.Context, + line_length: int, + check: bool, + diff: bool, + fast: bool, + src: List[str], ) -> None: """The uncompromising code formatter.""" sources: List[Path] = [] @@ -140,6 +159,17 @@ def main( sources.append(Path("-")) else: err(f"invalid path: {s}") + if check and diff: + exc = click.ClickException("Options --check and --diff are mutually exclusive") + exc.exit_code = 2 + raise exc + + if check: + write_back = WriteBack.NO + elif diff: + write_back = WriteBack.DIFF + else: + write_back = WriteBack.YES if len(sources) == 0: ctx.exit(0) elif len(sources) == 1: @@ -148,11 +178,11 @@ def main( try: if not p.is_file() and str(p) == "-": changed = format_stdin_to_stdout( - line_length=line_length, fast=fast, write_back=not check + line_length=line_length, fast=fast, write_back=write_back ) else: changed = format_file_in_place( - p, line_length=line_length, fast=fast, write_back=not check + p, line_length=line_length, fast=fast, write_back=write_back ) report.done(p, changed) except Exception as exc: @@ -165,7 +195,7 @@ def main( try: return_code = loop.run_until_complete( schedule_formatting( - sources, line_length, not check, fast, loop, executor + sources, line_length, write_back, fast, loop, executor ) ) finally: @@ -176,7 +206,7 @@ def main( async def schedule_formatting( sources: List[Path], line_length: int, - write_back: bool, + write_back: WriteBack, fast: bool, loop: BaseEventLoop, executor: Executor, @@ -188,9 +218,15 @@ async def schedule_formatting( `line_length`, `write_back`, and `fast` options are passed to :func:`format_file_in_place`. """ + lock = None + if write_back == WriteBack.DIFF: + # For diff output, we need locks to ensure we don't interleave output + # from different processes. + manager = Manager() + lock = manager.Lock() tasks = { src: loop.run_in_executor( - executor, format_file_in_place, src, line_length, fast, write_back + executor, format_file_in_place, src, line_length, fast, write_back, lock ) for src in sources } @@ -220,7 +256,11 @@ async def schedule_formatting( def format_file_in_place( - src: Path, line_length: int, fast: bool, write_back: bool = False + src: Path, + line_length: int, + fast: bool, + write_back: WriteBack = WriteBack.NO, + lock: Any = None, # multiprocessing.Manager().Lock() is some crazy proxy ) -> bool: """Format file under `src` path. Return True if changed. @@ -230,37 +270,53 @@ def format_file_in_place( with tokenize.open(src) as src_buffer: src_contents = src_buffer.read() try: - contents = format_file_contents( + dst_contents = format_file_contents( src_contents, line_length=line_length, fast=fast ) except NothingChanged: return False - if write_back: + if write_back == write_back.YES: with open(src, "w", encoding=src_buffer.encoding) as f: - f.write(contents) + f.write(dst_contents) + elif write_back == write_back.DIFF: + src_name = f"{src.name} (original)" + dst_name = f"{src.name} (formatted)" + diff_contents = diff(src_contents, dst_contents, src_name, dst_name) + if lock: + lock.acquire() + try: + sys.stdout.write(diff_contents) + finally: + if lock: + lock.release() return True def format_stdin_to_stdout( - line_length: int, fast: bool, write_back: bool = False + line_length: int, fast: bool, write_back: WriteBack = WriteBack.NO ) -> bool: """Format file on stdin. Return True if changed. If `write_back` is True, write reformatted code back to stdout. `line_length` and `fast` arguments are passed to :func:`format_file_contents`. """ - contents = sys.stdin.read() + src = sys.stdin.read() try: - contents = format_file_contents(contents, line_length=line_length, fast=fast) + dst = format_file_contents(src, line_length=line_length, fast=fast) return True except NothingChanged: + dst = src return False finally: - if write_back: - sys.stdout.write(contents) + if write_back == WriteBack.YES: + sys.stdout.write(dst) + elif write_back == WriteBack.DIFF: + src_name = " (original)" + dst_name = " (formatted)" + sys.stdout.write(diff(src, dst, src_name, dst_name)) def format_file_contents( @@ -2064,7 +2120,8 @@ def dump_to_file(*output: str) -> str: ) as f: for lines in output: f.write(lines) - f.write("\n") + if lines and lines[-1] != "\n": + f.write("\n") return f.name diff --git a/tests/expression.diff b/tests/expression.diff new file mode 100644 index 0000000..4cdf803 --- /dev/null +++ b/tests/expression.diff @@ -0,0 +1,238 @@ +--- (original) ++++ (formatted) +@@ -1,8 +1,8 @@ + ... +-'some_string' +-b'\\xa3' ++"some_string" ++b"\\xa3" + Name + None + True + False + 1 +@@ -29,65 +29,74 @@ + ~great + +value + -1 + ~int and not v1 ^ 123 + v2 | True + (~int) and (not ((v1 ^ (123 + v2)) | True)) +-flags & ~ select.EPOLLIN and waiters.write_task is not None ++flags & ~select.EPOLLIN and waiters.write_task is not None + lambda arg: None + lambda a=True: a + lambda a, b, c=True: a +-lambda a, b, c=True, *, d=(1 << v2), e='str': a +-lambda a, b, c=True, *vararg, d=(v1 << 2), e='str', **kwargs: a + b ++lambda a, b, c=True, *, d=(1 << v2), e="str": a ++lambda a, b, c=True, *vararg, d=(v1 << 2), e="str", **kwargs: a + b + 1 if True else 2 + str or None if True else str or bytes or None + (str or None) if True else (str or bytes or None) + str or None if (1 if True else 2) else str or bytes or None + (str or None) if (1 if True else 2) else (str or bytes or None) +-{'2.7': dead, '3.7': (long_live or die_hard)} +-{'2.7': dead, '3.7': (long_live or die_hard), **{'3.6': verygood}} ++{"2.7": dead, "3.7": (long_live or die_hard)} ++{"2.7": dead, "3.7": (long_live or die_hard), **{"3.6": verygood}} + {**a, **b, **c} +-{'2.7', '3.6', '3.7', '3.8', '3.9', ('4.0' if gilectomy else '3.10')} +-({'a': 'b'}, (True or False), (+value), 'string', b'bytes') or None ++{"2.7", "3.6", "3.7", "3.8", "3.9", ("4.0" if gilectomy else "3.10")} ++({"a": "b"}, (True or False), (+value), "string", b"bytes") or None + () + (1,) + (1, 2) + (1, 2, 3) + [] + [1, 2, 3, 4, 5, 6, 7, 8, 9, (10 or A), (11 or B), (12 or C)] +-[1, 2, 3,] ++[1, 2, 3] + {i for i in (1, 2, 3)} + {(i ** 2) for i in (1, 2, 3)} +-{(i ** 2) for i, _ in ((1, 'a'), (2, 'b'), (3, 'c'))} ++{(i ** 2) for i, _ in ((1, "a"), (2, "b"), (3, "c"))} + {((i ** 2) + j) for i in (1, 2, 3) for j in (1, 2, 3)} + [i for i in (1, 2, 3)] + [(i ** 2) for i in (1, 2, 3)] +-[(i ** 2) for i, _ in ((1, 'a'), (2, 'b'), (3, 'c'))] ++[(i ** 2) for i, _ in ((1, "a"), (2, "b"), (3, "c"))] + [((i ** 2) + j) for i in (1, 2, 3) for j in (1, 2, 3)] + {i: 0 for i in (1, 2, 3)} +-{i: j for i, j in ((1, 'a'), (2, 'b'), (3, 'c'))} ++{i: j for i, j in ((1, "a"), (2, "b"), (3, "c"))} + {a: b * 2 for a, b in dictionary.items()} + {a: b * -2 for a, b in dictionary.items()} +-{k: v for k, v in this_is_a_very_long_variable_which_will_cause_a_trailing_comma_which_breaks_the_comprehension} ++{ ++ k: v ++ for k, v in this_is_a_very_long_variable_which_will_cause_a_trailing_comma_which_breaks_the_comprehension ++} + Python3 > Python2 > COBOL + Life is Life + call() + call(arg) +-call(kwarg='hey') +-call(arg, kwarg='hey') +-call(arg, another, kwarg='hey', **kwargs) +-call(this_is_a_very_long_variable_which_will_force_a_delimiter_split, arg, another, kwarg='hey', **kwargs) # note: no trailing comma pre-3.6 ++call(kwarg="hey") ++call(arg, kwarg="hey") ++call(arg, another, kwarg="hey", **kwargs) ++call( ++ this_is_a_very_long_variable_which_will_force_a_delimiter_split, ++ arg, ++ another, ++ kwarg="hey", ++ **kwargs ++) # note: no trailing comma pre-3.6 + call(*gidgets[:2]) + call(**self.screen_kwargs) + lukasz.langa.pl + call.me(maybe) + 1 .real + 1.0 .real + ....__class__ + list[str] + dict[str, int] + tuple[str, ...] +-tuple[str, int, float, dict[str, int],] ++tuple[str, int, float, dict[str, int]] + very_long_variable_name_filters: t.List[ + t.Tuple[str, t.Union[str, t.List[t.Optional[str]]]], + ] + slice[0] + slice[0:1] +@@ -114,71 +123,90 @@ + numpy[-(c + 1):, d] + numpy[:, l[-2]] + numpy[:, ::-1] + numpy[np.newaxis, :] + (str or None) if (sys.version_info[0] > (3,)) else (str or bytes or None) +-{'2.7': dead, '3.7': long_live or die_hard} +-{'2.7', '3.6', '3.7', '3.8', '3.9', '4.0' if gilectomy else '3.10'} ++{"2.7": dead, "3.7": long_live or die_hard} ++{"2.7", "3.6", "3.7", "3.8", "3.9", "4.0" if gilectomy else "3.10"} + [1, 2, 3, 4, 5, 6, 7, 8, 9, 10 or A, 11 or B, 12 or C] + (SomeName) + SomeName + (Good, Bad, Ugly) + (i for i in (1, 2, 3)) + ((i ** 2) for i in (1, 2, 3)) +-((i ** 2) for i, _ in ((1, 'a'), (2, 'b'), (3, 'c'))) ++((i ** 2) for i, _ in ((1, "a"), (2, "b"), (3, "c"))) + (((i ** 2) + j) for i in (1, 2, 3) for j in (1, 2, 3)) + (*starred) + a = (1,) + b = 1, + c = 1 + d = (1,) + a + (2,) + e = (1,).count(1) +-what_is_up_with_those_new_coord_names = (coord_names + set(vars_to_create)) + set(vars_to_remove) +-what_is_up_with_those_new_coord_names = (coord_names | set(vars_to_create)) - set(vars_to_remove) +-result = session.query(models.Customer.id).filter(models.Customer.account_id == account_id, models.Customer.email == email_address).order_by(models.Customer.id.asc(),).all() ++what_is_up_with_those_new_coord_names = (coord_names + set(vars_to_create)) + set( ++ vars_to_remove ++) ++what_is_up_with_those_new_coord_names = (coord_names | set(vars_to_create)) - set( ++ vars_to_remove ++) ++result = session.query(models.Customer.id).filter( ++ models.Customer.account_id == account_id, models.Customer.email == email_address ++).order_by( ++ models.Customer.id.asc() ++).all() ++ + + def gen(): + yield from outside_of_generator ++ + a = (yield) ++ + + async def f(): + await some.complicated[0].call(with_args=(True or (1 is not 1))) + +-if ( +- threading.current_thread() != threading.main_thread() and +- threading.current_thread() != threading.main_thread() or +- signal.getsignal(signal.SIGINT) != signal.default_int_handler +-): +- return True +-if ( +- aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa | +- aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +-): +- return True +-if ( +- aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa & +- aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +-): +- return True +-if ( +- aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa + +- aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +-): +- return True +-if ( +- aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa - +- aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +-): +- return True +-if ( +- aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa * +- aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +-): +- return True +-if ( +- aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa / +- aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +-): +- return True ++ ++if ( ++ threading.current_thread() != threading.main_thread() ++ and threading.current_thread() != threading.main_thread() ++ or signal.getsignal(signal.SIGINT) != signal.default_int_handler ++): ++ return True ++ ++if ( ++ aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa ++ | aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa ++): ++ return True ++ ++if ( ++ aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa ++ & aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa ++): ++ return True ++ ++if ( ++ aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa ++ + aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa ++): ++ return True ++ ++if ( ++ aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa ++ - aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa ++): ++ return True ++ ++if ( ++ aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa ++ * aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa ++): ++ return True ++ ++if ( ++ aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa ++ / aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa ++): ++ return True ++ + last_call() + # standalone comment at ENDMARKER + diff --git a/tests/test_black.py b/tests/test_black.py index 30ecaf6..226a119 100644 --- a/tests/test_black.py +++ b/tests/test_black.py @@ -26,7 +26,7 @@ def dump_to_stderr(*output: str) -> str: def read_data(name: str) -> Tuple[str, str]: """read_data('test_name') -> 'input', 'output'""" - if not name.endswith((".py", ".out")): + if not name.endswith((".py", ".out", ".diff")): name += ".py" _input: List[str] = [] _output: List[str] = [] @@ -92,7 +92,9 @@ class BlackTestCase(unittest.TestCase): try: sys.stdin, sys.stdout = StringIO(source), StringIO() sys.stdin.name = "" - black.format_stdin_to_stdout(line_length=ll, fast=True, write_back=True) + black.format_stdin_to_stdout( + line_length=ll, fast=True, write_back=black.WriteBack.YES + ) sys.stdout.seek(0) actual = sys.stdout.read() finally: @@ -101,6 +103,23 @@ class BlackTestCase(unittest.TestCase): black.assert_equivalent(source, actual) black.assert_stable(source, actual, line_length=ll) + def test_piping_diff(self) -> None: + source, _ = read_data("expression.py") + expected, _ = read_data("expression.diff") + hold_stdin, hold_stdout = sys.stdin, sys.stdout + try: + sys.stdin, sys.stdout = StringIO(source), StringIO() + sys.stdin.name = "" + black.format_stdin_to_stdout( + line_length=ll, fast=True, write_back=black.WriteBack.DIFF + ) + sys.stdout.seek(0) + actual = sys.stdout.read() + finally: + sys.stdin, sys.stdout = hold_stdin, hold_stdout + actual = actual.rstrip() + "\n" # the diff output has a trailing space + self.assertEqual(expected, actual) + @patch("black.dump_to_file", dump_to_stderr) def test_setup(self) -> None: source, expected = read_data("../setup") @@ -126,6 +145,37 @@ class BlackTestCase(unittest.TestCase): black.assert_equivalent(source, actual) black.assert_stable(source, actual, line_length=ll) + def test_expression_ff(self) -> None: + source, expected = read_data("expression") + tmp_file = Path(black.dump_to_file(source)) + try: + self.assertTrue(ff(tmp_file, write_back=black.WriteBack.YES)) + with open(tmp_file) as f: + actual = f.read() + finally: + os.unlink(tmp_file) + self.assertFormatEqual(expected, actual) + with patch("black.dump_to_file", dump_to_stderr): + black.assert_equivalent(source, actual) + black.assert_stable(source, actual, line_length=ll) + + def test_expression_diff(self) -> None: + source, _ = read_data("expression.py") + expected, _ = read_data("expression.diff") + tmp_file = Path(black.dump_to_file(source)) + hold_stdout = sys.stdout + try: + sys.stdout = StringIO() + self.assertTrue(ff(tmp_file, write_back=black.WriteBack.DIFF)) + sys.stdout.seek(0) + actual = sys.stdout.read() + actual = actual.replace(tmp_file.name, "") + finally: + sys.stdout = hold_stdout + os.unlink(tmp_file) + actual = actual.rstrip() + "\n" # the diff output has a trailing space + self.assertEqual(expected, actual) + @patch("black.dump_to_file", dump_to_stderr) def test_fstring(self) -> None: source, expected = read_data("fstring") -- 2.39.5