Skip to content
This repository has been archived by the owner on Jan 29, 2025. It is now read-only.

[msl-out] wrap arrays in structs so that they can be returned by functions #764

Merged
merged 9 commits into from
Apr 28, 2021
110 changes: 63 additions & 47 deletions src/back/msl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ use std::{
const NAMESPACE: &str = "metal";
const INDENT: &str = " ";
const BAKE_PREFIX: &str = "_e";
const WRAPPED_ARRAY_FIELD: &str = "inner";

#[derive(Clone)]
struct Level(usize);
Expand Down Expand Up @@ -81,12 +82,9 @@ impl<'a> Display for TypeContext<'a> {
}
crate::TypeInner::Pointer { base, class } => {
let sub = Self {
arena: self.arena,
names: self.names,
handle: base,
usage: self.usage,
access: self.access,
first_time: false,
..*self
};
let class_name = match class.get_name(self.usage) {
Some(name) => name,
Expand Down Expand Up @@ -125,7 +123,17 @@ impl<'a> Display for TypeContext<'a> {
vector_size_str(size),
)
}
crate::TypeInner::Array { .. } | crate::TypeInner::Struct { .. } => unreachable!(),
crate::TypeInner::Array { base, .. } => {
let sub = Self {
handle: base,
expenses marked this conversation as resolved.
Show resolved Hide resolved
first_time: false,
..*self
};
// Array lengths go at the end of the type definition,
// so just print the element type here.
write!(out, "{}", sub)
}
crate::TypeInner::Struct { .. } => unreachable!(),
crate::TypeInner::Image {
dim,
arrayed,
Expand Down Expand Up @@ -533,39 +541,6 @@ impl<W: Write> Writer<W> {
Ok(())
}

fn put_initialization_component(
&mut self,
component: Handle<crate::Expression>,
context: &ExpressionContext,
) -> Result<(), Error> {
// we can't initialize the array members just like other members,
// we have to unwrap them one level deeper...
let component_res = &context.info[component].ty;
if let crate::TypeInner::Array {
size: crate::ArraySize::Constant(const_handle),
..
} = *component_res.inner_with(&context.module.types)
{
//HACK: we are forcefully duplicating the expression here,
// it would be nice to find a more C++ idiomatic solution for initializing array members
let size = context.module.constants[const_handle]
.to_array_length()
.unwrap();
write!(self.out, "{{")?;
for j in 0..size {
if j != 0 {
write!(self.out, ",")?;
}
self.put_expression(component, context, false)?;
write!(self.out, "[{}]", j)?;
}
write!(self.out, "}}")?;
} else {
self.put_expression(component, context, true)?;
}
Ok(())
}

fn put_expression(
&mut self,
expr_handle: Handle<crate::Expression>,
Expand All @@ -587,7 +562,25 @@ impl<W: Write> Writer<W> {
log::trace!("expression {:?} = {:?}", expr_handle, expression);
match *expression {
crate::Expression::Access { base, index } => {
let accessing_wrapped_array =
match *context.info[base].ty.inner_with(&context.module.types) {
crate::TypeInner::Array { .. } => true,
crate::TypeInner::Pointer {
base: pointer_base, ..
} => match context.module.types[pointer_base].inner {
crate::TypeInner::Array {
size: crate::ArraySize::Constant(_),
..
} => true,
_ => false,
},
_ => false,
};

self.put_expression(base, context, false)?;
if accessing_wrapped_array {
write!(self.out, ".{}", WRAPPED_ARRAY_FIELD)?;
}
write!(self.out, "[")?;
self.put_expression(index, context, true)?;
write!(self.out, "]")?;
Expand Down Expand Up @@ -690,7 +683,7 @@ impl<W: Write> Writer<W> {
if i != 0 {
write!(self.out, ", ")?;
}
self.put_initialization_component(component, context)?;
self.put_expression(component, context, true)?;
}
write!(self.out, "}}")?;
}
Expand Down Expand Up @@ -1120,7 +1113,9 @@ impl<W: Write> Writer<W> {
let comma = if is_first { "" } else { "," };
is_first = false;
let name = &self.names[&NameKey::StructMember(result_ty, index as u32)];
// logic similar to `put_initialization_component`
// HACK: we are forcefully deduplicating the expression here
// to convert from a wrapped struct to a raw array, e.g.
// `float gl_ClipDistance1 [[clip_distance]] [1];`.
if let crate::TypeInner::Array {
size: crate::ArraySize::Constant(const_handle),
..
Expand All @@ -1134,7 +1129,11 @@ impl<W: Write> Writer<W> {
if j != 0 {
write!(self.out, ",")?;
}
write!(self.out, "{}.{}[{}]", tmp, name, j)?;
write!(
kvark marked this conversation as resolved.
Show resolved Hide resolved
self.out,
"{}.{}.{}[{}]",
tmp, name, WRAPPED_ARRAY_FIELD, j
)?;
}
write!(self.out, "}}")?;
} else {
Expand Down Expand Up @@ -1344,9 +1343,9 @@ impl<W: Write> Writer<W> {
.unwrap();
write!(self.out, "{}for(int _i=0; _i<{}; ++_i) ", level, size)?;
self.put_expression(pointer, &context.expression, true)?;
write!(self.out, "[_i] = ")?;
write!(self.out, ".{}[_i] = ", WRAPPED_ARRAY_FIELD)?;
self.put_expression(value, &context.expression, true)?;
writeln!(self.out, "[_i];")?;
writeln!(self.out, ".{}[_i];", WRAPPED_ARRAY_FIELD)?;
}
None => {
write!(self.out, "{}", level)?;
Expand Down Expand Up @@ -1467,7 +1466,7 @@ impl<W: Write> Writer<W> {
access: crate::StorageAccess::empty(),
first_time: false,
};
write!(self.out, "typedef {} {}", base_name, name)?;

match size {
crate::ArraySize::Constant(const_handle) => {
let coco = ConstantContext {
Expand All @@ -1476,10 +1475,17 @@ impl<W: Write> Writer<W> {
names: &self.names,
first_time: false,
};
writeln!(self.out, "[{}];", coco)?;

writeln!(self.out, "struct {} {{", name)?;
writeln!(
self.out,
"{}{} {}[{}];",
INDENT, base_name, WRAPPED_ARRAY_FIELD, coco
)?;
writeln!(self.out, "}};")?;
}
crate::ArraySize::Dynamic => {
writeln!(self.out, "[1];")?;
writeln!(self.out, "typedef {} {}[1];", base_name, name)?;
}
}
}
Expand Down Expand Up @@ -1941,17 +1947,27 @@ impl<W: Write> Writer<W> {
names: &self.names,
usage: GlobalUse::empty(),
access: crate::StorageAccess::empty(),
first_time: false,
first_time: true,
kvark marked this conversation as resolved.
Show resolved Hide resolved
};
let binding = binding.ok_or(Error::Validation)?;
if !pipeline_options.allow_point_size
&& *binding == crate::Binding::BuiltIn(crate::BuiltIn::PointSize)
{
continue;
}
let array_len = match module.types[ty].inner {
crate::TypeInner::Array {
size: crate::ArraySize::Constant(handle),
..
} => module.constants[handle].to_array_length(),
_ => None,
};
let resolved = options.resolve_local_binding(binding, out_mode)?;
write!(self.out, "{}{} {}", INDENT, ty_name, name)?;
resolved.try_fmt_decorated(&mut self.out, "")?;
if let Some(array_len) = array_len {
write!(self.out, " [{}]", array_len)?;
}
writeln!(self.out, ";")?;
}
writeln!(self.out, "}};")?;
Expand Down
6 changes: 4 additions & 2 deletions tests/out/access.msl
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
#include <metal_stdlib>
#include <simd/simd.h>

typedef int type3[5];
struct type3 {
int inner[5];
};

struct fooInput {
};
Expand All @@ -11,5 +13,5 @@ struct fooOutput {
vertex fooOutput foo(
metal::uint vi [[vertex_id]]
) {
return fooOutput { static_cast<float4>(int4(type3 {1, 2, 3, 4, 5}[vi])) };
return fooOutput { static_cast<float4>(int4(type3 {1, 2, 3, 4, 5}.inner[vi])) };
}
10 changes: 6 additions & 4 deletions tests/out/quad-vert.msl
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
#include <metal_stdlib>
#include <simd/simd.h>

typedef float type6[1u];
struct type6 {
float inner[1u];
};
struct gl_PerVertex {
metal::float4 gl_Position;
float gl_PointSize;
Expand Down Expand Up @@ -35,7 +37,7 @@ struct main2Output {
metal::float2 member [[user(loc0), center_perspective]];
metal::float4 gl_Position1 [[position]];
float gl_PointSize1 [[point_size]];
type6 gl_ClipDistance1 [[clip_distance]];
float gl_ClipDistance1 [[clip_distance]] [1];
kvark marked this conversation as resolved.
Show resolved Hide resolved
};
vertex main2Output main2(
main2Input varyings [[stage_in]]
Expand All @@ -49,6 +51,6 @@ vertex main2Output main2(
a_uv = a_uv1;
a_pos = a_pos1;
main1(v_uv, a_uv, _, a_pos);
const auto _tmp = type10 {v_uv, _.gl_Position, _.gl_PointSize, {_.gl_ClipDistance[0]}};
return main2Output { _tmp.member, _tmp.gl_Position1, _tmp.gl_PointSize1, {_tmp.gl_ClipDistance1[0]} };
const auto _tmp = type10 {v_uv, _.gl_Position, _.gl_PointSize, _.gl_ClipDistance};
return main2Output { _tmp.member, _tmp.gl_Position1, _tmp.gl_PointSize1, {_tmp.gl_ClipDistance1.inner[0]} };
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

wait, so here we can't just do _tmp.gl_ClipDistance1 for the last argument?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nope, unfortunately not. Not gl_ClipDistance1.inner either.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hmm, I'm concerned now, a little bit. So when we are initializing the I/O struct, we can't just pass _tmp.gl_ClipDistance1 because it's an array. But how would this work in other places? I.e. if you have a user function returning a user struct, then passing { xxx.inner[0], xxx.inner[1] } would just fail horribly, since the only field of the target struct is inner.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Afaik, only having a single pair of those { brackets when constructing a struct is okay if it only has one field. I think more testing would be nice, but I haven't found a (trivial) example that this breaks. Something like the following, for example, is fine:

#version 450

vec3[8] hello_world() {
    vec3[8] xyz;
    return xyz;
}

struct X {
  vec3[8] yyy;
};

vec3[8] hmm2() {
  X x;
  x.yyy = hello_world();
  x.yyy[1] = vec3(3.0);
  return x.yyy;
}

layout(location = 0) out vec4 colour;

void main() {
    vec3[8] xyz = hmm2();
    colour = vec4(xyz[7], 1.0);
}

->

#include <metal_stdlib>
#include <simd/simd.h>

struct type3 {
    metal::float3 inner[8u];
};
struct X {
    type3 yyy;
};
constant metal::float3 const_type1_ = {3.0, 3.0, 3.0};

type3 hello_world(
) {
    type3 xyz;
    return xyz;
}

type3 hmm2_(
) {
    X x;
    type3 _e13 = hello_world();
    for(int _i=0; _i<8; ++_i) x.yyy.inner[_i] = _e13.inner[_i];
    x.yyy.inner[1] = const_type1_;
    return x.yyy;
}

void main1(
    thread metal::float4& colour
) {
    type3 xyz1;
    type3 _e13 = hmm2_();
    for(int _i=0; _i<8; ++_i) xyz1.inner[_i] = _e13.inner[_i];
    metal::float3 _e15 = xyz1.inner[7];
    colour = metal::float4(_e15.x, _e15.y, _e15.z, 1.0);
    return;
}

struct main2Output {
    metal::float4 member [[color(0)]];
};
fragment main2Output main2(
) {
    metal::float4 colour = {};
    main1(colour);
    return main2Output { colour };
}

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I mean, specifically for arrays.

struct Foo {
  xxx: array<f32; 3>;
};
fn bar() -> Foo {
  return Foo(array<f32;3>(1.0, 2.0, 3.0));
}

My understanding is that this PR right now is not going to handle this. Ideally, we'd want it to generate something like

const auto _tmp = Foo { .. };
return { { _tmp.xxx[0], _tmp.xxx[1], _tmp.xxx[2] } };

This will fail because the produced Foo would be declared with the inner in it, but the return code doesn't use it.

Copy link
Member

@kvark kvark Apr 28, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, I don't understand how type1 {1.0, 2.0, 3.0} part works :(
type1 is a structure with 1 element. So shouldn't it yell at us for providing 3 elements instead of 1?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah I'm not 100% sure either, but what seems to happen is that (as I wrote in #764 (comment)):

Afaik, only having a single pair of those { brackets when constructing a struct is okay if it only has one field

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same thing is allowed in C:

struct Arr {
    float inner[4];
};

int main() {
    struct Arr arr = { 1.0, 1.5, 2.0, 2.5 };
}

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Uh oh, in both MSL and C it lets you skip fields in the constructor though so

Foo bar1(
) {
    return Foo {type1 {1.0}, 2};
}

also works?? With no warnings or errors 😟 ??

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's going to happen if the array only has one element? How is the compiler going to figure out if the initializer is for this field or for the inner itself? And initializers don't even have to be provided for all the fields. Ugh, C is weird :(

}