All patches and comments are welcome. Please squash your changes into logical
commits before using git-format-patch and git-send-email to send them to
patches@git.madduck.net.
If you read over the Git project's submission guidelines and adhere to them,
I'd be especially grateful.
import os
import unittest
from functools import partial
from pathlib import Path
from typing import Any, List, Tuple

import black
from black.debug import DebugVisitor
from black.output import err, out
 
  11 THIS_DIR = Path(__file__).parent
 
  12 PROJECT_ROOT = THIS_DIR.parent
 
  13 EMPTY_LINE = "# EMPTY LINE WITH WHITESPACE" + " (this comment will be removed)"
 
  14 DETERMINISTIC_HEADER = "[Deterministic header]"
 
  17 DEFAULT_MODE = black.Mode()
 
  18 ff = partial(black.format_file_in_place, mode=DEFAULT_MODE, fast=True)
 
  19 fs = partial(black.format_str, mode=DEFAULT_MODE)
 
  22 def dump_to_stderr(*output: str) -> str:
 
  23     return "\n" + "\n".join(output) + "\n"
 
  26 class BlackBaseTestCase(unittest.TestCase):
 
  28     _diffThreshold = 2 ** 20
 
  30     def assertFormatEqual(self, expected: str, actual: str) -> None:
 
  31         if actual != expected and not os.environ.get("SKIP_AST_PRINT"):
 
  32             bdv: DebugVisitor[Any]
 
  33             out("Expected tree:", fg="green")
 
  35                 exp_node = black.lib2to3_parse(expected)
 
  37                 list(bdv.visit(exp_node))
 
  38             except Exception as ve:
 
  40             out("Actual tree:", fg="red")
 
  42                 exp_node = black.lib2to3_parse(actual)
 
  44                 list(bdv.visit(exp_node))
 
  45             except Exception as ve:
 
  47         self.assertMultiLineEqual(expected, actual)
 
  50 def read_data(name: str, data: bool = True) -> Tuple[str, str]:
 
  51     """read_data('test_name') -> 'input', 'output'"""
 
  52     if not name.endswith((".py", ".pyi", ".out", ".diff")):
 
  54     base_dir = THIS_DIR / "data" if data else PROJECT_ROOT
 
  55     return read_data_from_file(base_dir / name)
 
  58 def read_data_from_file(file_name: Path) -> Tuple[str, str]:
 
  59     with open(file_name, "r", encoding="utf8") as test:
 
  60         lines = test.readlines()
 
  61     _input: List[str] = []
 
  62     _output: List[str] = []
 
  65         line = line.replace(EMPTY_LINE, "")
 
  66         if line.rstrip() == "# output":
 
  71     if _input and not _output:
 
  72         # If there's no output marker, treat the entire file as already pre-formatted.
 
  74     return "".join(_input).strip() + "\n", "".join(_output).strip() + "\n"