diff --git a/src/main/java/com/google/devtools/build/lib/bazel/bzlmod/GitOverride.java b/src/main/java/com/google/devtools/build/lib/bazel/bzlmod/GitOverride.java index 1c209f18ed8543..7256d2cf2ee9bb 100644 --- a/src/main/java/com/google/devtools/build/lib/bazel/bzlmod/GitOverride.java +++ b/src/main/java/com/google/devtools/build/lib/bazel/bzlmod/GitOverride.java @@ -28,9 +28,10 @@ public static GitOverride create( ImmutableList patches, ImmutableList patchCmds, int patchStrip, - boolean initSubmodules) { + boolean initSubmodules, + String stripPrefix) { return new AutoValue_GitOverride( - remote, commit, patches, patchCmds, patchStrip, initSubmodules); + remote, commit, patches, patchCmds, patchStrip, initSubmodules, stripPrefix); } /** The URL pointing to the git repository. */ @@ -51,6 +52,9 @@ public static GitOverride create( /** Whether submodules in the fetched repo should be recursively initialized. */ public abstract boolean getInitSubmodules(); + /** The directory prefix to strip from the extracted files. */ + public abstract String getStripPrefix(); + /** Returns the {@link RepoSpec} that defines this repository. */ @Override public RepoSpec getRepoSpec() { @@ -61,6 +65,7 @@ public RepoSpec getRepoSpec() { .setPatchCmds(getPatchCmds()) .setPatchArgs(ImmutableList.of("-p" + getPatchStrip())) .setInitSubmodules(getInitSubmodules()) + .setStripPrefix(getStripPrefix()) .build(); } diff --git a/src/main/java/com/google/devtools/build/lib/bazel/bzlmod/ModuleFileGlobals.java b/src/main/java/com/google/devtools/build/lib/bazel/bzlmod/ModuleFileGlobals.java index 2d747935dd0129..f620d11520312a 100644 --- a/src/main/java/com/google/devtools/build/lib/bazel/bzlmod/ModuleFileGlobals.java +++ b/src/main/java/com/google/devtools/build/lib/bazel/bzlmod/ModuleFileGlobals.java @@ -984,10 +984,20 @@ public void archiveOverride( defaultValue = "0"), @Param( name = "init_submodules", - doc = "Whether submodules in the fetched repo should be recursively initialized.", + doc = "Whether git submodules in the fetched repo should be recursively initialized.", named = true, positional = false, defaultValue = "False"), + @Param( + name = "strip_prefix", + doc = + "A directory prefix to strip from the extracted files. This can be used to target" + + " a subdirectory of the git repo. Note that the subdirectory must have its" + + " own `MODULE.bazel` file with a module name that is the same as the" + + " `module_name` arg passed to this `git_override`.", + named = true, + positional = false, + defaultValue = "''"), }, useStarlarkThread = true) public void gitOverride( @@ -998,6 +1008,7 @@ public void gitOverride( Iterable patchCmds, StarlarkInt patchStrip, boolean initSubmodules, + String stripPrefix, StarlarkThread thread) throws EvalException { ModuleThreadContext context = ModuleThreadContext.fromOrFail(thread, "git_override()"); @@ -1013,7 +1024,8 @@ public void gitOverride( .collect(toImmutableList()), Sequence.cast(patchCmds, String.class, "patchCmds").getImmutableList(), patchStrip.toInt("git_override.patch_strip"), - initSubmodules)); + initSubmodules, + stripPrefix)); } @StarlarkMethod( diff --git a/src/test/py/bazel/bzlmod/bazel_overrides_test.py b/src/test/py/bazel/bzlmod/bazel_overrides_test.py index 9236094951c847..45c3f4d9890fd1 100644 --- a/src/test/py/bazel/bzlmod/bazel_overrides_test.py +++ b/src/test/py/bazel/bzlmod/bazel_overrides_test.py @@ -15,6 +15,7 @@ # pylint: disable=g-long-ternary import os +import shutil import tempfile from absl.testing import absltest from src.test.py.bazel import test_base @@ -36,6 +37,8 @@ def setUp(self): 'bbb', '1.1', {'aaa': '1.1'} ).createCcModule( 'ccc', '1.1', {'aaa': '1.1', 'bbb': '1.1'} + ).createCcModule( + 'ddd', '1.0' ) self.ScratchFile( '.bazelrc', @@ -272,6 +275,104 @@ def testGitOverride(self): self.assertIn('main function => bbb@1.1', stdout) self.assertIn('bbb@1.1 => aaa@1.0 (locally patched)', stdout) + def testGitOverrideStripPrefix(self): + self.writeMainProjectFiles() + + # Update BUILD and main.cc to also call `ddd`. + self.ScratchFile( + 'BUILD', + [ + 'cc_binary(', + ' name = "main",', + ' srcs = ["main.cc"],', + ' deps = [', + ' "@aaa//:lib_aaa",', + ' "@bbb//:lib_bbb",', + ' "@ddd//:lib_ddd",', + ' ],', + ')', + ], + ) + self.ScratchFile( + 'main.cc', + [ + '#include "aaa.h"', + '#include "bbb.h"', + '#include "ddd.h"', + 'int main() {', + ' hello_aaa("main function");', + ' hello_bbb("main function");', + ' hello_ddd("main function");', + '}', + ], + ) + src_aaa_1_0 = self.main_registry.projects.joinpath('aaa', '1.0') + src_ddd_1_0 = self.main_registry.projects.joinpath('ddd', '1.0') + self.RunProgram(['git', 'init'], cwd=src_aaa_1_0) + self.RunProgram( + ['git', 'config', 'user.name', 'tester'], + cwd=src_aaa_1_0, + ) + self.RunProgram( + ['git', 'config', 'user.email', 'tester@foo.com'], + cwd=src_aaa_1_0, + ) + + # Make a subdirectory that itself is the published module 'ddd'. + subdir_name = 'subdir_containing_ddd' + shutil.copytree(src=src_ddd_1_0, dst=src_aaa_1_0 / subdir_name) + + # Edit the code in 'subdir_containing_ddd/ddd.cc' so that we can assert + # that we're using it. + src_aaa_relpath = src_aaa_1_0.relative_to(self._test_cwd) + self.ScratchFile( + str(src_aaa_relpath / subdir_name / 'ddd.cc'), + [ + '#include ', + '#include "ddd.h"', + 'void hello_ddd(const std::string& caller) {', + ' std::string lib_name = "ddd@1.0";', + ( + ' printf("%s => %s from subdir\\n", caller.c_str(),' + ' lib_name.c_str());' + ), + '}', + ], + ) + + self.RunProgram(['git', 'add', './'], cwd=src_aaa_1_0) + self.RunProgram( + ['git', 'commit', '-m', 'Initial commit.'], + cwd=src_aaa_1_0, + ) + + _, stdout, _ = self.RunProgram( + ['git', 'rev-parse', 'HEAD'], cwd=src_aaa_1_0 + ) + + commit = stdout[0].strip() + + self.ScratchFile( + 'MODULE.bazel', + [ + 'bazel_dep(name = "aaa", version = "1.1")', + 'bazel_dep(name = "bbb", version = "1.1")', + 'bazel_dep(name = "ddd", version = "1.0")', + 'git_override(', + ' module_name = "ddd",', + ' remote = "%s",' % src_aaa_1_0.as_uri(), + ' commit = "%s",' % commit, + ' strip_prefix = "%s",' % subdir_name, + ')', + ], + ) + + _, stdout, _ = self.RunBazel(['run', '//:main']) + self.assertIn('main function => aaa@1.1', stdout) + self.assertIn('main function => bbb@1.1', stdout) + self.assertIn('bbb@1.1 => aaa@1.1', stdout) + self.assertIn('main function => ddd@1.0 from subdir', stdout) + def testLocalPathOverride(self): src_aaa_1_0 = self.main_registry.projects.joinpath('aaa', '1.0') self.writeMainProjectFiles()