From: Lihu Ben-Ezri-Ravin
Date: Wed, 24 Jun 2020 09:09:07 +0000 (-0400)
Subject: Find project root correctly (#1518)
X-Git-Url: https://git.madduck.net/etc/vim.git/commitdiff_plain/2471b9256d9d9dfea1124d20072201693b9b0865?ds=inline;hp=f90f50a7436ca13517933c290ef007e7cb2e7258

Find project root correctly (#1518)

Ensure root dir is a common parent of all inputs
Fixes #1493
---

diff --git a/src/black/__init__.py b/src/black/__init__.py
index 2b2d3d8..d4c6e62 100644
--- a/src/black/__init__.py
+++ b/src/black/__init__.py
@@ -5825,8 +5825,8 @@ def gen_python_files(
 def find_project_root(srcs: Iterable[str]) -> Path:
     """Return a directory containing .git, .hg, or pyproject.toml.
 
-    That directory can be one of the directories passed in `srcs` or their
-    common parent.
+    That directory will be a common parent of all files and directories
+    passed in `srcs`.
 
     If no directory in the tree contains a marker that would specify it's the
     project root, the root of the file system is returned.
@@ -5834,11 +5834,20 @@ def find_project_root(srcs: Iterable[str]) -> Path:
     if not srcs:
         return Path("/").resolve()
 
-    common_base = min(Path(src).resolve() for src in srcs)
-    if common_base.is_dir():
-        # Append a fake file so `parents` below returns `common_base_dir`, too.
-        common_base /= "fake-file"
-    for directory in common_base.parents:
+    path_srcs = [Path(src).resolve() for src in srcs]
+
+    # A list of lists of parents for each 'src'. 'src' is included as a
+    # "parent" of itself if it is a directory
+    src_parents = [
+        list(path.parents) + ([path] if path.is_dir() else []) for path in path_srcs
+    ]
+
+    common_base = max(
+        set.intersection(*(set(parents) for parents in src_parents)),
+        key=lambda path: path.parts,
+    )
+
+    for directory in (common_base, *common_base.parents):
         if (directory / ".git").exists():
             return directory
 
diff --git a/tests/test_black.py b/tests/test_black.py
index 88839d8..3ed5daa 100644
--- a/tests/test_black.py
+++ b/tests/test_black.py
@@ -1801,6 +1801,28 @@ class BlackTestCase(unittest.TestCase):
         self.assertEqual(config["exclude"], r"\.pyi?$")
         self.assertEqual(config["include"], r"\.py?$")
 
+    def test_find_project_root(self) -> None:
+        with TemporaryDirectory() as workspace:
+            root = Path(workspace)
+            test_dir = root / "test"
+            test_dir.mkdir()
+
+            src_dir = root / "src"
+            src_dir.mkdir()
+
+            root_pyproject = root / "pyproject.toml"
+            root_pyproject.touch()
+            src_pyproject = src_dir / "pyproject.toml"
+            src_pyproject.touch()
+            src_python = src_dir / "foo.py"
+            src_python.touch()
+
+            self.assertEqual(
+                black.find_project_root((src_dir, test_dir)), root.resolve()
+            )
+            self.assertEqual(black.find_project_root((src_dir,)), src_dir.resolve())
+            self.assertEqual(black.find_project_root((src_python,)), src_dir.resolve())
+
 
 class BlackDTestCase(AioHTTPTestCase):
     async def get_application(self) -> web.Application: