TypeVar,
Union,
cast,
+ TYPE_CHECKING,
)
from typing_extensions import Final
from mypy_extensions import mypyc_attr
from _black_version import version as __version__
+if TYPE_CHECKING:
+ import colorama # noqa: F401
+
# Default maximum line length.
DEFAULT_LINE_LENGTH = 88
# Regex of directory names excluded by default: it matches a `/name/` path
# segment such as `/.git/`, `/.tox/`, `/build/` or `/dist/`.
DEFAULT_EXCLUDES = r"/(\.eggs|\.git|\.hg|\.mypy_cache|\.nox|\.tox|\.venv|\.svn|_build|buck-out|build|dist)/" # noqa: B950
# Regex of files included by default: paths ending in `.py` or `.pyi`.
DEFAULT_INCLUDES = r"\.pyi?$"
YES = 1
DIFF = 2
CHECK = 3
+ COLOR_DIFF = 4
@classmethod
def from_configuration(
    cls, *, check: bool, diff: bool, color: bool = False
) -> "WriteBack":
    """Translate the ``check``/``diff``/``color`` flags into a write-back mode.

    When ``diff`` is off, ``check`` selects CHECK and its absence selects YES.
    When ``diff`` is on, ``check`` no longer matters: ``color`` selects
    COLOR_DIFF, otherwise the plain DIFF mode is used.
    """
    if diff:
        return cls.COLOR_DIFF if color else cls.DIFF
    return cls.CHECK if check else cls.YES
if not config:
return None
+ target_version = config.get("target_version")
+ if target_version is not None and not isinstance(target_version, list):
+ raise click.BadOptionUsage(
+ "target-version", f"Config key target-version must be a list"
+ )
+
default_map: Dict[str, Any] = {}
if ctx.default_map:
default_map.update(ctx.default_map)
is_flag=True,
help="Don't write the files back, just output a diff for each file on stdout.",
)
+@click.option(
+ "--color/--no-color",
+ is_flag=True,
+ help="Show colored diff. Only applies when `--diff` is given.",
+)
@click.option(
"--fast/--safe",
is_flag=True,
target_version: List[TargetVersion],
check: bool,
diff: bool,
+ color: bool,
fast: bool,
pyi: bool,
py36: bool,
config: Optional[str],
) -> None:
"""The uncompromising code formatter."""
- write_back = WriteBack.from_configuration(check=check, diff=diff)
+ write_back = WriteBack.from_configuration(check=check, diff=diff, color=color)
if target_version:
if py36:
err("Cannot use both --target-version and --py36")
if write_back == WriteBack.YES:
with open(src, "w", encoding=encoding, newline=newline) as f:
f.write(dst_contents)
- elif write_back == WriteBack.DIFF:
+ elif write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
now = datetime.utcnow()
src_name = f"{src}\t{then} +0000"
dst_name = f"{src}\t{now} +0000"
diff_contents = diff(src_contents, dst_contents, src_name, dst_name)
+ if write_back == write_back.COLOR_DIFF:
+ diff_contents = color_diff(diff_contents)
+
with lock or nullcontext():
f = io.TextIOWrapper(
sys.stdout.buffer,
newline=newline,
write_through=True,
)
+ f = wrap_stream_for_windows(f)
f.write(diff_contents)
f.detach()
return True
def color_diff(contents: str) -> str:
    """Return ``contents`` with ANSI color codes wrapped around diff lines.

    File header lines (``---``/``+++``) become bold white and hunk headers
    (``@@``) cyan. Added lines become green and removed lines red. Every other
    line is passed through unchanged. Lines are split and re-joined on ``\\n``.
    """
    reset = "\033[0m"

    def paint(row: str) -> str:
        # Header checks come first so "+++"/"---" are not mistaken for +/- lines.
        if row.startswith(("+++", "---")):
            return "\033[1;37m" + row + reset  # bold white
        if row.startswith("@@"):
            return "\033[36m" + row + reset  # cyan
        if row.startswith("+"):
            return "\033[32m" + row + reset  # green
        if row.startswith("-"):
            return "\033[31m" + row + reset  # red
        return row

    return "\n".join(paint(row) for row in contents.split("\n"))
