Skip to content

Commit

Permalink
Merge #3445
Browse files Browse the repository at this point in the history
3445: [mtl] experimental Naga support r=grovesNL a=kvark

First humble steps towards #71 ...

Naga is an optional github dependency (obviously will need to be switched to crates upon release).
If it fails, we fall back to spirv-cross. Extra steps are taking to ensure that shader modules compiled by Naga and SPIRV-Cross can still link together.

PR checklist:
- [x] `make` succeeds (on *nix)
- [ ] `make reftests` succeeds
- [x] tested examples with the following backends: Metal


Co-authored-by: Dzmitry Malyshau <[email protected]>
  • Loading branch information
bors[bot] and kvark authored Oct 29, 2020
2 parents 4e6fb6a + 151f5af commit 17249af
Show file tree
Hide file tree
Showing 3 changed files with 202 additions and 140 deletions.
7 changes: 7 additions & 0 deletions src/backend/metal/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,13 @@ storage-map = "0.3"
lazy_static = "1"
raw-window-handle = "0.3"

[dependencies.naga]
#path = "../../../../naga"
git = "https://github.com/gfx-rs/naga"
rev = "587dc01a2cc43fdf4637f1bf6d2b75b703e9f681"
features = ["spv-in", "msl-out"]
optional = true

# This forces docs.rs to build the crate on mac, otherwise the build fails
# and we get no docs at all.
[package.metadata.docs.rs]
Expand Down
314 changes: 188 additions & 126 deletions src/backend/metal/src/device.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ use std::{
collections::BTreeMap,
iter, mem,
ops::Range,
path::Path,
ptr,
sync::{
atomic::{AtomicBool, Ordering},
Expand Down Expand Up @@ -581,60 +580,7 @@ impl Device {
}
}

pub fn create_shader_library_from_file<P>(
&self,
_path: P,
) -> Result<n::ShaderModule, ShaderError>
where
P: AsRef<Path>,
{
unimplemented!()
}

pub fn create_shader_library_from_source<S>(
&self,
source: S,
version: LanguageVersion,
rasterization_enabled: bool,
) -> Result<n::ShaderModule, ShaderError>
where
S: AsRef<str>,
{
let options = metal::CompileOptions::new();
let msl_version = match version {
LanguageVersion { major: 1, minor: 0 } => MTLLanguageVersion::V1_0,
LanguageVersion { major: 1, minor: 1 } => MTLLanguageVersion::V1_1,
LanguageVersion { major: 1, minor: 2 } => MTLLanguageVersion::V1_2,
LanguageVersion { major: 2, minor: 0 } => MTLLanguageVersion::V2_0,
LanguageVersion { major: 2, minor: 1 } => MTLLanguageVersion::V2_1,
_ => {
return Err(ShaderError::CompilationFailed(
"shader model not supported".into(),
));
}
};
if msl_version > self.shared.private_caps.msl_version {
return Err(ShaderError::CompilationFailed(
"shader model too high".into(),
));
}
options.set_language_version(msl_version);

self.shared
.device
.lock()
.new_library_with_source(source.as_ref(), &options)
.map(|library| {
n::ShaderModule::Compiled(n::ModuleInfo {
library,
entry_point_map: n::EntryPointMap::default(),
rasterization_enabled,
})
})
.map_err(|e| ShaderError::CompilationFailed(e.into()))
}

fn compile_shader_library(
fn compile_shader_library_cross(
device: &Mutex<metal::Device>,
raw_data: &[u32],
compiler_options: &msl::CompilerOptions,
Expand Down Expand Up @@ -712,6 +658,62 @@ impl Device {
})
}

#[cfg(feature = "naga")]
fn compile_shader_library_naga(
device: &Mutex<metal::Device>,
module: &naga::Module,
naga_options: &naga::back::msl::Options,
) -> Result<n::ModuleInfo, ShaderError> {
let source = naga::back::msl::write_string(module, naga_options)
.map_err(|e| ShaderError::CompilationFailed(format!("{:?}", e)))?;

let mut entry_point_map = n::EntryPointMap::default();
for (&(stage, ref name), ep) in module.entry_points.iter() {
entry_point_map.insert(
name.clone(),
spirv::EntryPoint {
//TODO: fill that information by Naga
name: format!("{}{:?}", name, stage),
execution_model: match stage {
naga::ShaderStage::Vertex => spirv::ExecutionModel::Vertex,
naga::ShaderStage::Fragment => spirv::ExecutionModel::Fragment,
naga::ShaderStage::Compute => spirv::ExecutionModel::GlCompute,
},
work_group_size: spirv::WorkGroupSize {
x: ep.workgroup_size[0],
y: ep.workgroup_size[1],
z: ep.workgroup_size[2],
},
},
);
}

debug!("Naga generated shader:\n{}", source);

let options = metal::CompileOptions::new();
let msl_version = match naga_options.lang_version {
(1, 0) => MTLLanguageVersion::V1_0,
(1, 1) => MTLLanguageVersion::V1_1,
(1, 2) => MTLLanguageVersion::V1_2,
(2, 0) => MTLLanguageVersion::V2_0,
(2, 1) => MTLLanguageVersion::V2_1,
(2, 2) => MTLLanguageVersion::V2_2,
other => panic!("Unexpected language version {:?}", other),
};
options.set_language_version(msl_version);

let library = device
.lock()
.new_library_with_source(source.as_ref(), &options)
.map_err(|err| ShaderError::CompilationFailed(err.into()))?;

Ok(n::ModuleInfo {
library,
entry_point_map,
rasterization_enabled: true, //TODO
})
}

fn load_shader(
&self,
ep: &pso::EntryPoint<Backend>,
Expand All @@ -725,51 +727,62 @@ impl Device {
let module_map;
let (info_owned, info_guard);

let info = match *ep.module {
n::ShaderModule::Compiled(ref info) => info,
n::ShaderModule::Raw(ref data) => {
let compiler_options = &mut match primitive_class {
MTLPrimitiveTopologyClass::Point => layout.shader_compiler_options_point.clone(),
_ => layout.shader_compiler_options.clone(),
};
compiler_options.entry_point = Some((ep.entry.to_string(), match stage {
ShaderStage::Vertex => spirv::ExecutionModel::Vertex,
ShaderStage::Fragment => spirv::ExecutionModel::Fragment,
ShaderStage::Compute => spirv::ExecutionModel::GlCompute,
_ => return Err(pso::CreationError::UnsupportedPipeline),
}));
match pipeline_cache {
Some(cache) => {
module_map = cache
.modules
.get_or_create_with(compiler_options, || FastStorageMap::default());
info_guard = module_map.get_or_create_with(data, || {
Self::compile_shader_library(
device,
data,
compiler_options,
msl_version,
&ep.specialization,
)
.unwrap()
});
&*info_guard
}
None => {
info_owned = Self::compile_shader_library(
device,
data,
compiler_options,
msl_version,
&ep.specialization,
)
.map_err(|e| {
error!("Error compiling the shader {:?}", e);
pso::CreationError::Other
})?;
&info_owned
let compiler_options = &mut match primitive_class {
MTLPrimitiveTopologyClass::Point => layout.shader_compiler_options_point.clone(),
_ => layout.shader_compiler_options.clone(),
};
compiler_options.entry_point = Some((
ep.entry.to_string(),
match stage {
ShaderStage::Vertex => spirv::ExecutionModel::Vertex,
ShaderStage::Fragment => spirv::ExecutionModel::Fragment,
ShaderStage::Compute => spirv::ExecutionModel::GlCompute,
_ => return Err(pso::CreationError::UnsupportedPipeline),
},
));

let data = &ep.module.spv;
let info = match pipeline_cache {
Some(cache) => {
module_map = cache
.modules
.get_or_create_with(compiler_options, FastStorageMap::default);
info_guard = module_map.get_or_create_with(data, || {
Self::compile_shader_library_cross(
device,
data,
compiler_options,
msl_version,
&ep.specialization,
)
.unwrap()
});
&*info_guard
}
None => {
let mut result = Err(ShaderError::CompilationFailed(String::new()));
#[cfg(feature = "naga")]
if let Some(ref module) = ep.module.naga {
result =
Self::compile_shader_library_naga(device, module, &layout.naga_options);
if let Err(ShaderError::CompilationFailed(ref msg)) = result {
warn!("Naga: {:?}", msg);
}
}
if result.is_err() {
result = Self::compile_shader_library_cross(
device,
data,
compiler_options,
msl_version,
&ep.specialization,
);
}
info_owned = result.map_err(|e| {
error!("Error compiling the shader {:?}", e);
pso::CreationError::Other
})?;
&info_owned
}
};

Expand Down Expand Up @@ -1286,6 +1299,51 @@ impl hal::device::Device<Backend> for Device {
assert!(counters.samplers <= self.shared.private_caps.max_samplers_per_stage);
}

#[cfg(feature = "naga")]
let naga_options = {
use naga::back::msl;
fn res_index(id: u32) -> Option<u8> {
if id == !0 {
None
} else {
Some(id as _)
}
}
msl::Options {
lang_version: match self.shared.private_caps.msl_version {
MTLLanguageVersion::V1_0 => (1, 0),
MTLLanguageVersion::V1_1 => (1, 1),
MTLLanguageVersion::V1_2 => (1, 2),
MTLLanguageVersion::V2_0 => (2, 0),
MTLLanguageVersion::V2_1 => (2, 1),
MTLLanguageVersion::V2_2 => (2, 2),
},
spirv_cross_compatibility: true,
binding_map: res_overrides
.iter()
.map(|(loc, binding)| {
let source = msl::BindSource {
stage: match loc.stage {
spirv::ExecutionModel::Vertex => naga::ShaderStage::Vertex,
spirv::ExecutionModel::Fragment => naga::ShaderStage::Fragment,
spirv::ExecutionModel::GlCompute => naga::ShaderStage::Compute,
other => panic!("Unexpected stage: {:?}", other),
},
group: loc.desc_set,
binding: loc.binding,
};
let target = msl::BindTarget {
buffer: res_index(binding.buffer_id),
texture: res_index(binding.texture_id),
sampler: res_index(binding.sampler_id),
mutable: false, //TODO
};
(source, target)
})
.collect(),
}
};

let mut shader_compiler_options = msl::CompilerOptions::default();
shader_compiler_options.version = match self.shared.private_caps.msl_version {
MTLLanguageVersion::V1_0 => msl::Version::V1_0,
Expand All @@ -1308,6 +1366,8 @@ impl hal::device::Device<Backend> for Device {
Ok(n::PipelineLayout {
shader_compiler_options,
shader_compiler_options_point,
#[cfg(feature = "naga")]
naga_options,
infos,
total: n::MultiStageResourceCounters {
vs: stage_infos[0].2.clone(),
Expand Down Expand Up @@ -1465,16 +1525,26 @@ impl hal::device::Device<Backend> for Device {
}

// Vertex shader
let (vs_lib, vs_function, _, enable_rasterization) =
self.load_shader(vs, pipeline_layout, primitive_class, cache, ShaderStage::Vertex)?;
let (vs_lib, vs_function, _, enable_rasterization) = self.load_shader(
vs,
pipeline_layout,
primitive_class,
cache,
ShaderStage::Vertex,
)?;
pipeline.set_vertex_function(Some(&vs_function));

// Fragment shader
let fs_function;
let fs_lib = match pipeline_desc.fragment {
Some(ref ep) => {
let (lib, fun, _, _) =
self.load_shader(ep, pipeline_layout, primitive_class, cache, ShaderStage::Fragment)?;
let (lib, fun, _, _) = self.load_shader(
ep,
pipeline_layout,
primitive_class,
cache,
ShaderStage::Fragment,
)?;
fs_function = fun;
pipeline.set_fragment_function(Some(&fs_function));
Some(lib)
Expand Down Expand Up @@ -1773,30 +1843,22 @@ impl hal::device::Device<Backend> for Device {
raw_data: &[u32],
) -> Result<n::ShaderModule, ShaderError> {
//TODO: we can probably at least parse here and save the `Ast`
let depends_on_pipeline_layout = true; //TODO: !self.private_caps.argument_buffers

// TODO: also depends on pipeline layout if there are specialization constants that
// SPIRV-Cross generates macros for, which occurs when MSL version is older than 1.2 or the
// constant is used as an array size (see
// `CompilerMSL::emit_specialization_constants_and_structs` in SPIRV-Cross)
Ok(if depends_on_pipeline_layout {
n::ShaderModule::Raw(raw_data.to_vec())
} else {
let mut options = msl::CompilerOptions::default();
options.enable_point_size_builtin = false;
options.vertex.invert_y = !self.features.contains(hal::Features::NDC_Y_UP);
options.force_zero_initialized_variables = true;
options.force_native_arrays = true;
let info = Self::compile_shader_library(
&self.shared.device,
raw_data,
&options,
self.shared.private_caps.msl_version,
&pso::Specialization::default(), // we should only pass empty specialization constants
// here if we know they won't be used by
// SPIRV-Cross, see above
)?;
n::ShaderModule::Compiled(info)
Ok(n::ShaderModule {
spv: raw_data.to_vec(),
#[cfg(feature = "naga")]
naga: match naga::front::spv::Parser::new(raw_data.iter().cloned()).parse() {
Ok(module) => match naga::proc::Validator::new().validate(&module) {
Ok(()) => Some(module),
Err(e) => {
warn!("Naga validation failed: {:?}", e);
None
}
}
Err(e) => {
warn!("Naga parsing failed: {:?}", e);
None
}
},
})
}

Expand Down
Loading

0 comments on commit 17249af

Please sign in to comment.