import re
import sys
from functools import lru_cache
-from typing import List, Match, Pattern
+from typing import Final, List, Match, Pattern
+from black._width_table import WIDTH_TABLE
from blib2to3.pytree import Leaf
-if sys.version_info < (3, 8):
- from typing_extensions import Final
-else:
- from typing import Final
-
-
STRING_PREFIX_CHARS: Final = "furbFURB" # All possible string prefix characters.
STRING_PREFIX_RE: Final = re.compile(
r"^([" + STRING_PREFIX_CHARS + r"]*)(.*)$", re.DOTALL
return back_slashes + "N{" + groups["N"].upper() + "}"
leaf.value = re.sub(UNICODE_ESCAPE_RE, replace, text)
+
+
@lru_cache(maxsize=4096)
def char_width(char: str) -> int:
    """Return the width of a single character as it would be displayed in a
    terminal or editor (which respects Unicode East Asian Width).

    Full width characters are counted as 2, while half width characters are
    counted as 1. Also control characters are counted as 0.

    The width is looked up in `WIDTH_TABLE`, which is read as a sequence of
    `(start_codepoint, end_codepoint, width)` ranges. The lookup is a binary
    search, so the ranges are assumed to be sorted by codepoint and
    non-overlapping (NOTE(review): confirm against the table's generator).

    A negative table width is reported as 0, and a character that falls in
    none of the ranges (including when the table is empty) is reported as 1.
    Results are memoized by `lru_cache`, so repeated characters are cheap.
    """
    codepoint = ord(char)
    lowest = 0
    highest = len(WIDTH_TABLE) - 1
    while lowest <= highest:
        idx = (lowest + highest) // 2
        start_codepoint, end_codepoint, width = WIDTH_TABLE[idx]
        if codepoint < start_codepoint:
            # The character sorts before this range: search the lower half.
            highest = idx - 1
        elif codepoint > end_codepoint:
            # The character sorts after this range: search the upper half.
            lowest = idx + 1
        else:
            # Inside the range; negative widths in the table mean "no width".
            return max(width, 0)
    # Not covered by any range: treat as an ordinary single-column character.
    return 1
+
+
def str_width(line_str: str) -> int:
    """Return the width of `line_str` as it would be displayed in a terminal
    or editor (which respects Unicode East Asian Width).

    You could utilize this function to determine, for example, if a string
    is too wide to display in a terminal or editor.

    Pure-ASCII input (including the empty string) takes a fast path that
    returns `len(line_str)`, i.e. every ASCII character is counted as one
    column. Any other input is measured as the sum of `char_width` over its
    characters.
    """
    if line_str.isascii():
        # Fast path for a line consisting of only ASCII characters: skip the
        # per-character table lookup entirely.
        return len(line_str)
    return sum(map(char_width, line_str))
+
+
def count_chars_in_width(line_str: str, max_width: int) -> int:
    """Count the number of characters in `line_str` that would fit in a
    terminal or editor of `max_width` (which respects Unicode East Asian
    Width).

    Characters are measured one by one with `char_width`, from the start of
    the string. The result is the index of the first character that would
    push the running width past `max_width`, or `len(line_str)` when the
    whole string fits.
    """
    total_width = 0
    for i, char in enumerate(line_str):
        total_width += char_width(char)
        if total_width > max_width:
            # This character no longer fits, so exactly `i` characters do.
            return i
    return len(line_str)