avoid alloc of fusedspec vector #1454

Open · wants to merge 2 commits into main
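The change removes the per-call allocation of the `FusedSpec` vector: `LirMatMulUnary` becomes stateful, the state owns a `Vec<FusedSpec<'static>>` that is reused across `eval` calls, and the `'static` lifetime is transmuted down to the borrow of the current inputs for the duration of each call. A minimal, self-contained sketch of that reuse pattern, with hypothetical `Spec`/`State` types rather than the actual tract API:

```rust
// Sketch of the buffer-reuse pattern (hypothetical types, not tract's API):
// the state owns one Vec with an erased 'static lifetime, and each call
// re-borrows it with the short lifetime of the current inputs instead of
// allocating a fresh Vec.

struct Spec<'t>(&'t [f32]); // stand-in for FusedSpec<'t>

#[derive(Default)]
struct State {
    // Lives across eval() calls; the 'static is never observed because the
    // vector is cleared before eval() returns.
    buffer: Vec<Spec<'static>>,
}

impl State {
    fn eval<'t>(&mut self, inputs: &'t [Vec<f32>]) -> f32 {
        self.buffer.clear();
        self.buffer.reserve(inputs.len());
        // Shorten the lifetime for the duration of this call only.
        let specs: &mut Vec<Spec<'t>> = unsafe { std::mem::transmute(&mut self.buffer) };
        for input in inputs {
            specs.push(Spec(input.as_slice()));
        }
        let result = specs.iter().map(|s| s.0.iter().sum::<f32>()).sum();
        // Drop every borrow of `inputs` before returning so nothing with the
        // faked 'static lifetime outlives this call.
        specs.clear();
        result
    }
}

fn main() {
    let mut state = State::default();
    let inputs = vec![vec![1.0, 2.0], vec![3.0]];
    for _ in 0..3 {
        // Repeated calls reuse the same allocation.
        assert_eq!(state.eval(&inputs), 6.0);
    }
}
```

In the diff below, `LirMatMulUnaryState(Vec<FusedSpec<'static>>)` plays the role of `State` and the lifetime erasure happens at the `std::mem::transmute(&mut self.0)` line inside `eval`.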
141 changes: 88 additions & 53 deletions core/src/ops/matmul/lir_unary.rs
@@ -45,21 +45,19 @@ impl ProtoFusedSpec {
output: &Tensor,
) -> FusedSpec<'t> {
let fs = match self {
ProtoFusedSpec::AddMatMul { geo, a, b, packing } => {
let mut a = inputs[*a].view();
unsafe {
geo.c_to_a_axis_mapping.translate_view(output_coords, &mut a);
}
let a =
a.as_slice::<Opaque>().unwrap()[0].downcast_ref::<Box<dyn MMMInput>>().unwrap();
let mut b = inputs[*b].view();
unsafe {
geo.c_to_b_axis_mapping.translate_view(output_coords, &mut b);
}
let b =
b.as_slice::<Opaque>().unwrap()[0].downcast_ref::<Box<dyn MMMInput>>().unwrap();
ProtoFusedSpec::AddMatMul { geo, a, b, packing } => unsafe {
let mut a = inputs.get_unchecked(*a).view();
geo.c_to_a_axis_mapping.translate_view(output_coords, &mut a);
let a = a.as_slice_unchecked::<Opaque>()[0]
.downcast_ref::<Box<dyn MMMInput>>()
.unwrap_unchecked();
let mut b = inputs.get_unchecked(*b).view();
geo.c_to_b_axis_mapping.translate_view(output_coords, &mut b);
let b = b.as_slice_unchecked::<Opaque>()[0]
.downcast_ref::<Box<dyn MMMInput>>()
.unwrap_unchecked();
FusedSpec::AddMatMul { a: &**a, b: &**b, packing: *packing }
}
},
ProtoFusedSpec::BinScalar(v, op) => FusedSpec::BinScalar(&inputs[*v], *op),
ProtoFusedSpec::LeakyRelu(v) => FusedSpec::LeakyRelu(&inputs[*v]),
ProtoFusedSpec::BinPerRow(v, op, map) => {
@@ -96,21 +94,26 @@ impl ProtoFusedSpec {
}
}

#[inline]
pub fn resolve_trivial<'t>(
&'t self,
inputs: &'t [TValue],
output: &mut Tensor,
) -> FusedSpec<'t> {
let fs = match self {
ProtoFusedSpec::AddMatMul { a, b, packing, .. } => {
let a = &inputs[*a];
let b = &inputs[*b];
let a =
a.to_scalar::<Opaque>().unwrap().downcast_ref::<Box<dyn MMMInput>>().unwrap();
let b =
b.to_scalar::<Opaque>().unwrap().downcast_ref::<Box<dyn MMMInput>>().unwrap();
ProtoFusedSpec::AddMatMul { a, b, packing, .. } => unsafe {
let a = &inputs.get_unchecked(*a);
let b = &inputs.get_unchecked(*b);
let a = a
.to_scalar_unchecked::<Opaque>()
.downcast_ref::<Box<dyn MMMInput>>()
.unwrap_unchecked();
let b = b
.to_scalar_unchecked::<Opaque>()
.downcast_ref::<Box<dyn MMMInput>>()
.unwrap_unchecked();
FusedSpec::AddMatMul { a: &**a, b: &**b, packing: *packing }
}
},
ProtoFusedSpec::BinScalar(v, op) => FusedSpec::BinScalar(&inputs[*v], *op),
ProtoFusedSpec::LeakyRelu(v) => FusedSpec::LeakyRelu(&inputs[*v]),
ProtoFusedSpec::BinPerRow(v, op, _) => {
@@ -282,59 +285,91 @@ impl Op for LirMatMulUnary {

impl EvalOp for LirMatMulUnary {
fn is_stateless(&self) -> bool {
true
false
}

fn eval_with_session(
fn state(
&self,
session: &SessionState,
_session: &mut SessionState,
_node_id: usize,
) -> TractResult<Option<Box<dyn OpState>>> {
Ok(Some(Box::new(LirMatMulUnaryState::default())))
}
}

#[derive(Clone, Debug, Default)]
struct LirMatMulUnaryState(Vec<FusedSpec<'static>>);

impl OpState for LirMatMulUnaryState {
fn eval(
&mut self,
session: &mut SessionState,
op: &dyn Op,
inputs: TVec<TValue>,
) -> TractResult<TVec<TValue>> {
let op = op.downcast_ref::<LirMatMulUnary>().unwrap();
unsafe {
let mut cell = session.cached_mmm_scratch_space.borrow_mut();
if !cell
.as_ref()
.map(|scratch| self.mmm.can_use_scratch_space(&**scratch))
.map(|scratch| op.mmm.can_use_scratch_space(&**scratch))
.unwrap_or(false)
{
*cell = None
}
let scratch = cell.get_or_insert_with(|| self.mmm.allocate_scratch_space());

if self.trivial_path {
let c_shape = self.c_fact.shape.as_concrete().unwrap_unchecked();
let geometry = self.geometry.as_concrete().unwrap_unchecked();
let mut c = Tensor::uninitialized_dt(self.c_fact.datum_type, c_shape)?;
let uops: Vec<FusedSpec> =
self.micro_ops.iter().map(|o| o.resolve_trivial(&inputs, &mut c)).collect();
self.mmm.run_with_scratch_space(geometry.m, geometry.n, scratch.as_mut(), &uops)?;
Ok(tvec!(c.into_tvalue()))
self.0.reserve(op.micro_ops.len().saturating_sub(self.0.capacity()));
#[allow(clippy::uninit_vec)]
self.0.set_len(op.micro_ops.len());
// kill the static lifetime!
let fused_spec: &mut Vec<FusedSpec> = std::mem::transmute(&mut self.0);
let scratch = cell.get_or_insert_with(|| op.mmm.allocate_scratch_space());

let c = if op.trivial_path {
let c_shape = op.c_fact.shape.as_concrete().unwrap_unchecked();
let geometry = op.geometry.as_concrete().unwrap_unchecked();
let mut c = Tensor::uninitialized_dt(op.c_fact.datum_type, c_shape)?;
for i in 0..op.micro_ops.len() {
*fused_spec.get_unchecked_mut(i) =
op.micro_ops.get_unchecked(i).resolve_trivial(&inputs, &mut c);
}
op.mmm.run_with_scratch_space(
geometry.m,
geometry.n,
scratch.as_mut(),
fused_spec,
)?;
c
} else {
let geometry = self.geometry.to_concrete(&session.resolved_symbols)?;
let c_shape = self.c_fact.shape.eval_to_usize(&session.resolved_symbols)?;
let c = Tensor::uninitialized_dt(self.c_fact.datum_type, &c_shape)?;
let mut uops = vec![FusedSpec::ShiftLeft(0); self.micro_ops.len()];
let geometry = op.geometry.to_concrete(&session.resolved_symbols)?;
let c_shape = op.c_fact.shape.eval_to_usize(&session.resolved_symbols)?;
let c = Tensor::uninitialized_dt(op.c_fact.datum_type, &c_shape)?;
let mut looping_shape: TVec<usize> = c_shape.to_smallvec();
looping_shape[self.c_m_axis] = 1;
looping_shape[self.c_n_axis] = 1;
looping_shape[op.c_m_axis] = 1;
looping_shape[op.c_n_axis] = 1;
for c_coords in indices(&*looping_shape) {
for ix in 0..self.micro_ops.len() {
*uops.get_unchecked_mut(ix) =
self.micro_ops.get_unchecked(ix).resolve(&inputs, c_coords.slice(), &c);
for i in 0..op.micro_ops.len() {
*fused_spec.get_unchecked_mut(i) =
op.micro_ops.get_unchecked(i).resolve(&inputs, c_coords.slice(), &c)
}
self.mmm.run_with_scratch_space(
geometry.m,
geometry.n,
scratch.as_mut(),
&uops,
).context("In mmm.run_with_scratch_space")?;
op.mmm
.run_with_scratch_space(
geometry.m,
geometry.n,
scratch.as_mut(),
fused_spec,
)
.context("In mmm.run_with_scratch_space")?;
}
Ok(tvec!(c.into_tvalue()))
}
c
};
fused_spec.clear();
Ok(tvec!(c.into_tvalue()))
}
}
}

trivial_op_state_freeeze!(LirMatMulUnaryState);

impl TypedOp for LirMatMulUnary {
fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
ensure!(self.c_m_axis < self.c_fact.rank());