git.madduck.net Git - etc/vim.git/commitdiff

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access, can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Support --diff for both files and stdin
author Łukasz Langa <lukasz@langa.pl>
Sat, 31 Mar 2018 09:24:01 +0000 (02:24 -0700)
committer Łukasz Langa <lukasz@langa.pl>
Sat, 31 Mar 2018 22:46:09 +0000 (15:46 -0700)
Fixes #87

MANIFEST.in
README.md
black.py
tests/expression.diff [new file with mode: 0644]
tests/test_black.py

index e097e409745de58219b7aee7430c45570fe4b9ff..3d9555541fa94a8ee5c99d26e1774fcf9f8dbc72 100644 (file)
@@ -1,3 +1,3 @@
 include *.rst *.md LICENSE
 recursive-include blib2to3 *.txt *.py
 include *.rst *.md LICENSE
 recursive-include blib2to3 *.txt *.py
-recursive-include tests *.txt *.out *.py
+recursive-include tests *.txt *.out *.diff *.py
index 205163f3e158d92cdd6be653fda388980a18c0d2..a898efa88ad70b26b83545695c2de1ccaafc03fe 100644 (file)
--- a/README.md
+++ b/README.md
@@ -53,11 +53,13 @@ black [OPTIONS] [SRC]...
 
 Options:
   -l, --line-length INTEGER   Where to wrap around.  [default: 88]
 
 Options:
   -l, --line-length INTEGER   Where to wrap around.  [default: 88]
-  --check                     Don't write back the files, just return the
+  --check                     Don't write the files back, just return the
                               status.  Return code 0 means nothing would
                               change.  Return code 1 means some files would be
                               reformatted.  Return code 123 means there was an
                               internal error.
                               status.  Return code 0 means nothing would
                               change.  Return code 1 means some files would be
                               reformatted.  Return code 123 means there was an
                               internal error.
+  --diff                      Don't write the files back, just output a diff
+                              for each file on stdout.
   --fast / --safe             If --fast given, skip temporary sanity checks.
                               [default: --safe]
   --version                   Show the version and exit.
   --fast / --safe             If --fast given, skip temporary sanity checks.
                               [default: --safe]
   --version                   Show the version and exit.
@@ -394,8 +396,10 @@ More details can be found in [CONTRIBUTING](CONTRIBUTING.md).
 
 ### 18.3a5 (unreleased)
 
 
 ### 18.3a5 (unreleased)
 
