Skip to content

Commit

Permalink
[naga]: Let TypeInner::Matrix hold a Scalar, not just a width.
Browse files Browse the repository at this point in the history
Let `naga::TypeInner::Matrix` hold a full `Scalar`, with a kind and
byte width, not merely a byte width, to make it possible to represent
matrices of AbstractFloats for WGSL.
  • Loading branch information
jimblandy committed Nov 21, 2023
1 parent 42058cf commit baab10c
Show file tree
Hide file tree
Showing 36 changed files with 222 additions and 245 deletions.
8 changes: 3 additions & 5 deletions naga/src/back/glsl/features.rs
Original file line number Diff line number Diff line change
Expand Up @@ -275,11 +275,9 @@ impl<'a, W> Writer<'a, W> {

for (ty_handle, ty) in self.module.types.iter() {
match ty.inner {
TypeInner::Scalar(scalar) => self.scalar_required_features(scalar),
TypeInner::Vector { scalar, .. } => self.scalar_required_features(scalar),
TypeInner::Matrix { width, .. } => {
self.scalar_required_features(Scalar::float(width))
}
TypeInner::Scalar(scalar)
| TypeInner::Vector { scalar, .. }
| TypeInner::Matrix { scalar, .. } => self.scalar_required_features(scalar),
TypeInner::Array { base, size, .. } => {
if let TypeInner::Array { .. } = self.module.types[base].inner {
self.features.request(Features::ARRAY_OF_ARRAYS)
Expand Down
4 changes: 2 additions & 2 deletions naga/src/back/glsl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -985,11 +985,11 @@ impl<'a, W: Write> Writer<'a, W> {
TypeInner::Matrix {
columns,
rows,
width,
scalar,
} => write!(
self.out,
"{}mat{}x{}",
glsl_scalar(crate::Scalar::float(width))?.prefix,
glsl_scalar(scalar)?.prefix,
columns as u8,
rows as u8
)?,
Expand Down
10 changes: 5 additions & 5 deletions naga/src/back/hlsl/conv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,10 @@ impl crate::TypeInner {
Self::Matrix {
columns,
rows,
width,
scalar,
} => {
let stride = Alignment::from(rows) * width as u32;
let last_row_size = rows as u32 * width as u32;
let stride = Alignment::from(rows) * scalar.width as u32;
let last_row_size = rows as u32 * scalar.width as u32;
((columns as u32 - 1) * stride) + last_row_size
}
Self::Array { base, size, stride } => {
Expand Down Expand Up @@ -82,10 +82,10 @@ impl crate::TypeInner {
crate::TypeInner::Matrix {
columns,
rows,
width,
scalar,
} => Cow::Owned(format!(
"{}{}x{}",
crate::Scalar::float(width).to_hlsl_str()?,
scalar.to_hlsl_str()?,
crate::back::vector_size_str(columns),
crate::back::vector_size_str(rows),
)),
Expand Down
11 changes: 4 additions & 7 deletions naga/src/back/hlsl/help.rs
Original file line number Diff line number Diff line change
Expand Up @@ -656,10 +656,9 @@ impl<'a, W: Write> super::Writer<'a, W> {
_ => unreachable!(),
};
let vec_ty = match module.types[member.ty].inner {
crate::TypeInner::Matrix { rows, width, .. } => crate::TypeInner::Vector {
size: rows,
scalar: crate::Scalar::float(width),
},
crate::TypeInner::Matrix { rows, scalar, .. } => {
crate::TypeInner::Vector { size: rows, scalar }
}
_ => unreachable!(),
};
self.write_value_type(module, &vec_ty)?;
Expand Down Expand Up @@ -736,9 +735,7 @@ impl<'a, W: Write> super::Writer<'a, W> {
_ => unreachable!(),
};
let scalar_ty = match module.types[member.ty].inner {
crate::TypeInner::Matrix { width, .. } => {
crate::TypeInner::Scalar(crate::Scalar::float(width))
}
crate::TypeInner::Matrix { scalar, .. } => crate::TypeInner::Scalar(scalar),
_ => unreachable!(),
};
self.write_value_type(module, &scalar_ty)?;
Expand Down
26 changes: 10 additions & 16 deletions naga/src/back/hlsl/storage.rs
Original file line number Diff line number Diff line change
Expand Up @@ -180,23 +180,20 @@ impl<W: fmt::Write> super::Writer<'_, W> {
crate::TypeInner::Matrix {
columns,
rows,
width,
scalar,
} => {
write!(
self.out,
"{}{}x{}(",
crate::Scalar::float(width).to_hlsl_str()?,
scalar.to_hlsl_str()?,
columns as u8,
rows as u8,
)?;

// Note: Matrices containing vec3s, due to padding, act like they contain vec4s.
let row_stride = Alignment::from(rows) * width as u32;
let row_stride = Alignment::from(rows) * scalar.width as u32;
let iter = (0..columns as u32).map(|i| {
let ty_inner = crate::TypeInner::Vector {
size: rows,
scalar: crate::Scalar::float(width),
};
let ty_inner = crate::TypeInner::Vector { size: rows, scalar };
(TypeResolution::Value(ty_inner), i * row_stride)
});
self.write_storage_load_sequence(module, var_handle, iter, func_ctx)?;
Expand Down Expand Up @@ -316,7 +313,7 @@ impl<W: fmt::Write> super::Writer<'_, W> {
crate::TypeInner::Matrix {
columns,
rows,
width,
scalar,
} => {
// first, assign the value to a temporary
writeln!(self.out, "{level}{{")?;
Expand All @@ -325,7 +322,7 @@ impl<W: fmt::Write> super::Writer<'_, W> {
self.out,
"{}{}{}x{} {}{} = ",
level.next(),
crate::Scalar::float(width).to_hlsl_str()?,
scalar.to_hlsl_str()?,
columns as u8,
rows as u8,
STORE_TEMP_NAME,
Expand All @@ -335,16 +332,13 @@ impl<W: fmt::Write> super::Writer<'_, W> {
writeln!(self.out, ";")?;

// Note: Matrices containing vec3s, due to padding, act like they contain vec4s.
let row_stride = Alignment::from(rows) * width as u32;
let row_stride = Alignment::from(rows) * scalar.width as u32;

// then iterate the stores
for i in 0..columns as u32 {
self.temp_access_chain
.push(SubAccess::Offset(i * row_stride));
let ty_inner = crate::TypeInner::Vector {
size: rows,
scalar: crate::Scalar::float(width),
};
let ty_inner = crate::TypeInner::Vector { size: rows, scalar };
let sv = StoreValue::TempIndex {
depth,
index: i,
Expand Down Expand Up @@ -467,10 +461,10 @@ impl<W: fmt::Write> super::Writer<'_, W> {
crate::TypeInner::Vector { scalar, .. } => Parent::Array {
stride: scalar.width as u32,
},
crate::TypeInner::Matrix { rows, width, .. } => Parent::Array {
crate::TypeInner::Matrix { rows, scalar, .. } => Parent::Array {
// The stride between matrices is the count of rows as this is how
// long each column is.
stride: Alignment::from(rows) * width as u32,
stride: Alignment::from(rows) * scalar.width as u32,
},
_ => unreachable!(),
},
Expand Down
23 changes: 10 additions & 13 deletions naga/src/back/hlsl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -908,12 +908,9 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
TypeInner::Matrix {
rows,
columns,
width,
scalar,
} if member.binding.is_none() && rows == crate::VectorSize::Bi => {
let vec_ty = crate::TypeInner::Vector {
size: rows,
scalar: crate::Scalar::float(width),
};
let vec_ty = crate::TypeInner::Vector { size: rows, scalar };
let field_name_key = NameKey::StructMember(handle, index as u32);

for i in 0..columns as u8 {
Expand Down Expand Up @@ -1037,7 +1034,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
TypeInner::Matrix {
columns,
rows,
width,
scalar,
} => {
// The IR supports only float matrix
// https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-matrix
Expand All @@ -1046,7 +1043,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
write!(
self.out,
"{}{}x{}",
crate::Scalar::float(width).to_hlsl_str()?,
scalar.to_hlsl_str()?,
back::vector_size_str(columns),
back::vector_size_str(rows),
)?;
Expand Down Expand Up @@ -3241,11 +3238,11 @@ pub(super) fn get_inner_matrix_data(
TypeInner::Matrix {
columns,
rows,
width,
scalar,
} => Some(MatrixType {
columns,
rows,
width,
width: scalar.width,
}),
TypeInner::Array { base, .. } => get_inner_matrix_data(module, base),
_ => None,
Expand Down Expand Up @@ -3276,12 +3273,12 @@ pub(super) fn get_inner_matrix_of_struct_array_member(
TypeInner::Matrix {
columns,
rows,
width,
scalar,
} => {
mat_data = Some(MatrixType {
columns,
rows,
width,
width: scalar.width,
})
}
TypeInner::Array { base, .. } => {
Expand Down Expand Up @@ -3333,12 +3330,12 @@ fn get_inner_matrix_of_global_uniform(
TypeInner::Matrix {
columns,
rows,
width,
scalar,
} => {
mat_data = Some(MatrixType {
columns,
rows,
width,
width: scalar.width,
})
}
TypeInner::Array { base, .. } => {
Expand Down
9 changes: 4 additions & 5 deletions naga/src/back/msl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1942,11 +1942,11 @@ impl<W: Write> Writer<W> {
crate::TypeInner::Matrix {
columns,
rows,
width,
scalar,
} => {
let target_scalar = crate::Scalar {
kind,
width: convert.unwrap_or(width),
width: convert.unwrap_or(scalar.width),
};
put_numeric_type(&mut self.out, target_scalar, &[rows, columns])?;
write!(self.out, "(")?;
Expand Down Expand Up @@ -2555,10 +2555,9 @@ impl<W: Write> Writer<W> {
TypeResolution::Value(crate::TypeInner::Matrix {
columns,
rows,
width,
scalar,
}) => {
let element = crate::Scalar::float(width);
put_numeric_type(&mut self.out, element, &[rows, columns])?;
put_numeric_type(&mut self.out, scalar, &[rows, columns])?;
}
TypeResolution::Value(ref other) => {
log::warn!("Type {:?} isn't a known local", other); //TEMP!
Expand Down
12 changes: 5 additions & 7 deletions naga/src/back/spv/block.rs
Original file line number Diff line number Diff line change
Expand Up @@ -494,7 +494,7 @@ impl<'w> BlockContext<'w> {
crate::TypeInner::Matrix {
columns,
rows,
width,
scalar,
} => {
self.write_matrix_matrix_column_op(
block,
Expand All @@ -504,7 +504,7 @@ impl<'w> BlockContext<'w> {
right_id,
columns,
rows,
width,
scalar.width,
spirv::Op::FAdd,
);

Expand All @@ -522,7 +522,7 @@ impl<'w> BlockContext<'w> {
crate::TypeInner::Matrix {
columns,
rows,
width,
scalar,
} => {
self.write_matrix_matrix_column_op(
block,
Expand All @@ -532,7 +532,7 @@ impl<'w> BlockContext<'w> {
right_id,
columns,
rows,
width,
scalar.width,
spirv::Op::FSub,
);

Expand Down Expand Up @@ -1141,9 +1141,7 @@ impl<'w> BlockContext<'w> {
match *self.fun_info[expr].ty.inner_with(&self.ir_module.types) {
crate::TypeInner::Scalar(scalar) => (scalar, None, false),
crate::TypeInner::Vector { scalar, size } => (scalar, Some(size), false),
crate::TypeInner::Matrix { width, .. } => {
(crate::Scalar::float(width), None, true)
}
crate::TypeInner::Matrix { scalar, .. } => (scalar, None, true),
ref other => {
log::error!("As source {:?}", other);
return Err(Error::Validation("Unexpected Expression::As source"));
Expand Down
4 changes: 2 additions & 2 deletions naga/src/back/spv/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -367,11 +367,11 @@ fn make_local(inner: &crate::TypeInner) -> Option<LocalType> {
crate::TypeInner::Matrix {
columns,
rows,
width,
scalar,
} => LocalType::Matrix {
columns,
rows,
width,
width: scalar.width,
},
crate::TypeInner::Pointer { base, space } => LocalType::Pointer {
base,
Expand Down
4 changes: 2 additions & 2 deletions naga/src/back/spv/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1766,10 +1766,10 @@ impl Writer {
if let crate::TypeInner::Matrix {
columns: _,
rows,
width,
scalar,
} = *member_array_subty_inner
{
let byte_stride = Alignment::from(rows) * width as u32;
let byte_stride = Alignment::from(rows) * scalar.width as u32;
self.annotations.push(Instruction::member_decorate(
struct_id,
index as u32,
Expand Down
9 changes: 4 additions & 5 deletions naga/src/back/wgsl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -524,14 +524,14 @@ impl<W: Write> Writer<W> {
TypeInner::Matrix {
columns,
rows,
width,
scalar,
} => {
write!(
self.out,
"mat{}x{}<{}>",
back::vector_size_str(columns),
back::vector_size_str(rows),
scalar_kind_str(crate::Scalar::float(width))
scalar_kind_str(scalar)
)?;
}
TypeInner::Pointer { base, space } => {
Expand Down Expand Up @@ -1412,12 +1412,11 @@ impl<W: Write> Writer<W> {
TypeInner::Matrix {
columns,
rows,
width,
..
scalar,
} => {
let scalar = crate::Scalar {
kind,
width: convert.unwrap_or(width),
width: convert.unwrap_or(scalar.width),
};
let scalar_kind_str = scalar_kind_str(scalar);
write!(
Expand Down
4 changes: 2 additions & 2 deletions naga/src/front/glsl/builtins.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1276,7 +1276,7 @@ fn inject_common_builtin(
vec![TypeInner::Matrix {
columns,
rows,
width: float_width,
scalar: float_scalar,
}],
MacroCall::MathFunction(MathFunction::Transpose),
))
Expand All @@ -1295,7 +1295,7 @@ fn inject_common_builtin(
let args = vec![TypeInner::Matrix {
columns,
rows,
width: float_width,
scalar: float_scalar,
}];

declaration.overloads.push(module.add_builtin(
Expand Down
Loading

0 comments on commit baab10c

Please sign in to comment.