X-Git-Url: https://git.madduck.net/etc/vim.git/blobdiff_plain/692c0f50d91e3163bb87401e4a0e070b2eb5b163..368f043f138112f63ff521c3481993c589eb7508:/src/black/__init__.py diff --git a/src/black/__init__.py b/src/black/__init__.py index 9034bf6..5c9ab75 100644 --- a/src/black/__init__.py +++ b/src/black/__init__.py @@ -48,7 +48,20 @@ from appdirs import user_cache_dir from dataclasses import dataclass, field, replace import click import toml -from typed_ast import ast3, ast27 + +try: + from typed_ast import ast3, ast27 +except ImportError: + if sys.version_info < (3, 8): + print( + "The typed_ast package is not installed.\n" + "You can install it with `python3 -m pip install typed-ast`.", + file=sys.stderr, + ) + sys.exit(1) + else: + ast3 = ast27 = ast + from pathspec import PathSpec # lib2to3 fork @@ -69,7 +82,7 @@ if TYPE_CHECKING: import colorama # noqa: F401 DEFAULT_LINE_LENGTH = 88 -DEFAULT_EXCLUDES = r"/(\.direnv|\.eggs|\.git|\.hg|\.mypy_cache|\.nox|\.tox|\.venv|\.svn|_build|buck-out|build|dist)/" # noqa: B950 +DEFAULT_EXCLUDES = r"/(\.direnv|\.eggs|\.git|\.hg|\.mypy_cache|\.nox|\.tox|\.venv|venv|\.svn|_build|buck-out|build|dist)/" # noqa: B950 DEFAULT_INCLUDES = r"\.pyi?$" CACHE_DIR = Path(user_cache_dir("black", version=__version__)) STDIN_PLACEHOLDER = "__BLACK_STDIN_FILENAME__" @@ -93,7 +106,7 @@ Transformer = Callable[["Line", Collection["Feature"]], Iterator["Line"]] Timestamp = float FileSize = int CacheInfo = Tuple[Timestamp, FileSize] -Cache = Dict[Path, CacheInfo] +Cache = Dict[str, CacheInfo] out = partial(click.secho, bold=True, err=True) err = partial(click.secho, fg="red", err=True) @@ -289,11 +302,15 @@ def supports_feature(target_versions: Set[TargetVersion], feature: Feature) -> b return all(feature in VERSION_TO_FEATURES[version] for version in target_versions) -def find_pyproject_toml(path_search_start: Iterable[str]) -> Optional[str]: +def find_pyproject_toml(path_search_start: Tuple[str, ...]) -> Optional[str]: """Find the absolute filepath to a pyproject.toml if it exists""" path_project_root = find_project_root(path_search_start) path_pyproject_toml = path_project_root / "pyproject.toml" - return str(path_pyproject_toml) if path_pyproject_toml.is_file() else None + if path_pyproject_toml.is_file(): + return str(path_pyproject_toml) + + path_user_pyproject_toml = find_user_pyproject_toml() + return str(path_user_pyproject_toml) if path_user_pyproject_toml.is_file() else None def parse_pyproject_toml(path_config: str) -> Dict[str, Any]: @@ -363,6 +380,17 @@ def target_version_option_callback( return [TargetVersion[val.upper()] for val in v] +def validate_regex( + ctx: click.Context, + param: click.Parameter, + value: Optional[str], +) -> Optional[Pattern]: + try: + return re_compile_maybe_verbose(value) if value is not None else None + except re.error: + raise click.BadParameter("Not a valid regular expression") + + @click.command(context_settings=dict(help_option_names=["-h", "--help"])) @click.option("-c", "--code", type=str, help="Format the code passed in as a string.") @click.option( @@ -417,8 +445,8 @@ def target_version_option_callback( "--check", is_flag=True, help=( - "Don't write the files back, just return the status. Return code 0 means" - " nothing would change. Return code 1 means some files would be reformatted." + "Don't write the files back, just return the status. Return code 0 means" + " nothing would change. Return code 1 means some files would be reformatted." " Return code 123 means there was an internal error." ), ) @@ -441,11 +469,12 @@ def target_version_option_callback( "--include", type=str, default=DEFAULT_INCLUDES, + callback=validate_regex, help=( "A regular expression that matches files and directories that should be" - " included on recursive searches. An empty value means all files are included" - " regardless of the name. Use forward slashes for directories on all platforms" - " (Windows, too). Exclusions are calculated first, inclusions later." + " included on recursive searches. An empty value means all files are included" + " regardless of the name. Use forward slashes for directories on all platforms" + " (Windows, too). Exclusions are calculated first, inclusions later." ), show_default=True, ) @@ -453,17 +482,28 @@ def target_version_option_callback( "--exclude", type=str, default=DEFAULT_EXCLUDES, + callback=validate_regex, help=( "A regular expression that matches files and directories that should be" - " excluded on recursive searches. An empty value means no paths are excluded." - " Use forward slashes for directories on all platforms (Windows, too). " + " excluded on recursive searches. An empty value means no paths are excluded." + " Use forward slashes for directories on all platforms (Windows, too)." " Exclusions are calculated first, inclusions later." ), show_default=True, ) +@click.option( + "--extend-exclude", + type=str, + callback=validate_regex, + help=( + "Like --exclude, but adds additional files and directories on top of the" + " excluded ones. (Useful if you simply want to add to the default)" + ), +) @click.option( "--force-exclude", type=str, + callback=validate_regex, help=( "Like --exclude, but files and directories matching this regex will be " "excluded even when they are passed explicitly as arguments." @@ -493,7 +533,7 @@ def target_version_option_callback( is_flag=True, help=( "Also emit messages to stderr about files that were not changed or were ignored" - " due to --exclude=." + " due to exclusion patterns." ), ) @click.version_option(version=__version__) @@ -535,9 +575,10 @@ def main( experimental_string_processing: bool, quiet: bool, verbose: bool, - include: str, - exclude: str, - force_exclude: Optional[str], + include: Pattern, + exclude: Pattern, + extend_exclude: Optional[Pattern], + force_exclude: Optional[Pattern], stdin_filename: Optional[str], src: Tuple[str, ...], config: Optional[str], @@ -570,6 +611,7 @@ def main( verbose=verbose, include=include, exclude=exclude, + extend_exclude=extend_exclude, force_exclude=force_exclude, report=report, stdin_filename=stdin_filename, @@ -608,30 +650,14 @@ def get_sources( src: Tuple[str, ...], quiet: bool, verbose: bool, - include: str, - exclude: str, - force_exclude: Optional[str], + include: Pattern[str], + exclude: Pattern[str], + extend_exclude: Optional[Pattern[str]], + force_exclude: Optional[Pattern[str]], report: "Report", stdin_filename: Optional[str], ) -> Set[Path]: """Compute the set of files to be formatted.""" - try: - include_regex = re_compile_maybe_verbose(include) - except re.error: - err(f"Invalid regular expression for include given: {include!r}") - ctx.exit(2) - try: - exclude_regex = re_compile_maybe_verbose(exclude) - except re.error: - err(f"Invalid regular expression for exclude given: {exclude!r}") - ctx.exit(2) - try: - force_exclude_regex = ( - re_compile_maybe_verbose(force_exclude) if force_exclude else None - ) - except re.error: - err(f"Invalid regular expression for force_exclude given: {force_exclude!r}") - ctx.exit(2) root = find_project_root(src) sources: Set[Path] = set() @@ -653,8 +679,8 @@ def get_sources( normalized_path = "/" + normalized_path # Hard-exclude any files that matches the `--force-exclude` regex. - if force_exclude_regex: - force_exclude_match = force_exclude_regex.search(normalized_path) + if force_exclude: + force_exclude_match = force_exclude.search(normalized_path) else: force_exclude_match = None if force_exclude_match and force_exclude_match.group(0): @@ -670,9 +696,10 @@ def get_sources( gen_python_files( p.iterdir(), root, - include_regex, - exclude_regex, - force_exclude_regex, + include, + exclude, + extend_exclude, + force_exclude, report, gitignore, ) @@ -724,7 +751,8 @@ def reformat_one( if write_back not in (WriteBack.DIFF, WriteBack.COLOR_DIFF): cache = read_cache(mode) res_src = src.resolve() - if res_src in cache and cache[res_src] == get_cache_info(res_src): + res_src_s = str(res_src) + if res_src_s in cache and cache[res_src_s] == get_cache_info(res_src): changed = Changed.CACHED if changed is not Changed.CACHED and format_file_in_place( src, fast=fast, write_back=write_back, mode=mode @@ -756,7 +784,7 @@ def reformat_many( except (ImportError, OSError): # we arrive here if the underlying system does not support multi-processing # like in AWS Lambda or Termux, in which case we gracefully fallback to - # a ThreadPollExecutor with just a single worker (more workers would not do us + # a ThreadPoolExecutor with just a single worker (more workers would not do us # any good due to the Global Interpreter Lock) executor = ThreadPoolExecutor(max_workers=1) @@ -819,7 +847,7 @@ async def schedule_formatting( ): src for src in sorted(sources) } - pending: Iterable["asyncio.Future[bool]"] = tasks.keys() + pending = tasks.keys() try: loop.add_signal_handler(signal.SIGINT, cancel, pending) loop.add_signal_handler(signal.SIGTERM, cancel, pending) @@ -882,7 +910,7 @@ def format_file_in_place( dst_name = f"{src}\t{now} +0000" diff_contents = diff(src_contents, dst_contents, src_name, dst_name) - if write_back == write_back.COLOR_DIFF: + if write_back == WriteBack.COLOR_DIFF: diff_contents = color_diff(diff_contents) with lock or nullcontext(): @@ -1479,7 +1507,8 @@ class Line: comments: Dict[LeafID, List[Leaf]] = field(default_factory=dict) bracket_tracker: BracketTracker = field(default_factory=BracketTracker) inside_brackets: bool = False - should_explode: bool = False + should_split_rhs: bool = False + magic_trailing_comma: Optional[Leaf] = None def append(self, leaf: Leaf, preformatted: bool = False) -> None: """Add a new `leaf` to the end of the line. @@ -1507,7 +1536,7 @@ class Line: self.bracket_tracker.mark(leaf) if self.mode.magic_trailing_comma: if self.has_magic_trailing_comma(leaf): - self.should_explode = True + self.magic_trailing_comma = leaf elif self.has_magic_trailing_comma(leaf, ensure_removable=True): self.remove_trailing_comma() if not self.append_comment(leaf): @@ -1790,7 +1819,8 @@ class Line: mode=self.mode, depth=self.depth, inside_brackets=self.inside_brackets, - should_explode=self.should_explode, + should_split_rhs=self.should_split_rhs, + magic_trailing_comma=self.magic_trailing_comma, ) def __str__(self) -> str: @@ -2046,6 +2076,8 @@ class LineGenerator(Visitor[Line]): def visit_simple_stmt(self, node: Node) -> Iterator[Line]: """Visit a statement without nested statements.""" + if first_child_is_arith(node): + wrap_in_parentheses(node, node.children[0], visible=False) is_suite_like = node.parent and node.parent.type in STATEMENT if is_suite_like: if self.mode.is_pyi and is_stub_body(node): @@ -2121,16 +2153,35 @@ class LineGenerator(Visitor[Line]): # We're ignoring docstrings with backslash newline escapes because changing # indentation of those changes the AST representation of the code. prefix = get_string_prefix(leaf.value) - lead_len = len(prefix) + 3 - tail_len = -3 - indent = " " * 4 * self.current_line.depth - docstring = fix_docstring(leaf.value[lead_len:tail_len], indent) + docstring = leaf.value[len(prefix) :] # Remove the prefix + quote_char = docstring[0] + # A natural way to remove the outer quotes is to do: + # docstring = docstring.strip(quote_char) + # but that breaks on """""x""" (which is '""x'). + # So we actually need to remove the first character and the next two + # characters but only if they are the same as the first. + quote_len = 1 if docstring[1] != quote_char else 3 + docstring = docstring[quote_len:-quote_len] + + if is_multiline_string(leaf): + indent = " " * 4 * self.current_line.depth + docstring = fix_docstring(docstring, indent) + else: + docstring = docstring.strip() + if docstring: - if leaf.value[lead_len - 1] == docstring[0]: + # Add some padding if the docstring starts / ends with a quote mark. + if docstring[0] == quote_char: docstring = " " + docstring - if leaf.value[tail_len + 1] == docstring[-1]: + if docstring[-1] == quote_char: docstring = docstring + " " - leaf.value = leaf.value[0:lead_len] + docstring + leaf.value[tail_len:] + else: + # Add some padding if the docstring is empty. + docstring = " " + + # We could enforce triple quotes at this point. + quote = quote_char * quote_len + leaf.value = prefix + quote + docstring + quote yield from self.visit_default(leaf) @@ -2580,6 +2631,8 @@ def is_split_before_delimiter(leaf: Leaf, previous: Optional[Leaf] = None) -> Pr FMT_OFF = {"# fmt: off", "# fmt:off", "# yapf: disable"} +FMT_SKIP = {"# fmt: skip", "# fmt:skip"} +FMT_PASS = {*FMT_OFF, *FMT_SKIP} FMT_ON = {"# fmt: on", "# fmt:on", "# yapf: enable"} @@ -2634,7 +2687,7 @@ def list_comments(prefix: str, *, is_endmarker: bool) -> List[ProtoComment]: consumed = 0 nlines = 0 ignored_lines = 0 - for index, line in enumerate(prefix.split("\n")): + for index, line in enumerate(re.split("\r?\n", prefix)): consumed += len(line) + 1 # adding the length of the split '\n' line = line.lstrip() if not line: @@ -2675,6 +2728,13 @@ def make_comment(content: str) -> str: if content[0] == "#": content = content[1:] + NON_BREAKING_SPACE = " " + if ( + content + and content[0] == NON_BREAKING_SPACE + and not content.lstrip().startswith("type:") + ): + content = " " + content[1:] # Replace NBSP by a simple space if content and content[0] not in " !:#'%": content = " " + content return "#" + content @@ -2707,7 +2767,8 @@ def transform_line( transformers: List[Transformer] if ( not line.contains_uncollapsable_type_comments() - and not line.should_explode + and not line.should_split_rhs + and not line.magic_trailing_comma and ( is_line_short_enough(line, line_length=mode.line_length, line_str=line_str) or line.contains_unsplittable_type_ignore() @@ -4381,7 +4442,8 @@ class StringParenWrapper(CustomSplitMapMixin, BaseStringSplitter): mode=line.mode, depth=line.depth + 1, inside_brackets=True, - should_explode=line.should_explode, + should_split_rhs=line.should_split_rhs, + magic_trailing_comma=line.magic_trailing_comma, ) string_leaf = Leaf(token.STRING, string_value) insert_str_child(string_leaf) @@ -5002,8 +5064,8 @@ def bracket_split_build_line( result.append(leaf, preformatted=True) for comment_after in original.comments_after(leaf): result.append(comment_after, preformatted=True) - if is_body and should_split_body_explode(result, opening_bracket): - result.should_explode = True + if is_body and should_split_line(result, opening_bracket): + result.should_split_rhs = True return result @@ -5361,10 +5423,7 @@ def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None: check_lpar = True if check_lpar: - if is_walrus_assignment(child): - pass - - elif child.type == syms.atom: + if child.type == syms.atom: if maybe_make_parens_invisible_in_atom(child, parent=node): wrap_in_parentheses(node, child, visible=False) elif is_one_tuple(child): @@ -5403,58 +5462,80 @@ def convert_one_fmt_off_pair(node: Node) -> bool: for leaf in node.leaves(): previous_consumed = 0 for comment in list_comments(leaf.prefix, is_endmarker=False): - if comment.value in FMT_OFF: - # We only want standalone comments. If there's no previous leaf or - # the previous leaf is indentation, it's a standalone comment in - # disguise. - if comment.type != STANDALONE_COMMENT: - prev = preceding_leaf(leaf) - if prev and prev.type not in WHITESPACE: + if comment.value not in FMT_PASS: + previous_consumed = comment.consumed + continue + # We only want standalone comments. If there's no previous leaf or + # the previous leaf is indentation, it's a standalone comment in + # disguise. + if comment.value in FMT_PASS and comment.type != STANDALONE_COMMENT: + prev = preceding_leaf(leaf) + if prev: + if comment.value in FMT_OFF and prev.type not in WHITESPACE: + continue + if comment.value in FMT_SKIP and prev.type in WHITESPACE: continue - ignored_nodes = list(generate_ignored_nodes(leaf)) - if not ignored_nodes: - continue - - first = ignored_nodes[0] # Can be a container node with the `leaf`. - parent = first.parent - prefix = first.prefix - first.prefix = prefix[comment.consumed :] - hidden_value = ( - comment.value + "\n" + "".join(str(n) for n in ignored_nodes) - ) - if hidden_value.endswith("\n"): - # That happens when one of the `ignored_nodes` ended with a NEWLINE - # leaf (possibly followed by a DEDENT). - hidden_value = hidden_value[:-1] - first_idx: Optional[int] = None - for ignored in ignored_nodes: - index = ignored.remove() - if first_idx is None: - first_idx = index - assert parent is not None, "INTERNAL ERROR: fmt: on/off handling (1)" - assert first_idx is not None, "INTERNAL ERROR: fmt: on/off handling (2)" - parent.insert_child( - first_idx, - Leaf( - STANDALONE_COMMENT, - hidden_value, - prefix=prefix[:previous_consumed] + "\n" * comment.newlines, - ), - ) - return True + ignored_nodes = list(generate_ignored_nodes(leaf, comment)) + if not ignored_nodes: + continue - previous_consumed = comment.consumed + first = ignored_nodes[0] # Can be a container node with the `leaf`. + parent = first.parent + prefix = first.prefix + first.prefix = prefix[comment.consumed :] + hidden_value = "".join(str(n) for n in ignored_nodes) + if comment.value in FMT_OFF: + hidden_value = comment.value + "\n" + hidden_value + if comment.value in FMT_SKIP: + hidden_value += " " + comment.value + if hidden_value.endswith("\n"): + # That happens when one of the `ignored_nodes` ended with a NEWLINE + # leaf (possibly followed by a DEDENT). + hidden_value = hidden_value[:-1] + first_idx: Optional[int] = None + for ignored in ignored_nodes: + index = ignored.remove() + if first_idx is None: + first_idx = index + assert parent is not None, "INTERNAL ERROR: fmt: on/off handling (1)" + assert first_idx is not None, "INTERNAL ERROR: fmt: on/off handling (2)" + parent.insert_child( + first_idx, + Leaf( + STANDALONE_COMMENT, + hidden_value, + prefix=prefix[:previous_consumed] + "\n" * comment.newlines, + ), + ) + return True return False -def generate_ignored_nodes(leaf: Leaf) -> Iterator[LN]: +def generate_ignored_nodes(leaf: Leaf, comment: ProtoComment) -> Iterator[LN]: """Starting from the container of `leaf`, generate all leaves until `# fmt: on`. + If comment is skip, returns leaf only. Stops at the end of the block. """ container: Optional[LN] = container_of(leaf) + if comment.value in FMT_SKIP: + prev_sibling = leaf.prev_sibling + if comment.value in leaf.prefix and prev_sibling is not None: + leaf.prefix = leaf.prefix.replace(comment.value, "") + siblings = [prev_sibling] + while ( + "\n" not in prev_sibling.prefix + and prev_sibling.prev_sibling is not None + ): + prev_sibling = prev_sibling.prev_sibling + siblings.insert(0, prev_sibling) + for sibling in siblings: + yield sibling + elif leaf.parent is not None: + yield leaf.parent + return while container is not None and container.type != token.ENDMARKER: if is_fmt_on(container): return @@ -5514,6 +5595,7 @@ def maybe_make_parens_invisible_in_atom(node: LN, parent: LN) -> bool: Returns whether the node should itself be wrapped in invisible parentheses. """ + if ( node.type != syms.atom or is_empty_tuple(node) @@ -5523,6 +5605,10 @@ def maybe_make_parens_invisible_in_atom(node: LN, parent: LN) -> bool: ): return False + if is_walrus_assignment(node): + if parent.type in [syms.annassign, syms.expr_stmt]: + return False + first = node.children[0] last = node.children[-1] if first.type == token.LPAR and last.type == token.RPAR: @@ -5584,6 +5670,17 @@ def unwrap_singleton_parenthesis(node: LN) -> Optional[LN]: return wrapped +def first_child_is_arith(node: Node) -> bool: + """Whether first child is an arithmetic or a binary arithmetic expression""" + expr_types = { + syms.arith_expr, + syms.shift_expr, + syms.xor_expr, + syms.and_expr, + } + return bool(node.children and node.children[0].type in expr_types) + + def wrap_in_parentheses(parent: Node, child: LN, *, visible: bool = True) -> None: """Wrap `child` in parentheses. @@ -5785,7 +5882,7 @@ def ensure_visible(leaf: Leaf) -> None: leaf.value = ")" -def should_split_body_explode(line: Line, opening_bracket: Leaf) -> bool: +def should_split_line(line: Line, opening_bracket: Leaf) -> bool: """Should `line` be immediately split with `delimiter_split()` after RHS?""" if not (opening_bracket.parent and opening_bracket.value in "[{("): @@ -5921,7 +6018,7 @@ def generate_trailers_to_omit(line: Line, line_length: int) -> Iterator[Set[Leaf """ omit: Set[LeafID] = set() - if not line.should_explode: + if not line.magic_trailing_comma: yield omit length = 4 * line.depth @@ -5943,8 +6040,7 @@ def generate_trailers_to_omit(line: Line, line_length: int) -> Iterator[Set[Leaf elif leaf.type in CLOSING_BRACKETS: prev = line.leaves[index - 1] if index > 0 else None if ( - line.should_explode - and prev + prev and prev.type == token.COMMA and not is_one_tuple_between( leaf.opening_bracket, leaf, line.leaves @@ -5971,8 +6067,7 @@ def generate_trailers_to_omit(line: Line, line_length: int) -> Iterator[Set[Leaf yield omit if ( - line.should_explode - and prev + prev and prev.type == token.COMMA and not is_one_tuple_between(leaf.opening_bracket, leaf, line.leaves) ): @@ -6037,7 +6132,7 @@ def get_future_imports(node: Node) -> Set[str]: @lru_cache() def get_gitignore(root: Path) -> PathSpec: - """ Return a PathSpec matching gitignore content if present.""" + """Return a PathSpec matching gitignore content if present.""" gitignore = root / ".gitignore" lines: List[str] = [] if gitignore.is_file(): @@ -6070,17 +6165,27 @@ def normalize_path_maybe_ignore( return normalized_path +def path_is_excluded( + normalized_path: str, + pattern: Optional[Pattern[str]], +) -> bool: + match = pattern.search(normalized_path) if pattern else None + return bool(match and match.group(0)) + + def gen_python_files( paths: Iterable[Path], root: Path, include: Optional[Pattern[str]], exclude: Pattern[str], + extend_exclude: Optional[Pattern[str]], force_exclude: Optional[Pattern[str]], report: "Report", gitignore: PathSpec, ) -> Iterator[Path]: """Generate all files under `path` whose paths are not excluded by the - `exclude_regex` or `force_exclude` regexes, but are included by the `include` regex. + `exclude_regex`, `extend_exclude`, or `force_exclude` regexes, + but are included by the `include` regex. Symbolic links pointing outside of the `root` directory are ignored. @@ -6097,20 +6202,22 @@ def gen_python_files( report.path_ignored(child, "matches the .gitignore file content") continue - # Then ignore with `--exclude` and `--force-exclude` options. + # Then ignore with `--exclude` `--extend-exclude` and `--force-exclude` options. normalized_path = "/" + normalized_path if child.is_dir(): normalized_path += "/" - exclude_match = exclude.search(normalized_path) if exclude else None - if exclude_match and exclude_match.group(0): + if path_is_excluded(normalized_path, exclude): report.path_ignored(child, "matches the --exclude regular expression") continue - force_exclude_match = ( - force_exclude.search(normalized_path) if force_exclude else None - ) - if force_exclude_match and force_exclude_match.group(0): + if path_is_excluded(normalized_path, extend_exclude): + report.path_ignored( + child, "matches the --extend-exclude regular expression" + ) + continue + + if path_is_excluded(normalized_path, force_exclude): report.path_ignored(child, "matches the --force-exclude regular expression") continue @@ -6120,6 +6227,7 @@ def gen_python_files( root, include, exclude, + extend_exclude, force_exclude, report, gitignore, @@ -6132,7 +6240,7 @@ def gen_python_files( @lru_cache() -def find_project_root(srcs: Iterable[str]) -> Path: +def find_project_root(srcs: Tuple[str, ...]) -> Path: """Return a directory containing .git, .hg, or pyproject.toml. That directory will be a common parent of all files and directories @@ -6170,6 +6278,22 @@ def find_project_root(srcs: Iterable[str]) -> Path: return directory +@lru_cache() +def find_user_pyproject_toml() -> Path: + r"""Return the path to the top-level user configuration for black. + + This looks for ~\.black on Windows and ~/.config/black on Linux and other + Unix systems. + """ + if sys.platform == "win32": + # Windows + user_config_path = Path.home() / ".black" + else: + config_root = os.environ.get("XDG_CONFIG_HOME", "~/.config") + user_config_path = Path(config_root).expanduser() / "black" + return user_config_path.resolve() + + @dataclass class Report: """Provides a reformatting counter. Can be rendered with `str(report)`.""" @@ -6271,7 +6395,12 @@ def parse_ast(src: str) -> Union[ast.AST, ast3.AST, ast27.AST]: return ast3.parse(src, filename, feature_version=feature_version) except SyntaxError: continue - + if ast27.__name__ == "ast": + raise SyntaxError( + "The requested source code has invalid Python 3 syntax.\n" + "If you are trying to format Python 2 files please reinstall Black" + " with the 'python2' extra: `python3 -m pip install black[python2]`." + ) return ast27.parse(src) @@ -6337,12 +6466,22 @@ def _stringify_ast( # Constant strings may be indented across newlines, if they are # docstrings; fold spaces after newlines when comparing. Similarly, # trailing and leading space may be removed. + # Note that when formatting Python 2 code, at least with Windows + # line-endings, docstrings can end up here as bytes instead of + # str so make sure that we handle both cases. if ( isinstance(node, ast.Constant) and field == "value" - and isinstance(value, str) + and isinstance(value, (str, bytes)) ): - normalized = re.sub(r" *\n[ \t]*", "\n", value).strip() + lineend = "\n" if isinstance(value, str) else b"\n" + # To normalize, we strip any leading and trailing space from + # each line... + stripped = [line.strip() for line in value.splitlines()] + normalized = lineend.join(stripped) # type: ignore[attr-defined] + # ...and remove any blank lines at the beginning and end of + # the whole string + normalized = normalized.strip() else: normalized = value yield f"{' ' * (depth+2)}{normalized!r}, # {value.__class__.__name__}" @@ -6398,14 +6537,14 @@ def assert_stable(src: str, dst: str, mode: Mode) -> None: @mypyc_attr(patchable=True) -def dump_to_file(*output: str) -> str: +def dump_to_file(*output: str, ensure_final_newline: bool = True) -> str: """Dump `output` to a temporary file. Return path to the file.""" with tempfile.NamedTemporaryFile( mode="w", prefix="blk_", suffix=".log", delete=False, encoding="utf8" ) as f: for lines in output: f.write(lines) - if lines and lines[-1] != "\n": + if ensure_final_newline and lines and lines[-1] != "\n": f.write("\n") return f.name @@ -6423,11 +6562,20 @@ def diff(a: str, b: str, a_name: str, b_name: str) -> str: """Return a unified diff string between strings `a` and `b`.""" import difflib - a_lines = [line + "\n" for line in a.splitlines()] - b_lines = [line + "\n" for line in b.splitlines()] - return "".join( - difflib.unified_diff(a_lines, b_lines, fromfile=a_name, tofile=b_name, n=5) - ) + a_lines = [line for line in a.splitlines(keepends=True)] + b_lines = [line for line in b.splitlines(keepends=True)] + diff_lines = [] + for line in difflib.unified_diff( + a_lines, b_lines, fromfile=a_name, tofile=b_name, n=5 + ): + # Work around https://bugs.python.org/issue2142 + # See https://www.gnu.org/software/diffutils/manual/html_node/Incomplete-Lines.html + if line[-1] == "\n": + diff_lines.append(line) + else: + diff_lines.append(line + "\n") + diff_lines.append("\\ No newline at end of file\n") + return "".join(diff_lines) def cancel(tasks: Iterable["asyncio.Task[Any]"]) -> None: @@ -6604,7 +6752,7 @@ def can_omit_invisible_parens( penultimate = line.leaves[-2] last = line.leaves[-1] - if line.should_explode: + if line.magic_trailing_comma: try: penultimate, last = last_two_except(line.leaves, omit=omit_on_explode) except LookupError: @@ -6631,7 +6779,7 @@ def can_omit_invisible_parens( # unnecessary. return True - if line.should_explode and penultimate.type == token.COMMA: + if line.magic_trailing_comma and penultimate.type == token.COMMA: # The rightmost non-omitted bracket pair is the one we want to explode on. return True @@ -6781,8 +6929,8 @@ def filter_cached(cache: Cache, sources: Iterable[Path]) -> Tuple[Set[Path], Set """ todo, done = set(), set() for src in sources: - src = src.resolve() - if cache.get(src) != get_cache_info(src): + res_src = src.resolve() + if cache.get(str(res_src)) != get_cache_info(res_src): todo.add(src) else: done.add(src) @@ -6794,7 +6942,10 @@ def write_cache(cache: Cache, sources: Iterable[Path], mode: Mode) -> None: cache_file = get_cache_file(mode) try: CACHE_DIR.mkdir(parents=True, exist_ok=True) - new_cache = {**cache, **{src.resolve(): get_cache_info(src) for src in sources}} + new_cache = { + **cache, + **{str(src.resolve()): get_cache_info(src) for src in sources}, + } with tempfile.NamedTemporaryFile(dir=str(cache_file.parent), delete=False) as f: pickle.dump(new_cache, f, protocol=4) os.replace(f.name, cache_file) @@ -6831,11 +6982,6 @@ def patched_main() -> None: def is_docstring(leaf: Leaf) -> bool: - if not is_multiline_string(leaf): - # For the purposes of docstring re-indentation, we don't need to do anything - # with single-line docstrings. - return False - if prev_siblings_are( leaf.parent, [None, token.NEWLINE, token.INDENT, syms.simple_stmt] ):