X-Git-Url: https://git.madduck.net/etc/vim.git/blobdiff_plain/de806405d2934b629d67e2a6317ad7e826765a20..bc7a4b2391a19b9ff1093a26d249e32324402498:/black.py?ds=sidebyside

diff --git a/black.py b/black.py
index 7bfcfca..7629d9f 100644
--- a/black.py
+++ b/black.py
@@ -40,6 +40,7 @@ from appdirs import user_cache_dir
 from attr import dataclass, evolve, Factory
 import click
 import toml
+from typed_ast import ast3, ast27
 
 # lib2to3 fork
 from blib2to3.pytree import Node, Leaf, type_repr
@@ -135,19 +136,28 @@ class Feature(Enum):
     NUMERIC_UNDERSCORES = 3
     TRAILING_COMMA_IN_CALL = 4
     TRAILING_COMMA_IN_DEF = 5
+    # The following two feature flags are mutually exclusive, and exactly one should be
+    # set for every version of Python.
+    ASYNC_IDENTIFIERS = 6
+    ASYNC_KEYWORDS = 7
 
 
 VERSION_TO_FEATURES: Dict[TargetVersion, Set[Feature]] = {
-    TargetVersion.PY27: set(),
-    TargetVersion.PY33: {Feature.UNICODE_LITERALS},
-    TargetVersion.PY34: {Feature.UNICODE_LITERALS},
-    TargetVersion.PY35: {Feature.UNICODE_LITERALS, Feature.TRAILING_COMMA_IN_CALL},
+    TargetVersion.PY27: {Feature.ASYNC_IDENTIFIERS},
+    TargetVersion.PY33: {Feature.UNICODE_LITERALS, Feature.ASYNC_IDENTIFIERS},
+    TargetVersion.PY34: {Feature.UNICODE_LITERALS, Feature.ASYNC_IDENTIFIERS},
+    TargetVersion.PY35: {
+        Feature.UNICODE_LITERALS,
+        Feature.TRAILING_COMMA_IN_CALL,
+        Feature.ASYNC_IDENTIFIERS,
+    },
     TargetVersion.PY36: {
         Feature.UNICODE_LITERALS,
         Feature.F_STRINGS,
         Feature.NUMERIC_UNDERSCORES,
         Feature.TRAILING_COMMA_IN_CALL,
         Feature.TRAILING_COMMA_IN_DEF,
+        Feature.ASYNC_IDENTIFIERS,
     },
     TargetVersion.PY37: {
         Feature.UNICODE_LITERALS,
@@ -155,6 +165,7 @@ VERSION_TO_FEATURES: Dict[TargetVersion, Set[Feature]] = {
         Feature.NUMERIC_UNDERSCORES,
         Feature.TRAILING_COMMA_IN_CALL,
         Feature.TRAILING_COMMA_IN_DEF,
+        Feature.ASYNC_KEYWORDS,
     },
     TargetVersion.PY38: {
         Feature.UNICODE_LITERALS,
@@ -162,6 +173,7 @@ VERSION_TO_FEATURES: Dict[TargetVersion, Set[Feature]] = {
         Feature.NUMERIC_UNDERSCORES,
         Feature.TRAILING_COMMA_IN_CALL,
         Feature.TRAILING_COMMA_IN_DEF,
+        Feature.ASYNC_KEYWORDS,
     },
 }
 
@@ -231,6 +243,7 @@ def read_pyproject_toml(
 
 
 @click.command(context_settings=dict(help_option_names=["-h", "--help"]))
+@click.option("-c", "--code", type=str, help="Format the code passed in as a string.")
 @click.option(
     "-l",
     "--line-length",
@@ -357,6 +370,7 @@ def read_pyproject_toml(
 @click.pass_context
 def main(
     ctx: click.Context,
+    code: Optional[str],
     line_length: int,
     target_version: List[TargetVersion],
     check: bool,
@@ -397,6 +411,9 @@ def main(
     )
     if config and verbose:
         out(f"Using configuration from {config}.", bold=False, fg="blue")
+    if code is not None:
+        print(format_str(code, mode=mode))
+        ctx.exit(0)
     try:
         include_regex = re_compile_maybe_verbose(include)
     except re.error:
@@ -435,25 +452,12 @@ def main(
             report=report,
         )
     else:
-        loop = asyncio.get_event_loop()
-        executor = ProcessPoolExecutor(max_workers=os.cpu_count())
-        try:
-            loop.run_until_complete(
-                schedule_formatting(
-                    sources=sources,
-                    fast=fast,
-                    write_back=write_back,
-                    mode=mode,
-                    report=report,
-                    loop=loop,
-                    executor=executor,
-                )
-            )
-        finally:
-            shutdown(loop)
+        reformat_many(
+            sources=sources, fast=fast, write_back=write_back, mode=mode, report=report
+        )
+
     if verbose or not quiet:
-        bang = "💥 💔 💥" if report.return_code else "✨ 🍰 ✨"
-        out(f"All done! {bang}")
+        out("Oh no! 💥 💔 💥" if report.return_code else "All done! ✨ 🍰 ✨")
         click.secho(str(report), err=True)
     ctx.exit(report.return_code)
 
@@ -492,6 +496,36 @@ def reformat_one(
         report.failed(src, str(exc))
 
 
+def reformat_many(
+    sources: Set[Path],
+    fast: bool,
+    write_back: WriteBack,
+    mode: FileMode,
+    report: "Report",
+) -> None:
+    """Reformat `sources` via a ProcessPoolExecutor; the event loop is always shut down."""
+    loop = asyncio.get_event_loop()
+    worker_count = os.cpu_count() or 1  # cpu_count() may return None
+    if sys.platform == "win32":
+        # Windows limits the pool to 61 workers: https://bugs.python.org/issue26903
+        worker_count = min(worker_count, 61)
+    executor = ProcessPoolExecutor(max_workers=worker_count)
+    try:
+        loop.run_until_complete(
+            schedule_formatting(
+                sources=sources,
+                fast=fast,
+                write_back=write_back,
+                mode=mode,
+                report=report,
+                loop=loop,
+                executor=executor,
+            )
+        )
+    finally:
+        shutdown(loop)
+
+
 async def schedule_formatting(
     sources: Set[Path],
     fast: bool,
@@ -673,7 +707,7 @@ def format_str(src_contents: str, *, mode: FileMode) -> FileContent:
     `line_length` determines how many characters per line are allowed.
     """
     src_node = lib2to3_parse(src_contents.lstrip(), mode.target_versions)
-    dst_contents = ""
+    dst_contents = []
     future_imports = get_future_imports(src_node)
     if mode.target_versions:
         versions = mode.target_versions
@@ -696,15 +730,15 @@ def format_str(src_contents: str, *, mode: FileMode) -> FileContent:
     }
     for current_line in lines.visit(src_node):
         for _ in range(after):
-            dst_contents += str(empty_line)
+            dst_contents.append(str(empty_line))
         before, after = elt.maybe_empty_lines(current_line)
         for _ in range(before):
-            dst_contents += str(empty_line)
+            dst_contents.append(str(empty_line))
         for line in split_line(
             current_line, line_length=mode.line_length, features=split_line_features
         ):
-            dst_contents += str(line)
-    return dst_contents
+            dst_contents.append(str(line))
+    return "".join(dst_contents)
 
 
 def decode_bytes(src: bytes) -> Tuple[FileContent, Encoding, NewLine]:
@@ -728,16 +762,38 @@ def get_grammars(target_versions: Set[TargetVersion]) -> List[Grammar]:
     if not target_versions:
         # No target_version specified, so try all grammars.
         return [
+            # Python 3.7+
+            pygram.python_grammar_no_print_statement_no_exec_statement_async_keywords,
+            # Python 3.0-3.6
             pygram.python_grammar_no_print_statement_no_exec_statement,
+            # Python 2.7 with future print_function import
             pygram.python_grammar_no_print_statement,
+            # Python 2.7
             pygram.python_grammar,
         ]
     elif all(version.is_python2() for version in target_versions):
         # Python 2-only code, so try Python 2 grammars.
-        return [pygram.python_grammar_no_print_statement, pygram.python_grammar]
+        return [
+            # Python 2.7 with future print_function import
+            pygram.python_grammar_no_print_statement,
+            # Python 2.7
+            pygram.python_grammar,
+        ]
     else:
         # Python 3-compatible code, so only try Python 3 grammar.
-        return [pygram.python_grammar_no_print_statement_no_exec_statement]
+        grammars = []
+        # If we have to parse both, try to parse async as a keyword first
+        if not supports_feature(target_versions, Feature.ASYNC_IDENTIFIERS):
+            # Python 3.7+
+            grammars.append(
+                pygram.python_grammar_no_print_statement_no_exec_statement_async_keywords  # noqa: B950
+            )
+        if not supports_feature(target_versions, Feature.ASYNC_KEYWORDS):
+            # Python 3.0-3.6
+            grammars.append(pygram.python_grammar_no_print_statement_no_exec_statement)
+        # At least one of the above branches must have been taken, because every Python
+        # version has exactly one of the two 'ASYNC_*' flags
+        return grammars
 
 
 def lib2to3_parse(src_txt: str, target_versions: Iterable[TargetVersion] = ()) -> Node:
@@ -1545,6 +1601,26 @@ class LineGenerator(Visitor[Line]):
                 self.current_line.append(node)
         yield from super().visit_default(node)
 
+    def visit_atom(self, node: Node) -> Iterator[Line]:
+        """Make parentheses around a lone leaf invisible, except for `yield`."""
+        # Parentheses around a single leaf should not be needed, so always make them
+        # invisible; `(yield)` is exempt because removing them produces a SyntaxError.
+        if (
+            len(node.children) == 3
+            and isinstance(node.children[0], Leaf)
+            and node.children[0].type == token.LPAR
+            and isinstance(node.children[2], Leaf)
+            and node.children[2].type == token.RPAR
+            and isinstance(node.children[1], Leaf)
+            and not (
+                node.children[1].type == token.NAME
+                and node.children[1].value == "yield"
+            )
+        ):
+            node.children[0].value = ""
+            node.children[2].value = ""
+        yield from super().visit_default(node)
+
     def visit_INDENT(self, node: Node) -> Iterator[Line]:
         """Increase indentation level, maybe yield a line."""
         # In blib2to3 INDENT never holds comments.
@@ -2122,15 +2198,21 @@ def list_comments(prefix: str, *, is_endmarker: bool) -> List[ProtoComment]:
 
     consumed = 0
     nlines = 0
+    ignored_lines = 0
     for index, line in enumerate(prefix.split("\n")):
         consumed += len(line) + 1  # adding the length of the split '\n'
         line = line.lstrip()
         if not line:
             nlines += 1
         if not line.startswith("#"):
+            # Escaped newlines outside of a comment are not really newlines at
+            # all. We treat a single-line comment following an escaped newline
+            # as a simple trailing comment.
+            if line.endswith("\\"):
+                ignored_lines += 1
             continue
 
-        if index == 0 and not is_endmarker:
+        if index == ignored_lines and not is_endmarker:
             comment_type = token.COMMENT  # simple trailing comment
         else:
             comment_type = STANDALONE_COMMENT
@@ -2646,7 +2728,15 @@ def normalize_string_quotes(leaf: Leaf) -> None:
         new_body = sub_twice(escaped_orig_quote, rf"\1\2{orig_quote}", new_body)
         new_body = sub_twice(unescaped_new_quote, rf"\1\\{new_quote}", new_body)
     if "f" in prefix.casefold():
-        matches = re.findall(r"[^{]\{(.*?)\}[^}]", new_body)
+        matches = re.findall(
+            r"""
+            (?:[^{]|^)\{  # start of the string or a non-{ followed by a single {
+                ([^{].*?)  # contents of the brackets except if begins with {{
+            \}(?:[^}]|$)  # A } followed by end of the string or a non-}
+            """,
+            new_body,
+            re.VERBOSE,
+        )
         for m in matches:
             if "\\" in str(m):
                 # Do not introduce backslashes in interpolated expressions
@@ -3351,17 +3441,34 @@ class Report:
         return ", ".join(report) + "."
 
 
+def parse_ast(src: str) -> Union[ast3.AST, ast27.AST]:
+    """Parse `src` with ast3 at feature_version 7, then 6; fall back to ast27."""
+    for feature_version in (7, 6):
+        try:
+            return ast3.parse(src, feature_version=feature_version)
+        except SyntaxError:
+            continue
+    return ast27.parse(src)
+
+
 def assert_equivalent(src: str, dst: str) -> None:
     """Raise AssertionError if `src` and `dst` aren't equivalent."""
 
-    import ast
     import traceback
 
-    def _v(node: ast.AST, depth: int = 0) -> Iterator[str]:
+    def _v(node: Union[ast3.AST, ast27.AST], depth: int = 0) -> Iterator[str]:
         """Simple visitor generating strings to compare ASTs by content."""
         yield f"{'  ' * depth}{node.__class__.__name__}("
 
         for field in sorted(node._fields):
+            # TypeIgnore has only one field 'lineno' which breaks this comparison
+            if isinstance(node, (ast3.TypeIgnore, ast27.TypeIgnore)):
+                break
+
+            # Ignore str `kind`: it is case sensitive and ignores unicode_literals
+            if isinstance(node, (ast3.Str, ast27.Str, ast3.Bytes)) and field == "kind":
+                continue
+
             try:
                 value = getattr(node, field)
             except AttributeError:
@@ -3375,15 +3482,15 @@ def assert_equivalent(src: str, dst: str) -> None:
                     # parentheses and they change the AST.
                     if (
                         field == "targets"
-                        and isinstance(node, ast.Delete)
-                        and isinstance(item, ast.Tuple)
+                        and isinstance(node, (ast3.Delete, ast27.Delete))
+                        and isinstance(item, (ast3.Tuple, ast27.Tuple))
                     ):
                         for item in item.elts:
                             yield from _v(item, depth + 2)
-                    elif isinstance(item, ast.AST):
+                    elif isinstance(item, (ast3.AST, ast27.AST)):
                         yield from _v(item, depth + 2)
 
-            elif isinstance(value, ast.AST):
+            elif isinstance(value, (ast3.AST, ast27.AST)):
                 yield from _v(value, depth + 2)
 
             else:
@@ -3392,17 +3499,15 @@ def assert_equivalent(src: str, dst: str) -> None:
         yield f"{'  ' * depth})  # /{node.__class__.__name__}"
 
     try:
-        src_ast = ast.parse(src)
+        src_ast = parse_ast(src)
     except Exception as exc:
-        major, minor = sys.version_info[:2]
         raise AssertionError(
-            f"cannot use --safe with this file; failed to parse source file "
-            f"with Python {major}.{minor}'s builtin AST. Re-run with --fast "
-            f"or stop using deprecated Python 2 syntax. AST error message: {exc}"
+            f"cannot use --safe with this file; failed to parse source file.  "
+            f"AST error message: {exc}"
         )
 
     try:
-        dst_ast = ast.parse(dst)
+        dst_ast = parse_ast(dst)
     except Exception as exc:
         log = dump_to_file("".join(traceback.format_tb(exc.__traceback__)), dst)
         raise AssertionError(