+
+
def wrap_stream_for_windows(
    f: io.TextIOWrapper,
) -> Union[io.TextIOWrapper, "colorama.AnsiToWin32"]:
    """
    Wrap the stream with colorama's wrap_stream so colors are shown on Windows.

    If `colorama` cannot be imported, `f` is returned unchanged. Otherwise
    colorama's `wrap_stream()` decides whether wrapping is needed. It returns
    either a wrapper around `f` or `f` itself.
    """
    try:
        from colorama.initialise import wrap_stream
    except ImportError:
        return f

    # Set `strip=False` so that test_express_diff_with_color does not need to
    # be modified.
    #
    # Do NOT replace `detach` on the returned object. Callers detach the
    # stream from `sys.stdout.buffer` after writing. When colorama hands back
    # `f` itself (no wrapping needed), a no-op `detach` would leave `f`
    # attached, and closing or collecting it would close stdout's buffer.
    # colorama's wrapper is expected to forward unknown attributes such as
    # `detach` to the wrapped stream. NOTE(review): confirm against the
    # installed colorama version.
    return wrap_stream(f, convert=None, strip=False, autoreset=False, wrap=True)
+
+
def format_stdin_to_stdout(
fast: bool, *, write_back: WriteBack = WriteBack.NO, mode: Mode
) -> bool:
)
if write_back == WriteBack.YES:
f.write(dst)
- elif write_back == WriteBack.DIFF:
+ elif write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
now = datetime.utcnow()
src_name = f"STDIN\t{then} +0000"
dst_name = f"STDOUT\t{now} +0000"
- f.write(diff(src, dst, src_name, dst_name))
+ d = diff(src, dst, src_name, dst_name)
+ if write_back == WriteBack.COLOR_DIFF:
+ d = color_diff(d)
+ f = wrap_stream_for_windows(f)
+ f.write(d)
f.detach()
node.insert_child(index, Node(syms.atom, [lpar, operand, rpar]))
yield from self.visit_default(node)
def visit_STRING(self, leaf: Leaf) -> Iterator[Line]:
    """Re-indent multi-line docstrings, then delegate to `visit_default`.

    A string is treated as a docstring when it is multi-line and its parent is
    a `simple_stmt` preceded, among its siblings, by exactly NEWLINE then
    INDENT at the start of the children list. Presumably this is the first
    statement of an indented block. The body between the three-character
    opening and closing delimiters is re-indented with `fix_docstring`.
    """
    if prev_siblings_are(
        leaf.parent, [None, token.NEWLINE, token.INDENT, syms.simple_stmt]
    ) and is_multiline_string(leaf):
        prefix = " " * self.current_line.depth
        docstring = fix_docstring(leaf.value[3:-3], prefix)

        # fix_docstring right-strips the last line, and may drop whitespace-only
        # trailing lines. That can remove the whitespace that separated a
        # trailing quote or backslash from the closing quotes, e.g.
        # `He said "hi" """`. Without padding, the rewritten literal would
        # be invalid Python.
        closing_quote = leaf.value[-1]
        if docstring.endswith(closing_quote):
            before = docstring[:-1]
            # An even run of backslashes before the quote means it is unescaped.
            needs_padding = (len(before) - len(before.rstrip("\\"))) % 2 == 0
        else:
            # An odd run of trailing backslashes would escape the closing quote.
            needs_padding = (len(docstring) - len(docstring.rstrip("\\"))) % 2 == 1
        if needs_padding:
            docstring += " "

        leaf.value = leaf.value[0:3] + docstring + leaf.value[-3:]
        # NOTE(review): quotes are normalized here unconditionally; confirm this
        # is intended when string normalization is disabled.
        normalize_string_quotes(leaf)

    yield from self.visit_default(leaf)
+
def __post_init__(self) -> None:
"""You are in a twisty little maze of passages."""
v = self.visit_stmt
return None
def prev_siblings_are(node: Optional[LN], tokens: List[Optional[NodeType]]) -> bool:
    """Return whether `node` and the siblings before it match `tokens`.

    `tokens` is read from right to left. The last entry is compared with the
    type of `node`, the previous entry with `node.prev_sibling`, and so on.
    Matching stops at the first `None` entry encountered this way, which
    succeeds only if there is no node at that position. A leading `None` thus
    anchors the pattern at the first child of the parent. An empty `tokens`
    always matches.
    """
    current = node
    for expected in reversed(tokens):
        if expected is None:
            return current is None
        if not current or current.type != expected:
            return False
        current = current.prev_sibling
    return True
