git.madduck.net Git - etc/vim.git/blobdiff - black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access, can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Issue 1144: Fix type annotations for --config option (#1166)
[etc/vim.git] / black.py
index 69d24c54133320e3cd7b6c8e0a17cc3c23f17981..2df03f7bacd0adeee573c8a7d2c9228af5a34b30 100644 (file)
--- a/black.py
+++ b/black.py
@@ -186,7 +186,7 @@ VERSION_TO_FEATURES: Dict[TargetVersion, Set[Feature]] = {
 
 
 @dataclass
 
 
 @dataclass
-class FileMode:
+class Mode:
     target_versions: Set[TargetVersion] = field(default_factory=set)
     line_length: int = DEFAULT_LINE_LENGTH
     string_normalization: bool = True
     target_versions: Set[TargetVersion] = field(default_factory=set)
     line_length: int = DEFAULT_LINE_LENGTH
     string_normalization: bool = True
@@ -209,6 +209,10 @@ class FileMode:
         return ".".join(parts)
 
 
         return ".".join(parts)
 
 
+# Legacy name, left for integrations.
+FileMode = Mode
+
+
 def supports_feature(target_versions: Set[TargetVersion], feature: Feature) -> bool:
     return all(feature in VERSION_TO_FEATURES[version] for version in target_versions)
 
 def supports_feature(target_versions: Set[TargetVersion], feature: Feature) -> bool:
     return all(feature in VERSION_TO_FEATURES[version] for version in target_versions)
 
@@ -231,14 +235,13 @@ def parse_pyproject_toml(path_config: str) -> Dict[str, Any]:
 
 
 def read_pyproject_toml(
 
 
 def read_pyproject_toml(
-    ctx: click.Context, param: click.Parameter, value: Union[str, int, bool, None]
+    ctx: click.Context, param: click.Parameter, value: Optional[str]
 ) -> Optional[str]:
     """Inject Black configuration from "pyproject.toml" into defaults in `ctx`.
 
     Returns the path to a successfully found and read configuration file, None
     otherwise.
     """
 ) -> Optional[str]:
     """Inject Black configuration from "pyproject.toml" into defaults in `ctx`.
 
     Returns the path to a successfully found and read configuration file, None
     otherwise.
     """
-    assert not isinstance(value, (int, bool)), "Invalid parameter type passed"
     if not value:
         value = find_pyproject_toml(ctx.params.get("src", ()))
         if value is None:
     if not value:
         value = find_pyproject_toml(ctx.params.get("src", ()))
         if value is None:
@@ -254,9 +257,12 @@ def read_pyproject_toml(
     if not config:
         return None
 
     if not config:
         return None
 
-    if ctx.default_map is None:
-        ctx.default_map = {}
-    ctx.default_map.update(config)  # type: ignore  # bad types in .pyi
+    default_map: Dict[str, Any] = {}
+    if ctx.default_map:
+        default_map.update(ctx.default_map)
+    default_map.update(config)
+
+    ctx.default_map = default_map
     return value
 
 
     return value
 
 
@@ -390,7 +396,12 @@ def target_version_option_callback(
 @click.option(
     "--config",
     type=click.Path(
 @click.option(
     "--config",
     type=click.Path(
-        exists=False, file_okay=True, dir_okay=False, readable=True, allow_dash=False
+        exists=True,
+        file_okay=True,
+        dir_okay=False,
+        readable=True,
+        allow_dash=False,
+        path_type=str,
     ),
     is_eager=True,
     callback=read_pyproject_toml,
     ),
     is_eager=True,
     callback=read_pyproject_toml,
@@ -419,7 +430,7 @@ def main(
     write_back = WriteBack.from_configuration(check=check, diff=diff)
     if target_version:
         if py36:
     write_back = WriteBack.from_configuration(check=check, diff=diff)
     if target_version:
         if py36:
-            err(f"Cannot use both --target-version and --py36")
+            err("Cannot use both --target-version and --py36")
             ctx.exit(2)
         else:
             versions = set(target_version)
             ctx.exit(2)
         else:
             versions = set(target_version)
@@ -432,7 +443,7 @@ def main(
     else:
         # We'll autodetect later.
         versions = set()
     else:
         # We'll autodetect later.
         versions = set()
-    mode = FileMode(
+    mode = Mode(
         target_versions=versions,
         line_length=line_length,
         is_pyi=pyi,
         target_versions=versions,
         line_length=line_length,
         is_pyi=pyi,
@@ -507,7 +518,7 @@ def path_empty(
 
 
 def reformat_one(
 
 
 def reformat_one(
-    src: Path, fast: bool, write_back: WriteBack, mode: FileMode, report: "Report"
+    src: Path, fast: bool, write_back: WriteBack, mode: Mode, report: "Report"
 ) -> None:
     """Reformat a single file under `src` without spawning child processes.
 
 ) -> None:
     """Reformat a single file under `src` without spawning child processes.
 
@@ -540,11 +551,7 @@ def reformat_one(
 
 
 def reformat_many(
 
 
 def reformat_many(
-    sources: Set[Path],
-    fast: bool,
-    write_back: WriteBack,
-    mode: FileMode,
-    report: "Report",
+    sources: Set[Path], fast: bool, write_back: WriteBack, mode: Mode, report: "Report"
 ) -> None:
     """Reformat multiple files using a ProcessPoolExecutor."""
     loop = asyncio.get_event_loop()
 ) -> None:
     """Reformat multiple files using a ProcessPoolExecutor."""
     loop = asyncio.get_event_loop()
@@ -574,7 +581,7 @@ async def schedule_formatting(
     sources: Set[Path],
     fast: bool,
     write_back: WriteBack,
     sources: Set[Path],
     fast: bool,
     write_back: WriteBack,
-    mode: FileMode,
+    mode: Mode,
     report: "Report",
     loop: asyncio.AbstractEventLoop,
     executor: Executor,
     report: "Report",
     loop: asyncio.AbstractEventLoop,
     executor: Executor,
@@ -644,7 +651,7 @@ async def schedule_formatting(
 def format_file_in_place(
     src: Path,
     fast: bool,
 def format_file_in_place(
     src: Path,
     fast: bool,
-    mode: FileMode,
+    mode: Mode,
     write_back: WriteBack = WriteBack.NO,
     lock: Any = None,  # multiprocessing.Manager().Lock() is some crazy proxy
 ) -> bool:
     write_back: WriteBack = WriteBack.NO,
     lock: Any = None,  # multiprocessing.Manager().Lock() is some crazy proxy
 ) -> bool:
@@ -688,7 +695,7 @@ def format_file_in_place(
 
 
 def format_stdin_to_stdout(
 
 
 def format_stdin_to_stdout(
-    fast: bool, *, write_back: WriteBack = WriteBack.NO, mode: FileMode
+    fast: bool, *, write_back: WriteBack = WriteBack.NO, mode: Mode
 ) -> bool:
     """Format file on stdin. Return True if changed.
 
 ) -> bool:
     """Format file on stdin. Return True if changed.
 
@@ -720,9 +727,7 @@ def format_stdin_to_stdout(
         f.detach()
 
 
         f.detach()
 
 
-def format_file_contents(
-    src_contents: str, *, fast: bool, mode: FileMode
-) -> FileContent:
+def format_file_contents(src_contents: str, *, fast: bool, mode: Mode) -> FileContent:
     """Reformat contents a file and return new contents.
 
     If `fast` is False, additionally confirm that the reformatted code is
     """Reformat contents a file and return new contents.
 
     If `fast` is False, additionally confirm that the reformatted code is
@@ -742,11 +747,34 @@ def format_file_contents(
     return dst_contents
 
 
     return dst_contents
 
 
-def format_str(src_contents: str, *, mode: FileMode) -> FileContent:
+def format_str(src_contents: str, *, mode: Mode) -> FileContent:
     """Reformat a string and return new contents.
 
     `mode` determines formatting options, such as how many characters per line are
     """Reformat a string and return new contents.
 
     `mode` determines formatting options, such as how many characters per line are
-    allowed.
+    allowed.  Example:
+
+    >>> import black
+    >>> print(black.format_str("def f(arg:str='')->None:...", mode=Mode()))
+    def f(arg: str = "") -> None:
+        ...
+
+    A more complex example:
+    >>> print(
+    ...   black.format_str(
+    ...     "def f(arg:str='')->None: hey",
+    ...     mode=black.Mode(
+    ...       target_versions={black.TargetVersion.PY36},
+    ...       line_length=10,
+    ...       string_normalization=False,
+    ...       is_pyi=False,
+    ...     ),
+    ...   ),
+    ... )
+    def f(
+        arg: str = '',
+    ) -> None:
+        hey
+
     """
     src_node = lib2to3_parse(src_contents.lstrip(), mode.target_versions)
     dst_contents = []
     """
     src_node = lib2to3_parse(src_contents.lstrip(), mode.target_versions)
     dst_contents = []
@@ -2842,7 +2870,7 @@ def normalize_string_prefix(leaf: Leaf, remove_u_prefix: bool = False) -> None:
     match = re.match(r"^([furbFURB]*)(.*)$", leaf.value, re.DOTALL)
     assert match is not None, f"failed to match string {leaf.value!r}"
     orig_prefix = match.group(1)
     match = re.match(r"^([furbFURB]*)(.*)$", leaf.value, re.DOTALL)
     assert match is not None, f"failed to match string {leaf.value!r}"
     orig_prefix = match.group(1)
-    new_prefix = orig_prefix.lower()
+    new_prefix = orig_prefix.replace("F", "f").replace("B", "b").replace("U", "u")
     if remove_u_prefix:
         new_prefix = new_prefix.replace("u", "")
     leaf.value = f"{new_prefix}{match.group(2)}"
     if remove_u_prefix:
         new_prefix = new_prefix.replace("u", "")
     leaf.value = f"{new_prefix}{match.group(2)}"
@@ -3546,7 +3574,7 @@ def gen_python_files_in_dir(
     for child in path.iterdir():
         # First ignore files matching .gitignore
         if gitignore.match_file(child.as_posix()):
     for child in path.iterdir():
         # First ignore files matching .gitignore
         if gitignore.match_file(child.as_posix()):
-            report.path_ignored(child, f"matches the .gitignore file content")
+            report.path_ignored(child, "matches the .gitignore file content")
             continue
 
         # Then ignore with `exclude` option.
             continue
 
         # Then ignore with `exclude` option.
@@ -3570,7 +3598,7 @@ def gen_python_files_in_dir(
 
         exclude_match = exclude.search(normalized_path)
         if exclude_match and exclude_match.group(0):
 
         exclude_match = exclude.search(normalized_path)
         if exclude_match and exclude_match.group(0):
-            report.path_ignored(child, f"matches the --exclude regular expression")
+            report.path_ignored(child, "matches the --exclude regular expression")
             continue
 
         if child.is_dir():
             continue
 
         if child.is_dir():
@@ -3813,7 +3841,7 @@ def assert_equivalent(src: str, dst: str) -> None:
         ) from None
 
 
         ) from None
 
 
-def assert_stable(src: str, dst: str, mode: FileMode) -> None:
+def assert_stable(src: str, dst: str, mode: Mode) -> None:
     """Raise AssertionError if `dst` reformats differently the second time."""
     newdst = format_str(dst, mode=mode)
     if dst != newdst:
     """Raise AssertionError if `dst` reformats differently the second time."""
     newdst = format_str(dst, mode=mode)
     if dst != newdst:
@@ -3855,8 +3883,8 @@ def diff(a: str, b: str, a_name: str, b_name: str) -> str:
     """Return a unified diff string between strings `a` and `b`."""
     import difflib
 
     """Return a unified diff string between strings `a` and `b`."""
     import difflib
 
-    a_lines = [line + "\n" for line in a.split("\n")]
-    b_lines = [line + "\n" for line in b.split("\n")]
+    a_lines = [line + "\n" for line in a.splitlines()]
+    b_lines = [line + "\n" for line in b.splitlines()]
     return "".join(
         difflib.unified_diff(a_lines, b_lines, fromfile=a_name, tofile=b_name, n=5)
     )
     return "".join(
         difflib.unified_diff(a_lines, b_lines, fromfile=a_name, tofile=b_name, n=5)
     )
@@ -4084,11 +4112,11 @@ def can_omit_invisible_parens(line: Line, line_length: int) -> bool:
     return False
 
 
     return False
 
 
-def get_cache_file(mode: FileMode) -> Path:
+def get_cache_file(mode: Mode) -> Path:
     return CACHE_DIR / f"cache.{mode.get_cache_key()}.pickle"
 
 
     return CACHE_DIR / f"cache.{mode.get_cache_key()}.pickle"
 
 
-def read_cache(mode: FileMode) -> Cache:
+def read_cache(mode: Mode) -> Cache:
     """Read the cache if it exists and is well formed.
 
     If it is not well formed, the call to write_cache later should resolve the issue.
     """Read the cache if it exists and is well formed.
 
     If it is not well formed, the call to write_cache later should resolve the issue.
@@ -4128,7 +4156,7 @@ def filter_cached(cache: Cache, sources: Iterable[Path]) -> Tuple[Set[Path], Set
     return todo, done
 
 
     return todo, done
 
 
-def write_cache(cache: Cache, sources: Iterable[Path], mode: FileMode) -> None:
+def write_cache(cache: Cache, sources: Iterable[Path], mode: Mode) -> None:
     """Update the cache file."""
     cache_file = get_cache_file(mode)
     try:
     """Update the cache file."""
     cache_file = get_cache_file(mode)
     try: