X-Git-Url: https://git.madduck.net/etc/vim.git/blobdiff_plain/22127c633eba10d41519fb562c1252f859e2d7fa..8e0803e7e5acabdd28b80258f15d8aebf11fbb4c:/src/black/__init__.py?ds=sidebyside diff --git a/src/black/__init__.py b/src/black/__init__.py index c2b0ad4..431ee02 100644 --- a/src/black/__init__.py +++ b/src/black/__init__.py @@ -48,7 +48,20 @@ from appdirs import user_cache_dir from dataclasses import dataclass, field, replace import click import toml -from typed_ast import ast3, ast27 + +try: + from typed_ast import ast3, ast27 +except ImportError: + if sys.version_info < (3, 8): + print( + "The typed_ast package is not installed.\n" + "You can install it with `python3 -m pip install typed-ast`.", + file=sys.stderr, + ) + sys.exit(1) + else: + ast3 = ast27 = ast + from pathspec import PathSpec # lib2to3 fork @@ -293,7 +306,11 @@ def find_pyproject_toml(path_search_start: Iterable[str]) -> Optional[str]: """Find the absolute filepath to a pyproject.toml if it exists""" path_project_root = find_project_root(path_search_start) path_pyproject_toml = path_project_root / "pyproject.toml" - return str(path_pyproject_toml) if path_pyproject_toml.is_file() else None + if path_pyproject_toml.is_file(): + return str(path_pyproject_toml) + + path_user_pyproject_toml = find_user_pyproject_toml() + return str(path_user_pyproject_toml) if path_user_pyproject_toml.is_file() else None def parse_pyproject_toml(path_config: str) -> Dict[str, Any]: @@ -363,6 +380,17 @@ def target_version_option_callback( return [TargetVersion[val.upper()] for val in v] +def validate_regex( + ctx: click.Context, + param: click.Parameter, + value: Optional[str], +) -> Optional[Pattern]: + try: + return re_compile_maybe_verbose(value) if value is not None else None + except re.error: + raise click.BadParameter("Not a valid regular expression") + + @click.command(context_settings=dict(help_option_names=["-h", "--help"])) @click.option("-c", "--code", type=str, help="Format the code passed in as a string.") @click.option( @@ -417,8 +445,8 @@ def target_version_option_callback( "--check", is_flag=True, help=( - "Don't write the files back, just return the status. Return code 0 means" - " nothing would change. Return code 1 means some files would be reformatted." + "Don't write the files back, just return the status. Return code 0 means" + " nothing would change. Return code 1 means some files would be reformatted." " Return code 123 means there was an internal error." ), ) @@ -441,11 +469,12 @@ def target_version_option_callback( "--include", type=str, default=DEFAULT_INCLUDES, + callback=validate_regex, help=( "A regular expression that matches files and directories that should be" - " included on recursive searches. An empty value means all files are included" - " regardless of the name. Use forward slashes for directories on all platforms" - " (Windows, too). Exclusions are calculated first, inclusions later." + " included on recursive searches. An empty value means all files are included" + " regardless of the name. Use forward slashes for directories on all platforms" + " (Windows, too). Exclusions are calculated first, inclusions later." ), show_default=True, ) @@ -453,17 +482,28 @@ def target_version_option_callback( "--exclude", type=str, default=DEFAULT_EXCLUDES, + callback=validate_regex, help=( "A regular expression that matches files and directories that should be" - " excluded on recursive searches. An empty value means no paths are excluded." - " Use forward slashes for directories on all platforms (Windows, too). " + " excluded on recursive searches. An empty value means no paths are excluded." + " Use forward slashes for directories on all platforms (Windows, too)." " Exclusions are calculated first, inclusions later." ), show_default=True, ) +@click.option( + "--extend-exclude", + type=str, + callback=validate_regex, + help=( + "Like --exclude, but adds additional files and directories on top of the" + " excluded ones. (Useful if you simply want to add to the default)" + ), +) @click.option( "--force-exclude", type=str, + callback=validate_regex, help=( "Like --exclude, but files and directories matching this regex will be " "excluded even when they are passed explicitly as arguments." @@ -493,7 +533,7 @@ def target_version_option_callback( is_flag=True, help=( "Also emit messages to stderr about files that were not changed or were ignored" - " due to --exclude=." + " due to exclusion patterns." ), ) @click.version_option(version=__version__) @@ -535,9 +575,10 @@ def main( experimental_string_processing: bool, quiet: bool, verbose: bool, - include: str, - exclude: str, - force_exclude: Optional[str], + include: Pattern, + exclude: Pattern, + extend_exclude: Optional[Pattern], + force_exclude: Optional[Pattern], stdin_filename: Optional[str], src: Tuple[str, ...], config: Optional[str], @@ -570,6 +611,7 @@ def main( verbose=verbose, include=include, exclude=exclude, + extend_exclude=extend_exclude, force_exclude=force_exclude, report=report, stdin_filename=stdin_filename, @@ -608,30 +650,14 @@ def get_sources( src: Tuple[str, ...], quiet: bool, verbose: bool, - include: str, - exclude: str, - force_exclude: Optional[str], + include: Pattern[str], + exclude: Pattern[str], + extend_exclude: Optional[Pattern[str]], + force_exclude: Optional[Pattern[str]], report: "Report", stdin_filename: Optional[str], ) -> Set[Path]: """Compute the set of files to be formatted.""" - try: - include_regex = re_compile_maybe_verbose(include) - except re.error: - err(f"Invalid regular expression for include given: {include!r}") - ctx.exit(2) - try: - exclude_regex = re_compile_maybe_verbose(exclude) - except re.error: - err(f"Invalid regular expression for exclude given: {exclude!r}") - ctx.exit(2) - try: - force_exclude_regex = ( - re_compile_maybe_verbose(force_exclude) if force_exclude else None - ) - except re.error: - err(f"Invalid regular expression for force_exclude given: {force_exclude!r}") - ctx.exit(2) root = find_project_root(src) sources: Set[Path] = set() @@ -653,8 +679,8 @@ def get_sources( normalized_path = "/" + normalized_path # Hard-exclude any files that matches the `--force-exclude` regex. - if force_exclude_regex: - force_exclude_match = force_exclude_regex.search(normalized_path) + if force_exclude: + force_exclude_match = force_exclude.search(normalized_path) else: force_exclude_match = None if force_exclude_match and force_exclude_match.group(0): @@ -670,9 +696,10 @@ def get_sources( gen_python_files( p.iterdir(), root, - include_regex, - exclude_regex, - force_exclude_regex, + include, + exclude, + extend_exclude, + force_exclude, report, gitignore, ) @@ -883,7 +910,7 @@ def format_file_in_place( dst_name = f"{src}\t{now} +0000" diff_contents = diff(src_contents, dst_contents, src_name, dst_name) - if write_back == write_back.COLOR_DIFF: + if write_back == WriteBack.COLOR_DIFF: diff_contents = color_diff(diff_contents) with lock or nullcontext(): @@ -1480,7 +1507,7 @@ class Line: comments: Dict[LeafID, List[Leaf]] = field(default_factory=dict) bracket_tracker: BracketTracker = field(default_factory=BracketTracker) inside_brackets: bool = False - should_split: bool = False + should_split_rhs: bool = False magic_trailing_comma: Optional[Leaf] = None def append(self, leaf: Leaf, preformatted: bool = False) -> None: @@ -1792,7 +1819,7 @@ class Line: mode=self.mode, depth=self.depth, inside_brackets=self.inside_brackets, - should_split=self.should_split, + should_split_rhs=self.should_split_rhs, magic_trailing_comma=self.magic_trailing_comma, ) @@ -2049,6 +2076,8 @@ class LineGenerator(Visitor[Line]): def visit_simple_stmt(self, node: Node) -> Iterator[Line]: """Visit a statement without nested statements.""" + if first_child_is_arith(node): + wrap_in_parentheses(node, node.children[0], visible=False) is_suite_like = node.parent and node.parent.type in STATEMENT if is_suite_like: if self.mode.is_pyi and is_stub_body(node): @@ -2712,7 +2741,7 @@ def transform_line( transformers: List[Transformer] if ( not line.contains_uncollapsable_type_comments() - and not line.should_split + and not line.should_split_rhs and not line.magic_trailing_comma and ( is_line_short_enough(line, line_length=mode.line_length, line_str=line_str) @@ -4387,7 +4416,7 @@ class StringParenWrapper(CustomSplitMapMixin, BaseStringSplitter): mode=line.mode, depth=line.depth + 1, inside_brackets=True, - should_split=line.should_split, + should_split_rhs=line.should_split_rhs, magic_trailing_comma=line.magic_trailing_comma, ) string_leaf = Leaf(token.STRING, string_value) @@ -5010,7 +5039,7 @@ def bracket_split_build_line( for comment_after in original.comments_after(leaf): result.append(comment_after, preformatted=True) if is_body and should_split_line(result, opening_bracket): - result.should_split = True + result.should_split_rhs = True return result @@ -5368,10 +5397,7 @@ def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None: check_lpar = True if check_lpar: - if is_walrus_assignment(child): - pass - - elif child.type == syms.atom: + if child.type == syms.atom: if maybe_make_parens_invisible_in_atom(child, parent=node): wrap_in_parentheses(node, child, visible=False) elif is_one_tuple(child): @@ -5543,6 +5569,7 @@ def maybe_make_parens_invisible_in_atom(node: LN, parent: LN) -> bool: Returns whether the node should itself be wrapped in invisible parentheses. """ + if ( node.type != syms.atom or is_empty_tuple(node) @@ -5552,6 +5579,10 @@ def maybe_make_parens_invisible_in_atom(node: LN, parent: LN) -> bool: ): return False + if is_walrus_assignment(node): + if parent.type in [syms.annassign, syms.expr_stmt]: + return False + first = node.children[0] last = node.children[-1] if first.type == token.LPAR and last.type == token.RPAR: @@ -5613,6 +5644,17 @@ def unwrap_singleton_parenthesis(node: LN) -> Optional[LN]: return wrapped +def first_child_is_arith(node: Node) -> bool: + """Whether first child is an arithmetic or a binary arithmetic expression""" + expr_types = { + syms.arith_expr, + syms.shift_expr, + syms.xor_expr, + syms.and_expr, + } + return bool(node.children and node.children[0].type in expr_types) + + def wrap_in_parentheses(parent: Node, child: LN, *, visible: bool = True) -> None: """Wrap `child` in parentheses. @@ -6097,17 +6139,27 @@ def normalize_path_maybe_ignore( return normalized_path +def path_is_excluded( + normalized_path: str, + pattern: Optional[Pattern[str]], +) -> bool: + match = pattern.search(normalized_path) if pattern else None + return bool(match and match.group(0)) + + def gen_python_files( paths: Iterable[Path], root: Path, include: Optional[Pattern[str]], exclude: Pattern[str], + extend_exclude: Optional[Pattern[str]], force_exclude: Optional[Pattern[str]], report: "Report", gitignore: PathSpec, ) -> Iterator[Path]: """Generate all files under `path` whose paths are not excluded by the - `exclude_regex` or `force_exclude` regexes, but are included by the `include` regex. + `exclude_regex`, `extend_exclude`, or `force_exclude` regexes, + but are included by the `include` regex. Symbolic links pointing outside of the `root` directory are ignored. @@ -6124,20 +6176,22 @@ def gen_python_files( report.path_ignored(child, "matches the .gitignore file content") continue - # Then ignore with `--exclude` and `--force-exclude` options. + # Then ignore with `--exclude` `--extend-exclude` and `--force-exclude` options. normalized_path = "/" + normalized_path if child.is_dir(): normalized_path += "/" - exclude_match = exclude.search(normalized_path) if exclude else None - if exclude_match and exclude_match.group(0): + if path_is_excluded(normalized_path, exclude): report.path_ignored(child, "matches the --exclude regular expression") continue - force_exclude_match = ( - force_exclude.search(normalized_path) if force_exclude else None - ) - if force_exclude_match and force_exclude_match.group(0): + if path_is_excluded(normalized_path, extend_exclude): + report.path_ignored( + child, "matches the --extend-exclude regular expression" + ) + continue + + if path_is_excluded(normalized_path, force_exclude): report.path_ignored(child, "matches the --force-exclude regular expression") continue @@ -6147,6 +6201,7 @@ def gen_python_files( root, include, exclude, + extend_exclude, force_exclude, report, gitignore, @@ -6197,6 +6252,22 @@ def find_project_root(srcs: Iterable[str]) -> Path: return directory +@lru_cache() +def find_user_pyproject_toml() -> Path: + r"""Return the path to the top-level user configuration for black. + + This looks for ~\.black on Windows and ~/.config/black on Linux and other + Unix systems. + """ + if sys.platform == "win32": + # Windows + user_config_path = Path.home() / ".black" + else: + config_root = os.environ.get("XDG_CONFIG_HOME", "~/.config") + user_config_path = Path(config_root).expanduser() / "black" + return user_config_path.resolve() + + @dataclass class Report: """Provides a reformatting counter. Can be rendered with `str(report)`.""" @@ -6298,7 +6369,12 @@ def parse_ast(src: str) -> Union[ast.AST, ast3.AST, ast27.AST]: return ast3.parse(src, filename, feature_version=feature_version) except SyntaxError: continue - + if ast27.__name__ == "ast": + raise SyntaxError( + "The requested source code has invalid Python 3 syntax.\n" + "If you are trying to format Python 2 files please reinstall Black" + " with the 'python2' extra: `python3 -m pip install black[python2]`." + ) return ast27.parse(src) @@ -6425,14 +6501,14 @@ def assert_stable(src: str, dst: str, mode: Mode) -> None: @mypyc_attr(patchable=True) -def dump_to_file(*output: str) -> str: +def dump_to_file(*output: str, ensure_final_newline: bool = True) -> str: """Dump `output` to a temporary file. Return path to the file.""" with tempfile.NamedTemporaryFile( mode="w", prefix="blk_", suffix=".log", delete=False, encoding="utf8" ) as f: for lines in output: f.write(lines) - if lines and lines[-1] != "\n": + if ensure_final_newline and lines and lines[-1] != "\n": f.write("\n") return f.name @@ -6450,11 +6526,20 @@ def diff(a: str, b: str, a_name: str, b_name: str) -> str: """Return a unified diff string between strings `a` and `b`.""" import difflib - a_lines = [line + "\n" for line in a.splitlines()] - b_lines = [line + "\n" for line in b.splitlines()] - return "".join( - difflib.unified_diff(a_lines, b_lines, fromfile=a_name, tofile=b_name, n=5) - ) + a_lines = [line for line in a.splitlines(keepends=True)] + b_lines = [line for line in b.splitlines(keepends=True)] + diff_lines = [] + for line in difflib.unified_diff( + a_lines, b_lines, fromfile=a_name, tofile=b_name, n=5 + ): + # Work around https://bugs.python.org/issue2142 + # See https://www.gnu.org/software/diffutils/manual/html_node/Incomplete-Lines.html + if line[-1] == "\n": + diff_lines.append(line) + else: + diff_lines.append(line + "\n") + diff_lines.append("\\ No newline at end of file\n") + return "".join(diff_lines) def cancel(tasks: Iterable["asyncio.Task[Any]"]) -> None: @@ -6658,7 +6743,7 @@ def can_omit_invisible_parens( # unnecessary. return True - if penultimate.type == token.COMMA: + if line.magic_trailing_comma and penultimate.type == token.COMMA: # The rightmost non-omitted bracket pair is the one we want to explode on. return True