Skip to content

Commit

Permalink
[ir] Use Function::IsEntryPoint() where possible
Browse files Browse the repository at this point in the history
Now that we have this helper function, we can use it many more places
instead of comparing the stage to `kUndefined`.

Change-Id: I4d8b3a78b6685606d70e99450cde571396b6d3ce
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/219774
Commit-Queue: Antonio Maiorano <[email protected]>
Auto-Submit: James Price <[email protected]>
Reviewed-by: Antonio Maiorano <[email protected]>
  • Loading branch information
jrprice authored and Dawn LUCI CQ committed Dec 17, 2024
1 parent 0c4e4be commit 9146b9f
Show file tree
Hide file tree
Showing 14 changed files with 26 additions and 28 deletions.
2 changes: 1 addition & 1 deletion src/tint/lang/core/ir/disassembler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,7 @@ void Disassembler::EmitFunction(const Function* func) {
}
out_ << " =";

if (func->Stage() != Function::PipelineStage::kUndefined) {
if (func->IsEntryPoint()) {
out_ << " " << StyleAttribute("@", func->Stage());
}
if (func->WorkgroupSize()) {
Expand Down
2 changes: 1 addition & 1 deletion src/tint/lang/core/ir/transform/add_empty_entry_point.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ namespace {

void Run(ir::Module& ir) {
for (auto& func : ir.functions) {
if (func->Stage() != Function::PipelineStage::kUndefined) {
if (func->IsEntryPoint()) {
return;
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/tint/lang/core/ir/transform/shader_io.cc
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ struct State {
auto functions = ir.functions;
for (auto& func : functions) {
// Only process entry points.
if (func->Stage() == Function::PipelineStage::kUndefined) {
if (!func->IsEntryPoint()) {
continue;
}

Expand Down
2 changes: 1 addition & 1 deletion src/tint/lang/core/ir/transform/single_entry_point.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ Result<SuccessType> Run(ir::Module& ir, std::string_view entry_point_name) {
// Find the entry point.
ir::Function* entry_point = nullptr;
for (auto& func : ir.functions) {
if (func->Stage() == Function::PipelineStage::kUndefined) {
if (!func->IsEntryPoint()) {
continue;
}
if (ir.NameOf(func).NameView() == entry_point_name) {
Expand Down
14 changes: 7 additions & 7 deletions src/tint/lang/core/ir/validator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1284,7 +1284,7 @@ class Validator {
}
visited.Add(calling_function);

if (calling_function->Stage() != Function::PipelineStage::kUndefined) {
if (calling_function->IsEntryPoint()) {
result.Add(calling_function);
}

Expand Down Expand Up @@ -1946,7 +1946,7 @@ void Validator::CheckFunction(const Function* func) {
TINT_DEFER(scope_stack_.Pop());

// Checking the name early, so its usage can be recorded, even if the function is malformed.
if (func->Stage() != Function::PipelineStage::kUndefined) {
if (func->IsEntryPoint()) {
const auto name = mod_.NameOf(func).Name();
if (!entry_point_names_.Add(name)) {
AddError(func) << "entry point name " << style::Function(name) << " is not unique";
Expand Down Expand Up @@ -2028,7 +2028,7 @@ void Validator::CheckFunction(const Function* func) {
CheckFrontFacingIfBoolFunc<FunctionParam>(
"fragment entry point param memebers can only be a bool if "
"decorated with @builtin(front_facing)"));
} else if (func->Stage() != Function::PipelineStage::kUndefined) {
} else if (func->IsEntryPoint()) {
CheckFunctionParamAttributesAndType(
param, CheckNotBool<FunctionParam>(
"entry point params can only be a bool for fragment shaders"));
Expand All @@ -2045,7 +2045,7 @@ void Validator::CheckFunction(const Function* func) {
}
}

if (func->Stage() != Function::PipelineStage::kUndefined) {
if (func->IsEntryPoint()) {
{
auto result = ValidateShaderIOAnnotations(param->Type(), param->BindingPoint(),
param->Attributes(), "input param");
Expand Down Expand Up @@ -2090,7 +2090,7 @@ void Validator::CheckFunction(const Function* func) {
AddError(func) << "function return type must be constructible";
}

if (func->Stage() != Function::PipelineStage::kUndefined) {
if (func->IsEntryPoint()) {
if (DAWN_UNLIKELY(mod_.NameOf(func).Name().empty())) {
AddError(func) << "entry points must have names";
}
Expand All @@ -2104,7 +2104,7 @@ void Validator::CheckFunction(const Function* func) {
}
}

if (func->Stage() != Function::PipelineStage::kUndefined) {
if (func->IsEntryPoint()) {
auto result = ValidateShaderIOAnnotations(func->ReturnType(), std::nullopt,
func->ReturnAttributes(), "return values");
if (result != Success) {
Expand Down Expand Up @@ -2775,7 +2775,7 @@ void Validator::CheckUserCall(const UserCall* call) {
return;
}

if (call->Target()->Stage() != Function::PipelineStage::kUndefined) {
if (call->Target()->IsEntryPoint()) {
AddError(call, UserCall::kFunctionOperandOffset)
<< "call target must not have a pipeline stage";
}
Expand Down
2 changes: 1 addition & 1 deletion src/tint/lang/glsl/writer/writer_bench.cc
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ void GenerateGLSL(benchmark::State& state, std::string input_name) {

// Get the list of entry point names.
for (auto func : ir->functions) {
if (func->Stage() != core::ir::Function::PipelineStage::kUndefined) {
if (func->IsEntryPoint()) {
names.push_back(ir->NameOf(func).Name());
}
}
Expand Down
6 changes: 3 additions & 3 deletions src/tint/lang/glsl/writer/writer_fuzz.cc
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ Options GenerateOptions(core::ir::Module& module) {

// Set offsets for push constants used for certain builtins.
for (auto& func : module.functions) {
if (func->Stage() == core::ir::Function::PipelineStage::kUndefined) {
if (!func->IsEntryPoint()) {
continue;
}

Expand Down Expand Up @@ -178,7 +178,7 @@ bool CanRun(const core::ir::Module& module, Options& options) {

// Check for unsupported shader IO builtins.
for (auto& func : module.functions) {
if (func->Stage() == core::ir::Function::PipelineStage::kUndefined) {
if (!func->IsEntryPoint()) {
continue;
}

Expand Down Expand Up @@ -257,7 +257,7 @@ Result<SuccessType> IRFuzzer(core::ir::Module& module, const fuzz::ir::Context&
// Strip the module down to a single entry point.
core::ir::Function* entry_point = nullptr;
for (auto& func : module.functions) {
if (func->Stage() != core::ir::Function::PipelineStage::kUndefined) {
if (func->IsEntryPoint()) {
entry_point = func;
break;
}
Expand Down
2 changes: 1 addition & 1 deletion src/tint/lang/hlsl/writer/printer/printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,7 @@ class Printer : public tint::TextGenerator {

out << " " << func_name << "(";

bool is_ep = func->Stage() != core::ir::Function::PipelineStage::kUndefined;
bool is_ep = func->IsEntryPoint();
size_t i = 0;
for (auto* param : func->Params()) {
if (i > 0) {
Expand Down
10 changes: 4 additions & 6 deletions src/tint/lang/msl/writer/printer/printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ class Printer : public tint::TextGenerator {
// We only look at function parameters of entry points, since this is how binding resources
// are handled in MSL.
for (auto func : ir_.functions) {
if (func->Stage() == core::ir::Function::PipelineStage::kUndefined) {
if (!func->IsEntryPoint()) {
continue;
}
for (auto* param : func->Params()) {
Expand Down Expand Up @@ -347,7 +347,7 @@ class Printer : public tint::TextGenerator {
case core::ir::Function::PipelineStage::kUndefined:
break;
}
if (func->Stage() != core::ir::Function::PipelineStage::kUndefined) {
if (func->IsEntryPoint()) {
result_.workgroup_info.allocations.insert({func_name, {}});
}

Expand All @@ -365,8 +365,7 @@ class Printer : public tint::TextGenerator {
out << " ";

// Non-entrypoint pointers are set to `const` for the value
if (func->Stage() == core::ir::Function::PipelineStage::kUndefined &&
param->Type()->Is<core::type::Pointer>()) {
if (!func->IsEntryPoint() && param->Type()->Is<core::type::Pointer>()) {
out << "const ";
}

Expand All @@ -378,8 +377,7 @@ class Printer : public tint::TextGenerator {
out << " [[" << name << "]]";
}

if (param->Type()->Is<core::type::Struct>() &&
func->Stage() != core::ir::Function::PipelineStage::kUndefined) {
if (param->Type()->Is<core::type::Struct>() && func->IsEntryPoint()) {
out << " [[stage_in]]";
}

Expand Down
2 changes: 1 addition & 1 deletion src/tint/lang/msl/writer/raise/module_scope_vars.cc
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ struct State {
// Add the structure the holds the module-scope variable pointers to the function and record
// it in the map. Entry points will create the structure, other functions will declare it as
// a parameter.
if (func->Stage() != core::ir::Function::PipelineStage::kUndefined) {
if (func->IsEntryPoint()) {
function_to_struct_value.Add(func, AddModuleVarsToEntryPoint(func, refs));
} else {
function_to_struct_value.Add(func, AddModuleVarsToFunction(func));
Expand Down
2 changes: 1 addition & 1 deletion src/tint/lang/msl/writer/raise/simd_ballot.cc
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ struct State {
return var == subgroup_size_mask;
});
for (auto func : ir.functions) {
if (func->Stage() != core::ir::Function::PipelineStage::kUndefined) {
if (func->IsEntryPoint()) {
if (refs.TransitiveReferences(func).Contains(subgroup_size_mask)) {
SetSubgroupSizeMaskForEntryPoint(func);
}
Expand Down
4 changes: 2 additions & 2 deletions src/tint/lang/spirv/reader/lower/shader_io.cc
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ struct State {
// Use a worklist as `ProcessEntryPointOutputs()` will add new functions.
Vector<core::ir::Function*, 4> entry_points;
for (auto& func : ir.functions) {
if (func->Stage() != core::ir::Function::PipelineStage::kUndefined) {
if (func->IsEntryPoint()) {
entry_points.Push(func);
}
}
Expand Down Expand Up @@ -309,7 +309,7 @@ struct State {
/// @returns the function parameter
core::ir::Value* GetParameter(core::ir::Function* func, core::ir::Var* var) {
return function_parameter_map.GetOrAddZero(func).GetOrAdd(var, [&] {
const bool entry_point = func->Stage() != core::ir::Function::PipelineStage::kUndefined;
const bool entry_point = func->IsEntryPoint();
auto* var_type = var->Result(0)->Type()->UnwrapPtr();

// Use a scalar u32 for sample_mask builtins for entry point parameters.
Expand Down
2 changes: 1 addition & 1 deletion src/tint/lang/spirv/writer/printer/printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -711,7 +711,7 @@ class Printer {
PushName(id, func);

// Emit OpEntryPoint and OpExecutionMode declarations if needed.
if (func->Stage() != core::ir::Function::PipelineStage::kUndefined) {
if (func->IsEntryPoint()) {
EmitEntryPoint(func, id);
}

Expand Down
2 changes: 1 addition & 1 deletion src/tint/lang/spirv/writer/raise/merge_return.cc
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ struct State {
/// Process the function.
/// @param fn the function to process
void Process(core::ir::Function* fn) {
if (fn->Stage() != core::ir::Function::PipelineStage::kUndefined) {
if (fn->IsEntryPoint()) {
// Entry points are not called and do not require this transformation to ensure
// convergence.
return;
Expand Down

0 comments on commit 9146b9f

Please sign in to comment.