git.madduck.net Git - etc/vim.git/blobdiff - src/black/__init__.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Do not add trailing commas to return type annotations using PEP 604 unions (#3735)
[etc/vim.git] / src / black / __init__.py
index d9fba41ebd34136560e09ae296e1a4a425dec8f5..dbcb559f09d5d647b603c1cbc4099443a63de2f5 100644 (file)
@@ -7,7 +7,7 @@ import tokenize
 import traceback
 from contextlib import contextmanager
 from dataclasses import replace
 import traceback
 from contextlib import contextmanager
 from dataclasses import replace
-from datetime import datetime
+from datetime import datetime, timezone
 from enum import Enum
 from json.decoder import JSONDecodeError
 from pathlib import Path
 from enum import Enum
 from json.decoder import JSONDecodeError
 from pathlib import Path
@@ -30,6 +30,7 @@ from typing import (
 import click
 from click.core import ParameterSource
 from mypy_extensions import mypyc_attr
 import click
 from click.core import ParameterSource
 from mypy_extensions import mypyc_attr
+from pathspec import PathSpec
 from pathspec.patterns.gitwildmatch import GitWildMatchPatternError
 
 from _black_version import version as __version__
 from pathspec.patterns.gitwildmatch import GitWildMatchPatternError
 
 from _black_version import version as __version__
@@ -218,8 +219,9 @@ def validate_regex(
     callback=target_version_option_callback,
     multiple=True,
     help=(
     callback=target_version_option_callback,
     multiple=True,
     help=(
-        "Python versions that should be supported by Black's output. [default: per-file"
-        " auto-detection]"
+        "Python versions that should be supported by Black's output. By default, Black"
+        " will try to infer this from the project metadata in pyproject.toml. If this"
+        " does not yield conclusive results, Black will use per-file auto-detection."
     ),
 )
 @click.option(
     ),
 )
 @click.option(
@@ -243,7 +245,7 @@ def validate_regex(
     multiple=True,
     help=(
         "When processing Jupyter Notebooks, add the given magic to the list"
     multiple=True,
     help=(
         "When processing Jupyter Notebooks, add the given magic to the list"
-        f" of known python-magics ({', '.join(PYTHON_CELL_MAGICS)})."
+        f" of known python-magics ({', '.join(sorted(PYTHON_CELL_MAGICS))})."
         " Useful for formatting cells with custom python magics."
     ),
     default=[],
         " Useful for formatting cells with custom python magics."
     ),
     default=[],
@@ -477,16 +479,20 @@ def main(  # noqa: C901
             )
 
             normalized = [
             )
 
             normalized = [
-                (source, source)
-                if source == "-"
-                else (normalize_path_maybe_ignore(Path(source), root), source)
+                (
+                    (source, source)
+                    if source == "-"
+                    else (normalize_path_maybe_ignore(Path(source), root), source)
+                )
                 for source in src
             ]
             srcs_string = ", ".join(
                 [
                 for source in src
             ]
             srcs_string = ", ".join(
                 [
-                    f'"{_norm}"'
-                    if _norm
-                    else f'\033[31m"{source} (skipping - invalid)"\033[34m'
+                    (
+                        f'"{_norm}"'
+                        if _norm
+                        else f'\033[31m"{source} (skipping - invalid)"\033[34m'
+                    )
                     for _norm, source in normalized
                 ]
             )
                     for _norm, source in normalized
                 ]
             )
@@ -508,6 +514,9 @@ def main(  # noqa: C901
                 out("Using configuration from project root.", fg="blue")
             else:
                 out(f"Using configuration in '{config}'.", fg="blue")
                 out("Using configuration from project root.", fg="blue")
             else:
                 out(f"Using configuration in '{config}'.", fg="blue")
+            if ctx.default_map:
+                for param, value in ctx.default_map.items():
+                    out(f"{param}: {value}")
 
     error_msg = "Oh no! 💥 💔 💥"
     if (
 
     error_msg = "Oh no! 💥 💔 💥"
     if (
@@ -625,6 +634,11 @@ def get_sources(
     sources: Set[Path] = set()
     root = ctx.obj["root"]
 
     sources: Set[Path] = set()
     root = ctx.obj["root"]
 
+    using_default_exclude = exclude is None
+    exclude = re_compile_maybe_verbose(DEFAULT_EXCLUDES) if exclude is None else exclude
+    gitignore: Optional[Dict[Path, PathSpec]] = None
+    root_gitignore = get_gitignore(root)
+
     for s in src:
         if s == "-" and stdin_filename:
             p = Path(stdin_filename)
     for s in src:
         if s == "-" and stdin_filename:
             p = Path(stdin_filename)
@@ -658,16 +672,12 @@ def get_sources(
 
             sources.add(p)
         elif p.is_dir():
 
             sources.add(p)
         elif p.is_dir():
-            if exclude is None:
-                exclude = re_compile_maybe_verbose(DEFAULT_EXCLUDES)
-                gitignore = get_gitignore(root)
-                p_gitignore = get_gitignore(p)
-                # No need to use p's gitignore if it is identical to root's gitignore
-                # (i.e. root and p point to the same directory).
-                if gitignore != p_gitignore:
-                    gitignore += p_gitignore
-            else:
-                gitignore = None
+            p = root / normalize_path_maybe_ignore(p, ctx.obj["root"], report)
+            if using_default_exclude:
+                gitignore = {
+                    root: root_gitignore,
+                    p: get_gitignore(p),
+                }
             sources.update(
                 gen_python_files(
                     p.iterdir(),
             sources.update(
                 gen_python_files(
                     p.iterdir(),
@@ -797,7 +807,7 @@ def format_file_in_place(
     elif src.suffix == ".ipynb":
         mode = replace(mode, is_ipynb=True)
 
     elif src.suffix == ".ipynb":
         mode = replace(mode, is_ipynb=True)
 
-    then = datetime.utcfromtimestamp(src.stat().st_mtime)
+    then = datetime.fromtimestamp(src.stat().st_mtime, timezone.utc)
     header = b""
     with open(src, "rb") as buf:
         if mode.skip_source_first_line:
     header = b""
     with open(src, "rb") as buf:
         if mode.skip_source_first_line:
@@ -818,9 +828,9 @@ def format_file_in_place(
         with open(src, "w", encoding=encoding, newline=newline) as f:
             f.write(dst_contents)
     elif write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
         with open(src, "w", encoding=encoding, newline=newline) as f:
             f.write(dst_contents)
     elif write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
-        now = datetime.utcnow()
-        src_name = f"{src}\t{then} +0000"
-        dst_name = f"{src}\t{now} +0000"
+        now = datetime.now(timezone.utc)
+        src_name = f"{src}\t{then}"
+        dst_name = f"{src}\t{now}"
         if mode.is_ipynb:
             diff_contents = ipynb_diff(src_contents, dst_contents, src_name, dst_name)
         else:
         if mode.is_ipynb:
             diff_contents = ipynb_diff(src_contents, dst_contents, src_name, dst_name)
         else:
@@ -858,7 +868,7 @@ def format_stdin_to_stdout(
     write a diff to stdout. The `mode` argument is passed to
     :func:`format_file_contents`.
     """
     write a diff to stdout. The `mode` argument is passed to
     :func:`format_file_contents`.
     """
-    then = datetime.utcnow()
+    then = datetime.now(timezone.utc)
 
     if content is None:
         src, encoding, newline = decode_bytes(sys.stdin.buffer.read())
 
     if content is None:
         src, encoding, newline = decode_bytes(sys.stdin.buffer.read())
@@ -883,9 +893,9 @@ def format_stdin_to_stdout(
                 dst += "\n"
             f.write(dst)
         elif write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
                 dst += "\n"
             f.write(dst)
         elif write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
-            now = datetime.utcnow()
-            src_name = f"STDIN\t{then} +0000"
-            dst_name = f"STDOUT\t{now} +0000"
+            now = datetime.now(timezone.utc)
+            src_name = f"STDIN\t{then}"
+            dst_name = f"STDOUT\t{now}"
             d = diff(src, dst, src_name, dst_name)
             if write_back == WriteBack.COLOR_DIFF:
                 d = color_diff(d)
             d = diff(src, dst, src_name, dst_name)
             if write_back == WriteBack.COLOR_DIFF:
                 d = color_diff(d)
@@ -914,9 +924,6 @@ def format_file_contents(src_contents: str, *, fast: bool, mode: Mode) -> FileCo
     valid by calling :func:`assert_equivalent` and :func:`assert_stable` on it.
     `mode` is passed to :func:`format_str`.
     """
     valid by calling :func:`assert_equivalent` and :func:`assert_stable` on it.
     `mode` is passed to :func:`format_str`.
     """
-    if not src_contents.strip():
-        raise NothingChanged
-
     if mode.is_ipynb:
         dst_contents = format_ipynb_string(src_contents, fast=fast, mode=mode)
     else:
     if mode.is_ipynb:
         dst_contents = format_ipynb_string(src_contents, fast=fast, mode=mode)
     else:
@@ -1011,6 +1018,9 @@ def format_ipynb_string(src_contents: str, *, fast: bool, mode: Mode) -> FileCon
     Operate cell-by-cell, only on code cells, only for Python notebooks.
     If the ``.ipynb`` originally had a trailing newline, it'll be preserved.
     """
     Operate cell-by-cell, only on code cells, only for Python notebooks.
     If the ``.ipynb`` originally had a trailing newline, it'll be preserved.
     """
+    if not src_contents:
+        raise NothingChanged
+
     trailing_newline = src_contents[-1] == "\n"
     modified = False
     nb = json.loads(src_contents)
     trailing_newline = src_contents[-1] == "\n"
     modified = False
     nb = json.loads(src_contents)
@@ -1082,8 +1092,13 @@ def _format_str_once(src_contents: str, *, mode: Mode) -> str:
         future_imports = get_future_imports(src_node)
         versions = detect_target_versions(src_node, future_imports=future_imports)
 
         future_imports = get_future_imports(src_node)
         versions = detect_target_versions(src_node, future_imports=future_imports)
 
-    normalize_fmt_off(src_node, preview=mode.preview)
-    lines = LineGenerator(mode=mode)
+    context_manager_features = {
+        feature
+        for feature in {Feature.PARENTHESIZED_CONTEXT_MANAGERS}
+        if supports_feature(versions, feature)
+    }
+    normalize_fmt_off(src_node)
+    lines = LineGenerator(mode=mode, features=context_manager_features)
     elt = EmptyLineTracker(mode=mode)
     split_line_features = {
         feature
     elt = EmptyLineTracker(mode=mode)
     split_line_features = {
         feature
@@ -1103,6 +1118,13 @@ def _format_str_once(src_contents: str, *, mode: Mode) -> str:
     dst_contents = []
     for block in dst_blocks:
         dst_contents.extend(block.all_lines())
     dst_contents = []
     for block in dst_blocks:
         dst_contents.extend(block.all_lines())
+    if not dst_contents:
+        # Use decode_bytes to retrieve the correct source newline (CRLF or LF),
+        # and check if normalized_content has more than one line
+        normalized_content, _, newline = decode_bytes(src_contents.encode("utf-8"))
+        if "\n" in normalized_content:
+            return newline
+        return ""
     return "".join(dst_contents)
 
 
     return "".join(dst_contents)
 
 
@@ -1138,6 +1160,10 @@ def get_features_used(  # noqa: C901
     - relaxed decorator syntax;
     - usage of __future__ flags (annotations);
     - print / exec statements;
     - relaxed decorator syntax;
     - usage of __future__ flags (annotations);
     - print / exec statements;
+    - parenthesized context managers;
+    - match statements;
+    - except* clause;
+    - variadic generics;
     """
     features: Set[Feature] = set()
     if future_imports:
     """
     features: Set[Feature] = set()
     if future_imports:
@@ -1213,6 +1239,23 @@ def get_features_used(  # noqa: C901
         ):
             features.add(Feature.ANN_ASSIGN_EXTENDED_RHS)
 
         ):
             features.add(Feature.ANN_ASSIGN_EXTENDED_RHS)
 
+        elif (
+            n.type == syms.with_stmt
+            and len(n.children) > 2
+            and n.children[1].type == syms.atom
+        ):
+            atom_children = n.children[1].children
+            if (
+                len(atom_children) == 3
+                and atom_children[0].type == token.LPAR
+                and atom_children[1].type == syms.testlist_gexp
+                and atom_children[2].type == token.RPAR
+            ):
+                features.add(Feature.PARENTHESIZED_CONTEXT_MANAGERS)
+
+        elif n.type == syms.match_stmt:
+            features.add(Feature.PATTERN_MATCHING)
+
         elif (
             n.type == syms.except_clause
             and len(n.children) >= 2
         elif (
             n.type == syms.except_clause
             and len(n.children) >= 2
@@ -1232,6 +1275,9 @@ def get_features_used(  # noqa: C901
         ):
             features.add(Feature.VARIADIC_GENERICS)
 
         ):
             features.add(Feature.VARIADIC_GENERICS)
 
+        elif n.type in (syms.type_stmt, syms.typeparams):
+            features.add(Feature.TYPE_PARAMS)
+
     return features
 
 
     return features