Skip to content

Commit

Permalink
Fix yolov5 postprocessing
Browse files Browse the repository at this point in the history
  • Loading branch information
cchudant committed Sep 14, 2022
1 parent 1c13f28 commit cc55da7
Show file tree
Hide file tree
Showing 5 changed files with 108 additions and 48 deletions.
42 changes: 26 additions & 16 deletions core/src/ops/array/dyn_slice.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,23 +51,29 @@ impl EvalOp for DynSlice {
}

fn eval(&self, inputs: TVec<Arc<Tensor>>) -> TractResult<TVec<Arc<Tensor>>> {
unsafe {
let start =
if self.start_input { inputs[1].cast_to_scalar::<i64>()? as usize } else { 0 };
let end = if self.end_input {
inputs[1 + self.start_input as usize].cast_to_scalar::<i64>()? as usize
} else {
inputs[0].shape()[self.axis]
};
if start >= end {
bail!("Invalid range {}-{}", start, end);
}
let mut shape: TVec<_> = inputs[0].shape().into();
shape[self.axis] = end - start;
let mut tensor = Tensor::uninitialized_dt(inputs[0].datum_type(), &shape)?;
tensor.assign_slice_unchecked(.., &inputs[0], start..end, self.axis);
Ok(tvec!(tensor.into_arc_tensor()))
let start = if self.start_input { inputs[1].cast_to_scalar::<i64>()? as usize } else { 0 };
let end = if self.end_input {
inputs[1 + self.start_input as usize].cast_to_scalar::<i64>()? as usize
} else {
inputs[0].shape()[self.axis]
};

let actual_axis_len = inputs[0].shape()[self.axis];
let (src_start, src_end) = (start.min(actual_axis_len), end.min(actual_axis_len));

if start > end {
bail!("Invalid range {}-{}", start, end);
}

let mut shape: TVec<_> = inputs[0].shape().into();
shape[self.axis] = src_end - src_start;

let tensor = unsafe {
let mut tensor = Tensor::uninitialized_dt(inputs[0].datum_type(), &shape)?;
tensor.assign_slice_unchecked(.., &inputs[0], src_start..src_end, self.axis);
tensor
};
Ok(tvec!(tensor.into_arc_tensor()))
}
}

