git.madduck.net Git - etc/vim.git/blobdiff - src/black/__init__.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access, can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Do not add trailing commas to return type annotations using PEP 604 unions (#3735)
[etc/vim.git] / src / black / __init__.py
index 39d12968c4b15e4756fdf138fa17f68423a5005b..dbcb559f09d5d647b603c1cbc4099443a63de2f5 100644 (file)
@@ -7,7 +7,7 @@ import tokenize
 import traceback
 from contextlib import contextmanager
 from dataclasses import replace
 import traceback
 from contextlib import contextmanager
 from dataclasses import replace
-from datetime import datetime
+from datetime import datetime, timezone
 from enum import Enum
 from json.decoder import JSONDecodeError
 from pathlib import Path
 from enum import Enum
 from json.decoder import JSONDecodeError
 from pathlib import Path
@@ -219,8 +219,9 @@ def validate_regex(
     callback=target_version_option_callback,
     multiple=True,
     help=(
     callback=target_version_option_callback,
     multiple=True,
     help=(
-        "Python versions that should be supported by Black's output. [default: per-file"
-        " auto-detection]"
+        "Python versions that should be supported by Black's output. By default, Black"
+        " will try to infer this from the project metadata in pyproject.toml. If this"
+        " does not yield conclusive results, Black will use per-file auto-detection."
     ),
 )
 @click.option(
     ),
 )
 @click.option(
@@ -244,7 +245,7 @@ def validate_regex(
     multiple=True,
     help=(
         "When processing Jupyter Notebooks, add the given magic to the list"
     multiple=True,
     help=(
         "When processing Jupyter Notebooks, add the given magic to the list"
-        f" of known python-magics ({', '.join(PYTHON_CELL_MAGICS)})."
+        f" of known python-magics ({', '.join(sorted(PYTHON_CELL_MAGICS))})."
         " Useful for formatting cells with custom python magics."
     ),
     default=[],
         " Useful for formatting cells with custom python magics."
     ),
     default=[],
@@ -478,16 +479,20 @@ def main(  # noqa: C901
             )
 
             normalized = [
             )
 
             normalized = [
-                (source, source)
-                if source == "-"
-                else (normalize_path_maybe_ignore(Path(source), root), source)
+                (
+                    (source, source)
+                    if source == "-"
+                    else (normalize_path_maybe_ignore(Path(source), root), source)
+                )
                 for source in src
             ]
             srcs_string = ", ".join(
                 [
                 for source in src
             ]
             srcs_string = ", ".join(
                 [
-                    f'"{_norm}"'
-                    if _norm
-                    else f'\033[31m"{source} (skipping - invalid)"\033[34m'
+                    (
+                        f'"{_norm}"'
+                        if _norm
+                        else f'\033[31m"{source} (skipping - invalid)"\033[34m'
+                    )
                     for _norm, source in normalized
                 ]
             )
                     for _norm, source in normalized
                 ]
             )
@@ -498,10 +503,8 @@ def main(  # noqa: C901
             user_level_config = str(find_user_pyproject_toml())
             if config == user_level_config:
                 out(
             user_level_config = str(find_user_pyproject_toml())
             if config == user_level_config:
                 out(
-                    (
-                        "Using configuration from user-level config at "
-                        f"'{user_level_config}'."
-                    ),
+                    "Using configuration from user-level config at "
+                    f"'{user_level_config}'.",
                     fg="blue",
                 )
             elif config_source in (
                     fg="blue",
                 )
             elif config_source in (
@@ -511,6 +514,9 @@ def main(  # noqa: C901
                 out("Using configuration from project root.", fg="blue")
             else:
                 out(f"Using configuration in '{config}'.", fg="blue")
                 out("Using configuration from project root.", fg="blue")
             else:
                 out(f"Using configuration in '{config}'.", fg="blue")
+            if ctx.default_map:
+                for param, value in ctx.default_map.items():
+                    out(f"{param}: {value}")
 
     error_msg = "Oh no! 💥 💔 💥"
     if (
 
     error_msg = "Oh no! 💥 💔 💥"
     if (
@@ -666,10 +672,11 @@ def get_sources(
 
             sources.add(p)
         elif p.is_dir():
 
             sources.add(p)
         elif p.is_dir():
+            p = root / normalize_path_maybe_ignore(p, ctx.obj["root"], report)
             if using_default_exclude:
                 gitignore = {
                     root: root_gitignore,
             if using_default_exclude:
                 gitignore = {
                     root: root_gitignore,
-                    root / p: get_gitignore(p),
+                    p: get_gitignore(p),
                 }
             sources.update(
                 gen_python_files(
                 }
             sources.update(
                 gen_python_files(
@@ -800,7 +807,7 @@ def format_file_in_place(
     elif src.suffix == ".ipynb":
         mode = replace(mode, is_ipynb=True)
 
     elif src.suffix == ".ipynb":
         mode = replace(mode, is_ipynb=True)
 
-    then = datetime.utcfromtimestamp(src.stat().st_mtime)
+    then = datetime.fromtimestamp(src.stat().st_mtime, timezone.utc)
     header = b""
     with open(src, "rb") as buf:
         if mode.skip_source_first_line:
     header = b""
     with open(src, "rb") as buf:
         if mode.skip_source_first_line:
@@ -821,9 +828,9 @@ def format_file_in_place(
         with open(src, "w", encoding=encoding, newline=newline) as f:
             f.write(dst_contents)
     elif write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
         with open(src, "w", encoding=encoding, newline=newline) as f:
             f.write(dst_contents)
     elif write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
-        now = datetime.utcnow()
-        src_name = f"{src}\t{then} +0000"
-        dst_name = f"{src}\t{now} +0000"
+        now = datetime.now(timezone.utc)
+        src_name = f"{src}\t{then}"
+        dst_name = f"{src}\t{now}"
         if mode.is_ipynb:
             diff_contents = ipynb_diff(src_contents, dst_contents, src_name, dst_name)
         else:
         if mode.is_ipynb:
             diff_contents = ipynb_diff(src_contents, dst_contents, src_name, dst_name)
         else:
@@ -861,7 +868,7 @@ def format_stdin_to_stdout(
     write a diff to stdout. The `mode` argument is passed to
     :func:`format_file_contents`.
     """
     write a diff to stdout. The `mode` argument is passed to
     :func:`format_file_contents`.
     """
-    then = datetime.utcnow()
+    then = datetime.now(timezone.utc)
 
     if content is None:
         src, encoding, newline = decode_bytes(sys.stdin.buffer.read())
 
     if content is None:
         src, encoding, newline = decode_bytes(sys.stdin.buffer.read())
@@ -886,9 +893,9 @@ def format_stdin_to_stdout(
                 dst += "\n"
             f.write(dst)
         elif write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
                 dst += "\n"
             f.write(dst)
         elif write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
-            now = datetime.utcnow()
-            src_name = f"STDIN\t{then} +0000"
-            dst_name = f"STDOUT\t{now} +0000"
+            now = datetime.now(timezone.utc)
+            src_name = f"STDIN\t{then}"
+            dst_name = f"STDOUT\t{now}"
             d = diff(src, dst, src_name, dst_name)
             if write_back == WriteBack.COLOR_DIFF:
                 d = color_diff(d)
             d = diff(src, dst, src_name, dst_name)
             if write_back == WriteBack.COLOR_DIFF:
                 d = color_diff(d)
@@ -917,9 +924,6 @@ def format_file_contents(src_contents: str, *, fast: bool, mode: Mode) -> FileCo
     valid by calling :func:`assert_equivalent` and :func:`assert_stable` on it.
     `mode` is passed to :func:`format_str`.
     """
     valid by calling :func:`assert_equivalent` and :func:`assert_stable` on it.
     `mode` is passed to :func:`format_str`.
     """
-    if not mode.preview and not src_contents.strip():
-        raise NothingChanged
-
     if mode.is_ipynb:
         dst_contents = format_ipynb_string(src_contents, fast=fast, mode=mode)
     else:
     if mode.is_ipynb:
         dst_contents = format_ipynb_string(src_contents, fast=fast, mode=mode)
     else:
@@ -1014,7 +1018,7 @@ def format_ipynb_string(src_contents: str, *, fast: bool, mode: Mode) -> FileCon
     Operate cell-by-cell, only on code cells, only for Python notebooks.
     If the ``.ipynb`` originally had a trailing newline, it'll be preserved.
     """
     Operate cell-by-cell, only on code cells, only for Python notebooks.
     If the ``.ipynb`` originally had a trailing newline, it'll be preserved.
     """
-    if mode.preview and not src_contents:
+    if not src_contents:
         raise NothingChanged
 
     trailing_newline = src_contents[-1] == "\n"
         raise NothingChanged
 
     trailing_newline = src_contents[-1] == "\n"
@@ -1088,8 +1092,13 @@ def _format_str_once(src_contents: str, *, mode: Mode) -> str:
         future_imports = get_future_imports(src_node)
         versions = detect_target_versions(src_node, future_imports=future_imports)
 
         future_imports = get_future_imports(src_node)
         versions = detect_target_versions(src_node, future_imports=future_imports)
 
-    normalize_fmt_off(src_node, preview=mode.preview)
-    lines = LineGenerator(mode=mode)
+    context_manager_features = {
+        feature
+        for feature in {Feature.PARENTHESIZED_CONTEXT_MANAGERS}
+        if supports_feature(versions, feature)
+    }
+    normalize_fmt_off(src_node)
+    lines = LineGenerator(mode=mode, features=context_manager_features)
     elt = EmptyLineTracker(mode=mode)
     split_line_features = {
         feature
     elt = EmptyLineTracker(mode=mode)
     split_line_features = {
         feature
@@ -1109,7 +1118,7 @@ def _format_str_once(src_contents: str, *, mode: Mode) -> str:
     dst_contents = []
     for block in dst_blocks:
         dst_contents.extend(block.all_lines())
     dst_contents = []
     for block in dst_blocks:
         dst_contents.extend(block.all_lines())
-    if mode.preview and not dst_contents:
+    if not dst_contents:
         # Use decode_bytes to retrieve the correct source newline (CRLF or LF),
         # and check if normalized_content has more than one line
         normalized_content, _, newline = decode_bytes(src_contents.encode("utf-8"))
         # Use decode_bytes to retrieve the correct source newline (CRLF or LF),
         # and check if normalized_content has more than one line
         normalized_content, _, newline = decode_bytes(src_contents.encode("utf-8"))
@@ -1151,6 +1160,10 @@ def get_features_used(  # noqa: C901
     - relaxed decorator syntax;
     - usage of __future__ flags (annotations);
     - print / exec statements;
     - relaxed decorator syntax;
     - usage of __future__ flags (annotations);
     - print / exec statements;
+    - parenthesized context managers;
+    - match statements;
+    - except* clause;
+    - variadic generics;
     """
     features: Set[Feature] = set()
     if future_imports:
     """
     features: Set[Feature] = set()
     if future_imports:
@@ -1226,6 +1239,23 @@ def get_features_used(  # noqa: C901
         ):
             features.add(Feature.ANN_ASSIGN_EXTENDED_RHS)
 
         ):
             features.add(Feature.ANN_ASSIGN_EXTENDED_RHS)
 
+        elif (
+            n.type == syms.with_stmt
+            and len(n.children) > 2
+            and n.children[1].type == syms.atom
+        ):
+            atom_children = n.children[1].children
+            if (
+                len(atom_children) == 3
+                and atom_children[0].type == token.LPAR
+                and atom_children[1].type == syms.testlist_gexp
+                and atom_children[2].type == token.RPAR
+            ):
+                features.add(Feature.PARENTHESIZED_CONTEXT_MANAGERS)
+
+        elif n.type == syms.match_stmt:
+            features.add(Feature.PATTERN_MATCHING)
+
         elif (
             n.type == syms.except_clause
             and len(n.children) >= 2
         elif (
             n.type == syms.except_clause
             and len(n.children) >= 2
@@ -1245,6 +1275,9 @@ def get_features_used(  # noqa: C901
         ):
             features.add(Feature.VARIADIC_GENERICS)
 
         ):
             features.add(Feature.VARIADIC_GENERICS)
 
+        elif n.type in (syms.type_stmt, syms.typeparams):
+            features.add(Feature.TYPE_PARAMS)
+
     return features
 
 
     return features