X-Git-Url: https://git.madduck.net/etc/vim.git/blobdiff_plain/767604e03f5e454ae5b5c268cd5831c672f46de8..01b8d3d4095ebdb91d0d39012a517931625c63cb:/src/black/__init__.py

diff --git a/src/black/__init__.py b/src/black/__init__.py
index ded4a73..dbcb559 100644
--- a/src/black/__init__.py
+++ b/src/black/__init__.py
@@ -7,7 +7,7 @@ import tokenize
 import traceback
 from contextlib import contextmanager
 from dataclasses import replace
-from datetime import datetime
+from datetime import datetime, timezone
 from enum import Enum
 from json.decoder import JSONDecodeError
 from pathlib import Path
@@ -30,6 +30,7 @@ from typing import (
 import click
 from click.core import ParameterSource
 from mypy_extensions import mypyc_attr
+from pathspec import PathSpec
 from pathspec.patterns.gitwildmatch import GitWildMatchPatternError
 
 from _black_version import version as __version__
@@ -61,7 +62,7 @@ from black.handle_ipynb_magics import (
     unmask_cell,
 )
 from black.linegen import LN, LineGenerator, transform_line
-from black.lines import EmptyLineTracker, Line
+from black.lines import EmptyLineTracker, LinesBlock
 from black.mode import (
     FUTURE_FLAG_TO_FEATURE,
     VERSION_TO_FEATURES,
@@ -218,8 +219,9 @@ def validate_regex(
     callback=target_version_option_callback,
     multiple=True,
     help=(
-        "Python versions that should be supported by Black's output. [default: per-file"
-        " auto-detection]"
+        "Python versions that should be supported by Black's output. By default, Black"
+        " will try to infer this from the project metadata in pyproject.toml. If this"
+        " does not yield conclusive results, Black will use per-file auto-detection."
     ),
 )
 @click.option(
@@ -243,11 +245,17 @@ def validate_regex(
     multiple=True,
     help=(
         "When processing Jupyter Notebooks, add the given magic to the list"
-        f" of known python-magics ({', '.join(PYTHON_CELL_MAGICS)})."
+        f" of known python-magics ({', '.join(sorted(PYTHON_CELL_MAGICS))})."
         " Useful for formatting cells with custom python magics."
     ),
     default=[],
 )
+@click.option(
+    "-x",
+    "--skip-source-first-line",
+    is_flag=True,
+    help="Skip the first line of the source code.",
+)
 @click.option(
     "-S",
     "--skip-string-normalization",
@@ -428,6 +436,7 @@ def main(  # noqa: C901
     pyi: bool,
     ipynb: bool,
     python_cell_magics: Sequence[str],
+    skip_source_first_line: bool,
     skip_string_normalization: bool,
     skip_magic_trailing_comma: bool,
     experimental_string_processing: bool,
@@ -470,16 +479,20 @@ def main(  # noqa: C901
             )
 
             normalized = [
-                (source, source)
-                if source == "-"
-                else (normalize_path_maybe_ignore(Path(source), root), source)
+                (
+                    (source, source)
+                    if source == "-"
+                    else (normalize_path_maybe_ignore(Path(source), root), source)
+                )
                 for source in src
             ]
             srcs_string = ", ".join(
                 [
-                    f'"{_norm}"'
-                    if _norm
-                    else f'\033[31m"{source} (skipping - invalid)"\033[34m'
+                    (
+                        f'"{_norm}"'
+                        if _norm
+                        else f'\033[31m"{source} (skipping - invalid)"\033[34m'
+                    )
                     for _norm, source in normalized
                 ]
             )
@@ -501,6 +514,9 @@ def main(  # noqa: C901
                 out("Using configuration from project root.", fg="blue")
             else:
                 out(f"Using configuration in '{config}'.", fg="blue")
+            if ctx.default_map:
+                for param, value in ctx.default_map.items():
+                    out(f"{param}: {value}")
 
     error_msg = "Oh no! 💥 💔 💥"
     if (
@@ -528,6 +544,7 @@ def main(  # noqa: C901
         line_length=line_length,
         is_pyi=pyi,
         is_ipynb=ipynb,
+        skip_source_first_line=skip_source_first_line,
         string_normalization=not skip_string_normalization,
         magic_trailing_comma=not skip_magic_trailing_comma,
         experimental_string_processing=experimental_string_processing,
@@ -617,6 +634,11 @@ def get_sources(
     sources: Set[Path] = set()
     root = ctx.obj["root"]
 
+    using_default_exclude = exclude is None
+    exclude = re_compile_maybe_verbose(DEFAULT_EXCLUDES) if exclude is None else exclude
+    gitignore: Optional[Dict[Path, PathSpec]] = None
+    root_gitignore = get_gitignore(root)
+
     for s in src:
         if s == "-" and stdin_filename:
             p = Path(stdin_filename)
@@ -650,16 +672,12 @@ def get_sources(
 
             sources.add(p)
         elif p.is_dir():
-            if exclude is None:
-                exclude = re_compile_maybe_verbose(DEFAULT_EXCLUDES)
-                gitignore = get_gitignore(root)
-                p_gitignore = get_gitignore(p)
-                # No need to use p's gitignore if it is identical to root's gitignore
-                # (i.e. root and p point to the same directory).
-                if gitignore != p_gitignore:
-                    gitignore += p_gitignore
-            else:
-                gitignore = None
+            p = root / normalize_path_maybe_ignore(p, ctx.obj["root"], report)
+            if using_default_exclude:
+                gitignore = {
+                    root: root_gitignore,
+                    p: get_gitignore(p),
+                }
             sources.update(
                 gen_python_files(
                     p.iterdir(),
@@ -789,8 +807,11 @@ def format_file_in_place(
     elif src.suffix == ".ipynb":
         mode = replace(mode, is_ipynb=True)
 
-    then = datetime.utcfromtimestamp(src.stat().st_mtime)
+    then = datetime.fromtimestamp(src.stat().st_mtime, timezone.utc)
+    header = b""
     with open(src, "rb") as buf:
+        if mode.skip_source_first_line:
+            header = buf.readline()
         src_contents, encoding, newline = decode_bytes(buf.read())
     try:
         dst_contents = format_file_contents(src_contents, fast=fast, mode=mode)
@@ -800,14 +821,16 @@ def format_file_in_place(
         raise ValueError(
             f"File '{src}' cannot be parsed as valid Jupyter notebook."
         ) from None
+    src_contents = header.decode(encoding) + src_contents
+    dst_contents = header.decode(encoding) + dst_contents
 
     if write_back == WriteBack.YES:
         with open(src, "w", encoding=encoding, newline=newline) as f:
             f.write(dst_contents)
     elif write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
-        now = datetime.utcnow()
-        src_name = f"{src}\t{then} +0000"
-        dst_name = f"{src}\t{now} +0000"
+        now = datetime.now(timezone.utc)
+        src_name = f"{src}\t{then}"
+        dst_name = f"{src}\t{now}"
         if mode.is_ipynb:
             diff_contents = ipynb_diff(src_contents, dst_contents, src_name, dst_name)
         else:
@@ -845,7 +868,7 @@ def format_stdin_to_stdout(
     write a diff to stdout. The `mode` argument is passed to
     :func:`format_file_contents`.
     """
-    then = datetime.utcnow()
+    then = datetime.now(timezone.utc)
 
     if content is None:
         src, encoding, newline = decode_bytes(sys.stdin.buffer.read())
@@ -870,9 +893,9 @@ def format_stdin_to_stdout(
                 dst += "\n"
             f.write(dst)
         elif write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
-            now = datetime.utcnow()
-            src_name = f"STDIN\t{then} +0000"
-            dst_name = f"STDOUT\t{now} +0000"
+            now = datetime.now(timezone.utc)
+            src_name = f"STDIN\t{then}"
+            dst_name = f"STDOUT\t{now}"
             d = diff(src, dst, src_name, dst_name)
             if write_back == WriteBack.COLOR_DIFF:
                 d = color_diff(d)
@@ -901,9 +924,6 @@ def format_file_contents(src_contents: str, *, fast: bool, mode: Mode) -> FileCo
     valid by calling :func:`assert_equivalent` and :func:`assert_stable` on it.
     `mode` is passed to :func:`format_str`.
     """
-    if not src_contents.strip():
-        raise NothingChanged
-
     if mode.is_ipynb:
         dst_contents = format_ipynb_string(src_contents, fast=fast, mode=mode)
     else:
@@ -998,6 +1018,9 @@ def format_ipynb_string(src_contents: str, *, fast: bool, mode: Mode) -> FileCon
     Operate cell-by-cell, only on code cells, only for Python notebooks.
     If the ``.ipynb`` originally had a trailing newline, it'll be preserved.
     """
+    if not src_contents:
+        raise NothingChanged
+
     trailing_newline = src_contents[-1] == "\n"
     modified = False
     nb = json.loads(src_contents)
@@ -1062,31 +1085,46 @@ def format_str(src_contents: str, *, mode: Mode) -> str:
 
 def _format_str_once(src_contents: str, *, mode: Mode) -> str:
     src_node = lib2to3_parse(src_contents.lstrip(), mode.target_versions)
-    dst_contents = []
+    dst_blocks: List[LinesBlock] = []
     if mode.target_versions:
         versions = mode.target_versions
     else:
         future_imports = get_future_imports(src_node)
         versions = detect_target_versions(src_node, future_imports=future_imports)
 
-    normalize_fmt_off(src_node, preview=mode.preview)
-    lines = LineGenerator(mode=mode)
-    elt = EmptyLineTracker(is_pyi=mode.is_pyi)
-    empty_line = Line(mode=mode)
-    after = 0
+    context_manager_features = {
+        feature
+        for feature in {Feature.PARENTHESIZED_CONTEXT_MANAGERS}
+        if supports_feature(versions, feature)
+    }
+    normalize_fmt_off(src_node)
+    lines = LineGenerator(mode=mode, features=context_manager_features)
+    elt = EmptyLineTracker(mode=mode)
     split_line_features = {
         feature
         for feature in {Feature.TRAILING_COMMA_IN_CALL, Feature.TRAILING_COMMA_IN_DEF}
         if supports_feature(versions, feature)
     }
+    block: Optional[LinesBlock] = None
     for current_line in lines.visit(src_node):
-        dst_contents.append(str(empty_line) * after)
-        before, after = elt.maybe_empty_lines(current_line)
-        dst_contents.append(str(empty_line) * before)
+        block = elt.maybe_empty_lines(current_line)
+        dst_blocks.append(block)
         for line in transform_line(
             current_line, mode=mode, features=split_line_features
         ):
-            dst_contents.append(str(line))
+            block.content_lines.append(str(line))
+    if dst_blocks:
+        dst_blocks[-1].after = 0
+    dst_contents = []
+    for block in dst_blocks:
+        dst_contents.extend(block.all_lines())
+    if not dst_contents:
+        # Use decode_bytes to retrieve the correct source newline (CRLF or LF),
+        # and check if normalized_content has more than one line
+        normalized_content, _, newline = decode_bytes(src_contents.encode("utf-8"))
+        if "\n" in normalized_content:
+            return newline
+        return ""
     return "".join(dst_contents)
 
 
@@ -1122,6 +1160,10 @@ def get_features_used(  # noqa: C901
     - relaxed decorator syntax;
     - usage of __future__ flags (annotations);
     - print / exec statements;
+    - parenthesized context managers;
+    - match statements;
+    - except* clause;
+    - variadic generics;
     """
     features: Set[Feature] = set()
     if future_imports:
@@ -1197,6 +1239,23 @@ def get_features_used(  # noqa: C901
         ):
             features.add(Feature.ANN_ASSIGN_EXTENDED_RHS)
 
+        elif (
+            n.type == syms.with_stmt
+            and len(n.children) > 2
+            and n.children[1].type == syms.atom
+        ):
+            atom_children = n.children[1].children
+            if (
+                len(atom_children) == 3
+                and atom_children[0].type == token.LPAR
+                and atom_children[1].type == syms.testlist_gexp
+                and atom_children[2].type == token.RPAR
+            ):
+                features.add(Feature.PARENTHESIZED_CONTEXT_MANAGERS)
+
+        elif n.type == syms.match_stmt:
+            features.add(Feature.PATTERN_MATCHING)
+
         elif (
             n.type == syms.except_clause
             and len(n.children) >= 2
@@ -1216,6 +1275,9 @@ def get_features_used(  # noqa: C901
         ):
             features.add(Feature.VARIADIC_GENERICS)
 
+        elif n.type in (syms.type_stmt, syms.typeparams):
+            features.add(Feature.TYPE_PARAMS)
+
     return features
 
 
@@ -1369,13 +1431,15 @@ def patch_click() -> None:
 
     for module in modules:
         if hasattr(module, "_verify_python3_env"):
-            module._verify_python3_env = lambda: None  # type: ignore
+            module._verify_python3_env = lambda: None
         if hasattr(module, "_verify_python_env"):
-            module._verify_python_env = lambda: None  # type: ignore
+            module._verify_python_env = lambda: None
 
 
 def patched_main() -> None:
-    if sys.platform == "win32" and getattr(sys, "frozen", False):
+    # PyInstaller patches multiprocessing to need freeze_support() even in non-Windows
+    # environments so just assume we always need to call it if frozen.
+    if getattr(sys, "frozen", False):
         from multiprocessing import freeze_support
 
         freeze_support()