Skip to content

Commit

Permalink
Prevent side-effects in initializer callbacks
Browse files Browse the repository at this point in the history
Side-effect are prevented by opening a new StarlarkThread without BazelStarlarkContext, Most of the side-effectful functions we want to prevent need the context.

Fix Bazel crashes when special functions are used in callback. There are several other preexisting locations where StarlarkThread without BazelStarlarkContext is used to execute a callback. If callback uses a function that calls BazelStarlarkContext.from, this resulted in ISE and Bazel crash. Caused solely by user inputs.

Improve error messaging. Whenever StarlarkThread wo/ BazelStarlarkContext is used for a callback, we also provide information where the problem happened.

PiperOrigin-RevId: 581146696
Change-Id: I682fdaabd881814a7ef42ab55aae2a50367bf213
  • Loading branch information
comius authored and copybara-github committed Nov 10, 2023
1 parent 305ab3b commit 8473928
Show file tree
Hide file tree
Showing 14 changed files with 149 additions and 94 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,11 @@ public ActionResult execute(ActionExecutionContext ctx)
Object substitutionDictObject = null;
try (Mutability mutability = Mutability.create("translate_build_info_file")) {
try {
StarlarkThread thread = new StarlarkThread(mutability, semantics);
StarlarkThread thread =
new StarlarkThread(
mutability,
semantics,
isVolatile() ? "transform_version_file callback" : "transform_info_file callback");
substitutionDictObject =
Starlark.call(
thread,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -899,7 +899,7 @@ public StarlarkActionResourceSetBuilder(
@Override
public ResourceSet buildResourceSet(OS os, int inputsSize) throws ExecException {
try (Mutability mu = Mutability.create("resource_set_builder_function")) {
StarlarkThread thread = new StarlarkThread(mu, semantics);
StarlarkThread thread = new StarlarkThread(mu, semantics, "resource_set callback");
StarlarkInt inputInt = StarlarkInt.of(inputsSize);
Object response =
Starlark.call(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1021,8 +1021,8 @@ public Object call(StarlarkThread thread, Tuple args, Dict<String, Object> kwarg
throw new EvalException("Unexpected positional arguments");
}
try {
BazelStarlarkContext.from(thread).checkLoadingPhase(getName());
} catch (IllegalStateException unused) {
BazelStarlarkContext.checkLoadingPhase(thread, getName());
} catch (EvalException unused) {
throw new EvalException(
"A rule can only be instantiated in a BUILD file, or a macro "
+ "invoked from a BUILD file");
Expand All @@ -1033,58 +1033,71 @@ public Object call(StarlarkThread thread, Tuple args, Dict<String, Object> kwarg

validateRulePropagatedAspects(ruleClass);

// We call all the initializers of the rule and its ancestor rules, proceeding from child to
// ancestor, so each initializer can transform the attributes it knows about in turn.
for (RuleClass currentRuleClass = ruleClass;
currentRuleClass != null;
currentRuleClass = currentRuleClass.getStarlarkParent()) {
if (currentRuleClass.getInitializer() == null) {
continue;
}
// Remove {@link BazelStarlarkContext} to prevent calls to load and analysis time functions.
// Mutating values in initializers is mostly not a problem, because the attribute values are
// copied before calling the initializers (<-TODO) and before they are set on the target.
// Exception is a legacy case allowing arbitrary type of parameter values. In that case the
// values may be mutated by the initializer, but they are still copied when set on the target.
BazelStarlarkContext bazelStarlarkContext = BazelStarlarkContext.fromOrFail(thread);
try {
thread.setThreadLocal(BazelStarlarkContext.class, null);
thread.setUncheckedExceptionContext(() -> "an initializer");

// We call all the initializers of the rule and its ancestor rules, proceeding from child to
// ancestor, so each initializer can transform the attributes it knows about in turn.
for (RuleClass currentRuleClass = ruleClass;
currentRuleClass != null;
currentRuleClass = currentRuleClass.getStarlarkParent()) {
if (currentRuleClass.getInitializer() == null) {
continue;
}

// TODO: b/298561048 - lift parameters to more accurate type - for example strings to
// Labels
// You might feel tempted to inspect the signature of the initializer function. The
// temptation might come from handling default values, making them work for better for the
// users.
// The less magic the better. Do not give in those temptations!
Dict.Builder<String, Object> initializerKwargs = Dict.builder();
for (var attr : currentRuleClass.getAttributes()) {
if (attr.isPublic() && attr.starlarkDefined()) {
if (kwargs.containsKey(attr.getName())) {
initializerKwargs.put(attr.getName(), kwargs.get(attr.getName()));
// TODO: b/298561048 - lift parameters to more accurate type - for example strings to
// Labels
// You might feel tempted to inspect the signature of the initializer function. The
// temptation might come from handling default values, making them work for better for the
// users.
// The less magic the better. Do not give in those temptations!
Dict.Builder<String, Object> initializerKwargs = Dict.builder();
for (var attr : currentRuleClass.getAttributes()) {
if (attr.isPublic() && attr.starlarkDefined()) {
if (kwargs.containsKey(attr.getName())) {
initializerKwargs.put(attr.getName(), kwargs.get(attr.getName()));
}
}
}
}
Object ret =
Starlark.call(
thread,
currentRuleClass.getInitializer(),
Tuple.of(),
initializerKwargs.build(thread.mutability()));
Dict<String, Object> newKwargs =
ret == Starlark.NONE
? Dict.empty()
: Dict.cast(ret, String.class, Object.class, "rule's initializer return value");

for (var arg : newKwargs.keySet()) {
checkAttributeName(arg);
if (arg.startsWith("_")) {
// allow setting private attributes from initializers in builtins
Label definitionLabel = ruleClass.getRuleDefinitionEnvironmentLabel();
BuiltinRestriction.failIfLabelOutsideAllowlist(
definitionLabel,
RepositoryMapping.ALWAYS_FALLBACK,
ALLOWLIST_RULE_EXTENSION_API_EXPERIMENTAL);
}
String nativeName = arg.startsWith("_") ? "$" + arg.substring(1) : arg;
Attribute attr = currentRuleClass.getAttributeByNameMaybe(nativeName);
if (attr != null && !attr.starlarkDefined()) {
throw Starlark.errorf(
"Initializer can only set Starlark defined attributes, not '%s'", arg);
Object ret =
Starlark.call(
thread,
currentRuleClass.getInitializer(),
Tuple.of(),
initializerKwargs.build(thread.mutability()));
Dict<String, Object> newKwargs =
ret == Starlark.NONE
? Dict.empty()
: Dict.cast(ret, String.class, Object.class, "rule's initializer return value");

for (var arg : newKwargs.keySet()) {
checkAttributeName(arg);
if (arg.startsWith("_")) {
// allow setting private attributes from initializers in builtins
Label definitionLabel = ruleClass.getRuleDefinitionEnvironmentLabel();
BuiltinRestriction.failIfLabelOutsideAllowlist(
definitionLabel,
RepositoryMapping.ALWAYS_FALLBACK,
ALLOWLIST_RULE_EXTENSION_API_EXPERIMENTAL);
}
String nativeName = arg.startsWith("_") ? "$" + arg.substring(1) : arg;
Attribute attr = currentRuleClass.getAttributeByNameMaybe(nativeName);
if (attr != null && !attr.starlarkDefined()) {
throw Starlark.errorf(
"Initializer can only set Starlark defined attributes, not '%s'", arg);
}
kwargs.putEntry(nativeName, newKwargs.get(arg));
}
kwargs.putEntry(nativeName, newKwargs.get(arg));
}
} finally {
bazelStarlarkContext.storeInThread(thread);
}

BuildLangTypedAttributeValuesMap attributeValues =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ public LazySubstitution(
@Override
public String getValue() throws EvalException {
try (Mutability mutability = Mutability.create("expand_template")) {
StarlarkThread execThread = new StarlarkThread(mutability, semantics);
StarlarkThread execThread = new StarlarkThread(mutability, semantics, "map_each callback");
ImmutableList<?> values = valuesSet.toList();
List<String> parts = new ArrayList<>(values.size());
for (Object val : values) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,7 @@ public StarlarkCallable repositoryRule(
Object doc, // <String> or Starlark.NONE
StarlarkThread thread)
throws EvalException {
BazelStarlarkContext context = BazelStarlarkContext.from(thread);
context.checkLoadingOrWorkspacePhase("repository_rule");
BazelStarlarkContext.checkLoadingOrWorkspacePhase(thread, "repository_rule");
// We'll set the name later, pass the empty string for now.
RuleClass.Builder builder = new RuleClass.Builder("", RuleClassType.WORKSPACE, true);

Expand Down Expand Up @@ -241,7 +240,7 @@ private String getRuleClassName() {

private Object createRuleLegacy(StarlarkThread thread, Dict<String, Object> kwargs)
throws EvalException, InterruptedException {
BazelStarlarkContext.from(thread).checkWorkspacePhase("repository rule " + exportedName);
BazelStarlarkContext.checkWorkspacePhase(thread, "repository rule " + exportedName);
String ruleClassName = getRuleClassName();
try {
RuleClass ruleClass = builder.build(ruleClassName, ruleClassName);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,14 +72,14 @@ public enum Phase {
}

/**
* Retrieves this context from a Starlark thread, or throws {@link IllegalStateException} if
* unavailable.
* Retrieves this context from a Starlark thread, or throws {@link EvalException} if unavailable.
*/
public static BazelStarlarkContext from(StarlarkThread thread) {
public static BazelStarlarkContext fromOrFail(StarlarkThread thread) throws EvalException {
BazelStarlarkContext ctx = thread.getThreadLocal(BazelStarlarkContext.class);
// ISE rather than NPE for symmetry with subclasses.
Preconditions.checkState(
ctx != null, "Expected BazelStarlarkContext to be available in this Starlark thread");
if (ctx == null) {
throw Starlark.errorf(
"this function cannot be called from %s", thread.getContextDescription());
}
return ctx;
}

Expand Down Expand Up @@ -136,8 +136,14 @@ public String getContextForUncheckedException() {
*/
// TODO(b/236456122): The Phase enum is incomplete. Ex: `Args.map_each` evaluation happens at
// execution time. So this is a misnomer and possibly wrong in those contexts.
public void checkLoadingOrWorkspacePhase(String function) throws EvalException {
if (phase == Phase.ANALYSIS) {
public static void checkLoadingOrWorkspacePhase(StarlarkThread thread, String function)
throws EvalException {
BazelStarlarkContext ctx = thread.getThreadLocal(BazelStarlarkContext.class);
if (ctx == null) {
throw Starlark.errorf(
"'%s' cannot be called from %s", function, thread.getContextDescription());
}
if (ctx.phase == Phase.ANALYSIS) {
throw Starlark.errorf("'%s' cannot be called during the analysis phase", function);
}
}
Expand All @@ -147,9 +153,17 @@ public void checkLoadingOrWorkspacePhase(String function) throws EvalException {
*
* @param function name of a function that requires this check
*/
public void checkLoadingPhase(String function) throws EvalException {
if (phase != Phase.LOADING) {
throw Starlark.errorf("'%s' can only be called during the loading phase", function);
public static void checkLoadingPhase(StarlarkThread thread, String function)
throws EvalException {
BazelStarlarkContext ctx = thread.getThreadLocal(BazelStarlarkContext.class);
if (ctx == null) {
throw Starlark.errorf(
"'%s' cannot be called from %s", function, thread.getContextDescription());
}
if (ctx.phase != Phase.LOADING) {
throw Starlark.errorf(
"'%s' can only be called from a BUILD file, or a macro invoked from a BUILD file",
function);
}
}

Expand All @@ -158,8 +172,14 @@ public void checkLoadingPhase(String function) throws EvalException {
*
* @param function name of a function that requires this check
*/
public void checkWorkspacePhase(String function) throws EvalException {
if (phase != Phase.WORKSPACE) {
public static void checkWorkspacePhase(StarlarkThread thread, String function)
throws EvalException {
BazelStarlarkContext ctx = thread.getThreadLocal(BazelStarlarkContext.class);
if (ctx == null) {
throw Starlark.errorf(
"'%s' cannot be called from %s", function, thread.getContextDescription());
}
if (ctx.phase != Phase.WORKSPACE) {
throw Starlark.errorf("'%s' can only be called during workspace loading", function);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ public NoneType environmentGroup(
Sequence<?> defaultsList, // <Label>
StarlarkThread thread)
throws EvalException {
BazelStarlarkContext.checkLoadingPhase(thread, "environment_group");
PackageContext context = PackageFactory.getContext(thread);
List<Label> environments =
BuildType.LABEL_LIST.convert(
Expand Down Expand Up @@ -118,6 +119,7 @@ public NoneType licenses(
Sequence<?> licensesList, // list of license strings
StarlarkThread thread)
throws EvalException {
BazelStarlarkContext.checkLoadingPhase(thread, "licenses");
PackageContext context = PackageFactory.getContext(thread);
try {
License license = BuildType.LICENSE.convert(licensesList, "'licenses' operand");
Expand All @@ -139,6 +141,7 @@ public NoneType licenses(
documented = false,
useStarlarkThread = true)
public NoneType distribs(Object object, StarlarkThread thread) throws EvalException {
BazelStarlarkContext.checkLoadingPhase(thread, "distribs");
PackageContext context = PackageFactory.getContext(thread);

try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,7 @@ public NoneType call(StarlarkThread thread, Tuple args, Dict<String, Object> kwa
if (!args.isEmpty()) {
throw Starlark.errorf("unexpected positional arguments");
}
BazelStarlarkContext.from(thread).checkLoadingOrWorkspacePhase(ruleClass.getName());
BazelStarlarkContext.checkLoadingOrWorkspacePhase(thread, ruleClass.getName());
try {
PackageContext context = PackageFactory.getContext(thread);
RuleFactory.createAndAddRule(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ public Sequence<?> glob(
Object allowEmptyArgument,
StarlarkThread thread)
throws EvalException, InterruptedException {
BazelStarlarkContext.from(thread).checkLoadingPhase("native.glob");
BazelStarlarkContext.checkLoadingPhase(thread, "native.glob");
PackageContext context = getContext(thread);

List<String> includes = Type.STRING_LIST.convert(include, "'glob' argument");
Expand Down Expand Up @@ -435,7 +435,7 @@ public Object existingRule(String name, StarlarkThread thread) throws EvalExcept
if (thread.getThreadLocal(ExistingRulesShouldBeNoOp.class) != null) {
return Starlark.NONE;
}
BazelStarlarkContext.from(thread).checkLoadingOrWorkspacePhase("native.existing_rule");
BazelStarlarkContext.checkLoadingOrWorkspacePhase(thread, "native.existing_rule");
PackageContext context = getContext(thread);
Target target = context.pkgBuilder.getTarget(name);
if (target instanceof Rule /* `instanceof` also verifies that target != null */) {
Expand Down Expand Up @@ -508,7 +508,7 @@ public Object existingRules(StarlarkThread thread) throws EvalException {
if (thread.getThreadLocal(ExistingRulesShouldBeNoOp.class) != null) {
return Dict.empty();
}
BazelStarlarkContext.from(thread).checkLoadingOrWorkspacePhase("native.existing_rules");
BazelStarlarkContext.checkLoadingOrWorkspacePhase(thread, "native.existing_rules");
PackageContext context = getContext(thread);
if (thread
.getSemantics()
Expand All @@ -531,7 +531,7 @@ public Object existingRules(StarlarkThread thread) throws EvalException {
public NoneType packageGroup(
String name, Sequence<?> packagesO, Sequence<?> includesO, StarlarkThread thread)
throws EvalException {
BazelStarlarkContext.from(thread).checkLoadingPhase("native.package_group");
BazelStarlarkContext.checkLoadingPhase(thread, "native.package_group");
PackageContext context = getContext(thread);

List<String> packages =
Expand Down Expand Up @@ -566,7 +566,7 @@ public NoneType packageGroup(
public NoneType exportsFiles(
Sequence<?> srcs, Object visibilityO, Object licensesO, StarlarkThread thread)
throws EvalException {
BazelStarlarkContext.from(thread).checkLoadingPhase("native.exports_files");
BazelStarlarkContext.checkLoadingPhase(thread, "native.exports_files");
Package.Builder pkgBuilder = getContext(thread).pkgBuilder;
List<String> files = Type.STRING_LIST.convert(srcs, "'exports_files' operand");

Expand Down Expand Up @@ -607,23 +607,23 @@ public NoneType exportsFiles(

@Override
public String packageName(StarlarkThread thread) throws EvalException {
BazelStarlarkContext.from(thread).checkLoadingPhase("native.package_name");
BazelStarlarkContext.checkLoadingPhase(thread, "native.package_name");
PackageIdentifier packageId =
PackageFactory.getContext(thread).getBuilder().getPackageIdentifier();
return packageId.getPackageFragment().getPathString();
}

@Override
public String repositoryName(StarlarkThread thread) throws EvalException {
BazelStarlarkContext.from(thread).checkLoadingPhase("native.repository_name");
BazelStarlarkContext.checkLoadingPhase(thread, "native.repository_name");
PackageIdentifier packageId =
PackageFactory.getContext(thread).getBuilder().getPackageIdentifier();
return packageId.getRepository().getNameWithAt();
}

@Override
public Label packageRelativeLabel(Object input, StarlarkThread thread) throws EvalException {
BazelStarlarkContext.from(thread).checkLoadingPhase("native.package_relative_label");
BazelStarlarkContext.checkLoadingPhase(thread, "native.package_relative_label");
if (input instanceof Label) {
return (Label) input;
}
Expand All @@ -638,14 +638,14 @@ public Label packageRelativeLabel(Object input, StarlarkThread thread) throws Ev
@Override
@Nullable
public String moduleName(StarlarkThread thread) throws EvalException {
BazelStarlarkContext.from(thread).checkLoadingPhase("native.module_name");
BazelStarlarkContext.checkLoadingPhase(thread, "native.module_name");
return PackageFactory.getContext(thread).getBuilder().getAssociatedModuleName().orElse(null);
}

@Override
@Nullable
public String moduleVersion(StarlarkThread thread) throws EvalException {
BazelStarlarkContext.from(thread).checkLoadingPhase("native.module_version");
BazelStarlarkContext.checkLoadingPhase(thread, "native.module_version");
return PackageFactory.getContext(thread).getBuilder().getAssociatedModuleVersion().orElse(null);
}

Expand Down Expand Up @@ -826,7 +826,7 @@ private static Object starlarkifyValue(Mutability mu, Object val, Package pkg) {
public Sequence<?> subpackages(
Sequence<?> include, Sequence<?> exclude, boolean allowEmpty, StarlarkThread thread)
throws EvalException, InterruptedException {
BazelStarlarkContext.from(thread).checkLoadingPhase("native.subpackages");
BazelStarlarkContext.checkLoadingPhase(thread, "native.subpackages");
PackageContext context = getContext(thread);

List<String> includes = Type.STRING_LIST.convert(include, "'subpackages' argument");
Expand Down
Loading

0 comments on commit 8473928

Please sign in to comment.