include *.rst *.md LICENSE
recursive-include blib2to3 *.txt *.py
-recursive-include tests *.txt *.out *.py
+recursive-include tests *.txt *.out *.diff *.py
Options:
-l, --line-length INTEGER Where to wrap around. [default: 88]
- --check Don't write back the files, just return the
+ --check Don't write the files back, just return the
status. Return code 0 means nothing would
change. Return code 1 means some files would be
reformatted. Return code 123 means there was an
internal error.
+ --diff Don't write the files back, just output a diff
+ for each file on stdout.
--fast / --safe If --fast given, skip temporary sanity checks.
[default: --safe]
--version Show the version and exit.
### 18.3a5 (unreleased)
+* added `--diff` (#87)
+
* add line breaks before all delimiters, except in cases like commas, to better
- comply with PEP8 (#73)
+ comply with PEP 8 (#73)
* fixed handling of standalone comments within nested bracketed
expressions; Black will no longer produce super long lines or put all
import asyncio
from asyncio.base_events import BaseEventLoop
from concurrent.futures import Executor, ProcessPoolExecutor
+from enum import Enum
from functools import partial, wraps
import keyword
import logging
+from multiprocessing import Manager
import os
from pathlib import Path
import tokenize
import signal
import sys
from typing import (
+ Any,
Callable,
Dict,
Generic,
"""Found a comment like `# fmt: off` in the file."""
class WriteBack(Enum):
    """Selects what happens to formatted output for each source.

    NO   -- write nothing; only report whether anything would change
            (selected by ``--check``).
    YES  -- write the reformatted code back: into the file for paths, or to
            stdout for ``-`` (stdin).
    DIFF -- write a diff between original and formatted code to stdout
            (selected by ``--diff``).
    """

    # Explicit integer values are kept stable so WriteBack(0/1/2) lookups work.
    NO = 0
    YES = 1
    DIFF = 2
+
+
@click.command()
@click.option(
"-l",
"--check",
is_flag=True,
help=(
- "Don't write back the files, just return the status. Return code 0 "
+ "Don't write the files back, just return the status. Return code 0 "
"means nothing would change. Return code 1 means some files would be "
"reformatted. Return code 123 means there was an internal error."
),
)
+@click.option(
+ "--diff",
+ is_flag=True,
+ help="Don't write the files back, just output a diff for each file on stdout.",
+)
@click.option(
"--fast/--safe",
is_flag=True,
)
@click.pass_context
def main(
- ctx: click.Context, line_length: int, check: bool, fast: bool, src: List[str]
+ ctx: click.Context,
+ line_length: int,
+ check: bool,
+ diff: bool,
+ fast: bool,
+ src: List[str],
) -> None:
"""The uncompromising code formatter."""
sources: List[Path] = []
sources.append(Path("-"))
else:
err(f"invalid path: {s}")
+ if check and diff:
+ exc = click.ClickException("Options --check and --diff are mutually exclusive")
+ exc.exit_code = 2
+ raise exc
+
+ if check:
+ write_back = WriteBack.NO
+ elif diff:
+ write_back = WriteBack.DIFF
+ else:
+ write_back = WriteBack.YES
if len(sources) == 0:
ctx.exit(0)
elif len(sources) == 1:
try:
if not p.is_file() and str(p) == "-":
changed = format_stdin_to_stdout(
- line_length=line_length, fast=fast, write_back=not check
+ line_length=line_length, fast=fast, write_back=write_back
)
else:
changed = format_file_in_place(
- p, line_length=line_length, fast=fast, write_back=not check
+ p, line_length=line_length, fast=fast, write_back=write_back
)
report.done(p, changed)
except Exception as exc:
try:
return_code = loop.run_until_complete(
schedule_formatting(
- sources, line_length, not check, fast, loop, executor
+ sources, line_length, write_back, fast, loop, executor
)
)
finally:
async def schedule_formatting(
sources: List[Path],
line_length: int,
- write_back: bool,
+ write_back: WriteBack,
fast: bool,
loop: BaseEventLoop,
executor: Executor,
`line_length`, `write_back`, and `fast` options are passed to
:func:`format_file_in_place`.
"""
+ lock = None
+ if write_back == WriteBack.DIFF:
+ # For diff output, we need locks to ensure we don't interleave output
+ # from different processes.
+ manager = Manager()
+ lock = manager.Lock()
tasks = {
src: loop.run_in_executor(
- executor, format_file_in_place, src, line_length, fast, write_back
+ executor, format_file_in_place, src, line_length, fast, write_back, lock
)
for src in sources
}
def format_file_in_place(
    src: Path,
    line_length: int,
    fast: bool,
    write_back: WriteBack = WriteBack.NO,
    lock: Any = None,  # multiprocessing.Manager().Lock() is some crazy proxy
) -> bool:
    """Format file under `src` path. Return True if changed.

    What happens to the result depends on `write_back`:

    * ``WriteBack.YES``: the reformatted code is written back to `src`,
      using the encoding detected when reading it.
    * ``WriteBack.DIFF``: a diff between the original and the formatted code
      is written to stdout.  If `lock` is given it is held while writing, so
      that diffs produced by parallel workers do not interleave.
    * ``WriteBack.NO``: nothing is written.

    For backward compatibility, a plain bool `write_back` (the pre-enum API)
    is accepted: True means ``WriteBack.YES``, False means ``WriteBack.NO``.

    `line_length` and `fast` options are passed to :func:`format_file_contents`.
    """
    if isinstance(write_back, bool):
        # Legacy callers passed `write_back=True/False`; map onto the enum
        # instead of failing on attribute access or silently doing nothing.
        write_back = WriteBack.YES if write_back else WriteBack.NO

    with tokenize.open(src) as src_buffer:
        src_contents = src_buffer.read()
    try:
        dst_contents = format_file_contents(
            src_contents, line_length=line_length, fast=fast
        )
    except NothingChanged:
        return False

    # Compare against the enum class itself rather than `write_back.YES`:
    # reaching a member through another member is a discouraged spelling.
    if write_back == WriteBack.YES:
        with open(src, "w", encoding=src_buffer.encoding) as f:
            f.write(dst_contents)
    elif write_back == WriteBack.DIFF:
        src_name = f"{src.name} (original)"
        dst_name = f"{src.name} (formatted)"
        diff_contents = diff(src_contents, dst_contents, src_name, dst_name)
        # Serialize whole-diff writes across worker processes when a shared
        # lock was provided.
        if lock is not None:
            lock.acquire()
        try:
            sys.stdout.write(diff_contents)
        finally:
            if lock is not None:
                lock.release()
    return True
def format_stdin_to_stdout(
    line_length: int, fast: bool, write_back: WriteBack = WriteBack.NO
) -> bool:
    """Format file on stdin. Return True if changed.

    What is written to stdout depends on `write_back`:

    * ``WriteBack.YES``: the code is echoed, reformatted if formatting
      succeeded and unchanged otherwise.
    * ``WriteBack.DIFF``: a diff between stdin and the formatted code.
    * ``WriteBack.NO``: nothing.

    For backward compatibility, a plain bool `write_back` (the pre-enum API)
    is accepted: True means ``WriteBack.YES``, False means ``WriteBack.NO``.

    `line_length` and `fast` arguments are passed to :func:`format_file_contents`.
    """
    if isinstance(write_back, bool):
        # Legacy callers passed `write_back=True/False`; map onto the enum so
        # `True` still echoes output instead of silently writing nothing.
        write_back = WriteBack.YES if write_back else WriteBack.NO

    src = sys.stdin.read()
    # Bind `dst` up front: if format_file_contents raises anything other than
    # NothingChanged, the `finally` clause still needs a value to emit.
    # Without this it would raise UnboundLocalError and mask the real error.
    dst = src
    try:
        dst = format_file_contents(src, line_length=line_length, fast=fast)
        return True
    except NothingChanged:
        return False
    finally:
        if write_back == WriteBack.YES:
            sys.stdout.write(dst)
        elif write_back == WriteBack.DIFF:
            src_name = "<stdin> (original)"
            dst_name = "<stdin> (formatted)"
            sys.stdout.write(diff(src, dst, src_name, dst_name))
def format_file_contents(
) as f:
for lines in output:
f.write(lines)
- f.write("\n")
+ if lines and lines[-1] != "\n":
+ f.write("\n")
return f.name
--- /dev/null
+--- <stdin> (original)
++++ <stdin> (formatted)
+@@ -1,8 +1,8 @@
+ ...
+-'some_string'
+-b'\\xa3'
++"some_string"
++b"\\xa3"
+ Name
+ None
+ True
+ False
+ 1
+@@ -29,65 +29,74 @@
+ ~great
+ +value
+ -1
+ ~int and not v1 ^ 123 + v2 | True
+ (~int) and (not ((v1 ^ (123 + v2)) | True))
+-flags & ~ select.EPOLLIN and waiters.write_task is not None
++flags & ~select.EPOLLIN and waiters.write_task is not None
+ lambda arg: None
+ lambda a=True: a
+ lambda a, b, c=True: a
+-lambda a, b, c=True, *, d=(1 << v2), e='str': a
+-lambda a, b, c=True, *vararg, d=(v1 << 2), e='str', **kwargs: a + b
++lambda a, b, c=True, *, d=(1 << v2), e="str": a
++lambda a, b, c=True, *vararg, d=(v1 << 2), e="str", **kwargs: a + b
+ 1 if True else 2
+ str or None if True else str or bytes or None
+ (str or None) if True else (str or bytes or None)
+ str or None if (1 if True else 2) else str or bytes or None
+ (str or None) if (1 if True else 2) else (str or bytes or None)
+-{'2.7': dead, '3.7': (long_live or die_hard)}
+-{'2.7': dead, '3.7': (long_live or die_hard), **{'3.6': verygood}}
++{"2.7": dead, "3.7": (long_live or die_hard)}
++{"2.7": dead, "3.7": (long_live or die_hard), **{"3.6": verygood}}
+ {**a, **b, **c}
+-{'2.7', '3.6', '3.7', '3.8', '3.9', ('4.0' if gilectomy else '3.10')}
+-({'a': 'b'}, (True or False), (+value), 'string', b'bytes') or None
++{"2.7", "3.6", "3.7", "3.8", "3.9", ("4.0" if gilectomy else "3.10")}
++({"a": "b"}, (True or False), (+value), "string", b"bytes") or None
+ ()
+ (1,)
+ (1, 2)
+ (1, 2, 3)
+ []
+ [1, 2, 3, 4, 5, 6, 7, 8, 9, (10 or A), (11 or B), (12 or C)]
+-[1, 2, 3,]
++[1, 2, 3]
+ {i for i in (1, 2, 3)}
+ {(i ** 2) for i in (1, 2, 3)}
+-{(i ** 2) for i, _ in ((1, 'a'), (2, 'b'), (3, 'c'))}
++{(i ** 2) for i, _ in ((1, "a"), (2, "b"), (3, "c"))}
+ {((i ** 2) + j) for i in (1, 2, 3) for j in (1, 2, 3)}
+ [i for i in (1, 2, 3)]
+ [(i ** 2) for i in (1, 2, 3)]
+-[(i ** 2) for i, _ in ((1, 'a'), (2, 'b'), (3, 'c'))]
++[(i ** 2) for i, _ in ((1, "a"), (2, "b"), (3, "c"))]
+ [((i ** 2) + j) for i in (1, 2, 3) for j in (1, 2, 3)]
+ {i: 0 for i in (1, 2, 3)}
+-{i: j for i, j in ((1, 'a'), (2, 'b'), (3, 'c'))}
++{i: j for i, j in ((1, "a"), (2, "b"), (3, "c"))}
+ {a: b * 2 for a, b in dictionary.items()}
+ {a: b * -2 for a, b in dictionary.items()}
+-{k: v for k, v in this_is_a_very_long_variable_which_will_cause_a_trailing_comma_which_breaks_the_comprehension}
++{
++ k: v
++ for k, v in this_is_a_very_long_variable_which_will_cause_a_trailing_comma_which_breaks_the_comprehension
++}
+ Python3 > Python2 > COBOL
+ Life is Life
+ call()
+ call(arg)
+-call(kwarg='hey')
+-call(arg, kwarg='hey')
+-call(arg, another, kwarg='hey', **kwargs)
+-call(this_is_a_very_long_variable_which_will_force_a_delimiter_split, arg, another, kwarg='hey', **kwargs) # note: no trailing comma pre-3.6
++call(kwarg="hey")
++call(arg, kwarg="hey")
++call(arg, another, kwarg="hey", **kwargs)
++call(
++ this_is_a_very_long_variable_which_will_force_a_delimiter_split,
++ arg,
++ another,
++ kwarg="hey",
++ **kwargs
++) # note: no trailing comma pre-3.6
+ call(*gidgets[:2])
+ call(**self.screen_kwargs)
+ lukasz.langa.pl
+ call.me(maybe)
+ 1 .real
+ 1.0 .real
+ ....__class__
+ list[str]
+ dict[str, int]
+ tuple[str, ...]
+-tuple[str, int, float, dict[str, int],]
++tuple[str, int, float, dict[str, int]]
+ very_long_variable_name_filters: t.List[
+ t.Tuple[str, t.Union[str, t.List[t.Optional[str]]]],
+ ]
+ slice[0]
+ slice[0:1]
+@@ -114,71 +123,90 @@
+ numpy[-(c + 1):, d]
+ numpy[:, l[-2]]
+ numpy[:, ::-1]
+ numpy[np.newaxis, :]
+ (str or None) if (sys.version_info[0] > (3,)) else (str or bytes or None)
+-{'2.7': dead, '3.7': long_live or die_hard}
+-{'2.7', '3.6', '3.7', '3.8', '3.9', '4.0' if gilectomy else '3.10'}
++{"2.7": dead, "3.7": long_live or die_hard}
++{"2.7", "3.6", "3.7", "3.8", "3.9", "4.0" if gilectomy else "3.10"}
+ [1, 2, 3, 4, 5, 6, 7, 8, 9, 10 or A, 11 or B, 12 or C]
+ (SomeName)
+ SomeName
+ (Good, Bad, Ugly)
+ (i for i in (1, 2, 3))
+ ((i ** 2) for i in (1, 2, 3))
+-((i ** 2) for i, _ in ((1, 'a'), (2, 'b'), (3, 'c')))
++((i ** 2) for i, _ in ((1, "a"), (2, "b"), (3, "c")))
+ (((i ** 2) + j) for i in (1, 2, 3) for j in (1, 2, 3))
+ (*starred)
+ a = (1,)
+ b = 1,
+ c = 1
+ d = (1,) + a + (2,)
+ e = (1,).count(1)
+-what_is_up_with_those_new_coord_names = (coord_names + set(vars_to_create)) + set(vars_to_remove)
+-what_is_up_with_those_new_coord_names = (coord_names | set(vars_to_create)) - set(vars_to_remove)
+-result = session.query(models.Customer.id).filter(models.Customer.account_id == account_id, models.Customer.email == email_address).order_by(models.Customer.id.asc(),).all()
++what_is_up_with_those_new_coord_names = (coord_names + set(vars_to_create)) + set(
++ vars_to_remove
++)
++what_is_up_with_those_new_coord_names = (coord_names | set(vars_to_create)) - set(
++ vars_to_remove
++)
++result = session.query(models.Customer.id).filter(
++ models.Customer.account_id == account_id, models.Customer.email == email_address
++).order_by(
++ models.Customer.id.asc()
++).all()
++
+
+ def gen():
+ yield from outside_of_generator
++
+ a = (yield)
++
+
+ async def f():
+ await some.complicated[0].call(with_args=(True or (1 is not 1)))
+
+-if (
+- threading.current_thread() != threading.main_thread() and
+- threading.current_thread() != threading.main_thread() or
+- signal.getsignal(signal.SIGINT) != signal.default_int_handler
+-):
+- return True
+-if (
+- aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa |
+- aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa
+-):
+- return True
+-if (
+- aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa &
+- aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa
+-):
+- return True
+-if (
+- aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +
+- aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa
+-):
+- return True
+-if (
+- aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa -
+- aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa
+-):
+- return True
+-if (
+- aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa *
+- aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa
+-):
+- return True
+-if (
+- aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa /
+- aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa
+-):
+- return True
++
++if (
++ threading.current_thread() != threading.main_thread()
++ and threading.current_thread() != threading.main_thread()
++ or signal.getsignal(signal.SIGINT) != signal.default_int_handler
++):
++ return True
++
++if (
++ aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa
++ | aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa
++):
++ return True
++
++if (
++ aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa
++ & aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa
++):
++ return True
++
++if (
++ aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa
++ + aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa
++):
++ return True
++
++if (
++ aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa
++ - aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa
++):
++ return True
++
++if (
++ aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa
++ * aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa
++):
++ return True
++
++if (
++ aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa
++ / aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa
++):
++ return True
++
+ last_call()
+ # standalone comment at ENDMARKER
+
def read_data(name: str) -> Tuple[str, str]:
"""read_data('test_name') -> 'input', 'output'"""
- if not name.endswith((".py", ".out")):
+ if not name.endswith((".py", ".out", ".diff")):
name += ".py"
_input: List[str] = []
_output: List[str] = []
try:
sys.stdin, sys.stdout = StringIO(source), StringIO()
sys.stdin.name = "<stdin>"
- black.format_stdin_to_stdout(line_length=ll, fast=True, write_back=True)
+ black.format_stdin_to_stdout(
+ line_length=ll, fast=True, write_back=black.WriteBack.YES
+ )
sys.stdout.seek(0)
actual = sys.stdout.read()
finally:
black.assert_equivalent(source, actual)
black.assert_stable(source, actual, line_length=ll)
def test_piping_diff(self) -> None:
    """Piping expression.py through stdin in DIFF mode yields the fixture diff."""
    src_text, _ = read_data("expression.py")
    expected_diff, _ = read_data("expression.diff")
    saved_stdin, saved_stdout = sys.stdin, sys.stdout
    try:
        fake_stdin = StringIO(src_text)
        fake_stdin.name = "<stdin>"
        sys.stdin = fake_stdin
        sys.stdout = StringIO()
        black.format_stdin_to_stdout(
            line_length=ll, fast=True, write_back=black.WriteBack.DIFF
        )
        sys.stdout.seek(0)
        produced = sys.stdout.read()
    finally:
        # Always restore the real streams, even if formatting raised.
        sys.stdin, sys.stdout = saved_stdin, saved_stdout
    # Normalize trailing whitespace before comparing with the fixture file.
    produced = produced.rstrip() + "\n"
    self.assertEqual(expected_diff, produced)
+
@patch("black.dump_to_file", dump_to_stderr)
def test_setup(self) -> None:
source, expected = read_data("../setup")
black.assert_equivalent(source, actual)
black.assert_stable(source, actual, line_length=ll)
def test_expression_ff(self) -> None:
    """Formatting expression.py in place produces the expected output file."""
    source, expected = read_data("expression")
    tmp_path = Path(black.dump_to_file(source))
    try:
        changed = ff(tmp_path, write_back=black.WriteBack.YES)
        self.assertTrue(changed)
        with open(tmp_path) as fh:
            actual = fh.read()
    finally:
        # The temporary file is removed whether or not formatting succeeded.
        os.unlink(tmp_path)
    self.assertFormatEqual(expected, actual)
    with patch("black.dump_to_file", dump_to_stderr):
        black.assert_equivalent(source, actual)
        black.assert_stable(source, actual, line_length=ll)
+
def test_expression_diff(self) -> None:
    """Formatting a file in DIFF mode prints the fixture diff to stdout."""
    source, _ = read_data("expression.py")
    expected, _ = read_data("expression.diff")
    tmp_path = Path(black.dump_to_file(source))
    saved_stdout = sys.stdout
    try:
        sys.stdout = captured = StringIO()
        self.assertTrue(ff(tmp_path, write_back=black.WriteBack.DIFF))
        captured.seek(0)
        # The diff headers name the temporary file; the fixture was recorded
        # with "<stdin>", so substitute it before comparing.
        actual = captured.read().replace(tmp_path.name, "<stdin>")
    finally:
        sys.stdout = saved_stdout
        os.unlink(tmp_path)
    # Normalize trailing whitespace before comparing with the fixture file.
    actual = actual.rstrip() + "\n"
    self.assertEqual(expected, actual)
+
@patch("black.dump_to_file", dump_to_stderr)
def test_fstring(self) -> None:
source, expected = read_data("fstring")