From e3c71c3a477a44e6d817d37825a59bc6ba6a9897 Mon Sep 17 00:00:00 2001 From: Joshua Cannon Date: Tue, 2 Mar 2021 19:21:50 -0600 Subject: [PATCH] Turn test_regex into a click callback (#2016) Co-authored-by: Richard Si <63936253+ichard26@users.noreply.github.com> --- src/black/__init__.py | 61 ++++++++++++++++++++----------------------- tests/test_black.py | 28 ++++++++++---------- 2 files changed, 43 insertions(+), 46 deletions(-) diff --git a/src/black/__init__.py b/src/black/__init__.py index e21e2af..a8f4f89 100644 --- a/src/black/__init__.py +++ b/src/black/__init__.py @@ -363,6 +363,17 @@ def target_version_option_callback( return [TargetVersion[val.upper()] for val in v] +def validate_regex( + ctx: click.Context, + param: click.Parameter, + value: Optional[str], +) -> Optional[Pattern]: + try: + return re_compile_maybe_verbose(value) if value is not None else None + except re.error: + raise click.BadParameter("Not a valid regular expression") + + @click.command(context_settings=dict(help_option_names=["-h", "--help"])) @click.option("-c", "--code", type=str, help="Format the code passed in as a string.") @click.option( @@ -441,6 +452,7 @@ def target_version_option_callback( "--include", type=str, default=DEFAULT_INCLUDES, + callback=validate_regex, help=( "A regular expression that matches files and directories that should be" " included on recursive searches. An empty value means all files are included" @@ -453,6 +465,7 @@ def target_version_option_callback( "--exclude", type=str, default=DEFAULT_EXCLUDES, + callback=validate_regex, help=( "A regular expression that matches files and directories that should be" " excluded on recursive searches. An empty value means no paths are excluded." @@ -464,6 +477,7 @@ def target_version_option_callback( @click.option( "--extend-exclude", type=str, + callback=validate_regex, help=( "Like --exclude, but adds additional files and directories on top of the" " excluded ones. (Useful if you simply want to add to the default)" @@ -472,6 +486,7 @@ def target_version_option_callback( @click.option( "--force-exclude", type=str, + callback=validate_regex, help=( "Like --exclude, but files and directories matching this regex will be " "excluded even when they are passed explicitly as arguments." @@ -543,10 +558,10 @@ def main( experimental_string_processing: bool, quiet: bool, verbose: bool, - include: str, - exclude: str, - extend_exclude: Optional[str], - force_exclude: Optional[str], + include: Pattern, + exclude: Pattern, + extend_exclude: Optional[Pattern], + force_exclude: Optional[Pattern], stdin_filename: Optional[str], src: Tuple[str, ...], config: Optional[str], @@ -612,39 +627,21 @@ def main( ctx.exit(report.return_code) -def test_regex( - ctx: click.Context, - regex_name: str, - regex: Optional[str], -) -> Optional[Pattern]: - try: - return re_compile_maybe_verbose(regex) if regex is not None else None - except re.error: - err(f"Invalid regular expression for {regex_name} given: {regex!r}") - ctx.exit(2) - - def get_sources( *, ctx: click.Context, src: Tuple[str, ...], quiet: bool, verbose: bool, - include: str, - exclude: str, - extend_exclude: Optional[str], - force_exclude: Optional[str], + include: Pattern[str], + exclude: Pattern[str], + extend_exclude: Optional[Pattern[str]], + force_exclude: Optional[Pattern[str]], report: "Report", stdin_filename: Optional[str], ) -> Set[Path]: """Compute the set of files to be formatted.""" - include_regex = test_regex(ctx, "include", include) - exclude_regex = test_regex(ctx, "exclude", exclude) - assert exclude_regex is not None - extend_exclude_regex = test_regex(ctx, "extend_exclude", extend_exclude) - force_exclude_regex = test_regex(ctx, "force_exclude", force_exclude) - root = find_project_root(src) sources: Set[Path] = set() path_empty(src, "No Path provided. Nothing to do 😴", quiet, verbose, ctx) @@ -665,8 +662,8 @@ def get_sources( normalized_path = "/" + normalized_path # Hard-exclude any files that matches the `--force-exclude` regex. - if force_exclude_regex: - force_exclude_match = force_exclude_regex.search(normalized_path) + if force_exclude: + force_exclude_match = force_exclude.search(normalized_path) else: force_exclude_match = None if force_exclude_match and force_exclude_match.group(0): @@ -682,10 +679,10 @@ def get_sources( gen_python_files( p.iterdir(), root, - include_regex, - exclude_regex, - extend_exclude_regex, - force_exclude_regex, + include, + exclude, + extend_exclude, + force_exclude, report, gitignore, ) diff --git a/tests/test_black.py b/tests/test_black.py index ba1869a..72e16a3 100644 --- a/tests/test_black.py +++ b/tests/test_black.py @@ -1375,8 +1375,8 @@ class BlackTestCase(BlackBaseTestCase): src=(src,), quiet=True, verbose=False, - include=include, - exclude=exclude, + include=re.compile(include), + exclude=re.compile(exclude), extend_exclude=None, force_exclude=None, report=report, @@ -1398,8 +1398,8 @@ class BlackTestCase(BlackBaseTestCase): src=(src,), quiet=True, verbose=False, - include=include, - exclude=exclude, + include=re.compile(include), + exclude=re.compile(exclude), extend_exclude=None, force_exclude=None, report=report, @@ -1422,8 +1422,8 @@ class BlackTestCase(BlackBaseTestCase): src=(src,), quiet=True, verbose=False, - include=include, - exclude=exclude, + include=re.compile(include), + exclude=re.compile(exclude), extend_exclude=None, force_exclude=None, report=report, @@ -1450,8 +1450,8 @@ class BlackTestCase(BlackBaseTestCase): src=(src,), quiet=True, verbose=False, - include=include, - exclude=exclude, + include=re.compile(include), + exclude=re.compile(exclude), extend_exclude=None, force_exclude=None, report=report, @@ -1478,9 +1478,9 @@ class BlackTestCase(BlackBaseTestCase): src=(src,), quiet=True, verbose=False, - include=include, - exclude="", - extend_exclude=extend_exclude, + include=re.compile(include), + exclude=re.compile(""), + extend_exclude=re.compile(extend_exclude), force_exclude=None, report=report, stdin_filename=stdin_filename, @@ -1504,10 +1504,10 @@ class BlackTestCase(BlackBaseTestCase): src=(src,), quiet=True, verbose=False, - include=include, - exclude="", + include=re.compile(include), + exclude=re.compile(""), extend_exclude=None, - force_exclude=force_exclude, + force_exclude=re.compile(force_exclude), report=report, stdin_filename=stdin_filename, ) -- 2.39.5