git.madduck.net Git - etc/vim.git/commitdiff

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

[#149] Make check and diff not mutually exclusive (#161)
author Vishwas B Sharma <sharma.vishwas88@gmail.com>
Mon, 23 Apr 2018 18:23:11 +0000 (11:23 -0700)
committer Łukasz Langa <lukasz@langa.pl>
Mon, 23 Apr 2018 18:23:11 +0000 (11:23 -0700)
Fixes #149.

black.py
tests/test_black.py

index 58f7976aa60bc72b0ceb72b5a160a6107bf6637b..c77166a2a6c454486c564aa8ad0a3065e3d0edde 100644 (file)
--- a/black.py
+++ b/black.py
@@ -184,12 +184,8 @@ def main(
             sources.append(Path("-"))
         else:
             err(f"invalid path: {s}")
             sources.append(Path("-"))
         else:
             err(f"invalid path: {s}")
-    if check and diff:
-        exc = click.ClickException("Options --check and --diff are mutually exclusive")
-        exc.exit_code = 2
-        raise exc
 
 
-    if check:
+    if check and not diff:
         write_back = WriteBack.NO
     elif diff:
         write_back = WriteBack.DIFF
         write_back = WriteBack.NO
     elif diff:
         write_back = WriteBack.DIFF
@@ -200,7 +196,9 @@ def main(
         return
 
     elif len(sources) == 1:
         return
 
     elif len(sources) == 1:
-        return_code = reformat_one(sources[0], line_length, fast, quiet, write_back)
+        return_code = reformat_one(
+            sources[0], line_length, fast, quiet, write_back, check
+        )
     else:
         loop = asyncio.get_event_loop()
         executor = ProcessPoolExecutor(max_workers=os.cpu_count())
     else:
         loop = asyncio.get_event_loop()
         executor = ProcessPoolExecutor(max_workers=os.cpu_count())
@@ -208,7 +206,7 @@ def main(
         try:
             return_code = loop.run_until_complete(
                 schedule_formatting(
         try:
             return_code = loop.run_until_complete(
                 schedule_formatting(
-                    sources, line_length, write_back, fast, quiet, loop, executor
+                    sources, line_length, write_back, fast, quiet, loop, executor, check
                 )
             )
         finally:
                 )
             )
         finally:
@@ -217,14 +215,19 @@ def main(
 
 
 def reformat_one(
 
 
 def reformat_one(
-    src: Path, line_length: int, fast: bool, quiet: bool, write_back: WriteBack
+    src: Path,
+    line_length: int,
+    fast: bool,
+    quiet: bool,
+    write_back: WriteBack,
+    check: bool,
 ) -> int:
     """Reformat a single file under `src` without spawning child processes.
 
     If `quiet` is True, non-error messages are not output. `line_length`,
     `write_back`, and `fast` options are passed to :func:`format_file_in_place`.
     """
 ) -> int:
     """Reformat a single file under `src` without spawning child processes.
 
     If `quiet` is True, non-error messages are not output. `line_length`,
     `write_back`, and `fast` options are passed to :func:`format_file_in_place`.
     """
-    report = Report(check=write_back is WriteBack.NO, quiet=quiet)
+    report = Report(check=check, quiet=quiet)
     try:
         changed = Changed.NO
         if not src.is_file() and str(src) == "-":
     try:
         changed = Changed.NO
         if not src.is_file() and str(src) == "-":
@@ -262,6 +265,7 @@ async def schedule_formatting(
     quiet: bool,
     loop: BaseEventLoop,
     executor: Executor,
     quiet: bool,
     loop: BaseEventLoop,
     executor: Executor,
+    check: bool,
 ) -> int:
     """Run formatting of `sources` in parallel using the provided `executor`.
 
 ) -> int:
     """Run formatting of `sources` in parallel using the provided `executor`.
 
@@ -270,7 +274,7 @@ async def schedule_formatting(
     `line_length`, `write_back`, and `fast` options are passed to
     :func:`format_file_in_place`.
     """
     `line_length`, `write_back`, and `fast` options are passed to
     :func:`format_file_in_place`.
     """
-    report = Report(check=write_back is WriteBack.NO, quiet=quiet)
+    report = Report(check=check, quiet=quiet)
     cache: Cache = {}
     if write_back != WriteBack.DIFF:
         cache = read_cache()
     cache: Cache = {}
     if write_back != WriteBack.DIFF:
         cache = read_cache()
index fc310690ccec649a339bad18b53e6d44ccd335c7..b4820531fab62d860808b37448881ccf3b494a78 100644 (file)
@@ -595,6 +595,24 @@ class BlackTestCase(unittest.TestCase):
             mock.side_effect = OSError
             black.write_cache({}, [])
 
             mock.side_effect = OSError
             black.write_cache({}, [])
 
+    def test_check_diff_use_together(self) -> None:
+        with cache_dir():
+            # Files which will be reformatted.
+            src1 = (THIS_DIR / "string_quotes.py").resolve()
+            result = CliRunner().invoke(black.main, [str(src1), "--diff", "--check"])
+            self.assertEqual(result.exit_code, 1)
+
+            # Files which will not be reformatted.
+            src2 = (THIS_DIR / "composition.py").resolve()
+            result = CliRunner().invoke(black.main, [str(src2), "--diff", "--check"])
+            self.assertEqual(result.exit_code, 0)
+
+            # Multi file command.
+            result = CliRunner().invoke(
+                black.main, [str(src1), str(src2), "--diff", "--check"]
+            )
+            self.assertEqual(result.exit_code, 1)
+
 
 if __name__ == "__main__":
     unittest.main()
 
 if __name__ == "__main__":
     unittest.main()