+
def child_towards(ancestor: Node, descendant: LN) -> Optional[LN]:
"""Return the child of `ancestor` that contains `descendant`."""
node: Optional[LN] = descendant
is_line_short_enough(line, line_length=line_length, line_str=line_str)
or line.contains_unsplittable_type_ignore()
)
+ and not (line.contains_standalone_comments() and line.inside_brackets)
):
# Only apply basic string preprocessing, since lines shouldn't be split here.
transformers = [string_merge, string_paren_strip]
return node
def _stringify_ast(
    node: Union[ast.AST, ast3.AST, ast27.AST], depth: int = 0
) -> Iterator[str]:
    """Simple visitor generating strings to compare ASTs by content.

    Yields one line per node and field, indented by `depth`. String constants
    are normalized so that the whitespace changes made by docstring
    re-indentation do not count as differences.
    """
    node = _fixup_ast_constants(node)

    yield f"{' ' * depth}{node.__class__.__name__}("

    # TypeIgnore has only one field 'lineno' which breaks this comparison, so
    # its fields are skipped entirely. The class tuple is loop-invariant.
    type_ignore_classes = (ast3.TypeIgnore, ast27.TypeIgnore)
    if sys.version_info >= (3, 8):
        type_ignore_classes += (ast.TypeIgnore,)
    fields = [] if isinstance(node, type_ignore_classes) else sorted(node._fields)

    for field in fields:  # noqa: F402
        try:
            value = getattr(node, field)
        except AttributeError:
            continue

        yield f"{' ' * (depth+1)}{field}="

        if isinstance(value, list):
            for item in value:
                # Ignore nested tuples within del statements, because we may insert
                # parentheses and they change the AST.
                if (
                    field == "targets"
                    and isinstance(node, (ast.Delete, ast3.Delete, ast27.Delete))
                    and isinstance(item, (ast.Tuple, ast3.Tuple, ast27.Tuple))
                ):
                    for elt in item.elts:
                        yield from _stringify_ast(elt, depth + 2)

                elif isinstance(item, (ast.AST, ast3.AST, ast27.AST)):
                    yield from _stringify_ast(item, depth + 2)

        elif isinstance(value, (ast.AST, ast3.AST, ast27.AST)):
            yield from _stringify_ast(value, depth + 2)

        else:
            # Constant strings may be re-indented if they are docstrings.
            # fix_docstring strips the first line and right-strips the others.
            # It also empties whitespace-only lines and may drop trailing blank
            # lines. So whitespace around every newline is collapsed and the
            # whole value stripped before comparing.
            if (
                isinstance(node, ast.Constant)
                and field == "value"
                and isinstance(value, str)
            ):
                normalized = re.sub(r"[ \t]*\n[ \t]*", "\n", value).strip()
            else:
                normalized = value
            yield f"{' ' * (depth+2)}{normalized!r}, # {value.__class__.__name__}"

    yield f"{' ' * depth}) # /{node.__class__.__name__}"

+def assert_equivalent(src: str, dst: str) -> None:
+ """Raise AssertionError if `src` and `dst` aren't equivalent."""
try:
src_ast = parse_ast(src)
except Exception as exc:
f" helpful: {log}"
) from None
- src_ast_str = "\n".join(_v(src_ast))
- dst_ast_str = "\n".join(_v(dst_ast))
+ src_ast_str = "\n".join(_stringify_ast(src_ast))
+ dst_ast_str = "\n".join(_stringify_ast(dst_ast))
if src_ast_str != dst_ast_str:
log = dump_to_file(diff(src_ast_str, dst_ast_str, "src", "dst"))
raise AssertionError(
main()
def fix_docstring(docstring: str, prefix: str) -> str:
    """Re-indent a docstring body following PEP 257's trimming algorithm.

    Tabs are expanded and the text is split into lines. The first line is
    stripped on both sides. The common leading indentation of the non-blank
    continuation lines is removed, each one is right-stripped and re-indented
    with `prefix`. Blank continuation lines become empty, except the final
    one, which keeps `prefix`. If every continuation line is blank, only the
    stripped first line is returned. An empty input yields "".
    """
    # https://www.python.org/dev/peps/pep-0257/#handling-docstring-indentation
    if not docstring:
        return ""
    head, *tail = docstring.expandtabs().splitlines()
    margins = [len(text) - len(text.lstrip()) for text in tail if text.lstrip()]
    result = [head.strip()]
    if margins:
        common = min(margins)
        final = len(tail) - 1
        for position, text in enumerate(tail):
            body = text[common:].rstrip()
            result.append(prefix + body if body or position == final else "")
    return "\n".join(result)
+
+
# Invoke the CLI entry point only when this module is executed directly,
# not when it is imported.
if __name__ == "__main__":
    patched_main()