Skip to content

Commit

Permalink
[msl] Move symbol renaming into the backend
Browse files Browse the repository at this point in the history
Instead of running the AST renamer transform in Dawn's Metal backend,
we just modify the names when emitting them in the printer.

This is a little cleaner than using a transform, as transforms cannot
rename structures and struct members without requiring
const_cast. This also keeps the list of MSL keywords local to the
printer, and means we can retain the meaningful names throughout the
`raise` process.

Bug: 380043958
Change-Id: Ib5ae9b36947826928fd7babaed9892d06a8c0c29
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/219295
Reviewed-by: Antonio Maiorano <[email protected]>
Commit-Queue: James Price <[email protected]>
  • Loading branch information
jrprice authored and Dawn LUCI CQ committed Dec 13, 2024
1 parent f0ed19b commit 65727b7
Show file tree
Hide file tree
Showing 624 changed files with 2,399 additions and 2,012 deletions.
28 changes: 18 additions & 10 deletions src/dawn/native/metal/ShaderModuleMTL.mm
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,9 @@
namespace dawn::native::metal {
namespace {

// The name to use when remapping entry points.
constexpr char kRemappedEntryPointName[] = "dawn_entry_point";

using OptionalVertexPullingTransformConfig =
std::optional<tint::ast::transform::VertexPulling::Config>;

Expand Down Expand Up @@ -274,6 +277,11 @@
req.disableSymbolRenaming = device->IsToggleEnabled(Toggle::DisableSymbolRenaming);
req.platform = UnsafeUnkeyedValue(device->GetPlatform());

req.use_tint_ir = device->IsToggleEnabled(Toggle::UseTintIR);
if (req.use_tint_ir) {
req.tintOptions.strip_all_names = !req.disableSymbolRenaming;
req.tintOptions.remapped_entry_point_name = kRemappedEntryPointName;
}
req.tintOptions.disable_robustness = !device->IsRobustnessEnabled();
req.tintOptions.buffer_size_ubo_index = kBufferLengthBufferSlot;
req.tintOptions.fixed_sample_mask = sampleMask;
Expand All @@ -286,7 +294,6 @@
req.tintOptions.array_length_from_uniform = std::move(arrayLengthFromUniform);
req.tintOptions.pixel_local_attachments = std::move(pixelLocalAttachments);
req.tintOptions.bindings = std::move(bindings);
req.use_tint_ir = device->IsToggleEnabled(Toggle::UseTintIR);
req.tintOptions.disable_polyfill_integer_div_mod =
device->IsToggleEnabled(Toggle::DisablePolyfillsOnIntegerDivisonAndModulo);

Expand All @@ -297,7 +304,6 @@
DAWN_TRY_LOAD_OR_RUN(
mslCompilation, device, std::move(req), MslCompilation::FromBlob,
[](MslCompilationRequest r) -> ResultOrError<MslCompilation> {
constexpr char kRemappedEntryPointName[] = "dawn_entry_point";
tint::ast::transform::Manager transformManager;
tint::ast::transform::DataMap transformInputs;

Expand All @@ -307,14 +313,16 @@
transformManager.Add<tint::ast::transform::SingleEntryPoint>();
transformInputs.Add<tint::ast::transform::SingleEntryPoint::Config>(r.entryPointName);

// Needs to run before all other transforms so that they can use builtin names safely.
tint::ast::transform::Renamer::Remappings requestedNames = {
{r.entryPointName, kRemappedEntryPointName}};
transformManager.Add<tint::ast::transform::Renamer>();
transformInputs.Add<tint::ast::transform::Renamer::Config>(
r.disableSymbolRenaming ? tint::ast::transform::Renamer::Target::kMslKeywords
: tint::ast::transform::Renamer::Target::kAll,
std::move(requestedNames));
if (!r.use_tint_ir) {
// Needs to run before other transforms so that they can use builtin names safely.
tint::ast::transform::Renamer::Remappings requestedNames = {
{r.entryPointName, kRemappedEntryPointName}};
transformManager.Add<tint::ast::transform::Renamer>();
transformInputs.Add<tint::ast::transform::Renamer::Config>(
r.disableSymbolRenaming ? tint::ast::transform::Renamer::Target::kMslKeywords
: tint::ast::transform::Renamer::Target::kAll,
std::move(requestedNames));
}

if (r.vertexPullingTransformConfig) {
transformManager.Add<tint::ast::transform::VertexPulling>();
Expand Down
10 changes: 9 additions & 1 deletion src/tint/cmd/tint/main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -593,6 +593,10 @@ When specified, automatically enables HLSL validation)",
tint::ast::transform::DataMap& transform_inputs) {
switch (options.format) {
case Format::kMsl: {
if (options.use_ir) {
// Renaming is handled in the backend.
break;
}
if (!options.rename_all) {
transform_inputs.Add<tint::ast::transform::Renamer::Config>(
tint::ast::transform::Renamer::Target::kMslKeywords);
Expand Down Expand Up @@ -888,8 +892,12 @@ bool GenerateMsl([[maybe_unused]] Options& options,
input_program = std::move(flattened.value());
}

// TODO(jrprice): Provide a way for the user to set non-default options.
// Set up the backend options.
tint::msl::writer::Options gen_options;
if (options.rename_all) {
gen_options.remapped_entry_point_name = "tint_entry_point";
gen_options.strip_all_names = true;
}
gen_options.disable_robustness = !options.enable_robustness;
gen_options.disable_workgroup_init = options.disable_workgroup_init;
gen_options.pixel_local_attachments = options.pixel_local_attachments;
Expand Down
10 changes: 10 additions & 0 deletions src/tint/lang/msl/writer/common/options.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
#ifndef SRC_TINT_LANG_MSL_WRITER_COMMON_OPTIONS_H_
#define SRC_TINT_LANG_MSL_WRITER_COMMON_OPTIONS_H_

#include <optional>
#include <string>
#include <unordered_map>

#include "src/tint/api/common/binding_point.h"
Expand Down Expand Up @@ -135,6 +137,12 @@ struct Options {
/// @returns this Options
Options& operator=(const Options&);

/// An optional remapped name to use when emitting the entry point.
std::optional<std::string> remapped_entry_point_name = {};

/// Set to `true` to strip all user-declared identifiers from the module.
bool strip_all_names = false;

/// Set to `true` to disable software robustness that prevents out-of-bounds accesses.
bool disable_robustness = false;

Expand Down Expand Up @@ -171,6 +179,8 @@ struct Options {

/// Reflect the fields of this class so that it can be used by tint::ForeachField()
TINT_REFLECT(Options,
remapped_entry_point_name,
strip_all_names,
disable_robustness,
disable_workgroup_init,
disable_demote_to_helper,
Expand Down
12 changes: 6 additions & 6 deletions src/tint/lang/msl/writer/function_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,8 @@ fragment void foo(device int* storage_var [[buffer(1)]], const constant int* uni

TEST_F(MslWriterTest, EntryPointParameterHandleBindingPoint) {
auto* t = ty.Get<core::type::SampledTexture>(core::type::TextureDimension::k2d, ty.f32());
auto* texture = b.Var("texture", ty.ptr<handle>(t));
auto* sampler = b.Var("sampler", ty.ptr<handle>(ty.sampler()));
auto* texture = b.Var("t", ty.ptr<handle>(t));
auto* sampler = b.Var("s", ty.ptr<handle>(ty.sampler()));
texture->SetBindingPoint(0, 1);
sampler->SetBindingPoint(0, 2);
mod.root_block->Append(texture);
Expand All @@ -119,12 +119,12 @@ TEST_F(MslWriterTest, EntryPointParameterHandleBindingPoint) {
using namespace metal;
struct tint_module_vars_struct {
texture2d<float, access::sample> texture;
sampler sampler;
texture2d<float, access::sample> t;
sampler s;
};
fragment void foo(texture2d<float, access::sample> texture [[texture(1)]], sampler sampler [[sampler(2)]]) {
tint_module_vars_struct const tint_module_vars = tint_module_vars_struct{.texture=texture, .sampler=sampler};
fragment void foo(texture2d<float, access::sample> t [[texture(1)]], sampler s [[sampler(2)]]) {
tint_module_vars_struct const tint_module_vars = tint_module_vars_struct{.t=t, .s=s};
}
)");
}
Expand Down
Loading

0 comments on commit 65727b7

Please sign in to comment.