From a0230fd16f5c80b6f85b0c97bfd5173049e06b3d Mon Sep 17 00:00:00 2001 From: Fabian Meumertzheim Date: Wed, 29 Mar 2023 22:42:17 +0200 Subject: [PATCH 1/2] Verify that `module()` is called first --- .../lib/bazel/bzlmod/ModuleFileGlobals.java | 14 +++++++++ .../bazel/bzlmod/ModuleFileFunctionTest.java | 30 +++++++++++++++++++ 2 files changed, 44 insertions(+) diff --git a/src/main/java/com/google/devtools/build/lib/bazel/bzlmod/ModuleFileGlobals.java b/src/main/java/com/google/devtools/build/lib/bazel/bzlmod/ModuleFileGlobals.java index 3b3f42c0ba0709..59db92124946cf 100644 --- a/src/main/java/com/google/devtools/build/lib/bazel/bzlmod/ModuleFileGlobals.java +++ b/src/main/java/com/google/devtools/build/lib/bazel/bzlmod/ModuleFileGlobals.java @@ -61,6 +61,7 @@ public class ModuleFileGlobals { Pattern.compile("(>|<|-|<=|>=)(\\d+\\.){2}\\d+"); private boolean moduleCalled = false; + private boolean hadNonModuleCall = false; private final boolean ignoreDevDeps; private final Module.Builder module; private final Map deps = new LinkedHashMap<>(); @@ -208,6 +209,9 @@ public void module( if (moduleCalled) { throw Starlark.errorf("the module() directive can only be called once"); } + if (hadNonModuleCall) { + throw Starlark.errorf("if module() is called, it must be called before any other functions"); + } moduleCalled = true; if (!name.isEmpty()) { validateModuleName(name); @@ -298,6 +302,7 @@ private static ImmutableList checkAllCompatibilityVersions( public void bazelDep( String name, String version, String repoName, boolean devDependency, StarlarkThread thread) throws EvalException { + hadNonModuleCall = true; if (repoName.isEmpty()) { repoName = name; } @@ -330,6 +335,7 @@ public void bazelDep( allowedTypes = {@ParamType(type = Sequence.class, generic1 = String.class)}, doc = "The labels of the platforms to register.")) public void registerExecutionPlatforms(Sequence platformLabels) throws EvalException { + hadNonModuleCall = true; module.addExecutionPlatformsToRegister( checkAllAbsolutePatterns(platformLabels, "register_execution_platforms")); } @@ -347,6 +353,7 @@ public void registerExecutionPlatforms(Sequence platformLabels) throws EvalEx allowedTypes = {@ParamType(type = Sequence.class, generic1 = String.class)}, doc = "The labels of the toolchains to register.")) public void registerToolchains(Sequence toolchainLabels) throws EvalException { + hadNonModuleCall = true; module.addToolchainsToRegister( checkAllAbsolutePatterns(toolchainLabels, "register_toolchains")); } @@ -377,6 +384,7 @@ public void registerToolchains(Sequence toolchainLabels) throws EvalException useStarlarkThread = true) public ModuleExtensionProxy useExtension( String extensionBzlFile, String extensionName, boolean devDependency, StarlarkThread thread) { + hadNonModuleCall = true; ModuleExtensionUsageBuilder newUsageBuilder = new ModuleExtensionUsageBuilder( extensionBzlFile, extensionName, thread.getCallerLocation()); @@ -516,6 +524,7 @@ public void useRepo( Dict kwargs, StarlarkThread thread) throws EvalException { + hadNonModuleCall = true; Location location = thread.getCallerLocation(); for (String arg : Sequence.cast(args, String.class, "args")) { extensionProxy.addImport(arg, arg, location); @@ -598,6 +607,7 @@ public void singleVersionOverride( Iterable patchCmds, StarlarkInt patchStrip) throws EvalException { + hadNonModuleCall = true; Version parsedVersion; try { parsedVersion = Version.parse(version); @@ -652,6 +662,7 @@ public void singleVersionOverride( }) public void multipleVersionOverride(String moduleName, Iterable versions, String registry) throws EvalException { + hadNonModuleCall = true; ImmutableList.Builder parsedVersionsBuilder = new ImmutableList.Builder<>(); try { for (String version : Sequence.cast(versions, String.class, "versions").getImmutableList()) { @@ -735,6 +746,7 @@ public void archiveOverride( Iterable patchCmds, StarlarkInt patchStrip) throws EvalException { + hadNonModuleCall = true; ImmutableList urlList = urls instanceof String ? ImmutableList.of((String) urls) @@ -806,6 +818,7 @@ public void gitOverride( Iterable patchCmds, StarlarkInt patchStrip) throws EvalException { + hadNonModuleCall = true; addOverride( moduleName, GitOverride.create( @@ -835,6 +848,7 @@ public void gitOverride( positional = false), }) public void localPathOverride(String moduleName, String path) throws EvalException { + hadNonModuleCall = true; addOverride(moduleName, LocalPathOverride.create(path)); } diff --git a/src/test/java/com/google/devtools/build/lib/bazel/bzlmod/ModuleFileFunctionTest.java b/src/test/java/com/google/devtools/build/lib/bazel/bzlmod/ModuleFileFunctionTest.java index 2df3f7af45a1e2..5321617d4c79b6 100644 --- a/src/test/java/com/google/devtools/build/lib/bazel/bzlmod/ModuleFileFunctionTest.java +++ b/src/test/java/com/google/devtools/build/lib/bazel/bzlmod/ModuleFileFunctionTest.java @@ -956,4 +956,34 @@ public void moduleRepoName_conflict() throws Exception { assertContainsEvent("The repo name 'bbb' is already being used as the module's own repo name"); } + + @Test + public void module_calledTwice() throws Exception { + scratch.file( + rootDirectory.getRelative("MODULE.bazel").getPathString(), + "module(name='aaa',version='0.1',repo_name='bbb')", + "module(name='aaa',version='0.1',repo_name='bbb')"); + FakeRegistry registry = registryFactory.newFakeRegistry("/foo"); + ModuleFileFunction.REGISTRIES.set(differencer, ImmutableList.of(registry.getUrl())); + + reporter.removeHandler(failFastHandler); // expect failures + evaluator.evaluate(ImmutableList.of(ModuleFileValue.KEY_FOR_ROOT_MODULE), evaluationContext); + + assertContainsEvent("the module() directive can only be called once"); + } + + @Test + public void module_calledLate() throws Exception { + scratch.file( + rootDirectory.getRelative("MODULE.bazel").getPathString(), + "use_extension('//:extensions.bzl', 'my_ext')", + "module(name='aaa',version='0.1',repo_name='bbb')"); + FakeRegistry registry = registryFactory.newFakeRegistry("/foo"); + ModuleFileFunction.REGISTRIES.set(differencer, ImmutableList.of(registry.getUrl())); + + reporter.removeHandler(failFastHandler); // expect failures + evaluator.evaluate(ImmutableList.of(ModuleFileValue.KEY_FOR_ROOT_MODULE), evaluationContext); + + assertContainsEvent("if module() is called, it must be called before any other functions"); + } } From 0df7c14ec44d593626994d66b09badd6caa508a9 Mon Sep 17 00:00:00 2001 From: Fabian Meumertzheim Date: Wed, 29 Mar 2023 22:42:36 +0200 Subject: [PATCH 2/2] Normalize `use_extension` label Normalize the label by adding the current module's repo_name if the label doesn't specify a repository name. This is necessary as ModuleExtensionUsages are grouped by the string value of this label, but later mapped to their Label representation. If multiple strings map to the same Label, this would result in a crash. --- .../build/lib/bazel/bzlmod/Module.java | 2 + .../lib/bazel/bzlmod/ModuleFileGlobals.java | 23 +++++- .../bzlmod/ModuleExtensionResolutionTest.java | 82 +++++++++++++++++++ .../bazel/bzlmod/ModuleFileFunctionTest.java | 8 +- 4 files changed, 109 insertions(+), 6 deletions(-) diff --git a/src/main/java/com/google/devtools/build/lib/bazel/bzlmod/Module.java b/src/main/java/com/google/devtools/build/lib/bazel/bzlmod/Module.java index 2952548e851bd5..cfc09237b415b0 100644 --- a/src/main/java/com/google/devtools/build/lib/bazel/bzlmod/Module.java +++ b/src/main/java/com/google/devtools/build/lib/bazel/bzlmod/Module.java @@ -242,6 +242,8 @@ public Builder addExtensionUsage(ModuleExtensionUsage value) { return this; } + abstract ModuleKey getKey(); + abstract String getName(); abstract Optional getRepoName(); diff --git a/src/main/java/com/google/devtools/build/lib/bazel/bzlmod/ModuleFileGlobals.java b/src/main/java/com/google/devtools/build/lib/bazel/bzlmod/ModuleFileGlobals.java index 59db92124946cf..0a4bbf3391976a 100644 --- a/src/main/java/com/google/devtools/build/lib/bazel/bzlmod/ModuleFileGlobals.java +++ b/src/main/java/com/google/devtools/build/lib/bazel/bzlmod/ModuleFileGlobals.java @@ -382,9 +382,12 @@ public void registerToolchains(Sequence toolchainLabels) throws EvalException defaultValue = "False"), }, useStarlarkThread = true) - public ModuleExtensionProxy useExtension( - String extensionBzlFile, String extensionName, boolean devDependency, StarlarkThread thread) { + public ModuleExtensionProxy useExtension(String rawExtensionBzlFile, String extensionName, + boolean devDependency, StarlarkThread thread) { hadNonModuleCall = true; + + String extensionBzlFile = normalizeLabelString(rawExtensionBzlFile); + ModuleExtensionUsageBuilder newUsageBuilder = new ModuleExtensionUsageBuilder( extensionBzlFile, extensionName, thread.getCallerLocation()); @@ -407,6 +410,22 @@ public ModuleExtensionProxy useExtension( return newUsageBuilder.getProxy(devDependency); } + private String normalizeLabelString(String rawExtensionBzlFile) { + // Normalize the label by adding the current module's repo_name if the label doesn't specify a + // repository name. This is necessary as ModuleExtensionUsages are grouped by the string value + // of this label, but later mapped to their Label representation. If multiple strings map to the + // same Label, this would result in a crash. + // ownName can't change anymore as calling module() after this results in an error. + String ownName = module.getRepoName().orElse(module.getName()); + if (module.getKey().equals(ModuleKey.ROOT) && rawExtensionBzlFile.startsWith("@//")) { + return "@" + ownName + rawExtensionBzlFile.substring(1); + } else if (rawExtensionBzlFile.startsWith("//")) { + return "@" + ownName + rawExtensionBzlFile; + } else { + return rawExtensionBzlFile; + } + } + class ModuleExtensionUsageBuilder { private final String extensionBzlFile; private final String extensionName; diff --git a/src/test/java/com/google/devtools/build/lib/bazel/bzlmod/ModuleExtensionResolutionTest.java b/src/test/java/com/google/devtools/build/lib/bazel/bzlmod/ModuleExtensionResolutionTest.java index a6f3127a8d9b24..2b488324e115b0 100644 --- a/src/test/java/com/google/devtools/build/lib/bazel/bzlmod/ModuleExtensionResolutionTest.java +++ b/src/test/java/com/google/devtools/build/lib/bazel/bzlmod/ModuleExtensionResolutionTest.java @@ -329,6 +329,88 @@ public void simpleExtension() throws Exception { assertThat(result.get(skyKey).getModule().getGlobal("data")).isEqualTo("foo:fu bar:ba"); } + @Test + public void simpleExtension_nonCanonicalLabel() throws Exception { + scratch.file( + workspaceRoot.getRelative("MODULE.bazel").getPathString(), + "module(name='my_module', version = '1.0')", + "bazel_dep(name='data_repo', version='1.0')", + "ext1 = use_extension('//:defs.bzl', 'ext')", + "ext1.tag(name='foo', data='fu')", + "use_repo(ext1, 'foo')", + "ext2 = use_extension('@my_module//:defs.bzl', 'ext')", + "ext2.tag(name='bar', data='ba')", + "use_repo(ext2, 'bar')", + "ext3 = use_extension('@//:defs.bzl', 'ext')", + "ext3.tag(name='quz', data='qu')", + "use_repo(ext3, 'quz')"); + scratch.file( + workspaceRoot.getRelative("defs.bzl").getPathString(), + "load('@data_repo//:defs.bzl','data_repo')", + "tag = tag_class(attrs = {'name':attr.string(),'data':attr.string()})", + "def _ext_impl(ctx):", + " for mod in ctx.modules:", + " for tag in mod.tags.tag:", + " data_repo(name=tag.name,data=tag.data)", + "ext = module_extension(implementation=_ext_impl, tag_classes={'tag':tag})"); + scratch.file(workspaceRoot.getRelative("BUILD").getPathString()); + scratch.file( + workspaceRoot.getRelative("data.bzl").getPathString(), + "load('@foo//:data.bzl', foo_data='data')", + "load('@bar//:data.bzl', bar_data='data')", + "load('@quz//:data.bzl', quz_data='data')", + "data = 'foo:'+foo_data+' bar:'+bar_data+' quz:'+quz_data"); + + SkyKey skyKey = BzlLoadValue.keyForBuild(Label.parseCanonical("//:data.bzl")); + EvaluationResult result = + evaluator.evaluate(ImmutableList.of(skyKey), evaluationContext); + if (result.hasError()) { + throw result.getError().getException(); + } + assertThat(result.get(skyKey).getModule().getGlobal("data")).isEqualTo("foo:fu bar:ba quz:qu"); + } + + @Test + public void simpleExtension_nonCanonicalLabel_repoName() throws Exception { + scratch.file( + workspaceRoot.getRelative("MODULE.bazel").getPathString(), + "module(name='my_module', version = '1.0', repo_name='my_name')", + "bazel_dep(name='data_repo', version='1.0')", + "ext1 = use_extension('//:defs.bzl', 'ext')", + "ext1.tag(name='foo', data='fu')", + "use_repo(ext1, 'foo')", + "ext2 = use_extension('@my_name//:defs.bzl', 'ext')", + "ext2.tag(name='bar', data='ba')", + "use_repo(ext2, 'bar')", + "ext3 = use_extension('@//:defs.bzl', 'ext')", + "ext3.tag(name='quz', data='qu')", + "use_repo(ext3, 'quz')"); + scratch.file( + workspaceRoot.getRelative("defs.bzl").getPathString(), + "load('@data_repo//:defs.bzl','data_repo')", + "tag = tag_class(attrs = {'name':attr.string(),'data':attr.string()})", + "def _ext_impl(ctx):", + " for mod in ctx.modules:", + " for tag in mod.tags.tag:", + " data_repo(name=tag.name,data=tag.data)", + "ext = module_extension(implementation=_ext_impl, tag_classes={'tag':tag})"); + scratch.file(workspaceRoot.getRelative("BUILD").getPathString()); + scratch.file( + workspaceRoot.getRelative("data.bzl").getPathString(), + "load('@foo//:data.bzl', foo_data='data')", + "load('@bar//:data.bzl', bar_data='data')", + "load('@quz//:data.bzl', quz_data='data')", + "data = 'foo:'+foo_data+' bar:'+bar_data+' quz:'+quz_data"); + + SkyKey skyKey = BzlLoadValue.keyForBuild(Label.parseCanonical("//:data.bzl")); + EvaluationResult result = + evaluator.evaluate(ImmutableList.of(skyKey), evaluationContext); + if (result.hasError()) { + throw result.getError().getException(); + } + assertThat(result.get(skyKey).getModule().getGlobal("data")).isEqualTo("foo:fu bar:ba quz:qu"); + } + @Test public void multipleModules() throws Exception { scratch.file( diff --git a/src/test/java/com/google/devtools/build/lib/bazel/bzlmod/ModuleFileFunctionTest.java b/src/test/java/com/google/devtools/build/lib/bazel/bzlmod/ModuleFileFunctionTest.java index 5321617d4c79b6..1b8e52ac9cc8fe 100644 --- a/src/test/java/com/google/devtools/build/lib/bazel/bzlmod/ModuleFileFunctionTest.java +++ b/src/test/java/com/google/devtools/build/lib/bazel/bzlmod/ModuleFileFunctionTest.java @@ -473,7 +473,7 @@ public void testModuleExtensions_good() throws Exception { .setRegistry(registry) .addExtensionUsage( ModuleExtensionUsage.builder() - .setExtensionBzlFile("//:defs.bzl") + .setExtensionBzlFile("@mymod//:defs.bzl") .setExtensionName("myext1") .setLocation(Location.fromFileLineColumn("mymod@1.0/MODULE.bazel", 2, 23)) .setImports(ImmutableBiMap.of("repo1", "repo1")) @@ -491,7 +491,7 @@ public void testModuleExtensions_good() throws Exception { .build()) .addExtensionUsage( ModuleExtensionUsage.builder() - .setExtensionBzlFile("//:defs.bzl") + .setExtensionBzlFile("@mymod//:defs.bzl") .setExtensionName("myext2") .setLocation(Location.fromFileLineColumn("mymod@1.0/MODULE.bazel", 5, 23)) .setImports(ImmutableBiMap.of("other_repo1", "repo1", "repo2", "repo2")) @@ -582,7 +582,7 @@ public void testModuleExtensions_duplicateProxy_asRoot() throws Exception { .setKey(ModuleKey.ROOT) .addExtensionUsage( ModuleExtensionUsage.builder() - .setExtensionBzlFile("//:defs.bzl") + .setExtensionBzlFile("@//:defs.bzl") .setExtensionName("myext") .setLocation(Location.fromFileLineColumn("/MODULE.bazel", 1, 23)) .setImports( @@ -672,7 +672,7 @@ public void testModuleExtensions_duplicateProxy_asDep() throws Exception { .setRegistry(registry) .addExtensionUsage( ModuleExtensionUsage.builder() - .setExtensionBzlFile("//:defs.bzl") + .setExtensionBzlFile("@mymod//:defs.bzl") .setExtensionName("myext") .setLocation(Location.fromFileLineColumn("mymod@1.0/MODULE.bazel", 5, 23)) .setImports(ImmutableBiMap.of("beta", "beta", "delta", "delta"))