import asyncio
from abc import ABC, abstractmethod
from collections import defaultdict
-from concurrent.futures import Executor, ProcessPoolExecutor
+from concurrent.futures import Executor, ThreadPoolExecutor, ProcessPoolExecutor
from contextlib import contextmanager
from datetime import datetime
from enum import Enum
Pattern,
Sequence,
Set,
+ Sized,
Tuple,
Type,
TypeVar,
Union,
cast,
+ TYPE_CHECKING,
)
from typing_extensions import Final
from mypy_extensions import mypyc_attr
from _black_version import version as __version__
+if TYPE_CHECKING:
+ import colorama # noqa: F401
+
DEFAULT_LINE_LENGTH = 88
DEFAULT_EXCLUDES = r"/(\.eggs|\.git|\.hg|\.mypy_cache|\.nox|\.tox|\.venv|\.svn|_build|buck-out|build|dist)/" # noqa: B950
DEFAULT_INCLUDES = r"\.pyi?$"
YES = 1
DIFF = 2
CHECK = 3
+ COLOR_DIFF = 4
@classmethod
def from_configuration(
    cls, *, check: bool, diff: bool, color: bool = False
) -> "WriteBack":
    """Map the command-line flags onto a write-back mode.

    Without `diff`, `check` selects ``CHECK`` and anything else selects
    ``YES``; `color` plays no role. With `diff`, `color` selects
    ``COLOR_DIFF`` and otherwise ``DIFF`` is returned, whether or not
    `check` is also set.
    """
    if not diff:
        return cls.CHECK if check else cls.YES

    return cls.COLOR_DIFF if color else cls.DIFF
