diff --git a/src/tint/lang/core/ir/transform/demote_to_helper.cc b/src/tint/lang/core/ir/transform/demote_to_helper.cc index bc69ec28e6..24e2529ce4 100644 --- a/src/tint/lang/core/ir/transform/demote_to_helper.cc +++ b/src/tint/lang/core/ir/transform/demote_to_helper.cc @@ -189,7 +189,7 @@ struct State { [&](Return* ret) { // Insert a conditional terminate invocation instruction before each return // instruction in the entry point function. - if (ret->Func()->Stage() == Function::PipelineStage::kFragment) { + if (ret->Func()->IsFragment()) { b.InsertBefore(ret, [&] { auto* cond = b.Load(continue_execution); auto* ifelse = b.If(b.Not(cond)); diff --git a/src/tint/lang/core/ir/transform/shader_io.cc b/src/tint/lang/core/ir/transform/shader_io.cc index 5c9be95aab..594fae3be7 100644 --- a/src/tint/lang/core/ir/transform/shader_io.cc +++ b/src/tint/lang/core/ir/transform/shader_io.cc @@ -106,7 +106,7 @@ struct State { // Add an output for the vertex point size if needed. std::optional vertex_point_size_index; - if (ep->Stage() == Function::PipelineStage::kVertex && backend->NeedsVertexPointSize()) { + if (ep->IsVertex() && backend->NeedsVertexPointSize()) { vertex_point_size_index = backend->AddOutput(ir.symbols.New("vertex_point_size"), ty.f32(), core::IOAttributes{ @@ -164,8 +164,7 @@ struct State { for (auto* member : str->Members()) { auto name = str->Name().Name() + "_" + member->Name().Name(); auto attributes = member->Attributes(); - if (attributes.interpolation && - ep->Stage() != Function::PipelineStage::kFragment) { + if (attributes.interpolation && !ep->IsFragment()) { // Strip interpolation on non-fragment inputs attributes.interpolation = {}; } @@ -175,7 +174,7 @@ struct State { } else { // Pull out the IO attributes and remove them from the parameter. auto attributes = param->Attributes(); - if (attributes.interpolation && ep->Stage() != Function::PipelineStage::kFragment) { + if (attributes.interpolation && !ep->IsFragment()) { // Strip interpolation on non-fragment inputs attributes.interpolation = {}; } @@ -197,7 +196,7 @@ struct State { for (auto* member : str->Members()) { auto name = str->Name().Name() + "_" + member->Name().Name(); auto attributes = member->Attributes(); - if (attributes.interpolation && ep->Stage() != Function::PipelineStage::kVertex) { + if (attributes.interpolation && !ep->IsVertex()) { // Strip interpolation on non-vertex outputs attributes.interpolation = {}; } @@ -207,7 +206,7 @@ struct State { } else { // Pull out the IO attributes and remove them from the original function. auto attributes = ep->ReturnAttributes(); - if (attributes.interpolation && ep->Stage() != Function::PipelineStage::kVertex) { + if (attributes.interpolation && !ep->IsVertex()) { // Strip interpolation on non-vertex outputs attributes.interpolation = {}; } diff --git a/src/tint/lang/core/ir/transform/substitute_overrides.cc b/src/tint/lang/core/ir/transform/substitute_overrides.cc index c4b17c9fe8..6276903bd2 100644 --- a/src/tint/lang/core/ir/transform/substitute_overrides.cc +++ b/src/tint/lang/core/ir/transform/substitute_overrides.cc @@ -124,7 +124,7 @@ struct State { // Find any workgroup_sizes to replace for (auto func : ir.functions) { - if (func->Stage() != core::ir::Function::PipelineStage::kCompute) { + if (!func->IsCompute()) { continue; } diff --git a/src/tint/lang/core/ir/validator.cc b/src/tint/lang/core/ir/validator.cc index 020b644663..cd98505e5b 100644 --- a/src/tint/lang/core/ir/validator.cc +++ b/src/tint/lang/core/ir/validator.cc @@ -1386,7 +1386,7 @@ void Validator::CheckForNonFragmentDiscards() { for (const auto& d : discards_) { const auto* f = ContainingFunction(d); for (const Function* ep : ContainingEndPoints(f)) { - if (ep->Stage() != Function::PipelineStage::kFragment) { + if (!ep->IsFragment()) { AddError(d) << "cannot be called in non-fragment end point"; } } @@ -2019,7 +2019,7 @@ void Validator::CheckFunction(const Function* func) { "invariant can only decorate a param member iff it is also " "decorated with position")); - if (func->Stage() == Function::PipelineStage::kFragment) { + if (func->IsFragment()) { CheckFunctionParamAttributesAndType( param, CheckFrontFacingIfBoolFunc( @@ -2128,8 +2128,7 @@ void Validator::CheckFunction(const Function* func) { continue; } - if (func->Stage() == Function::PipelineStage::kFragment && - mv->AddressSpace() == AddressSpace::kIn) { + if (func->IsFragment() && mv->AddressSpace() == AddressSpace::kIn) { CheckIOAttributesAndType( func, attr, ty, CheckFrontFacingIfBoolFunc("input address space values referenced by " @@ -2148,7 +2147,7 @@ void Validator::CheckFunction(const Function* func) { } } - if (func->Stage() == Function::PipelineStage::kVertex) { + if (func->IsVertex()) { CheckVertexEntryPoint(func); } @@ -2157,7 +2156,7 @@ void Validator::CheckFunction(const Function* func) { } void Validator::CheckWorkgroupSize(const Function* func) { - if (func->Stage() != Function::PipelineStage::kCompute) { + if (!func->IsCompute()) { if (func->WorkgroupSize().has_value()) { AddError(func) << "@workgroup_size only valid on compute entry point"; } diff --git a/src/tint/lang/glsl/writer/printer/printer.cc b/src/tint/lang/glsl/writer/printer/printer.cc index 3cda7930fe..e8ca2176b3 100644 --- a/src/tint/lang/glsl/writer/printer/printer.cc +++ b/src/tint/lang/glsl/writer/printer/printer.cc @@ -325,7 +325,7 @@ class Printer : public tint::TextGenerator { { auto out = Line(); - if (func->Stage() == core::ir::Function::PipelineStage::kCompute) { + if (func->IsCompute()) { auto wg_opt = func->WorkgroupSizeAsConst(); TINT_ASSERT(wg_opt.has_value()); @@ -342,7 +342,7 @@ class Printer : public tint::TextGenerator { out << " "; // Fragment shaders need a precision statement - if (func->Stage() == core::ir::Function::PipelineStage::kFragment) { + if (func->IsFragment()) { auto pre = Line(&header_buffer_); pre << "precision highp float;\n"; pre << "precision highp int;"; diff --git a/src/tint/lang/hlsl/writer/raise/pixel_local.cc b/src/tint/lang/hlsl/writer/raise/pixel_local.cc index 12a0bf47a8..47ab447aa7 100644 --- a/src/tint/lang/hlsl/writer/raise/pixel_local.cc +++ b/src/tint/lang/hlsl/writer/raise/pixel_local.cc @@ -232,7 +232,7 @@ struct State { auto rovs = CreateROVs(pixel_local_struct); for (auto f : ir.functions) { - if (f->Stage() == core::ir::Function::PipelineStage::kFragment) { + if (f->IsFragment()) { ProcessFragmentEntryPoint(f, pixel_local_var, pixel_local_struct, rovs); } } diff --git a/src/tint/lang/msl/writer/raise/shader_io.cc b/src/tint/lang/msl/writer/raise/shader_io.cc index caa6f08345..80d6a83471 100644 --- a/src/tint/lang/msl/writer/raise/shader_io.cc +++ b/src/tint/lang/msl/writer/raise/shader_io.cc @@ -136,8 +136,7 @@ struct StateImpl : core::ir::transform::ShaderIOBackendState { /// @copydoc ShaderIO::BackendState::FinalizeOutputs const core::type::Type* FinalizeOutputs() override { // Add a fixed sample mask builtin for fragment shaders if needed. - if (config.fixed_sample_mask != UINT32_MAX && - func->Stage() == core::ir::Function::PipelineStage::kFragment) { + if (config.fixed_sample_mask != UINT32_MAX && func->IsFragment()) { AddFixedSampleMaskOutput(); } diff --git a/src/tint/lang/spirv/writer/raise/shader_io.cc b/src/tint/lang/spirv/writer/raise/shader_io.cc index 14c9e53ce4..1477478615 100644 --- a/src/tint/lang/spirv/writer/raise/shader_io.cc +++ b/src/tint/lang/spirv/writer/raise/shader_io.cc @@ -98,8 +98,8 @@ struct StateImpl : core::ir::transform::ShaderIOBackendState { name << "_" << io.attributes.builtin.value(); // Vulkan requires that fragment integer builtin inputs be Flat decorated. - if (func->Stage() == core::ir::Function::PipelineStage::kFragment && - addrspace == core::AddressSpace::kIn && io.type->IsIntegerScalarOrVector()) { + if (func->IsFragment() && addrspace == core::AddressSpace::kIn && + io.type->IsIntegerScalarOrVector()) { io.attributes.interpolation = {core::InterpolationType::kFlat}; } }