+* added `--diff` (#87)
+
 * add line breaks before all delimiters, except in cases like commas, to better
 * add line breaks before all delimiters, except in cases like commas, to better
-  comply with PEP8 (#73)
+  comply with PEP 8 (#73)
 
 * fixed handling of standalone comments within nested bracketed
   expressions; Black will no longer produce super long lines or put all
 
 * fixed handling of standalone comments within nested bracketed
   expressions; Black will no longer produce super long lines or put all
index 3bb83daef28c396cd7670503d7ad21c7d20336c2..c2c118a4883c979d68439ef2636534d1c952a5a1 100644 (file)
--- a/black.py
+++ b/black.py
@@ -3,15 +3,18 @@
 import asyncio
 from asyncio.base_events import BaseEventLoop
 from concurrent.futures import Executor, ProcessPoolExecutor
 import asyncio
 from asyncio.base_events import BaseEventLoop
 from concurrent.futures import Executor, ProcessPoolExecutor
+from enum import Enum
 from functools import partial, wraps
 import keyword
 import logging
 from functools import partial, wraps
 import keyword
 import logging
+from multiprocessing import Manager
 import os
 from pathlib import Path
 import tokenize
 import signal
 import sys
 from typing import (
 import os
 from pathlib import Path
 import tokenize
 import signal
 import sys
 from typing import (
+    Any,
     Callable,
     Dict,
     Generic,
     Callable,
     Dict,
     Generic,
@@ -92,6 +95,12 @@ class FormatOff(FormatError):
     """Found a comment like `# fmt: off` in the file."""
 
 
     """Found a comment like `# fmt: off` in the file."""
 
 
+class WriteBack(Enum):
+    NO = 0
+    YES = 1
+    DIFF = 2
+
+
 @click.command()
 @click.option(
     "-l",
 @click.command()
 @click.option(
     "-l",
@@ -105,11 +114,16 @@ class FormatOff(FormatError):
     "--check",
     is_flag=True,
     help=(
     "--check",
     is_flag=True,
     help=(
-        "Don't write back the files, just return the status.  Return code 0 "
+        "Don't write the files back, just return the status.  Return code 0 "
         "means nothing would change.  Return code 1 means some files would be "
         "reformatted.  Return code 123 means there was an internal error."
     ),
 )
         "means nothing would change.  Return code 1 means some files would be "
         "reformatted.  Return code 123 means there was an internal error."
     ),
 )
+@click.option(
+    "--diff",
+    is_flag=True,
+    help="Don't write the files back, just output a diff for each file on stdout.",
+)
 @click.option(
     "--fast/--safe",
     is_flag=True,
 @click.option(
     "--fast/--safe",
     is_flag=True,
@@ -125,7 +139,12 @@ class FormatOff(FormatError):
 )
 @click.pass_context
 def main(
 )
 @click.pass_context
 def main(
-    ctx: click.Context, line_length: int, check: bool, fast: bool, src: List[str]
+    ctx: click.Context,
+    line_length: int,
+    check: bool,
+    diff: bool,
+    fast: bool,
+    src: List[str],
 ) -> None:
     """The uncompromising code formatter."""
     sources: List[Path] = []
 ) -> None:
     """The uncompromising code formatter."""
     sources: List[Path] = []
@@ -140,6 +159,17 @@ def main(
             sources.append(Path("-"))
         else:
             err(f"invalid path: {s}")
             sources.append(Path("-"))
         else:
             err(f"invalid path: {s}")
+    if check and diff:
+        exc = click.ClickException("Options --check and --diff are mutually exclusive")
+        exc.exit_code = 2
+        raise exc
+
+    if check:
+        write_back = WriteBack.NO
+    elif diff:
+        write_back = WriteBack.DIFF
+    else:
+        write_back = WriteBack.YES
     if len(sources) == 0:
         ctx.exit(0)
     elif len(sources) == 1:
     if len(sources) == 0:
         ctx.exit(0)
     elif len(sources) == 1:
@@ -148,11 +178,11 @@ def main(
         try:
             if not p.is_file() and str(p) == "-":
                 changed = format_stdin_to_stdout(
         try:
             if not p.is_file() and str(p) == "-":
                 changed = format_stdin_to_stdout(
-                    line_length=line_length, fast=fast, write_back=not check
+                    line_length=line_length, fast=fast, write_back=write_back
                 )
             else:
                 changed = format_file_in_place(
                 )
             else:
                 changed = format_file_in_place(
-                    p, line_length=line_length, fast=fast, write_back=not check
+                    p, line_length=line_length, fast=fast, write_back=write_back
                 )
             report.done(p, changed)
         except Exception as exc:
                 )
             report.done(p, changed)
         except Exception as exc:
@@ -165,7 +195,7 @@ def main(
         try:
             return_code = loop.run_until_complete(
                 schedule_formatting(
         try:
             return_code = loop.run_until_complete(
                 schedule_formatting(
-                    sources, line_length, not check, fast, loop, executor
+                    sources, line_length, write_back, fast, loop, executor
                 )
             )
         finally:
                 )
             )
         finally:
@@ -176,7 +206,7 @@ def main(
 async def schedule_formatting(
     sources: List[Path],
     line_length: int,
 async def schedule_formatting(
     sources: List[Path],
     line_length: int,
-    write_back: bool,
+    write_back: WriteBack,
     fast: bool,
     loop: BaseEventLoop,
     executor: Executor,
     fast: bool,
     loop: BaseEventLoop,
     executor: Executor,
@@ -188,9 +218,15 @@ async def schedule_formatting(
     `line_length`, `write_back`, and `fast` options are passed to
     :func:`format_file_in_place`.
     """
     `line_length`, `write_back`, and `fast` options are passed to
     :func:`format_file_in_place`.
     """
+    lock = None
+    if write_back == WriteBack.DIFF:
+        # For diff output, we need locks to ensure we don't interleave output
+        # from different processes.
+        manager = Manager()
+        lock = manager.Lock()
     tasks = {
         src: loop.run_in_executor(
     tasks = {
         src: loop.run_in_executor(
-            executor, format_file_in_place, src, line_length, fast, write_back
+            executor, format_file_in_place, src, line_length, fast, write_back, lock
         )
         for src in sources
     }
         )
         for src in sources
     }
@@ -220,7 +256,11 @@ async def schedule_formatting(
 
 
 def format_file_in_place(
 
 
 def format_file_in_place(
-    src: Path, line_length: int, fast: bool, write_back: bool = False
+    src: Path,
+    line_length: int,
+    fast: bool,
+    write_back: WriteBack = WriteBack.NO,
+    lock: Any = None,  # multiprocessing.Manager().Lock() is some crazy proxy
 ) -> bool:
     """Format file under `src` path. Return True if changed.
 
 ) -> bool:
     """Format file under `src` path. Return True if changed.
 
@@ -230,37 +270,53 @@ def format_file_in_place(
     with tokenize.open(src) as src_buffer:
         src_contents = src_buffer.read()
     try:
     with tokenize.open(src) as src_buffer:
         src_contents = src_buffer.read()
     try:
-        contents = format_file_contents(
+        dst_contents = format_file_contents(
             src_contents, line_length=line_length, fast=fast
         )
     except NothingChanged:
         return False
 
             src_contents, line_length=line_length, fast=fast
         )
     except NothingChanged:
         return False
 
-    if write_back:
+    if write_back == write_back.YES:
         with open(src, "w", encoding=src_buffer.encoding) as f:
         with open(src, "w", encoding=src_buffer.encoding) as f:
-            f.write(contents)
+            f.write(dst_contents)
+    elif write_back == write_back.DIFF:
+        src_name = f"{src.name}  (original)"
+        dst_name = f"{src.name}  (formatted)"
+        diff_contents = diff(src_contents, dst_contents, src_name, dst_name)
+        if lock:
+            lock.acquire()
+        try:
+            sys.stdout.write(diff_contents)
+        finally:
+            if lock:
+                lock.release()
     return True
 
 
 def format_stdin_to_stdout(
     return True
 
 
 def format_stdin_to_stdout(
-    line_length: int, fast: bool, write_back: bool = False
+    line_length: int, fast: bool, write_back: WriteBack = WriteBack.NO
 ) -> bool:
     """Format file on stdin. Return True if changed.
 
     If `write_back` is True, write reformatted code back to stdout.
     `line_length` and `fast` arguments are passed to :func:`format_file_contents`.
     """
 ) -> bool:
     """Format file on stdin. Return True if changed.
 
     If `write_back` is True, write reformatted code back to stdout.
     `line_length` and `fast` arguments are passed to :func:`format_file_contents`.
     """
-    contents = sys.stdin.read()
+    src = sys.stdin.read()
     try:
     try:
-        contents = format_file_contents(contents, line_length=line_length, fast=fast)
+        dst = format_file_contents(src, line_length=line_length, fast=fast)
         return True
 
     except NothingChanged:
         return True
 
     except NothingChanged:
+        dst = src
         return False
 
     finally:
         return False
 
     finally:
-        if write_back:
-            sys.stdout.write(contents)
+        if write_back == WriteBack.YES:
+            sys.stdout.write(dst)
+        elif write_back == WriteBack.DIFF:
+            src_name = "<stdin>  (original)"
+            dst_name = "<stdin>  (formatted)"
+            sys.stdout.write(diff(src, dst, src_name, dst_name))
 
 
 def format_file_contents(
 
 
 def format_file_contents(
@@ -2064,7 +2120,8 @@ def dump_to_file(*output: str) -> str:
     ) as f:
         for lines in output:
             f.write(lines)
     ) as f:
         for lines in output:
             f.write(lines)
-            f.write("\n")
+            if lines and lines[-1] != "\n":
+                f.write("\n")
     return f.name
 
 
     return f.name
 
 
diff --git a/tests/expression.diff b/tests/expression.diff
new file mode 100644 (file)
index 0000000..4cdf803
--- /dev/null
@@ -0,0 +1,238 @@
+--- <stdin>  (original)
++++ <stdin>  (formatted)
+@@ -1,8 +1,8 @@
+ ...
+-'some_string'
+-b'\\xa3'
++"some_string"
++b"\\xa3"
+ Name
+ None
+ True
+ False
+ 1
+@@ -29,65 +29,74 @@
+ ~great
+ +value
+ -1
+ ~int and not v1 ^ 123 + v2 | True
+ (~int) and (not ((v1 ^ (123 + v2)) | True))
+-flags & ~ select.EPOLLIN and waiters.write_task is not None
++flags & ~select.EPOLLIN and waiters.write_task is not None
+ lambda arg: None
+ lambda a=True: a
+ lambda a, b, c=True: a
+-lambda a, b, c=True, *, d=(1 << v2), e='str': a
+-lambda a, b, c=True, *vararg, d=(v1 << 2), e='str', **kwargs: a + b
++lambda a, b, c=True, *, d=(1 << v2), e="str": a
++lambda a, b, c=True, *vararg, d=(v1 << 2), e="str", **kwargs: a + b
+ 1 if True else 2
+ str or None if True else str or bytes or None
+ (str or None) if True else (str or bytes or None)
+ str or None if (1 if True else 2) else str or bytes or None
+ (str or None) if (1 if True else 2) else (str or bytes or None)
+-{'2.7': dead, '3.7': (long_live or die_hard)}
+-{'2.7': dead, '3.7': (long_live or die_hard), **{'3.6': verygood}}
++{"2.7": dead, "3.7": (long_live or die_hard)}
++{"2.7": dead, "3.7": (long_live or die_hard), **{"3.6": verygood}}
+ {**a, **b, **c}
+-{'2.7', '3.6', '3.7', '3.8', '3.9', ('4.0' if gilectomy else '3.10')}
+-({'a': 'b'}, (True or False), (+value), 'string', b'bytes') or None
++{"2.7", "3.6", "3.7", "3.8", "3.9", ("4.0" if gilectomy else "3.10")}
++({"a": "b"}, (True or False), (+value), "string", b"bytes") or None
+ ()
+ (1,)
+ (1, 2)
+ (1, 2, 3)
+ []
+ [1, 2, 3, 4, 5, 6, 7, 8, 9, (10 or A), (11 or B), (12 or C)]
+-[1, 2, 3,]
++[1, 2, 3]
+ {i for i in (1, 2, 3)}
+ {(i ** 2) for i in (1, 2, 3)}
+-{(i ** 2) for i, _ in ((1, 'a'), (2, 'b'), (3, 'c'))}
++{(i ** 2) for i, _ in ((1, "a"), (2, "b"), (3, "c"))}
+ {((i ** 2) + j) for i in (1, 2, 3) for j in (1, 2, 3)}
+ [i for i in (1, 2, 3)]
+ [(i ** 2) for i in (1, 2, 3)]
+-[(i ** 2) for i, _ in ((1, 'a'), (2, 'b'), (3, 'c'))]
++[(i ** 2) for i, _ in ((1, "a"), (2, "b"), (3, "c"))]
+ [((i ** 2) + j) for i in (1, 2, 3) for j in (1, 2, 3)]
+ {i: 0 for i in (1, 2, 3)}
+-{i: j for i, j in ((1, 'a'), (2, 'b'), (3, 'c'))}
++{i: j for i, j in ((1, "a"), (2, "b"), (3, "c"))}
+ {a: b * 2 for a, b in dictionary.items()}
+ {a: b * -2 for a, b in dictionary.items()}
+-{k: v for k, v in this_is_a_very_long_variable_which_will_cause_a_trailing_comma_which_breaks_the_comprehension}
++{
++    k: v
++    for k, v in this_is_a_very_long_variable_which_will_cause_a_trailing_comma_which_breaks_the_comprehension
++}
+ Python3 > Python2 > COBOL
+ Life is Life
+ call()
+ call(arg)
+-call(kwarg='hey')
+-call(arg, kwarg='hey')
+-call(arg, another, kwarg='hey', **kwargs)
+-call(this_is_a_very_long_variable_which_will_force_a_delimiter_split, arg, another, kwarg='hey', **kwargs)  # note: no trailing comma pre-3.6
++call(kwarg="hey")
++call(arg, kwarg="hey")
++call(arg, another, kwarg="hey", **kwargs)
++call(
++    this_is_a_very_long_variable_which_will_force_a_delimiter_split,
++    arg,
++    another,
++    kwarg="hey",
++    **kwargs
++)  # note: no trailing comma pre-3.6
+ call(*gidgets[:2])
+ call(**self.screen_kwargs)
+ lukasz.langa.pl
+ call.me(maybe)
+ 1 .real
+ 1.0 .real
+ ....__class__
+ list[str]
+ dict[str, int]
+ tuple[str, ...]
+-tuple[str, int, float, dict[str, int],]
++tuple[str, int, float, dict[str, int]]
+ very_long_variable_name_filters: t.List[
+     t.Tuple[str, t.Union[str, t.List[t.Optional[str]]]],
+ ]
+ slice[0]
+ slice[0:1]
+@@ -114,71 +123,90 @@
+ numpy[-(c + 1):, d]
+ numpy[:, l[-2]]
+ numpy[:, ::-1]
+ numpy[np.newaxis, :]
+ (str or None) if (sys.version_info[0] > (3,)) else (str or bytes or None)
+-{'2.7': dead, '3.7': long_live or die_hard}
+-{'2.7', '3.6', '3.7', '3.8', '3.9', '4.0' if gilectomy else '3.10'}
++{"2.7": dead, "3.7": long_live or die_hard}
++{"2.7", "3.6", "3.7", "3.8", "3.9", "4.0" if gilectomy else "3.10"}
+ [1, 2, 3, 4, 5, 6, 7, 8, 9, 10 or A, 11 or B, 12 or C]
+ (SomeName)
+ SomeName
+ (Good, Bad, Ugly)
+ (i for i in (1, 2, 3))
+ ((i ** 2) for i in (1, 2, 3))
+-((i ** 2) for i, _ in ((1, 'a'), (2, 'b'), (3, 'c')))
++((i ** 2) for i, _ in ((1, "a"), (2, "b"), (3, "c")))
+ (((i ** 2) + j) for i in (1, 2, 3) for j in (1, 2, 3))
+ (*starred)
+ a = (1,)
+ b = 1,
+ c = 1
+ d = (1,) + a + (2,)
+ e = (1,).count(1)
+-what_is_up_with_those_new_coord_names = (coord_names + set(vars_to_create)) + set(vars_to_remove)
+-what_is_up_with_those_new_coord_names = (coord_names | set(vars_to_create)) - set(vars_to_remove)
+-result = session.query(models.Customer.id).filter(models.Customer.account_id == account_id, models.Customer.email == email_address).order_by(models.Customer.id.asc(),).all()
++what_is_up_with_those_new_coord_names = (coord_names + set(vars_to_create)) + set(
++    vars_to_remove
++)
++what_is_up_with_those_new_coord_names = (coord_names | set(vars_to_create)) - set(
++    vars_to_remove
++)
++result = session.query(models.Customer.id).filter(
++    models.Customer.account_id == account_id, models.Customer.email == email_address
++).order_by(
++    models.Customer.id.asc()
++).all()
++
+ def gen():
+     yield from outside_of_generator
++
+     a = (yield)
++
+ async def f():
+     await some.complicated[0].call(with_args=(True or (1 is not 1)))
+-if (
+-    threading.current_thread() != threading.main_thread() and
+-    threading.current_thread() != threading.main_thread() or
+-    signal.getsignal(signal.SIGINT) != signal.default_int_handler
+-):
+-    return True
+-if (
+-    aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa |
+-    aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa
+-):
+-    return True
+-if (
+-    aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa &
+-    aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa
+-):
+-    return True
+-if (
+-    aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +
+-    aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa
+-):
+-    return True
+-if (
+-    aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -
+-    aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa
+-):
+-    return True
+-if (
+-    aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa *
+-    aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa
+-):
+-    return True
+-if (
+-    aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa /
+-    aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa
+-):
+-    return True
++
++if (
++    threading.current_thread() != threading.main_thread()
++    and threading.current_thread() != threading.main_thread()
++    or signal.getsignal(signal.SIGINT) != signal.default_int_handler
++):
++    return True
++
++if (
++    aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa
++    | aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa
++):
++    return True
++
++if (
++    aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa
++    & aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa
++):
++    return True
++
++if (
++    aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa
++    + aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa
++):
++    return True
++
++if (
++    aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa
++    - aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa
++):
++    return True
++
++if (
++    aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa
++    * aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa
++):
++    return True
++
++if (
++    aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa
++    / aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa
++):
++    return True
++
+ last_call()
+ # standalone comment at ENDMARKER
+
index 30ecaf693a4346f0a487074d38969c29ece034d1..226a11913e8f85bc6d5563420572581060b34d06 100644 (file)
@@ -26,7 +26,7 @@ def dump_to_stderr(*output: str) -> str:
 
 def read_data(name: str) -> Tuple[str, str]:
     """read_data('test_name') -> 'input', 'output'"""
 
 def read_data(name: str) -> Tuple[str, str]:
     """read_data('test_name') -> 'input', 'output'"""
-    if not name.endswith((".py", ".out")):
+    if not name.endswith((".py", ".out", ".diff")):
         name += ".py"
     _input: List[str] = []
     _output: List[str] = []
         name += ".py"
     _input: List[str] = []
     _output: List[str] = []
@@ -92,7 +92,9 @@ class BlackTestCase(unittest.TestCase):
         try:
             sys.stdin, sys.stdout = StringIO(source), StringIO()
             sys.stdin.name = "<stdin>"
         try:
             sys.stdin, sys.stdout = StringIO(source), StringIO()
             sys.stdin.name = "<stdin>"
-            black.format_stdin_to_stdout(line_length=ll, fast=True, write_back=True)
+            black.format_stdin_to_stdout(
+                line_length=ll, fast=True, write_back=black.WriteBack.YES
+            )
             sys.stdout.seek(0)
             actual = sys.stdout.read()
         finally:
             sys.stdout.seek(0)
             actual = sys.stdout.read()
         finally:
@@ -101,6 +103,23 @@ class BlackTestCase(unittest.TestCase):
         black.assert_equivalent(source, actual)
         black.assert_stable(source, actual, line_length=ll)
 
         black.assert_equivalent(source, actual)
         black.assert_stable(source, actual, line_length=ll)
 
+    def test_piping_diff(self) -> None:
+        source, _ = read_data("expression.py")
+        expected, _ = read_data("expression.diff")
+        hold_stdin, hold_stdout = sys.stdin, sys.stdout
+        try:
+            sys.stdin, sys.stdout = StringIO(source), StringIO()
+            sys.stdin.name = "<stdin>"
+            black.format_stdin_to_stdout(
+                line_length=ll, fast=True, write_back=black.WriteBack.DIFF
+            )
+            sys.stdout.seek(0)
+            actual = sys.stdout.read()
+        finally:
+            sys.stdin, sys.stdout = hold_stdin, hold_stdout
+        actual = actual.rstrip() + "\n"  # the diff output has a trailing space
+        self.assertEqual(expected, actual)
+
     @patch("black.dump_to_file", dump_to_stderr)
     def test_setup(self) -> None:
         source, expected = read_data("../setup")
     @patch("black.dump_to_file", dump_to_stderr)
     def test_setup(self) -> None:
         source, expected = read_data("../setup")
@@ -126,6 +145,37 @@ class BlackTestCase(unittest.TestCase):
         black.assert_equivalent(source, actual)
         black.assert_stable(source, actual, line_length=ll)
 
         black.assert_equivalent(source, actual)
         black.assert_stable(source, actual, line_length=ll)
 
+    def test_expression_ff(self) -> None:
+        source, expected = read_data("expression")
+        tmp_file = Path(black.dump_to_file(source))
+        try:
+            self.assertTrue(ff(tmp_file, write_back=black.WriteBack.YES))
+            with open(tmp_file) as f:
+                actual = f.read()
+        finally:
+            os.unlink(tmp_file)
+        self.assertFormatEqual(expected, actual)
+        with patch("black.dump_to_file", dump_to_stderr):
+            black.assert_equivalent(source, actual)
+            black.assert_stable(source, actual, line_length=ll)
+
+    def test_expression_diff(self) -> None:
+        source, _ = read_data("expression.py")
+        expected, _ = read_data("expression.diff")
+        tmp_file = Path(black.dump_to_file(source))
+        hold_stdout = sys.stdout
+        try:
+            sys.stdout = StringIO()
+            self.assertTrue(ff(tmp_file, write_back=black.WriteBack.DIFF))
+            sys.stdout.seek(0)
+            actual = sys.stdout.read()
+            actual = actual.replace(tmp_file.name, "<stdin>")
+        finally:
+            sys.stdout = hold_stdout
+            os.unlink(tmp_file)
+        actual = actual.rstrip() + "\n"  # the diff output has a trailing space
+        self.assertEqual(expected, actual)
+
     @patch("black.dump_to_file", dump_to_stderr)
     def test_fstring(self) -> None:
         source, expected = read_data("fstring")
     @patch("black.dump_to_file", dump_to_stderr)
     def test_fstring(self) -> None:
         source, expected = read_data("fstring")