X-Git-Url: https://git.madduck.net/etc/vim.git/blobdiff_plain/a81702ce027ef6792955d5db02aad60a34c7ec44..e4003c2c439fc08a407d3024a195c010f03c4173:/src/black/__init__.py?ds=inline diff --git a/src/black/__init__.py b/src/black/__init__.py index 91f70d9..0065bea 100644 --- a/src/black/__init__.py +++ b/src/black/__init__.py @@ -48,7 +48,20 @@ from appdirs import user_cache_dir from dataclasses import dataclass, field, replace import click import toml -from typed_ast import ast3, ast27 + +try: + from typed_ast import ast3, ast27 +except ImportError: + if sys.version_info < (3, 8): + print( + "The typed_ast package is not installed.\n" + "You can install it with `python3 -m pip install typed-ast`.", + file=sys.stderr, + ) + sys.exit(1) + else: + ast3 = ast27 = ast + from pathspec import PathSpec # lib2to3 fork @@ -69,7 +82,7 @@ if TYPE_CHECKING: import colorama # noqa: F401 DEFAULT_LINE_LENGTH = 88 -DEFAULT_EXCLUDES = r"/(\.direnv|\.eggs|\.git|\.hg|\.mypy_cache|\.nox|\.tox|\.venv|\.svn|_build|buck-out|build|dist)/" # noqa: B950 +DEFAULT_EXCLUDES = r"/(\.direnv|\.eggs|\.git|\.hg|\.mypy_cache|\.nox|\.tox|\.venv|venv|\.svn|_build|buck-out|build|dist)/" # noqa: B950 DEFAULT_INCLUDES = r"\.pyi?$" CACHE_DIR = Path(user_cache_dir("black", version=__version__)) STDIN_PLACEHOLDER = "__BLACK_STDIN_FILENAME__" @@ -93,7 +106,7 @@ Transformer = Callable[["Line", Collection["Feature"]], Iterator["Line"]] Timestamp = float FileSize = int CacheInfo = Tuple[Timestamp, FileSize] -Cache = Dict[Path, CacheInfo] +Cache = Dict[str, CacheInfo] out = partial(click.secho, bold=True, err=True) err = partial(click.secho, fg="red", err=True) @@ -260,6 +273,7 @@ class Mode: target_versions: Set[TargetVersion] = field(default_factory=set) line_length: int = DEFAULT_LINE_LENGTH string_normalization: bool = True + magic_trailing_comma: bool = True experimental_string_processing: bool = False is_pyi: bool = False @@ -292,7 +306,11 @@ def find_pyproject_toml(path_search_start: Iterable[str]) -> Optional[str]: """Find the absolute filepath to a pyproject.toml if it exists""" path_project_root = find_project_root(path_search_start) path_pyproject_toml = path_project_root / "pyproject.toml" - return str(path_pyproject_toml) if path_pyproject_toml.is_file() else None + if path_pyproject_toml.is_file(): + return str(path_pyproject_toml) + + path_user_pyproject_toml = find_user_pyproject_toml() + return str(path_user_pyproject_toml) if path_user_pyproject_toml.is_file() else None def parse_pyproject_toml(path_config: str) -> Dict[str, Any]: @@ -362,6 +380,17 @@ def target_version_option_callback( return [TargetVersion[val.upper()] for val in v] +def validate_regex( + ctx: click.Context, + param: click.Parameter, + value: Optional[str], +) -> Optional[Pattern]: + try: + return re_compile_maybe_verbose(value) if value is not None else None + except re.error: + raise click.BadParameter("Not a valid regular expression") + + @click.command(context_settings=dict(help_option_names=["-h", "--help"])) @click.option("-c", "--code", type=str, help="Format the code passed in as a string.") @click.option( @@ -397,6 +426,12 @@ def target_version_option_callback( is_flag=True, help="Don't normalize string quotes or prefixes.", ) +@click.option( + "-C", + "--skip-magic-trailing-comma", + is_flag=True, + help="Don't use trailing commas as a reason to split lines.", +) @click.option( "--experimental-string-processing", is_flag=True, @@ -410,8 +445,8 @@ def target_version_option_callback( "--check", is_flag=True, help=( - "Don't write the files back, just return the status. Return code 0 means" - " nothing would change. Return code 1 means some files would be reformatted." + "Don't write the files back, just return the status. Return code 0 means" + " nothing would change. Return code 1 means some files would be reformatted." " Return code 123 means there was an internal error." ), ) @@ -434,11 +469,12 @@ def target_version_option_callback( "--include", type=str, default=DEFAULT_INCLUDES, + callback=validate_regex, help=( "A regular expression that matches files and directories that should be" - " included on recursive searches. An empty value means all files are included" - " regardless of the name. Use forward slashes for directories on all platforms" - " (Windows, too). Exclusions are calculated first, inclusions later." + " included on recursive searches. An empty value means all files are included" + " regardless of the name. Use forward slashes for directories on all platforms" + " (Windows, too). Exclusions are calculated first, inclusions later." ), show_default=True, ) @@ -446,17 +482,28 @@ def target_version_option_callback( "--exclude", type=str, default=DEFAULT_EXCLUDES, + callback=validate_regex, help=( "A regular expression that matches files and directories that should be" - " excluded on recursive searches. An empty value means no paths are excluded." - " Use forward slashes for directories on all platforms (Windows, too). " + " excluded on recursive searches. An empty value means no paths are excluded." + " Use forward slashes for directories on all platforms (Windows, too)." " Exclusions are calculated first, inclusions later." ), show_default=True, ) +@click.option( + "--extend-exclude", + type=str, + callback=validate_regex, + help=( + "Like --exclude, but adds additional files and directories on top of the" + " excluded ones. (Useful if you simply want to add to the default)" + ), +) @click.option( "--force-exclude", type=str, + callback=validate_regex, help=( "Like --exclude, but files and directories matching this regex will be " "excluded even when they are passed explicitly as arguments." @@ -486,7 +533,7 @@ def target_version_option_callback( is_flag=True, help=( "Also emit messages to stderr about files that were not changed or were ignored" - " due to --exclude=." + " due to exclusion patterns." ), ) @click.version_option(version=__version__) @@ -524,12 +571,14 @@ def main( fast: bool, pyi: bool, skip_string_normalization: bool, + skip_magic_trailing_comma: bool, experimental_string_processing: bool, quiet: bool, verbose: bool, - include: str, - exclude: str, - force_exclude: Optional[str], + include: Pattern, + exclude: Pattern, + extend_exclude: Optional[Pattern], + force_exclude: Optional[Pattern], stdin_filename: Optional[str], src: Tuple[str, ...], config: Optional[str], @@ -546,6 +595,7 @@ def main( line_length=line_length, is_pyi=pyi, string_normalization=not skip_string_normalization, + magic_trailing_comma=not skip_magic_trailing_comma, experimental_string_processing=experimental_string_processing, ) if config and verbose: @@ -561,6 +611,7 @@ def main( verbose=verbose, include=include, exclude=exclude, + extend_exclude=extend_exclude, force_exclude=force_exclude, report=report, stdin_filename=stdin_filename, @@ -599,30 +650,14 @@ def get_sources( src: Tuple[str, ...], quiet: bool, verbose: bool, - include: str, - exclude: str, - force_exclude: Optional[str], + include: Pattern[str], + exclude: Pattern[str], + extend_exclude: Optional[Pattern[str]], + force_exclude: Optional[Pattern[str]], report: "Report", stdin_filename: Optional[str], ) -> Set[Path]: """Compute the set of files to be formatted.""" - try: - include_regex = re_compile_maybe_verbose(include) - except re.error: - err(f"Invalid regular expression for include given: {include!r}") - ctx.exit(2) - try: - exclude_regex = re_compile_maybe_verbose(exclude) - except re.error: - err(f"Invalid regular expression for exclude given: {exclude!r}") - ctx.exit(2) - try: - force_exclude_regex = ( - re_compile_maybe_verbose(force_exclude) if force_exclude else None - ) - except re.error: - err(f"Invalid regular expression for force_exclude given: {force_exclude!r}") - ctx.exit(2) root = find_project_root(src) sources: Set[Path] = set() @@ -644,8 +679,8 @@ def get_sources( normalized_path = "/" + normalized_path # Hard-exclude any files that matches the `--force-exclude` regex. - if force_exclude_regex: - force_exclude_match = force_exclude_regex.search(normalized_path) + if force_exclude: + force_exclude_match = force_exclude.search(normalized_path) else: force_exclude_match = None if force_exclude_match and force_exclude_match.group(0): @@ -661,9 +696,10 @@ def get_sources( gen_python_files( p.iterdir(), root, - include_regex, - exclude_regex, - force_exclude_regex, + include, + exclude, + extend_exclude, + force_exclude, report, gitignore, ) @@ -715,7 +751,8 @@ def reformat_one( if write_back not in (WriteBack.DIFF, WriteBack.COLOR_DIFF): cache = read_cache(mode) res_src = src.resolve() - if res_src in cache and cache[res_src] == get_cache_info(res_src): + res_src_s = str(res_src) + if res_src_s in cache and cache[res_src_s] == get_cache_info(res_src): changed = Changed.CACHED if changed is not Changed.CACHED and format_file_in_place( src, fast=fast, write_back=write_back, mode=mode @@ -810,7 +847,7 @@ async def schedule_formatting( ): src for src in sorted(sources) } - pending: Iterable["asyncio.Future[bool]"] = tasks.keys() + pending = tasks.keys() try: loop.add_signal_handler(signal.SIGINT, cancel, pending) loop.add_signal_handler(signal.SIGTERM, cancel, pending) @@ -873,7 +910,7 @@ def format_file_in_place( dst_name = f"{src}\t{now} +0000" diff_contents = diff(src_contents, dst_contents, src_name, dst_name) - if write_back == write_back.COLOR_DIFF: + if write_back == WriteBack.COLOR_DIFF: diff_contents = color_diff(diff_contents) with lock or nullcontext(): @@ -1022,13 +1059,12 @@ def format_str(src_contents: str, *, mode: Mode) -> FileContent: versions = detect_target_versions(src_node) normalize_fmt_off(src_node) lines = LineGenerator( + mode=mode, remove_u_prefix="unicode_literals" in future_imports or supports_feature(versions, Feature.UNICODE_LITERALS), - is_pyi=mode.is_pyi, - normalize_strings=mode.string_normalization, ) elt = EmptyLineTracker(is_pyi=mode.is_pyi) - empty_line = Line() + empty_line = Line(mode=mode) after = 0 split_line_features = { feature @@ -1464,13 +1500,15 @@ class BracketTracker: class Line: """Holds leaves and comments. Can be printed with `str(line)`.""" + mode: Mode depth: int = 0 leaves: List[Leaf] = field(default_factory=list) # keys ordered like `leaves` comments: Dict[LeafID, List[Leaf]] = field(default_factory=dict) bracket_tracker: BracketTracker = field(default_factory=BracketTracker) inside_brackets: bool = False - should_explode: bool = False + should_split_rhs: bool = False + magic_trailing_comma: Optional[Leaf] = None def append(self, leaf: Leaf, preformatted: bool = False) -> None: """Add a new `leaf` to the end of the line. @@ -1496,8 +1534,11 @@ class Line: ) if self.inside_brackets or not preformatted: self.bracket_tracker.mark(leaf) - if self.maybe_should_explode(leaf): - self.should_explode = True + if self.mode.magic_trailing_comma: + if self.has_magic_trailing_comma(leaf): + self.magic_trailing_comma = leaf + elif self.has_magic_trailing_comma(leaf, ensure_removable=True): + self.remove_trailing_comma() if not self.append_comment(leaf): self.leaves.append(leaf) @@ -1673,10 +1714,14 @@ class Line: def contains_multiline_strings(self) -> bool: return any(is_multiline_string(leaf) for leaf in self.leaves) - def maybe_should_explode(self, closing: Leaf) -> bool: - """Return True if this line should explode (always be split), that is when: - - there's a trailing comma here; and - - it's not a one-tuple. + def has_magic_trailing_comma( + self, closing: Leaf, ensure_removable: bool = False + ) -> bool: + """Return True if we have a magic trailing comma, that is when: + - there's a trailing comma here + - it's not a one-tuple + Additionally, if ensure_removable: + - it's not from square bracket indexing """ if not ( closing.type in CLOSING_BRACKETS @@ -1685,9 +1730,15 @@ class Line: ): return False - if closing.type in {token.RBRACE, token.RSQB}: + if closing.type == token.RBRACE: return True + if closing.type == token.RSQB: + if not ensure_removable: + return True + comma = self.leaves[-1] + return bool(comma.parent and comma.parent.type == syms.listmaker) + if self.is_import: return True @@ -1765,9 +1816,11 @@ class Line: def clone(self) -> "Line": return Line( + mode=self.mode, depth=self.depth, inside_brackets=self.inside_brackets, - should_explode=self.should_explode, + should_split_rhs=self.should_split_rhs, + magic_trailing_comma=self.magic_trailing_comma, ) def __str__(self) -> str: @@ -1923,10 +1976,9 @@ class LineGenerator(Visitor[Line]): in ways that will no longer stringify to valid Python code on the tree. """ - is_pyi: bool = False - normalize_strings: bool = True - current_line: Line = field(default_factory=Line) + mode: Mode remove_u_prefix: bool = False + current_line: Line = field(init=False) def line(self, indent: int = 0) -> Iterator[Line]: """Generate a line. @@ -1941,7 +1993,7 @@ class LineGenerator(Visitor[Line]): return # Line is empty, don't emit. Creating a new one unnecessary. complete_line = self.current_line - self.current_line = Line(depth=complete_line.depth + indent) + self.current_line = Line(mode=self.mode, depth=complete_line.depth + indent) yield complete_line def visit_default(self, node: LN) -> Iterator[Line]: @@ -1965,7 +2017,7 @@ class LineGenerator(Visitor[Line]): yield from self.line() normalize_prefix(node, inside_brackets=any_open_brackets) - if self.normalize_strings and node.type == token.STRING: + if self.mode.string_normalization and node.type == token.STRING: normalize_string_prefix(node, remove_u_prefix=self.remove_u_prefix) normalize_string_quotes(node) if node.type == token.NUMBER: @@ -2017,16 +2069,18 @@ class LineGenerator(Visitor[Line]): def visit_suite(self, node: Node) -> Iterator[Line]: """Visit a suite.""" - if self.is_pyi and is_stub_suite(node): + if self.mode.is_pyi and is_stub_suite(node): yield from self.visit(node.children[2]) else: yield from self.visit_default(node) def visit_simple_stmt(self, node: Node) -> Iterator[Line]: """Visit a statement without nested statements.""" + if first_child_is_arith(node): + wrap_in_parentheses(node, node.children[0], visible=False) is_suite_like = node.parent and node.parent.type in STATEMENT if is_suite_like: - if self.is_pyi and is_stub_body(node): + if self.mode.is_pyi and is_stub_body(node): yield from self.visit_default(node) else: yield from self.line(+1) @@ -2034,7 +2088,11 @@ class LineGenerator(Visitor[Line]): yield from self.line(-1) else: - if not self.is_pyi or not node.parent or not is_stub_suite(node.parent): + if ( + not self.mode.is_pyi + or not node.parent + or not is_stub_suite(node.parent) + ): yield from self.line() yield from self.visit_default(node) @@ -2110,6 +2168,8 @@ class LineGenerator(Visitor[Line]): def __post_init__(self) -> None: """You are in a twisty little maze of passages.""" + self.current_line = Line(mode=self.mode) + v = self.visit_stmt Ø: Set[str] = set() self.visit_assert_stmt = partial(v, keywords={"assert"}, parens={"assert", ","}) @@ -2552,6 +2612,8 @@ def is_split_before_delimiter(leaf: Leaf, previous: Optional[Leaf] = None) -> Pr FMT_OFF = {"# fmt: off", "# fmt:off", "# yapf: disable"} +FMT_SKIP = {"# fmt: skip", "# fmt:skip"} +FMT_PASS = {*FMT_OFF, *FMT_SKIP} FMT_ON = {"# fmt: on", "# fmt:on", "# yapf: enable"} @@ -2606,7 +2668,7 @@ def list_comments(prefix: str, *, is_endmarker: bool) -> List[ProtoComment]: consumed = 0 nlines = 0 ignored_lines = 0 - for index, line in enumerate(prefix.split("\n")): + for index, line in enumerate(re.split("\r?\n", prefix)): consumed += len(line) + 1 # adding the length of the split '\n' line = line.lstrip() if not line: @@ -2679,7 +2741,8 @@ def transform_line( transformers: List[Transformer] if ( not line.contains_uncollapsable_type_comments() - and not line.should_explode + and not line.should_split_rhs + and not line.magic_trailing_comma and ( is_line_short_enough(line, line_length=mode.line_length, line_str=line_str) or line.contains_unsplittable_type_ignore() @@ -4350,9 +4413,11 @@ class StringParenWrapper(CustomSplitMapMixin, BaseStringSplitter): # `StringSplitter` will break it down further if necessary. string_value = LL[string_idx].value string_line = Line( + mode=line.mode, depth=line.depth + 1, inside_brackets=True, - should_explode=line.should_explode, + should_split_rhs=line.should_split_rhs, + magic_trailing_comma=line.magic_trailing_comma, ) string_leaf = Leaf(token.STRING, string_value) insert_str_child(string_leaf) @@ -4943,7 +5008,7 @@ def bracket_split_build_line( If `is_body` is True, the result line is one-indented inside brackets and as such has its first leaf's prefix normalized and a trailing comma added when expected. """ - result = Line(depth=original.depth) + result = Line(mode=original.mode, depth=original.depth) if is_body: result.inside_brackets = True result.depth += 1 @@ -4973,8 +5038,8 @@ def bracket_split_build_line( result.append(leaf, preformatted=True) for comment_after in original.comments_after(leaf): result.append(comment_after, preformatted=True) - if is_body and should_split_body_explode(result, opening_bracket): - result.should_explode = True + if is_body and should_split_line(result, opening_bracket): + result.should_split_rhs = True return result @@ -5015,7 +5080,9 @@ def delimiter_split(line: Line, features: Collection[Feature] = ()) -> Iterator[ if bt.delimiter_count_with_priority(delimiter_priority) == 1: raise CannotSplit("Splitting a single attribute from its owner looks wrong") - current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets) + current_line = Line( + mode=line.mode, depth=line.depth, inside_brackets=line.inside_brackets + ) lowest_depth = sys.maxsize trailing_comma_safe = True @@ -5027,7 +5094,9 @@ def delimiter_split(line: Line, features: Collection[Feature] = ()) -> Iterator[ except ValueError: yield current_line - current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets) + current_line = Line( + mode=line.mode, depth=line.depth, inside_brackets=line.inside_brackets + ) current_line.append(leaf) for leaf in line.leaves: @@ -5051,7 +5120,9 @@ def delimiter_split(line: Line, features: Collection[Feature] = ()) -> Iterator[ if leaf_priority == delimiter_priority: yield current_line - current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets) + current_line = Line( + mode=line.mode, depth=line.depth, inside_brackets=line.inside_brackets + ) if current_line: if ( trailing_comma_safe @@ -5072,7 +5143,9 @@ def standalone_comment_split( if not line.contains_standalone_comments(0): raise CannotSplit("Line does not have any standalone comments") - current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets) + current_line = Line( + mode=line.mode, depth=line.depth, inside_brackets=line.inside_brackets + ) def append_to_line(leaf: Leaf) -> Iterator[Line]: """Append `leaf` to current line or to new line if appending impossible.""" @@ -5082,7 +5155,9 @@ def standalone_comment_split( except ValueError: yield current_line - current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets) + current_line = Line( + line.mode, depth=line.depth, inside_brackets=line.inside_brackets + ) current_line.append(leaf) for leaf in line.leaves: @@ -5322,10 +5397,7 @@ def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None: check_lpar = True if check_lpar: - if is_walrus_assignment(child): - pass - - elif child.type == syms.atom: + if child.type == syms.atom: if maybe_make_parens_invisible_in_atom(child, parent=node): wrap_in_parentheses(node, child, visible=False) elif is_one_tuple(child): @@ -5364,58 +5436,80 @@ def convert_one_fmt_off_pair(node: Node) -> bool: for leaf in node.leaves(): previous_consumed = 0 for comment in list_comments(leaf.prefix, is_endmarker=False): - if comment.value in FMT_OFF: - # We only want standalone comments. If there's no previous leaf or - # the previous leaf is indentation, it's a standalone comment in - # disguise. - if comment.type != STANDALONE_COMMENT: - prev = preceding_leaf(leaf) - if prev and prev.type not in WHITESPACE: + if comment.value not in FMT_PASS: + previous_consumed = comment.consumed + continue + # We only want standalone comments. If there's no previous leaf or + # the previous leaf is indentation, it's a standalone comment in + # disguise. + if comment.value in FMT_PASS and comment.type != STANDALONE_COMMENT: + prev = preceding_leaf(leaf) + if prev: + if comment.value in FMT_OFF and prev.type not in WHITESPACE: + continue + if comment.value in FMT_SKIP and prev.type in WHITESPACE: continue - ignored_nodes = list(generate_ignored_nodes(leaf)) - if not ignored_nodes: - continue - - first = ignored_nodes[0] # Can be a container node with the `leaf`. - parent = first.parent - prefix = first.prefix - first.prefix = prefix[comment.consumed :] - hidden_value = ( - comment.value + "\n" + "".join(str(n) for n in ignored_nodes) - ) - if hidden_value.endswith("\n"): - # That happens when one of the `ignored_nodes` ended with a NEWLINE - # leaf (possibly followed by a DEDENT). - hidden_value = hidden_value[:-1] - first_idx: Optional[int] = None - for ignored in ignored_nodes: - index = ignored.remove() - if first_idx is None: - first_idx = index - assert parent is not None, "INTERNAL ERROR: fmt: on/off handling (1)" - assert first_idx is not None, "INTERNAL ERROR: fmt: on/off handling (2)" - parent.insert_child( - first_idx, - Leaf( - STANDALONE_COMMENT, - hidden_value, - prefix=prefix[:previous_consumed] + "\n" * comment.newlines, - ), - ) - return True + ignored_nodes = list(generate_ignored_nodes(leaf, comment)) + if not ignored_nodes: + continue - previous_consumed = comment.consumed + first = ignored_nodes[0] # Can be a container node with the `leaf`. + parent = first.parent + prefix = first.prefix + first.prefix = prefix[comment.consumed :] + hidden_value = "".join(str(n) for n in ignored_nodes) + if comment.value in FMT_OFF: + hidden_value = comment.value + "\n" + hidden_value + if comment.value in FMT_SKIP: + hidden_value += " " + comment.value + if hidden_value.endswith("\n"): + # That happens when one of the `ignored_nodes` ended with a NEWLINE + # leaf (possibly followed by a DEDENT). + hidden_value = hidden_value[:-1] + first_idx: Optional[int] = None + for ignored in ignored_nodes: + index = ignored.remove() + if first_idx is None: + first_idx = index + assert parent is not None, "INTERNAL ERROR: fmt: on/off handling (1)" + assert first_idx is not None, "INTERNAL ERROR: fmt: on/off handling (2)" + parent.insert_child( + first_idx, + Leaf( + STANDALONE_COMMENT, + hidden_value, + prefix=prefix[:previous_consumed] + "\n" * comment.newlines, + ), + ) + return True return False -def generate_ignored_nodes(leaf: Leaf) -> Iterator[LN]: +def generate_ignored_nodes(leaf: Leaf, comment: ProtoComment) -> Iterator[LN]: """Starting from the container of `leaf`, generate all leaves until `# fmt: on`. + If comment is skip, returns leaf only. Stops at the end of the block. """ container: Optional[LN] = container_of(leaf) + if comment.value in FMT_SKIP: + prev_sibling = leaf.prev_sibling + if comment.value in leaf.prefix and prev_sibling is not None: + leaf.prefix = leaf.prefix.replace(comment.value, "") + siblings = [prev_sibling] + while ( + "\n" not in prev_sibling.prefix + and prev_sibling.prev_sibling is not None + ): + prev_sibling = prev_sibling.prev_sibling + siblings.insert(0, prev_sibling) + for sibling in siblings: + yield sibling + elif leaf.parent is not None: + yield leaf.parent + return while container is not None and container.type != token.ENDMARKER: if is_fmt_on(container): return @@ -5475,6 +5569,7 @@ def maybe_make_parens_invisible_in_atom(node: LN, parent: LN) -> bool: Returns whether the node should itself be wrapped in invisible parentheses. """ + if ( node.type != syms.atom or is_empty_tuple(node) @@ -5484,6 +5579,10 @@ def maybe_make_parens_invisible_in_atom(node: LN, parent: LN) -> bool: ): return False + if is_walrus_assignment(node): + if parent.type in [syms.annassign, syms.expr_stmt]: + return False + first = node.children[0] last = node.children[-1] if first.type == token.LPAR and last.type == token.RPAR: @@ -5545,6 +5644,17 @@ def unwrap_singleton_parenthesis(node: LN) -> Optional[LN]: return wrapped +def first_child_is_arith(node: Node) -> bool: + """Whether first child is an arithmetic or a binary arithmetic expression""" + expr_types = { + syms.arith_expr, + syms.shift_expr, + syms.xor_expr, + syms.and_expr, + } + return bool(node.children and node.children[0].type in expr_types) + + def wrap_in_parentheses(parent: Node, child: LN, *, visible: bool = True) -> None: """Wrap `child` in parentheses. @@ -5746,7 +5856,7 @@ def ensure_visible(leaf: Leaf) -> None: leaf.value = ")" -def should_split_body_explode(line: Line, opening_bracket: Leaf) -> bool: +def should_split_line(line: Line, opening_bracket: Leaf) -> bool: """Should `line` be immediately split with `delimiter_split()` after RHS?""" if not (opening_bracket.parent and opening_bracket.value in "[{("): @@ -5767,7 +5877,7 @@ def should_split_body_explode(line: Line, opening_bracket: Leaf) -> bool: return False return max_priority == COMMA_PRIORITY and ( - trailing_comma + (line.mode.magic_trailing_comma and trailing_comma) # always explode imports or opening_bracket.parent.type in {syms.atom, syms.import_from} ) @@ -5882,7 +5992,7 @@ def generate_trailers_to_omit(line: Line, line_length: int) -> Iterator[Set[Leaf """ omit: Set[LeafID] = set() - if not line.should_explode: + if not line.magic_trailing_comma: yield omit length = 4 * line.depth @@ -5904,8 +6014,7 @@ def generate_trailers_to_omit(line: Line, line_length: int) -> Iterator[Set[Leaf elif leaf.type in CLOSING_BRACKETS: prev = line.leaves[index - 1] if index > 0 else None if ( - line.should_explode - and prev + prev and prev.type == token.COMMA and not is_one_tuple_between( leaf.opening_bracket, leaf, line.leaves @@ -5932,8 +6041,7 @@ def generate_trailers_to_omit(line: Line, line_length: int) -> Iterator[Set[Leaf yield omit if ( - line.should_explode - and prev + prev and prev.type == token.COMMA and not is_one_tuple_between(leaf.opening_bracket, leaf, line.leaves) ): @@ -6031,17 +6139,27 @@ def normalize_path_maybe_ignore( return normalized_path +def path_is_excluded( + normalized_path: str, + pattern: Optional[Pattern[str]], +) -> bool: + match = pattern.search(normalized_path) if pattern else None + return bool(match and match.group(0)) + + def gen_python_files( paths: Iterable[Path], root: Path, include: Optional[Pattern[str]], exclude: Pattern[str], + extend_exclude: Optional[Pattern[str]], force_exclude: Optional[Pattern[str]], report: "Report", gitignore: PathSpec, ) -> Iterator[Path]: """Generate all files under `path` whose paths are not excluded by the - `exclude_regex` or `force_exclude` regexes, but are included by the `include` regex. + `exclude_regex`, `extend_exclude`, or `force_exclude` regexes, + but are included by the `include` regex. Symbolic links pointing outside of the `root` directory are ignored. @@ -6058,20 +6176,22 @@ def gen_python_files( report.path_ignored(child, "matches the .gitignore file content") continue - # Then ignore with `--exclude` and `--force-exclude` options. + # Then ignore with `--exclude` `--extend-exclude` and `--force-exclude` options. normalized_path = "/" + normalized_path if child.is_dir(): normalized_path += "/" - exclude_match = exclude.search(normalized_path) if exclude else None - if exclude_match and exclude_match.group(0): + if path_is_excluded(normalized_path, exclude): report.path_ignored(child, "matches the --exclude regular expression") continue - force_exclude_match = ( - force_exclude.search(normalized_path) if force_exclude else None - ) - if force_exclude_match and force_exclude_match.group(0): + if path_is_excluded(normalized_path, extend_exclude): + report.path_ignored( + child, "matches the --extend-exclude regular expression" + ) + continue + + if path_is_excluded(normalized_path, force_exclude): report.path_ignored(child, "matches the --force-exclude regular expression") continue @@ -6081,6 +6201,7 @@ def gen_python_files( root, include, exclude, + extend_exclude, force_exclude, report, gitignore, @@ -6131,6 +6252,22 @@ def find_project_root(srcs: Iterable[str]) -> Path: return directory +@lru_cache() +def find_user_pyproject_toml() -> Path: + r"""Return the path to the top-level user configuration for black. + + This looks for ~\.black on Windows and ~/.config/black on Linux and other + Unix systems. + """ + if sys.platform == "win32": + # Windows + user_config_path = Path.home() / ".black" + else: + config_root = os.environ.get("XDG_CONFIG_HOME", "~/.config") + user_config_path = Path(config_root).expanduser() / "black" + return user_config_path.resolve() + + @dataclass class Report: """Provides a reformatting counter. Can be rendered with `str(report)`.""" @@ -6232,7 +6369,12 @@ def parse_ast(src: str) -> Union[ast.AST, ast3.AST, ast27.AST]: return ast3.parse(src, filename, feature_version=feature_version) except SyntaxError: continue - + if ast27.__name__ == "ast": + raise SyntaxError( + "The requested source code has invalid Python 3 syntax.\n" + "If you are trying to format Python 2 files please reinstall Black" + " with the 'python2' extra: `python3 -m pip install black[python2]`." + ) return ast27.parse(src) @@ -6359,14 +6501,14 @@ def assert_stable(src: str, dst: str, mode: Mode) -> None: @mypyc_attr(patchable=True) -def dump_to_file(*output: str) -> str: +def dump_to_file(*output: str, ensure_final_newline: bool = True) -> str: """Dump `output` to a temporary file. Return path to the file.""" with tempfile.NamedTemporaryFile( mode="w", prefix="blk_", suffix=".log", delete=False, encoding="utf8" ) as f: for lines in output: f.write(lines) - if lines and lines[-1] != "\n": + if ensure_final_newline and lines and lines[-1] != "\n": f.write("\n") return f.name @@ -6384,11 +6526,20 @@ def diff(a: str, b: str, a_name: str, b_name: str) -> str: """Return a unified diff string between strings `a` and `b`.""" import difflib - a_lines = [line + "\n" for line in a.splitlines()] - b_lines = [line + "\n" for line in b.splitlines()] - return "".join( - difflib.unified_diff(a_lines, b_lines, fromfile=a_name, tofile=b_name, n=5) - ) + a_lines = [line for line in a.splitlines(keepends=True)] + b_lines = [line for line in b.splitlines(keepends=True)] + diff_lines = [] + for line in difflib.unified_diff( + a_lines, b_lines, fromfile=a_name, tofile=b_name, n=5 + ): + # Work around https://bugs.python.org/issue2142 + # See https://www.gnu.org/software/diffutils/manual/html_node/Incomplete-Lines.html + if line[-1] == "\n": + diff_lines.append(line) + else: + diff_lines.append(line + "\n") + diff_lines.append("\\ No newline at end of file\n") + return "".join(diff_lines) def cancel(tasks: Iterable["asyncio.Task[Any]"]) -> None: @@ -6565,7 +6716,7 @@ def can_omit_invisible_parens( penultimate = line.leaves[-2] last = line.leaves[-1] - if line.should_explode: + if line.magic_trailing_comma: try: penultimate, last = last_two_except(line.leaves, omit=omit_on_explode) except LookupError: @@ -6592,7 +6743,7 @@ def can_omit_invisible_parens( # unnecessary. return True - if line.should_explode and penultimate.type == token.COMMA: + if line.magic_trailing_comma and penultimate.type == token.COMMA: # The rightmost non-omitted bracket pair is the one we want to explode on. return True @@ -6742,8 +6893,8 @@ def filter_cached(cache: Cache, sources: Iterable[Path]) -> Tuple[Set[Path], Set """ todo, done = set(), set() for src in sources: - src = src.resolve() - if cache.get(src) != get_cache_info(src): + res_src = src.resolve() + if cache.get(str(res_src)) != get_cache_info(res_src): todo.add(src) else: done.add(src) @@ -6755,7 +6906,10 @@ def write_cache(cache: Cache, sources: Iterable[Path], mode: Mode) -> None: cache_file = get_cache_file(mode) try: CACHE_DIR.mkdir(parents=True, exist_ok=True) - new_cache = {**cache, **{src.resolve(): get_cache_info(src) for src in sources}} + new_cache = { + **cache, + **{str(src.resolve()): get_cache_info(src) for src in sources}, + } with tempfile.NamedTemporaryFile(dir=str(cache_file.parent), delete=False) as f: pickle.dump(new_cache, f, protocol=4) os.replace(f.name, cache_file)