Expand Down Expand Up @@ -117,6 +123,10 @@ impl TypedOp for DynSlice {
node: &TypedNode,
) -> TractResult<Option<TypedModelPatch>> {
let inputs = model.node_input_facts(node.id)?;
if inputs[0].shape[self.axis].to_usize().is_err() {
return Ok(None)
}

let start =
if self.start_input { inputs[1].konst.clone() } else { Some(rctensor0(TDim::zero())) };
let end = if self.end_input {
Expand Down
15 changes: 10 additions & 5 deletions core/src/ops/array/gather.rs
Original file line number Diff line number Diff line change
Expand Up @@ -110,11 +110,16 @@ impl TypedOp for Gather {
},
&[wire],
)?[0];
wire = patch.wire_node(
format!("{}.rm_axis", node.name),
crate::ops::change_axes::AxisOp::Rm(self.axis),
&[wire],
)?[0];
let original_rank = model.outlet_fact(node.id.into())?.shape.rank();
let new_rank = patch.model.outlet_fact(wire)?.shape.rank();

if new_rank == original_rank + 1 {
wire = patch.wire_node(
format!("{}.rm_axis", node.name),
crate::ops::change_axes::AxisOp::Rm(self.axis),
&[wire],
)?[0];
}
patch.shunt_outside(model, node.id.into(), wire)?;
return Ok(Some(patch));
}
Expand Down
14 changes: 12 additions & 2 deletions hir/src/ops/array/strided_slice.rs
Original file line number Diff line number Diff line change
Expand Up @@ -258,8 +258,16 @@ impl Expansion for StridedSlice {
let begin = params[0].as_ref();
let end = params[1].as_ref();
for (ix, &axis) in axes.iter().enumerate() {
if let (Some(begin), Some(end)) = (begin, end) {
let d = &input_shape[axis];

// note: if the input axis has symbols, we really cannot know how to slice statically
// example: slice( 'h', (0..10) )
// means that if h < 10 at runtime, the resulting axis is < 10
// and if h > 10, resulting axis is always 10

let d = &input_shape[axis];
if let (Some(begin), Some(end), Ok(_)) = (begin, end, d.to_usize()) {
// this is the case where you can know the resulting axis statically

let preped = self.prepare_one_dim(ix, d, begin, end, &strides)?;
let (left, right) = if preped.stride > 0 {
(preped.begin, preped.end)
Expand All @@ -279,6 +287,8 @@ impl Expansion for StridedSlice {
)?[0];
}
} else if strides[ix] == 1 {
// this is the case where we can't know the resulting axis statically

let left = target.wire_node(
format!("{}.slice-axis-{}-start", prefix, axis),
crate::ops::array::Slice::new(0, ix, ix + 1),
Expand Down
4 changes: 2 additions & 2 deletions onnx/src/ops/array/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ pub fn register_all_ops(reg: &mut OnnxOpRegister) {
reg.insert("Scatter", scatter_elements);
reg.insert("ScatterElements", scatter_elements);
reg.insert("ScatterND", |_, _| Ok((Box::new(array::ScatterNd), vec![])));
reg.insert("Shape", |_, _| Ok((expand(array::Shape::new(DatumType::TDim)), vec![])));
reg.insert("Size", |_, _| Ok((expand(array::Size::new(DatumType::TDim)), vec![])));
reg.insert("Shape", |_, _| Ok((expand(array::Shape::new(DatumType::I64)), vec![])));
reg.insert("Size", |_, _| Ok((expand(array::Size::new(DatumType::I64)), vec![])));
reg.insert("Slice", slice::slice);
reg.insert("Split", split::split);
reg.insert("Squeeze", squeeze::squeeze);
Expand Down
81 changes: 58 additions & 23 deletions onnx/src/ops/logic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,9 @@ pub fn _if(
#[derive(Debug, Clone, new, Hash)]
struct If {
then_body: InferenceModel,
then_input_mapping: Vec<usize>,
then_input_mapping: TVec<usize>,
else_body: InferenceModel,
else_input_mapping: Vec<usize>,
else_input_mapping: TVec<usize>,
}

impl_dyn_hash!(If);
Expand Down Expand Up @@ -167,7 +167,6 @@ impl InferenceOp for If {

Ok(body.output_outlets()?.iter().map(|o| inner_mapping[o]).collect())
} else {

target.wire_node(
&node.name,
IfMir {
Expand All @@ -184,12 +183,54 @@ impl InferenceOp for If {
as_op!();
}

/// Returns the output fact that is the result of the If control flow.
/// This could be thought of as the output fact of the Phi node of the Then and Else subgraphs,
/// (but it's arguably not as fancy as that.)
pub fn phi_result(then: &TypedFact, elze: &TypedFact) -> TractResult<TypedFact> {
if then.konst.is_some() && elze.konst.is_some() && then.konst == elze.konst {
return Ok(then.clone());
}

if then.datum_type != elze.datum_type {
bail!(
"If operator branches has incompatible datum types (then: {:?}; else: {:?})",
then.datum_type,
elze.datum_type
)
}

if then.shape.rank() != elze.shape.rank() {
bail!(
"If operator branches has incompatible ranks (then: {:?}; else: {:?})",
then.shape.rank(),
elze.shape.rank()
)
}

// [4, 'n', 18] . [4, 'k', 3] => [4, '?', '?']
let shape: TVec<_> = then
.shape
.iter()
.zip(elze.shape.iter())
.map(|(then_dim, else_dim)| {
let then_dim = then_dim.eval(&SymbolValues::default());
if then_dim == else_dim.eval(&SymbolValues::default()) {
then_dim
} else {
Symbol::new('h').to_dim()
}
})
.collect();

Ok(TypedFact::dt_shape(then.datum_type, shape))
}

#[derive(Debug, Clone, new, Hash)]
struct IfMir {
then_body: TypedModel,
then_input_mapping: Vec<usize>,
then_input_mapping: TVec<usize>,
else_body: TypedModel,
else_input_mapping: Vec<usize>,
else_input_mapping: TVec<usize>,
}

impl_dyn_hash!(IfMir);
Expand Down Expand Up @@ -227,23 +268,17 @@ impl TypedOp for IfMir {
fn output_facts(&self, _inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
let then_outputs =
self.then_body.outputs.iter().copied().map(|outlet| self.then_body.outlet_fact(outlet));
// let else_outputs =
// self.else_body.outputs.iter().copied().map(|outlet| self.else_body.outlet_fact(outlet));

// then_outputs
// .zip(else_outputs)
// .map(|(tfact, efact)| {
// let (tfact, _efact) = (tfact?.without_value(), efact?.without_value());
// ensure!(
// tfact.same_as(&efact),
// "Then and Else body have different output types {:?} and {:?}",
// tfact,
// efact
// );
// Ok(tfact)
// })
// .collect()

then_outputs.map(|e| Ok(e?.without_value())).collect()
let else_outputs =
self.else_body.outputs.iter().copied().map(|outlet| self.else_body.outlet_fact(outlet));

let facts = then_outputs
.zip(else_outputs)
.map(|(tfact, efact)| {
let (tfact, efact) = (tfact?.without_value(), efact?.without_value());
phi_result(&tfact, &efact)
})
.collect();

facts
}
}

0 comments on commit cc55da7

Please sign in to comment.