if not config:
return None
+ target_version = config.get("target_version")
+ if target_version is not None and not isinstance(target_version, list):
+ raise click.BadOptionUsage(
+ "target-version", f"Config key target-version must be a list"
+ )
+
default_map: Dict[str, Any] = {}
if ctx.default_map:
default_map.update(ctx.default_map)
" auto-detection]"
),
)
-@click.option(
- "--py36",
- is_flag=True,
- help=(
- "Allow using Python 3.6-only syntax on all input files. This will put trailing"
- " commas in function signatures and calls also after *args and **kwargs."
- " Deprecated; use --target-version instead. [default: per-file auto-detection]"
- ),
-)
@click.option(
"--pyi",
is_flag=True,
is_flag=True,
help="Don't write the files back, just output a diff for each file on stdout.",
)
+@click.option(
+ "--color/--no-color",
+ is_flag=True,
+ help="Show colored diff. Only applies when `--diff` is given.",
+)
@click.option(
"--fast/--safe",
is_flag=True,
),
show_default=True,
)
+@click.option(
+ "--force-exclude",
+ type=str,
+ help=(
+ "Like --exclude, but files and directories matching this regex will be "
+ "excluded even when they are passed explicitly as arguments"
+ ),
+)
@click.option(
"-q",
"--quiet",
target_version: List[TargetVersion],
check: bool,
diff: bool,
+ color: bool,
fast: bool,
pyi: bool,
- py36: bool,
skip_string_normalization: bool,
quiet: bool,
verbose: bool,
include: str,
exclude: str,
+ force_exclude: Optional[str],
src: Tuple[str, ...],
config: Optional[str],
) -> None:
"""The uncompromising code formatter."""
- write_back = WriteBack.from_configuration(check=check, diff=diff)
+ write_back = WriteBack.from_configuration(check=check, diff=diff, color=color)
if target_version:
- if py36:
- err("Cannot use both --target-version and --py36")
- ctx.exit(2)
- else:
- versions = set(target_version)
- elif py36:
- err(
- "--py36 is deprecated and will be removed in a future version. Use"
- " --target-version py36 instead."
- )
- versions = PY36_VERSIONS
+ versions = set(target_version)
else:
# We'll autodetect later.
versions = set()
if code is not None:
print(format_str(code, mode=mode))
ctx.exit(0)
+ report = Report(check=check, diff=diff, quiet=quiet, verbose=verbose)
+ sources = get_sources(
+ ctx=ctx,
+ src=src,
+ quiet=quiet,
+ verbose=verbose,
+ include=include,
+ exclude=exclude,
+ force_exclude=force_exclude,
+ report=report,
+ )
+
+ path_empty(
+ sources,
+ "No Python files are present to be formatted. Nothing to do 😴",
+ quiet,
+ verbose,
+ ctx,
+ )
+
+ if len(sources) == 1:
+ reformat_one(
+ src=sources.pop(),
+ fast=fast,
+ write_back=write_back,
+ mode=mode,
+ report=report,
+ )
+ else:
+ reformat_many(
+ sources=sources, fast=fast, write_back=write_back, mode=mode, report=report
+ )
+
+ if verbose or not quiet:
+ out("Oh no! 💥 💔 💥" if report.return_code else "All done! ✨ 🍰 ✨")
+ click.secho(str(report), err=True)
+ ctx.exit(report.return_code)
+
+
+def get_sources(
+ *,
+ ctx: click.Context,
+ src: Tuple[str, ...],
+ quiet: bool,
+ verbose: bool,
+ include: str,
+ exclude: str,
+ force_exclude: Optional[str],
+ report: "Report",
+) -> Set[Path]:
+ """Compute the set of files to be formatted."""
try:
include_regex = re_compile_maybe_verbose(include)
except re.error:
except re.error:
err(f"Invalid regular expression for exclude given: {exclude!r}")
ctx.exit(2)
- report = Report(check=check, diff=diff, quiet=quiet, verbose=verbose)
+ try:
+ force_exclude_regex = (
+ re_compile_maybe_verbose(force_exclude) if force_exclude else None
+ )
+ except re.error:
+ err(f"Invalid regular expression for force_exclude given: {force_exclude!r}")
+ ctx.exit(2)
+
root = find_project_root(src)
sources: Set[Path] = set()
- path_empty(src, quiet, verbose, ctx)
+ path_empty(src, "No Path provided. Nothing to do 😴", quiet, verbose, ctx)
+ exclude_regexes = [exclude_regex]
+ if force_exclude_regex is not None:
+ exclude_regexes.append(force_exclude_regex)
+
for s in src:
p = Path(s)
if p.is_dir():
sources.update(
- gen_python_files_in_dir(
- p, root, include_regex, exclude_regex, report, get_gitignore(root)
+ gen_python_files(
+ p.iterdir(),
+ root,
+ include_regex,
+ exclude_regexes,
+ report,
+ get_gitignore(root),
)
)
- elif p.is_file() or s == "-":
- # if a file was explicitly given, we don't care about its extension
+ elif s == "-":
sources.add(p)
+ elif p.is_file():
+ sources.update(
+ gen_python_files(
+ [p], root, None, exclude_regexes, report, get_gitignore(root)
+ )
+ )
else:
err(f"invalid path: {s}")
- if len(sources) == 0:
- if verbose or not quiet:
- out("No Python files are present to be formatted. Nothing to do 😴")
- ctx.exit(0)
-
- if len(sources) == 1:
- reformat_one(
- src=sources.pop(),
- fast=fast,
- write_back=write_back,
- mode=mode,
- report=report,
- )
- else:
- reformat_many(
- sources=sources, fast=fast, write_back=write_back, mode=mode, report=report
- )
-
- if verbose or not quiet:
- out("Oh no! 💥 💔 💥" if report.return_code else "All done! ✨ 🍰 ✨")
- click.secho(str(report), err=True)
- ctx.exit(report.return_code)
+ return sources
def path_empty(
    src: Sized, msg: str, quiet: bool, verbose: bool, ctx: click.Context
) -> None:
    """
    Exit if there is no `src` provided for formatting

    `src` is any sized collection; it counts as empty when its length is 0.
    `msg` is printed with `out` before exiting when `verbose` is set or
    `quiet` is not. The exit itself (`ctx.exit(0)`) happens for every empty
    `src`, so a quiet run stops here as well instead of carrying on with
    nothing to format.
    """
    if len(src) > 0:
        return

    # Only the message depends on the verbosity flags, never the exit.
    if verbose or not quiet:
        out(msg)
    ctx.exit(0)
sources: Set[Path], fast: bool, write_back: WriteBack, mode: Mode, report: "Report"
) -> None:
"""Reformat multiple files using a ProcessPoolExecutor."""
+ executor: Executor
loop = asyncio.get_event_loop()
worker_count = os.cpu_count()
if sys.platform == "win32":
# Work around https://bugs.python.org/issue26903
worker_count = min(worker_count, 61)
- executor = ProcessPoolExecutor(max_workers=worker_count)
+ try:
+ executor = ProcessPoolExecutor(max_workers=worker_count)
+ except OSError:
+ # we arrive here if the underlying system does not support multi-processing
+ # like in AWS Lambda, in which case we gracefully fall back to
+ # a ThreadPoolExecutor with just a single worker (more workers would not do us
+ # any good due to the Global Interpreter Lock)
+ executor = ThreadPoolExecutor(max_workers=1)
+
try:
loop.run_until_complete(
schedule_formatting(
)
finally:
shutdown(loop)
- executor.shutdown()
+ if executor is not None:
+ executor.shutdown()
async def schedule_formatting(
if write_back == WriteBack.YES:
with open(src, "w", encoding=encoding, newline=newline) as f:
f.write(dst_contents)
- elif write_back == WriteBack.DIFF:
+ elif write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
now = datetime.utcnow()
src_name = f"{src}\t{then} +0000"
dst_name = f"{src}\t{now} +0000"
diff_contents = diff(src_contents, dst_contents, src_name, dst_name)
+ if write_back == write_back.COLOR_DIFF:
+ diff_contents = color_diff(diff_contents)
+
with lock or nullcontext():
f = io.TextIOWrapper(
sys.stdout.buffer,
newline=newline,
write_through=True,
)
+ f = wrap_stream_for_windows(f)
f.write(diff_contents)
f.detach()
return True
def color_diff(contents: str) -> str:
    """Inject the ANSI color codes to the diff.

    Each line of `contents` is classified by its prefix and wrapped in a
    color escape followed by a reset. The first matching rule wins:
    ``+++``/``---`` lines become bold white, ``@@`` lines cyan, ``+``
    lines green and ``-`` lines red. Every other line is left untouched.
    """
    reset = "\033[0m"

    def paint(line: str) -> str:
        if line.startswith(("+++", "---")):
            return "\033[1;37m" + line + reset  # bold white
        if line.startswith("@@"):
            return "\033[36m" + line + reset  # cyan
        if line.startswith("+"):
            return "\033[32m" + line + reset  # green
        if line.startswith("-"):
            return "\033[31m" + line + reset  # red
        return line

    return "\n".join(paint(line) for line in contents.split("\n"))
+
+
def wrap_stream_for_windows(
    f: io.TextIOWrapper,
) -> Union[io.TextIOWrapper, "colorama.AnsiToWin32.AnsiToWin32"]:
    """
    Wrap the stream in colorama's wrap_stream so colors are shown on Windows.

    If `colorama` is not found, then no change is made. If `colorama` does
    exist, then it handles the logic to determine whether or not to change
    things.

    Whatever is returned supports `detach()`, and calling it always detaches
    the original `f` from its underlying buffer. A no-op `detach()` would
    leave `f` attached, so the buffer would be closed when `f` is finalised.
    """
    try:
        from colorama import initialise
    except ImportError:
        return f

    # We set `strip=False` so that we don't have to modify
    # test_express_diff_with_color.
    wrapped = initialise.wrap_stream(
        f, convert=None, strip=False, autoreset=False, wrap=True
    )

    if wrapped is not f:
        # colorama substituted its own object for the stream. Route its
        # `detach()` to the original wrapper so callers still release the
        # underlying buffer. When `f` comes back unchanged, its real
        # `detach()` is left alone.
        wrapped.detach = f.detach  # type: ignore

    return wrapped
+
+
def format_stdin_to_stdout(
fast: bool, *, write_back: WriteBack = WriteBack.NO, mode: Mode
) -> bool:
)
if write_back == WriteBack.YES:
f.write(dst)
- elif write_back == WriteBack.DIFF:
+ elif write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
now = datetime.utcnow()
src_name = f"STDIN\t{then} +0000"
dst_name = f"STDOUT\t{now} +0000"
- f.write(diff(src, dst, src_name, dst_name))
+ d = diff(src, dst, src_name, dst_name)
+ if write_back == WriteBack.COLOR_DIFF:
+ d = color_diff(d)
+ f = wrap_stream_for_windows(f)
+ f.write(d)
f.detach()
node.insert_child(index, Node(syms.atom, [lpar, operand, rpar]))
yield from self.visit_default(node)
def visit_STRING(self, leaf: Leaf) -> Iterator[Line]:
    """Re-indent `leaf` when it is a multiline docstring, then visit it as usual.

    A string counts as a docstring here when its parent is a `simple_stmt`
    preceded only by INDENT and NEWLINE at the start of the enclosing node,
    and the string itself spans multiple lines.
    """
    docstring_position = [None, token.NEWLINE, token.INDENT, syms.simple_stmt]
    if prev_siblings_are(leaf.parent, docstring_position) and is_multiline_string(
        leaf
    ):
        # NOTE(review): the indent unit is reproduced exactly as it appears in
        # this view, where whitespace looks collapsed - confirm the literal.
        indent = " " * self.current_line.depth
        opening, body, closing = leaf.value[0:3], leaf.value[3:-3], leaf.value[-3:]
        leaf.value = opening + fix_docstring(body, indent) + closing
        normalize_string_quotes(leaf)

    yield from self.visit_default(leaf)
+
def __post_init__(self) -> None:
"""You are in a twisty little maze of passages."""
v = self.visit_stmt
return None
def prev_siblings_are(node: Optional[LN], tokens: List[Optional[NodeType]]) -> bool:
    """Return if the `node` and its previous siblings match types against the provided
    list of tokens; the provided `node` has its type matched against the last element in
    the list. `None` can be used as the first element to declare that the start of the
    list is anchored at the start of its parent's children."""
    current = node
    # Walk the expected types right to left while stepping back through the
    # siblings, starting from `node` itself.
    for expected in reversed(tokens):
        if expected is None:
            # Anchor: there must be no sibling left at this position.
            return current is None
        if not current or current.type != expected:
            return False
        current = current.prev_sibling
    return True
+
+
def child_towards(ancestor: Node, descendant: LN) -> Optional[LN]:
"""Return the child of `ancestor` that contains `descendant`."""
node: Optional[LN] = descendant
is_line_short_enough(line, line_length=line_length, line_str=line_str)
or line.contains_unsplittable_type_ignore()
)
+ and not (line.contains_standalone_comments() and line.inside_brackets)
):
# Only apply basic string preprocessing, since lines shouldn't be split here.
transformers = [string_merge, string_paren_strip]
"""
container: Optional[LN] = container_of(leaf)
while container is not None and container.type != token.ENDMARKER:
- if fmt_on(container):
+ if is_fmt_on(container):
return
# fix for fmt: on in children
container = container.next_sibling
def is_fmt_on(container: LN) -> bool:
    """Determine whether formatting is switched on within a container.

    The comments in the container's prefix are scanned from the last one
    backwards. The first `# fmt:` directive found that way decides the
    result: `on` gives True, `off` gives False. With no directive at all
    the answer is False.
    """
    comments = list(list_comments(container.prefix, is_endmarker=False))
    for comment in reversed(comments):
        if comment.value in FMT_ON:
            return True
        if comment.value in FMT_OFF:
            return False
    return False
def contains_fmt_on_at_column(container: LN, column: int) -> bool:
+ """Determine if children at a given column have formatting switched on."""
for child in container.children:
if (
isinstance(child, Node)
or isinstance(child, Leaf)
and child.column == column
):
- if fmt_on(child):
+ if is_fmt_on(child):
return True
return False
def first_leaf_column(node: Node) -> Optional[int]:
+ """Returns the column of the first leaf child of a node."""
for child in node.children:
if isinstance(child, Leaf):
return child.column
return PathSpec.from_lines("gitwildmatch", lines)
-def gen_python_files_in_dir(
- path: Path,
+def gen_python_files(
+ paths: Iterable[Path],
root: Path,
- include: Pattern[str],
- exclude: Pattern[str],
+ include: Optional[Pattern[str]],
+ exclude_regexes: Iterable[Pattern[str]],
report: "Report",
gitignore: PathSpec,
) -> Iterator[Path]:
`report` is where output about exclusions goes.
"""
assert root.is_absolute(), f"INTERNAL ERROR: `root` must be absolute but is {root}"
- for child in path.iterdir():
- # First ignore files matching .gitignore
- if gitignore.match_file(child.as_posix()):
- report.path_ignored(child, "matches the .gitignore file content")
- continue
-
+ for child in paths:
# Then ignore with `exclude` option.
try:
- normalized_path = "/" + child.resolve().relative_to(root).as_posix()
+ normalized_path = child.resolve().relative_to(root).as_posix()
except OSError as e:
report.path_ignored(child, f"cannot be read because {e}")
continue
-
except ValueError:
if child.is_symlink():
report.path_ignored(
raise
+ # First ignore files matching .gitignore
+ if gitignore.match_file(normalized_path):
+ report.path_ignored(child, "matches the .gitignore file content")
+ continue
+
+ normalized_path = "/" + normalized_path
if child.is_dir():
normalized_path += "/"
- exclude_match = exclude.search(normalized_path)
- if exclude_match and exclude_match.group(0):
- report.path_ignored(child, "matches the --exclude regular expression")
+ is_excluded = False
+ for exclude in exclude_regexes:
+ exclude_match = exclude.search(normalized_path) if exclude else None
+ if exclude_match and exclude_match.group(0):
+ report.path_ignored(child, "matches the --exclude regular expression")
+ is_excluded = True
+ break
+ if is_excluded:
continue
if child.is_dir():
- yield from gen_python_files_in_dir(
- child, root, include, exclude, report, gitignore
+ yield from gen_python_files(
+ child.iterdir(), root, include, exclude_regexes, report, gitignore
)
elif child.is_file():
- include_match = include.search(normalized_path)
+ include_match = include.search(normalized_path) if include else True
if include_match:
yield child
return node
-def assert_equivalent(src: str, dst: str) -> None:
- """Raise AssertionError if `src` and `dst` aren't equivalent."""
def _stringify_ast(
    node: Union[ast.AST, ast3.AST, ast27.AST], depth: int = 0
) -> Iterator[str]:
    """Simple visitor generating strings to compare ASTs by content.

    Yields one line opening the node, then for each field (in sorted order)
    a `field=` line followed by the rendering of its value, and finally a
    line closing the node. Child nodes are rendered recursively at
    `depth + 2`.
    """
    ast_types = (ast.AST, ast3.AST, ast27.AST)
    node = _fixup_ast_constants(node)
    class_name = node.__class__.__name__

    yield f"{' ' * depth}{class_name}("

    # TypeIgnore has only one field 'lineno' which breaks this comparison,
    # so none of its fields are rendered.
    type_ignore_classes = (ast3.TypeIgnore, ast27.TypeIgnore)
    if sys.version_info >= (3, 8):
        type_ignore_classes += (ast.TypeIgnore,)
    if isinstance(node, type_ignore_classes):
        field_names = []
    else:
        field_names = sorted(node._fields)

    for field in field_names:  # noqa: F402
        try:
            value = getattr(node, field)
        except AttributeError:
            continue

        yield f"{' ' * (depth+1)}{field}="

        if isinstance(value, list):
            for item in value:
                # Ignore nested tuples within del statements, because we may insert
                # parentheses and they change the AST.
                is_del_tuple = (
                    field == "targets"
                    and isinstance(node, (ast.Delete, ast3.Delete, ast27.Delete))
                    and isinstance(item, (ast.Tuple, ast3.Tuple, ast27.Tuple))
                )
                if is_del_tuple:
                    for element in item.elts:
                        yield from _stringify_ast(element, depth + 2)

                elif isinstance(item, ast_types):
                    yield from _stringify_ast(item, depth + 2)

        elif isinstance(value, ast_types):
            yield from _stringify_ast(value, depth + 2)

        else:
            # Constant strings may be indented across newlines, if they are
            # docstrings; fold spaces after newlines when comparing
            is_str_constant = (
                isinstance(node, ast.Constant)
                and field == "value"
                and isinstance(value, str)
            )
            if is_str_constant:
                normalized = re.sub(r"\n[ \t]+", "\n ", value)
            else:
                normalized = value
            yield f"{' ' * (depth+2)}{normalized!r}, # {value.__class__.__name__}"

    yield f"{' ' * depth}) # /{class_name}"
+
+def assert_equivalent(src: str, dst: str) -> None:
+ """Raise AssertionError if `src` and `dst` aren't equivalent."""
try:
src_ast = parse_ast(src)
except Exception as exc:
f" helpful: {log}"
) from None
- src_ast_str = "\n".join(_v(src_ast))
- dst_ast_str = "\n".join(_v(dst_ast))
+ src_ast_str = "\n".join(_stringify_ast(src_ast))
+ dst_ast_str = "\n".join(_stringify_ast(dst_ast))
if src_ast_str != dst_ast_str:
log = dump_to_file(diff(src_ast_str, dst_ast_str, "src", "dst"))
raise AssertionError(
main()
def fix_docstring(docstring: str, prefix: str) -> str:
    """Re-indent the body of a docstring so continuation lines start with `prefix`.

    Adapted from the indentation handling described in PEP 257:
    https://www.python.org/dev/peps/pep-0257/#handling-docstring-indentation

    `docstring` is the text between the quotes. Tabs are expanded, the first
    line is stripped on both sides, and the smallest indentation among the
    non-blank following lines is removed from every following line. Each
    non-blank following line, and the final line even when blank, is then
    prefixed with `prefix`; other blank lines become empty. The number of
    lines is always preserved, including when every line after the first is
    blank. An empty `docstring` yields "".
    """
    if not docstring:
        return ""
    # Convert tabs to spaces (following the normal Python rules)
    # and split into a list of lines:
    lines = docstring.expandtabs().splitlines()
    first, rest = lines[0], lines[1:]
    # Determine minimum indentation (first line doesn't count):
    indents = [len(line) - len(line.lstrip()) for line in rest if line.strip()]
    indent = min(indents, default=sys.maxsize)
    # Remove indentation (first line is special). When every following line
    # is blank, `indent` stays at sys.maxsize and the slice below is simply
    # empty, so those lines are kept as blanks rather than dropped.
    trimmed = [first.strip()]
    last_idx = len(rest) - 1
    for i, line in enumerate(rest):
        stripped_line = line[indent:].rstrip()
        if stripped_line or i == last_idx:
            # The final line keeps the prefix even when blank, so whatever
            # follows the docstring text stays indented.
            trimmed.append(prefix + stripped_line)
        else:
            trimmed.append("")
    # Return a single string:
    return "\n".join(trimmed)
+
+
if __name__ == "__main__":
patched_main()