git.madduck.net Git - etc/vim.git/blobdiff - black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access, can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Make --safe work for Python2.7 syntax, by using typed_ast for safe validation (#840)
[etc/vim.git] / black.py
index e7dce5bf2abc4a3d1a685f653d21a8ed5f712895..9494c4cf44fc20cc7b580d311e36c0bc46fa4fc2 100644 (file)
--- a/black.py
+++ b/black.py
@@ -40,6 +40,7 @@ from appdirs import user_cache_dir
 from attr import dataclass, evolve, Factory
 import click
 import toml
+from typed_ast import ast3, ast27
 
 # lib2to3 fork
 from blib2to3.pytree import Node, Leaf, type_repr
@@ -440,22 +441,10 @@ def main(
             report=report,
         )
     else:
-        loop = asyncio.get_event_loop()
-        executor = ProcessPoolExecutor(max_workers=os.cpu_count())
-        try:
-            loop.run_until_complete(
-                schedule_formatting(
-                    sources=sources,
-                    fast=fast,
-                    write_back=write_back,
-                    mode=mode,
-                    report=report,
-                    loop=loop,
-                    executor=executor,
-                )
-            )
-        finally:
-            shutdown(loop)
+        reformat_many(
+            sources=sources, fast=fast, write_back=write_back, mode=mode, report=report
+        )
+
     if verbose or not quiet:
         bang = "💥 💔 💥" if report.return_code else "✨ 🍰 ✨"
         out(f"All done! {bang}")
@@ -497,6 +486,36 @@ def reformat_one(
         report.failed(src, str(exc))
 
 
+def reformat_many(
+    sources: Set[Path],
+    fast: bool,
+    write_back: WriteBack,
+    mode: FileMode,
+    report: "Report",
+) -> None:
+    """Reformat multiple files using a ProcessPoolExecutor."""
+    loop = asyncio.get_event_loop()
+    worker_count = os.cpu_count()
+    if sys.platform == "win32":
+        # Work around https://bugs.python.org/issue26903
+        worker_count = min(worker_count, 61)
+    executor = ProcessPoolExecutor(max_workers=worker_count)
+    try:
+        loop.run_until_complete(
+            schedule_formatting(
+                sources=sources,
+                fast=fast,
+                write_back=write_back,
+                mode=mode,
+                report=report,
+                loop=loop,
+                executor=executor,
+            )
+        )
+    finally:
+        shutdown(loop)
+
+
 async def schedule_formatting(
     sources: Set[Path],
     fast: bool,
@@ -2127,15 +2146,21 @@ def list_comments(prefix: str, *, is_endmarker: bool) -> List[ProtoComment]:
 
     consumed = 0
     nlines = 0
+    ignored_lines = 0
     for index, line in enumerate(prefix.split("\n")):
         consumed += len(line) + 1  # adding the length of the split '\n'
         line = line.lstrip()
         if not line:
             nlines += 1
         if not line.startswith("#"):
+            # Escaped newlines outside of a comment are not really newlines at
+            # all. We treat a single-line comment following an escaped newline
+            # as a simple trailing comment.
+            if line.endswith("\\"):
+                ignored_lines += 1
             continue
 
-        if index == 0 and not is_endmarker:
+        if index == ignored_lines and not is_endmarker:
             comment_type = token.COMMENT  # simple trailing comment
         else:
             comment_type = STANDALONE_COMMENT
@@ -3356,17 +3381,31 @@ class Report:
         return ", ".join(report) + "."
 
 
+def parse_ast(src: str) -> Union[ast3.AST, ast27.AST]:
+    try:
+        return ast3.parse(src)
+    except SyntaxError:
+        return ast27.parse(src)
+
+
 def assert_equivalent(src: str, dst: str) -> None:
     """Raise AssertionError if `src` and `dst` aren't equivalent."""
 
-    import ast
     import traceback
 
-    def _v(node: ast.AST, depth: int = 0) -> Iterator[str]:
+    def _v(node: Union[ast3.AST, ast27.AST], depth: int = 0) -> Iterator[str]:
         """Simple visitor generating strings to compare ASTs by content."""
         yield f"{'  ' * depth}{node.__class__.__name__}("
 
         for field in sorted(node._fields):
+            # TypeIgnore has only one field 'lineno' which breaks this comparison
+            if isinstance(node, (ast3.TypeIgnore, ast27.TypeIgnore)):
+                break
+
+            # Ignore str kind which is case sensitive / and ignores unicode_literals
+            if isinstance(node, (ast3.Str, ast27.Str, ast3.Bytes)) and field == "kind":
+                continue
+
             try:
                 value = getattr(node, field)
             except AttributeError:
@@ -3380,15 +3419,15 @@ def assert_equivalent(src: str, dst: str) -> None:
                     # parentheses and they change the AST.
                     if (
                         field == "targets"
-                        and isinstance(node, ast.Delete)
-                        and isinstance(item, ast.Tuple)
+                        and isinstance(node, (ast3.Delete, ast27.Delete))
+                        and isinstance(item, (ast3.Tuple, ast27.Tuple))
                     ):
                         for item in item.elts:
                             yield from _v(item, depth + 2)
-                    elif isinstance(item, ast.AST):
+                    elif isinstance(item, (ast3.AST, ast27.AST)):
                         yield from _v(item, depth + 2)
 
-            elif isinstance(value, ast.AST):
+            elif isinstance(value, (ast3.AST, ast27.AST)):
                 yield from _v(value, depth + 2)
 
             else:
@@ -3397,7 +3436,7 @@ def assert_equivalent(src: str, dst: str) -> None:
         yield f"{'  ' * depth})  # /{node.__class__.__name__}"
 
     try:
-        src_ast = ast.parse(src)
+        src_ast = parse_ast(src)
     except Exception as exc:
         major, minor = sys.version_info[:2]
         raise AssertionError(
@@ -3407,7 +3446,7 @@ def assert_equivalent(src: str, dst: str) -> None:
         )
 
     try:
-        dst_ast = ast.parse(dst)
+        dst_ast = parse_ast(dst)
     except Exception as exc:
         log = dump_to_file("".join(traceback.format_tb(exc.__traceback__)), dst)
         raise AssertionError(