git.madduck.net Git - etc/vim.git/blobdiff - black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access, can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Reword isort configuration, add --combine-as
[etc/vim.git] / black.py
index 46c64a317935952e956a982321afb35d1d3d1dfd..05ec00c8f25944756f8f2f21e66df0a8d2568d8b 100644 (file)
--- a/black.py
+++ b/black.py
@@ -159,6 +159,23 @@ class Changed(Enum):
         "silence those with 2>/dev/null."
     ),
 )
+@click.option(
+    "--pyi",
+    is_flag=True,
+    help=(
+        "Consider all input files typing stubs regardless of file extension "
+        "(useful when piping source on standard input)."
+    ),
+)
+@click.option(
+    "--py36",
+    is_flag=True,
+    help=(
+        "Allow using Python 3.6-only syntax on all input files.  This will put "
+        "trailing commas in function signatures and calls also after *args and "
+        "**kwargs.  [default: per-file auto-detection]"
+    ),
+)
 @click.version_option(version=__version__)
 @click.argument(
     "src",
@@ -174,6 +191,8 @@ def main(
     check: bool,
     diff: bool,
     fast: bool,
+    pyi: bool,
+    py36: bool,
     quiet: bool,
     src: List[str],
 ) -> None:
@@ -204,14 +223,30 @@ def main(
         return
 
     elif len(sources) == 1:
-        reformat_one(sources[0], line_length, fast, write_back, report)
+        reformat_one(
+            src=sources[0],
+            line_length=line_length,
+            fast=fast,
+            pyi=pyi,
+            py36=py36,
+            write_back=write_back,
+            report=report,
+        )
     else:
         loop = asyncio.get_event_loop()
         executor = ProcessPoolExecutor(max_workers=os.cpu_count())
         try:
             loop.run_until_complete(
                 schedule_formatting(
-                    sources, line_length, fast, write_back, report, loop, executor
+                    sources=sources,
+                    line_length=line_length,
+                    fast=fast,
+                    pyi=pyi,
+                    py36=py36,
+                    write_back=write_back,
+                    report=report,
+                    loop=loop,
+                    executor=executor,
                 )
             )
         finally:
@@ -223,33 +258,49 @@ def main(
 
 
 def reformat_one(
-    src: Path, line_length: int, fast: bool, write_back: WriteBack, report: "Report"
+    src: Path,
+    line_length: int,
+    fast: bool,
+    pyi: bool,
+    py36: bool,
+    write_back: WriteBack,
+    report: "Report",
 ) -> None:
     """Reformat a single file under `src` without spawning child processes.
 
     If `quiet` is True, non-error messages are not output. `line_length`,
-    `write_back`, and `fast` options are passed to :func:`format_file_in_place`.
+    `write_back`, `fast` and `pyi` options are passed to
+    :func:`format_file_in_place` or :func:`format_stdin_to_stdout`.
     """
     try:
         changed = Changed.NO
         if not src.is_file() and str(src) == "-":
             if format_stdin_to_stdout(
-                line_length=line_length, fast=fast, write_back=write_back
+                line_length=line_length,
+                fast=fast,
+                is_pyi=pyi,
+                force_py36=py36,
+                write_back=write_back,
             ):
                 changed = Changed.YES
         else:
             cache: Cache = {}
             if write_back != WriteBack.DIFF:
-                cache = read_cache(line_length)
+                cache = read_cache(line_length, pyi, py36)
                 src = src.resolve()
                 if src in cache and cache[src] == get_cache_info(src):
                     changed = Changed.CACHED
             if changed is not Changed.CACHED and format_file_in_place(
-                src, line_length=line_length, fast=fast, write_back=write_back
+                src,
+                line_length=line_length,
+                fast=fast,
+                force_pyi=pyi,
+                force_py36=py36,
+                write_back=write_back,
             ):
                 changed = Changed.YES
             if write_back == WriteBack.YES and changed is not Changed.NO:
-                write_cache(cache, [src], line_length)
+                write_cache(cache, [src], line_length, pyi, py36)
         report.done(src, changed)
     except Exception as exc:
         report.failed(src, str(exc))
@@ -259,6 +310,8 @@ async def schedule_formatting(
     sources: List[Path],
     line_length: int,
     fast: bool,
+    pyi: bool,
+    py36: bool,
     write_back: WriteBack,
     report: "Report",
     loop: BaseEventLoop,
@@ -268,12 +321,12 @@ async def schedule_formatting(
 
     (Use ProcessPoolExecutors for actual parallelism.)
 
-    `line_length`, `write_back`, and `fast` options are passed to
+    `line_length`, `write_back`, `fast`, and `pyi` options are passed to
     :func:`format_file_in_place`.
     """
     cache: Cache = {}
     if write_back != WriteBack.DIFF:
-        cache = read_cache(line_length)
+        cache = read_cache(line_length, pyi, py36)
         sources, cached = filter_cached(cache, sources)
         for src in cached:
             report.done(src, Changed.CACHED)
@@ -288,7 +341,15 @@ async def schedule_formatting(
             lock = manager.Lock()
         tasks = {
             loop.run_in_executor(
-                executor, format_file_in_place, src, line_length, fast, write_back, lock
+                executor,
+                format_file_in_place,
+                src,
+                line_length,
+                fast,
+                pyi,
+                py36,
+                write_back,
+                lock,
             ): src
             for src in sorted(sources)
         }
@@ -313,13 +374,15 @@ async def schedule_formatting(
     if cancelled:
         await asyncio.gather(*cancelled, loop=loop, return_exceptions=True)
     if write_back == WriteBack.YES and formatted:
-        write_cache(cache, formatted, line_length)
+        write_cache(cache, formatted, line_length, pyi, py36)
 
 
 def format_file_in_place(
     src: Path,
     line_length: int,
     fast: bool,
+    force_pyi: bool = False,
+    force_py36: bool = False,
     write_back: WriteBack = WriteBack.NO,
     lock: Any = None,  # multiprocessing.Manager().Lock() is some crazy proxy
 ) -> bool:
@@ -328,13 +391,17 @@ def format_file_in_place(
     If `write_back` is True, write reformatted code back to stdout.
     `line_length` and `fast` options are passed to :func:`format_file_contents`.
     """
-    is_pyi = src.suffix == ".pyi"
+    is_pyi = force_pyi or src.suffix == ".pyi"
 
     with tokenize.open(src) as src_buffer:
         src_contents = src_buffer.read()
     try:
         dst_contents = format_file_contents(
-            src_contents, line_length=line_length, fast=fast, is_pyi=is_pyi
+            src_contents,
+            line_length=line_length,
+            fast=fast,
+            is_pyi=is_pyi,
+            force_py36=force_py36,
         )
     except NothingChanged:
         return False
@@ -357,17 +424,28 @@ def format_file_in_place(
 
 
 def format_stdin_to_stdout(
-    line_length: int, fast: bool, write_back: WriteBack = WriteBack.NO
+    line_length: int,
+    fast: bool,
+    is_pyi: bool = False,
+    force_py36: bool = False,
+    write_back: WriteBack = WriteBack.NO,
 ) -> bool:
     """Format file on stdin. Return True if changed.
 
     If `write_back` is True, write reformatted code back to stdout.
-    `line_length` and `fast` arguments are passed to :func:`format_file_contents`.
+    `line_length`, `fast`, `is_pyi`, and `force_py36` arguments are passed to
+    :func:`format_file_contents`.
     """
     src = sys.stdin.read()
     dst = src
     try:
-        dst = format_file_contents(src, line_length=line_length, fast=fast)
+        dst = format_file_contents(
+            src,
+            line_length=line_length,
+            fast=fast,
+            is_pyi=is_pyi,
+            force_py36=force_py36,
+        )
         return True
 
     except NothingChanged:
@@ -383,7 +461,12 @@ def format_stdin_to_stdout(
 
 
 def format_file_contents(
-    src_contents: str, *, line_length: int, fast: bool, is_pyi: bool = False
+    src_contents: str,
+    *,
+    line_length: int,
+    fast: bool,
+    is_pyi: bool = False,
+    force_py36: bool = False,
 ) -> FileContent:
     """Reformat contents a file and return new contents.
 
@@ -394,20 +477,30 @@ def format_file_contents(
     if src_contents.strip() == "":
         raise NothingChanged
 
-    dst_contents = format_str(src_contents, line_length=line_length, is_pyi=is_pyi)
+    dst_contents = format_str(
+        src_contents, line_length=line_length, is_pyi=is_pyi, force_py36=force_py36
+    )
     if src_contents == dst_contents:
         raise NothingChanged
 
     if not fast:
         assert_equivalent(src_contents, dst_contents)
         assert_stable(
-            src_contents, dst_contents, line_length=line_length, is_pyi=is_pyi
+            src_contents,
+            dst_contents,
+            line_length=line_length,
+            is_pyi=is_pyi,
+            force_py36=force_py36,
         )
     return dst_contents
 
 
 def format_str(
-    src_contents: str, line_length: int, *, is_pyi: bool = False
+    src_contents: str,
+    line_length: int,
+    *,
+    is_pyi: bool = False,
+    force_py36: bool = False,
 ) -> FileContent:
     """Reformat a string and return new contents.
 
@@ -417,7 +510,7 @@ def format_str(
     dst_contents = ""
     future_imports = get_future_imports(src_node)
     elt = EmptyLineTracker(is_pyi=is_pyi)
-    py36 = is_python36(src_node)
+    py36 = force_py36 or is_python36(src_node)
     lines = LineGenerator(
         remove_u_prefix=py36 or "unicode_literals" in future_imports, is_pyi=is_pyi
     )
@@ -878,27 +971,6 @@ class Line:
             and second_leaf.value == "def"
         )
 
-    @property
-    def is_flow_control(self) -> bool:
-        """Is this line a flow control statement?
-
-        Those are `return`, `raise`, `break`, and `continue`.
-        """
-        return (
-            bool(self)
-            and self.leaves[0].type == token.NAME
-            and self.leaves[0].value in FLOW_CONTROL
-        )
-
-    @property
-    def is_yield(self) -> bool:
-        """Is this line a yield statement?"""
-        return (
-            bool(self)
-            and self.leaves[0].type == token.NAME
-            and self.leaves[0].value == "yield"
-        )
-
     @property
     def is_class_paren_empty(self) -> bool:
         """Is this a class with no base classes but using parentheses?
@@ -1131,8 +1203,7 @@ class EmptyLineTracker:
         """Return the number of extra empty lines before and after the `current_line`.
 
         This is for separating `def`, `async def` and `class` with extra empty
-        lines (two on module-level), as well as providing an extra empty line
-        after flow control keywords to make them more prominent.
+        lines (two on module-level).
         """
         if isinstance(current_line, UnformattedLines):
             return 0, 0
@@ -1375,32 +1446,6 @@ class LineGenerator(Visitor[Line]):
             yield from self.line()
             yield from self.visit(child)
 
-    def visit_import_from(self, node: Node) -> Iterator[Line]:
-        """Visit import_from and maybe put invisible parentheses.
-
-        This is separate from `visit_stmt` because import statements don't
-        support arbitrary atoms and thus handling of parentheses is custom.
-        """
-        check_lpar = False
-        for index, child in enumerate(node.children):
-            if check_lpar:
-                if child.type == token.LPAR:
-                    # make parentheses invisible
-                    child.value = ""  # type: ignore
-                    node.children[-1].value = ""  # type: ignore
-                else:
-                    # insert invisible parentheses
-                    node.insert_child(index, Leaf(token.LPAR, ""))
-                    node.append_child(Leaf(token.RPAR, ""))
-                break
-
-            check_lpar = (
-                child.type == token.NAME and child.value == "import"  # type: ignore
-            )
-
-        for child in node.children:
-            yield from self.visit(child)
-
     def visit_SEMI(self, leaf: Leaf) -> Iterator[Line]:
         """Remove a semicolon and put the other statement on a separate line."""
         yield from self.line()
@@ -1447,6 +1492,7 @@ class LineGenerator(Visitor[Line]):
         self.visit_classdef = partial(v, keywords={"class"}, parens=Ø)
         self.visit_expr_stmt = partial(v, keywords=Ø, parens=ASSIGNMENTS)
         self.visit_return_stmt = partial(v, keywords={"return"}, parens={"return"})
+        self.visit_import_from = partial(v, keywords=Ø, parens={"import"})
         self.visit_async_funcdef = self.visit_async_stmt
         self.visit_decorated = self.visit_decorators
 
@@ -1820,7 +1866,7 @@ def is_split_before_delimiter(leaf: Leaf, previous: Leaf = None) -> int:
     return 0
 
 
-def generate_comments(leaf: Leaf) -> Iterator[Leaf]:
+def generate_comments(leaf: LN) -> Iterator[Leaf]:
     """Clean the prefix of the `leaf` and generate comments from it, if any.
 
     Comments in lib2to3 are shoved into the whitespace prefix.  This happens
@@ -2337,8 +2383,13 @@ def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None:
     Standardizes on visible parentheses for single-element tuples, and keeps
     existing visible parentheses for other tuples and generator expressions.
     """
+    try:
+        list(generate_comments(node))
+    except FormatOff:
+        return  # This `node` has a prefix with `# fmt: off`, don't mess with parens.
+
     check_lpar = False
-    for child in list(node.children):
+    for index, child in enumerate(list(node.children)):
         if check_lpar:
             if child.type == syms.atom:
                 maybe_make_parens_invisible_in_atom(child)
@@ -2346,8 +2397,21 @@ def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None:
                 # wrap child in visible parentheses
                 lpar = Leaf(token.LPAR, "(")
                 rpar = Leaf(token.RPAR, ")")
-                index = child.remove() or 0
+                child.remove()
                 node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
+            elif node.type == syms.import_from:
+                # "import from" nodes store parentheses directly as part of
+                # the statement
+                if child.type == token.LPAR:
+                    # make parentheses invisible
+                    child.value = ""  # type: ignore
+                    node.children[-1].value = ""  # type: ignore
+                elif child.type != token.STAR:
+                    # insert invisible parentheses
+                    node.insert_child(index, Leaf(token.LPAR, ""))
+                    node.append_child(Leaf(token.RPAR, ""))
+                break
+
             elif not (isinstance(child, Leaf) and is_multiline_string(child)):
                 # wrap child in invisible parentheses
                 lpar = Leaf(token.LPAR, "")
@@ -2606,6 +2670,10 @@ def generate_trailers_to_omit(line: Line, line_length: int) -> Iterator[Set[Leaf
         if length > line_length:
             break
 
+        has_inline_comment = leaf_length > len(leaf.value) + len(leaf.prefix)
+        if leaf.type == STANDALONE_COMMENT or has_inline_comment:
+            break
+
         optional_brackets.discard(id(leaf))
         if opening_bracket:
             if leaf is opening_bracket:
@@ -2839,9 +2907,13 @@ def assert_equivalent(src: str, dst: str) -> None:
         ) from None
 
 
-def assert_stable(src: str, dst: str, line_length: int, is_pyi: bool = False) -> None:
+def assert_stable(
+    src: str, dst: str, line_length: int, is_pyi: bool = False, force_py36: bool = False
+) -> None:
     """Raise AssertionError if `dst` reformats differently the second time."""
-    newdst = format_str(dst, line_length=line_length, is_pyi=is_pyi)
+    newdst = format_str(
+        dst, line_length=line_length, is_pyi=is_pyi, force_py36=force_py36
+    )
     if dst != newdst:
         log = dump_to_file(
             diff(src, dst, "source", "first pass"),
@@ -2944,9 +3016,6 @@ def enumerate_with_length(
 
         comment: Optional[Leaf]
         for comment in line.comments_after(leaf, index):
-            if "\n" in comment.prefix:
-                return  # Oops, standalone comment!
-
             length += len(comment.value)
 
         yield index, leaf, length
@@ -3055,16 +3124,19 @@ def can_omit_invisible_parens(line: Line, line_length: int) -> bool:
     return False
 
 
-def get_cache_file(line_length: int) -> Path:
-    return CACHE_DIR / f"cache.{line_length}.pickle"
+def get_cache_file(line_length: int, pyi: bool = False, py36: bool = False) -> Path:
+    return (
+        CACHE_DIR
+        / f"cache.{line_length}{'.pyi' if pyi else ''}{'.py36' if py36 else ''}.pickle"
+    )
 
 
-def read_cache(line_length: int) -> Cache:
+def read_cache(line_length: int, pyi: bool = False, py36: bool = False) -> Cache:
     """Read the cache if it exists and is well formed.
 
     If it is not well formed, the call to write_cache later should resolve the issue.
     """
-    cache_file = get_cache_file(line_length)
+    cache_file = get_cache_file(line_length, pyi, py36)
     if not cache_file.exists():
         return {}
 
@@ -3101,9 +3173,15 @@ def filter_cached(
     return todo, done
 
 
-def write_cache(cache: Cache, sources: List[Path], line_length: int) -> None:
+def write_cache(
+    cache: Cache,
+    sources: List[Path],
+    line_length: int,
+    pyi: bool = False,
+    py36: bool = False,
+) -> None:
     """Update the cache file."""
-    cache_file = get_cache_file(line_length)
+    cache_file = get_cache_file(line_length, pyi, py36)
     try:
         if not CACHE_DIR.exists():
             CACHE_DIR.mkdir(parents=True)