git.madduck.net Git - etc/vim.git/blobdiff - src/black/__init__.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access, can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Consistently format async statements similar to their non-async version. (#3609)
[etc/vim.git] / src / black / __init__.py
index d9fba41ebd34136560e09ae296e1a4a425dec8f5..4ebf28821c39ca339b05f372de2695246f6fcc99 100644 (file)
@@ -30,6 +30,7 @@ from typing import (
 import click
 from click.core import ParameterSource
 from mypy_extensions import mypyc_attr
 import click
 from click.core import ParameterSource
 from mypy_extensions import mypyc_attr
+from pathspec import PathSpec
 from pathspec.patterns.gitwildmatch import GitWildMatchPatternError
 
 from _black_version import version as __version__
 from pathspec.patterns.gitwildmatch import GitWildMatchPatternError
 
 from _black_version import version as __version__
@@ -218,8 +219,9 @@ def validate_regex(
     callback=target_version_option_callback,
     multiple=True,
     help=(
     callback=target_version_option_callback,
     multiple=True,
     help=(
-        "Python versions that should be supported by Black's output. [default: per-file"
-        " auto-detection]"
+        "Python versions that should be supported by Black's output. By default, Black"
+        " will try to infer this from the project metadata in pyproject.toml. If this"
+        " does not yield conclusive results, Black will use per-file auto-detection."
     ),
 )
 @click.option(
     ),
 )
 @click.option(
@@ -243,7 +245,7 @@ def validate_regex(
     multiple=True,
     help=(
         "When processing Jupyter Notebooks, add the given magic to the list"
     multiple=True,
     help=(
         "When processing Jupyter Notebooks, add the given magic to the list"
-        f" of known python-magics ({', '.join(PYTHON_CELL_MAGICS)})."
+        f" of known python-magics ({', '.join(sorted(PYTHON_CELL_MAGICS))})."
         " Useful for formatting cells with custom python magics."
     ),
     default=[],
         " Useful for formatting cells with custom python magics."
     ),
     default=[],
@@ -477,16 +479,20 @@ def main(  # noqa: C901
             )
 
             normalized = [
             )
 
             normalized = [
-                (source, source)
-                if source == "-"
-                else (normalize_path_maybe_ignore(Path(source), root), source)
+                (
+                    (source, source)
+                    if source == "-"
+                    else (normalize_path_maybe_ignore(Path(source), root), source)
+                )
                 for source in src
             ]
             srcs_string = ", ".join(
                 [
                 for source in src
             ]
             srcs_string = ", ".join(
                 [
-                    f'"{_norm}"'
-                    if _norm
-                    else f'\033[31m"{source} (skipping - invalid)"\033[34m'
+                    (
+                        f'"{_norm}"'
+                        if _norm
+                        else f'\033[31m"{source} (skipping - invalid)"\033[34m'
+                    )
                     for _norm, source in normalized
                 ]
             )
                     for _norm, source in normalized
                 ]
             )
@@ -497,8 +503,10 @@ def main(  # noqa: C901
             user_level_config = str(find_user_pyproject_toml())
             if config == user_level_config:
                 out(
             user_level_config = str(find_user_pyproject_toml())
             if config == user_level_config:
                 out(
-                    "Using configuration from user-level config at "
-                    f"'{user_level_config}'.",
+                    (
+                        "Using configuration from user-level config at "
+                        f"'{user_level_config}'."
+                    ),
                     fg="blue",
                 )
             elif config_source in (
                     fg="blue",
                 )
             elif config_source in (
@@ -508,6 +516,9 @@ def main(  # noqa: C901
                 out("Using configuration from project root.", fg="blue")
             else:
                 out(f"Using configuration in '{config}'.", fg="blue")
                 out("Using configuration from project root.", fg="blue")
             else:
                 out(f"Using configuration in '{config}'.", fg="blue")
+            if ctx.default_map:
+                for param, value in ctx.default_map.items():
+                    out(f"{param}: {value}")
 
     error_msg = "Oh no! 💥 💔 💥"
     if (
 
     error_msg = "Oh no! 💥 💔 💥"
     if (
@@ -625,6 +636,11 @@ def get_sources(
     sources: Set[Path] = set()
     root = ctx.obj["root"]
 
     sources: Set[Path] = set()
     root = ctx.obj["root"]
 
+    using_default_exclude = exclude is None
+    exclude = re_compile_maybe_verbose(DEFAULT_EXCLUDES) if exclude is None else exclude
+    gitignore: Optional[Dict[Path, PathSpec]] = None
+    root_gitignore = get_gitignore(root)
+
     for s in src:
         if s == "-" and stdin_filename:
             p = Path(stdin_filename)
     for s in src:
         if s == "-" and stdin_filename:
             p = Path(stdin_filename)
@@ -658,16 +674,12 @@ def get_sources(
 
             sources.add(p)
         elif p.is_dir():
 
             sources.add(p)
         elif p.is_dir():
-            if exclude is None:
-                exclude = re_compile_maybe_verbose(DEFAULT_EXCLUDES)
-                gitignore = get_gitignore(root)
-                p_gitignore = get_gitignore(p)
-                # No need to use p's gitignore if it is identical to root's gitignore
-                # (i.e. root and p point to the same directory).
-                if gitignore != p_gitignore:
-                    gitignore += p_gitignore
-            else:
-                gitignore = None
+            p = root / normalize_path_maybe_ignore(p, ctx.obj["root"], report)
+            if using_default_exclude:
+                gitignore = {
+                    root: root_gitignore,
+                    p: get_gitignore(p),
+                }
             sources.update(
                 gen_python_files(
                     p.iterdir(),
             sources.update(
                 gen_python_files(
                     p.iterdir(),
@@ -914,9 +926,6 @@ def format_file_contents(src_contents: str, *, fast: bool, mode: Mode) -> FileCo
     valid by calling :func:`assert_equivalent` and :func:`assert_stable` on it.
     `mode` is passed to :func:`format_str`.
     """
     valid by calling :func:`assert_equivalent` and :func:`assert_stable` on it.
     `mode` is passed to :func:`format_str`.
     """
-    if not src_contents.strip():
-        raise NothingChanged
-
     if mode.is_ipynb:
         dst_contents = format_ipynb_string(src_contents, fast=fast, mode=mode)
     else:
     if mode.is_ipynb:
         dst_contents = format_ipynb_string(src_contents, fast=fast, mode=mode)
     else:
@@ -1011,6 +1020,9 @@ def format_ipynb_string(src_contents: str, *, fast: bool, mode: Mode) -> FileCon
     Operate cell-by-cell, only on code cells, only for Python notebooks.
     If the ``.ipynb`` originally had a trailing newline, it'll be preserved.
     """
     Operate cell-by-cell, only on code cells, only for Python notebooks.
     If the ``.ipynb`` originally had a trailing newline, it'll be preserved.
     """
+    if not src_contents:
+        raise NothingChanged
+
     trailing_newline = src_contents[-1] == "\n"
     modified = False
     nb = json.loads(src_contents)
     trailing_newline = src_contents[-1] == "\n"
     modified = False
     nb = json.loads(src_contents)
@@ -1082,8 +1094,13 @@ def _format_str_once(src_contents: str, *, mode: Mode) -> str:
         future_imports = get_future_imports(src_node)
         versions = detect_target_versions(src_node, future_imports=future_imports)
 
         future_imports = get_future_imports(src_node)
         versions = detect_target_versions(src_node, future_imports=future_imports)
 
-    normalize_fmt_off(src_node, preview=mode.preview)
-    lines = LineGenerator(mode=mode)
+    context_manager_features = {
+        feature
+        for feature in {Feature.PARENTHESIZED_CONTEXT_MANAGERS}
+        if supports_feature(versions, feature)
+    }
+    normalize_fmt_off(src_node)
+    lines = LineGenerator(mode=mode, features=context_manager_features)
     elt = EmptyLineTracker(mode=mode)
     split_line_features = {
         feature
     elt = EmptyLineTracker(mode=mode)
     split_line_features = {
         feature
@@ -1103,6 +1120,13 @@ def _format_str_once(src_contents: str, *, mode: Mode) -> str:
     dst_contents = []
     for block in dst_blocks:
         dst_contents.extend(block.all_lines())
     dst_contents = []
     for block in dst_blocks:
         dst_contents.extend(block.all_lines())
+    if not dst_contents:
+        # Use decode_bytes to retrieve the correct source newline (CRLF or LF),
+        # and check if normalized_content has more than one line
+        normalized_content, _, newline = decode_bytes(src_contents.encode("utf-8"))
+        if "\n" in normalized_content:
+            return newline
+        return ""
     return "".join(dst_contents)
 
 
     return "".join(dst_contents)
 
 
@@ -1138,6 +1162,10 @@ def get_features_used(  # noqa: C901
     - relaxed decorator syntax;
     - usage of __future__ flags (annotations);
     - print / exec statements;
     - relaxed decorator syntax;
     - usage of __future__ flags (annotations);
     - print / exec statements;
+    - parenthesized context managers;
+    - match statements;
+    - except* clause;
+    - variadic generics;
     """
     features: Set[Feature] = set()
     if future_imports:
     """
     features: Set[Feature] = set()
     if future_imports:
@@ -1213,6 +1241,23 @@ def get_features_used(  # noqa: C901
         ):
             features.add(Feature.ANN_ASSIGN_EXTENDED_RHS)
 
         ):
             features.add(Feature.ANN_ASSIGN_EXTENDED_RHS)
 
+        elif (
+            n.type == syms.with_stmt
+            and len(n.children) > 2
+            and n.children[1].type == syms.atom
+        ):
+            atom_children = n.children[1].children
+            if (
+                len(atom_children) == 3
+                and atom_children[0].type == token.LPAR
+                and atom_children[1].type == syms.testlist_gexp
+                and atom_children[2].type == token.RPAR
+            ):
+                features.add(Feature.PARENTHESIZED_CONTEXT_MANAGERS)
+
+        elif n.type == syms.match_stmt:
+            features.add(Feature.PATTERN_MATCHING)
+
         elif (
             n.type == syms.except_clause
             and len(n.children) >= 2
         elif (
             n.type == syms.except_clause
             and len(n.children) >= 2