Skip to content
This repository has been archived by the owner on Jan 29, 2025. It is now read-only.

Commit

Permalink
[msl-out] Implement index bounds check policies for non-texture acces…
Browse files Browse the repository at this point in the history
…ses.
  • Loading branch information
jimblandy committed Nov 15, 2021
1 parent 418b0e5 commit c166082
Show file tree
Hide file tree
Showing 13 changed files with 1,344 additions and 178 deletions.
1 change: 0 additions & 1 deletion src/arena.rs
Original file line number Diff line number Diff line change
Expand Up @@ -607,7 +607,6 @@ impl<T> Default for HandleBitSet<T> {

impl<T> HandleBitSet<T> {
/// Return `true` if `handle` is in the set.
#[cfg(feature = "validate")]
pub fn contains(&self, handle: Handle<T>) -> bool {
self.set.contains(handle.index())
}
Expand Down
6 changes: 5 additions & 1 deletion src/back/msl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ For the result type, if it's a structure, we re-compose it with a temporary valu
holding the result.
!*/

use crate::{arena::Handle, valid::ModuleInfo};
use crate::{arena::Handle, proc::index, valid::ModuleInfo};
use std::{
fmt::{Error as FmtError, Write},
ops,
Expand Down Expand Up @@ -177,6 +177,9 @@ pub struct Options {
pub spirv_cross_compatibility: bool,
/// Don't panic on missing bindings, instead generate invalid MSL.
pub fake_missing_bindings: bool,
/// Bounds checking policies.
#[cfg_attr(feature = "deserialize", serde(default))]
pub bounds_check_policies: index::BoundsCheckPolicies,
}

impl Default for Options {
Expand All @@ -187,6 +190,7 @@ impl Default for Options {
inline_samplers: Vec::new(),
spirv_cross_compatibility: false,
fake_missing_bindings: true,
bounds_check_policies: index::BoundsCheckPolicies::default(),
}
}
}
Expand Down
199 changes: 190 additions & 9 deletions src/back/msl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use super::{sampler as sm, Error, LocationMode, Options, PipelineOptions, Transl
use crate::{
arena::Handle,
back,
proc::index,
proc::{self, NameKey, TypeResolution},
valid, FastHashMap, FastHashSet,
};
Expand Down Expand Up @@ -448,6 +449,7 @@ struct ExpressionContext<'a> {
function: &'a crate::Function,
origin: FunctionOrigin,
info: &'a valid::FunctionInfo,
index_info: index::IndexInfo,
module: &'a crate::Module,
pipeline_options: &'a PipelineOptions,
}
Expand Down Expand Up @@ -743,6 +745,16 @@ impl<W: Write> Writer<W> {
Ok(())
}

/// Emit code for the expression `expr_handle`.
///
/// The `is_scoped` argument is true if the surrounding operators have the precedence
/// of the comma operator, or lower. So, for example:
///
/// - Pass `true` for `is_scoped` when writing function arguments, an expression
/// statement, an initializer expression, or anything already wrapped in parenthesis
///
/// - Pass `false` if it is an operand of a `?:` operator, a `[]`, or really almost
/// anything else.
fn put_expression(
&mut self,
expr_handle: Handle<crate::Expression>,
Expand Down Expand Up @@ -775,6 +787,7 @@ impl<W: Write> Writer<W> {
}
_ => base_res.handle(),
};
log::trace!(" resolved {:?} as {:?}", base, resolved);
match *resolved {
crate::TypeInner::Struct { .. } => {
let base_ty = base_ty_handle.unwrap();
Expand All @@ -795,6 +808,7 @@ impl<W: Write> Writer<W> {
}
_ => {
// unexpected indexing, should fail validation
unreachable!("unexpected type given to AccessIndex: {:?}", resolved);
}
}
}
Expand Down Expand Up @@ -846,6 +860,12 @@ impl<W: Write> Writer<W> {
crate::Expression::GlobalVariable(handle) => {
let name = &self.names[&NameKey::GlobalVariable(handle)];
write!(self.out, "{}", name)?;
log::trace!(
" type: {:?}",
context.info[expr_handle]
.ty
.inner_with(&context.module.types)
);
}
crate::Expression::LocalVariable(handle) => {
let name_key = match context.origin {
Expand Down Expand Up @@ -1203,14 +1223,83 @@ impl<W: Write> Writer<W> {
Ok(())
}

/// Emit a boolean expression that is true if `indices` are all within bounds.
///
/// The result is not parenthesized.
fn put_condition(
&mut self,
accesses: &[index::GuardedAccess],
pointer: Handle<crate::Expression>,
context: &ExpressionContext,
) -> BackendResult {
let mut first = true;
for access in accesses {
if first {
first = false;
} else {
self.out.write_str(" && ")?;
}
match access.index {
index::GuardedIndex::Expression(expr) => {
self.put_expression(expr, context, false)?
}
index::GuardedIndex::Known(value) => write!(self.out, "{}", value)?,
}
self.out.write_str(" < ")?;
match access.length {
index::IndexableLength::Known(value) => write!(self.out, "{}", value)?,
index::IndexableLength::Dynamic => {
let global = pointer
.originating_global(context.function)
.ok_or(Error::Validation)?;
self.put_dynamic_array_length(global, context)?
}
}
}

Ok(())
}

fn put_load(
&mut self,
pointer: Handle<crate::Expression>,
context: &ExpressionContext,
is_scoped: bool,
) -> BackendResult {
let comparisons = context.index_info.choose_checks(
pointer,
context.module,
context.function,
context.info,
);
if !comparisons.is_empty() {
if !is_scoped {
write!(self.out, "(")?;
}

self.put_condition(&comparisons, pointer, context)?;
write!(self.out, " ? ")?;
self.put_unchecked_load(pointer, context, false)?;
write!(self.out, " : 0")?;

if !is_scoped {
write!(self.out, ")")?;
}
} else {
self.put_unchecked_load(pointer, context, is_scoped)?;
}

Ok(())
}

fn put_unchecked_load(
&mut self,
pointer: Handle<crate::Expression>,
context: &ExpressionContext,
is_scoped: bool,
) -> BackendResult {
// Because packed vectors such as `packed_float3` cannot be directly multipied by
// matrices, we wrap them with `float3` on load.
// matrices, we convert them to unpacked vectors like `float3` on load.
let wrap_packed_vec_scalar_kind = match context.function.expressions[pointer] {
crate::Expression::AccessIndex { base, index } => {
let ty = match *context.resolve_type(base) {
Expand Down Expand Up @@ -1287,7 +1376,33 @@ impl<W: Write> Writer<W> {
write!(self.out, ".{}", WRAPPED_ARRAY_FIELD)?;
}
write!(self.out, "[")?;
self.put_expression(index, context, true)?;

// If this index needs to be clamped to fall within range, then do so.
if let Some(limit) = context.index_info.needs_restriction(
base,
index,
context.module,
context.function,
context.info,
) {
write!(self.out, "{}::min(unsigned(", NAMESPACE)?;
self.put_expression(index, context, true)?;
write!(self.out, "), ")?;
match limit {
index::IndexableLength::Known(limit) => {
write!(self.out, "{}u)", limit - 1)?;
}
index::IndexableLength::Dynamic => {
let global = base
.originating_global(context.function)
.ok_or(Error::Validation)?;
self.put_dynamic_array_length(global, context)?;
write!(self.out, " - 1)")?;
}
}
} else {
self.put_expression(index, context, true)?;
}
write!(self.out, "]")?;

Ok(())
Expand Down Expand Up @@ -1439,6 +1554,7 @@ impl<W: Write> Writer<W> {
match *statement {
crate::Statement::Emit(ref range) => {
for handle in range.clone() {
log::trace!(" deciding whether to bake {:?}", handle);
let info = &context.expression.info[handle];
let ptr_class = info
.ty
Expand All @@ -1449,21 +1565,45 @@ impl<W: Write> Writer<W> {
} else if let Some(name) =
context.expression.function.named_expressions.get(&handle)
{
// Front end provides names for all variables at the start of writing.
// But we write them to step by step. We need to recache them
// Otherwise, we could accidentally write variable name instead of full expression.
// Also, we use sanitized names! It defense backend from generating variable with name from reserved keywords.
// The `crate::Function::named_expressions` table holds
// expressions that should be saved in temporaries once they
// are `Emit`ted. We only add them to `self.named_expressions`
// when we reach the `Emit` that covers them, so that we don't
// try to use their names before we've actually initialized
// the temporary that holds them.
//
// Don't assume the names in `named_expressions` are unique,
// or even valid. Use the `Namer`.
Some(self.namer.call(name))
} else {
let min_ref_count =
context.expression.function.expressions[handle].bake_ref_count();
if min_ref_count <= info.ref_count {
let bake;

// If this expression is an index that we're going to first compare
// against a limit, and then actually use as an index, then we may
// want to cache it in a temporary, to avoid evaluating it twice.
if context
.expression
.index_info
.guarded_indices
.contains(handle)
{
bake = true;
} else {
// Expressions whose reference count is above the
// threshold should always be stored in temporaries.
let min_ref_count = context.expression.function.expressions[handle]
.bake_ref_count();
bake = min_ref_count <= info.ref_count
};

if bake {
Some(format!("{}{}", back::BAKE_PREFIX, handle.index()))
} else {
None
}
};

log::trace!(" verdict: {:?}", expr_name);
if let Some(name) = expr_name {
write!(self.out, "{}", level)?;
self.start_baking_expression(handle, &context.expression, &name)?;
Expand Down Expand Up @@ -1739,6 +1879,29 @@ impl<W: Write> Writer<W> {
value: Handle<crate::Expression>,
level: back::Level,
context: &StatementContext,
) -> BackendResult {
let comparisons = context.expression.index_info.choose_checks(
pointer,
context.expression.module,
context.expression.function,
context.expression.info,
);
if !comparisons.is_empty() {
write!(self.out, "{}if (", level)?;
self.put_condition(&comparisons, pointer, &context.expression)?;
writeln!(self.out, ")")?;
self.put_unchecked_store(pointer, value, level.next(), context)
} else {
self.put_unchecked_store(pointer, value, level, context)
}
}

fn put_unchecked_store(
&mut self,
pointer: Handle<crate::Expression>,
value: Handle<crate::Expression>,
level: back::Level,
context: &StatementContext,
) -> BackendResult {
let pointer_info = &context.expression.info[pointer];
let (array_size, is_atomic) =
Expand Down Expand Up @@ -2138,6 +2301,7 @@ impl<W: Write> Writer<W> {

writeln!(self.out)?;
let fun_name = &self.names[&NameKey::Function(fun_handle)];
log::trace!("write_functions: function {:?}", fun_name);
match fun.result {
Some(ref result) => {
let ty_name = TypeContext {
Expand Down Expand Up @@ -2225,11 +2389,19 @@ impl<W: Write> Writer<W> {
writeln!(self.out, ";")?;
}

let index_info = index::IndexInfo::analyze_function(
module,
fun,
fun_info,
options.bounds_check_policies,
);

let context = StatementContext {
expression: ExpressionContext {
function: fun,
origin: FunctionOrigin::Handle(fun_handle),
info: fun_info,
index_info,
module,
pipeline_options,
},
Expand Down Expand Up @@ -2294,6 +2466,7 @@ impl<W: Write> Writer<W> {
continue;
}
let fun_name = &self.names[&NameKey::EntryPoint(ep_index as _)];
log::trace!("write_functions: entry point {:?}", fun_name);
info.entry_point_names.push(Ok(fun_name.clone()));

writeln!(self.out)?;
Expand Down Expand Up @@ -2647,11 +2820,19 @@ impl<W: Write> Writer<W> {
writeln!(self.out, ";")?;
}

let index_info = index::IndexInfo::analyze_function(
module,
fun,
fun_info,
options.bounds_check_policies,
);

let context = StatementContext {
expression: ExpressionContext {
function: fun,
origin: FunctionOrigin::EntryPoint(ep_index as _),
info: fun_info,
index_info,
module,
pipeline_options,
},
Expand Down
Loading

0 comments on commit c166082

Please sign in to comment.