From cc55da727f6a46066ed0bd6fefd8fd036590b644 Mon Sep 17 00:00:00 2001 From: Charles Chudant Date: Tue, 13 Sep 2022 16:06:31 +0200 Subject: [PATCH] Fix yolov5 postprocessing --- core/src/ops/array/dyn_slice.rs | 42 ++++++++++------ core/src/ops/array/gather.rs | 15 ++++-- hir/src/ops/array/strided_slice.rs | 14 +++++- onnx/src/ops/array/mod.rs | 4 +- onnx/src/ops/logic.rs | 81 +++++++++++++++++++++--------- 5 files changed, 108 insertions(+), 48 deletions(-) diff --git a/core/src/ops/array/dyn_slice.rs b/core/src/ops/array/dyn_slice.rs index a474c6a10f..3a32e43a5d 100644 --- a/core/src/ops/array/dyn_slice.rs +++ b/core/src/ops/array/dyn_slice.rs @@ -51,23 +51,29 @@ impl EvalOp for DynSlice { } fn eval(&self, inputs: TVec>) -> TractResult>> { - unsafe { - let start = - if self.start_input { inputs[1].cast_to_scalar::()? as usize } else { 0 }; - let end = if self.end_input { - inputs[1 + self.start_input as usize].cast_to_scalar::()? as usize - } else { - inputs[0].shape()[self.axis] - }; - if start >= end { - bail!("Invalid range {}-{}", start, end); - } - let mut shape: TVec<_> = inputs[0].shape().into(); - shape[self.axis] = end - start; - let mut tensor = Tensor::uninitialized_dt(inputs[0].datum_type(), &shape)?; - tensor.assign_slice_unchecked(.., &inputs[0], start..end, self.axis); - Ok(tvec!(tensor.into_arc_tensor())) + let start = if self.start_input { inputs[1].cast_to_scalar::()? as usize } else { 0 }; + let end = if self.end_input { + inputs[1 + self.start_input as usize].cast_to_scalar::()? as usize + } else { + inputs[0].shape()[self.axis] + }; + + let actual_axis_len = inputs[0].shape()[self.axis]; + let (src_start, src_end) = (start.min(actual_axis_len), end.min(actual_axis_len)); + + if start > end { + bail!("Invalid range {}-{}", start, end); } + + let mut shape: TVec<_> = inputs[0].shape().into(); + shape[self.axis] = src_end - src_start; + + let tensor = unsafe { + let mut tensor = Tensor::uninitialized_dt(inputs[0].datum_type(), &shape)?; + tensor.assign_slice_unchecked(.., &inputs[0], src_start..src_end, self.axis); + tensor + }; + Ok(tvec!(tensor.into_arc_tensor())) } } @@ -117,6 +123,10 @@ impl TypedOp for DynSlice { node: &TypedNode, ) -> TractResult> { let inputs = model.node_input_facts(node.id)?; + if inputs[0].shape[self.axis].to_usize().is_err() { + return Ok(None) + } + let start = if self.start_input { inputs[1].konst.clone() } else { Some(rctensor0(TDim::zero())) }; let end = if self.end_input { diff --git a/core/src/ops/array/gather.rs b/core/src/ops/array/gather.rs index 384bcdea9d..a6ebd6a709 100644 --- a/core/src/ops/array/gather.rs +++ b/core/src/ops/array/gather.rs @@ -110,11 +110,16 @@ impl TypedOp for Gather { }, &[wire], )?[0]; - wire = patch.wire_node( - format!("{}.rm_axis", node.name), - crate::ops::change_axes::AxisOp::Rm(self.axis), - &[wire], - )?[0]; + let original_rank = model.outlet_fact(node.id.into())?.shape.rank(); + let new_rank = patch.model.outlet_fact(wire)?.shape.rank(); + + if new_rank == original_rank + 1 { + wire = patch.wire_node( + format!("{}.rm_axis", node.name), + crate::ops::change_axes::AxisOp::Rm(self.axis), + &[wire], + )?[0]; + } patch.shunt_outside(model, node.id.into(), wire)?; return Ok(Some(patch)); } diff --git a/hir/src/ops/array/strided_slice.rs b/hir/src/ops/array/strided_slice.rs index 2bf2f37184..dcab946f16 100644 --- a/hir/src/ops/array/strided_slice.rs +++ b/hir/src/ops/array/strided_slice.rs @@ -258,8 +258,16 @@ impl Expansion for StridedSlice { let begin = params[0].as_ref(); let end = params[1].as_ref(); for (ix, &axis) in axes.iter().enumerate() { - if let (Some(begin), Some(end)) = (begin, end) { - let d = &input_shape[axis]; + + // note: if the input axis has symbols, we really cannot know how to slice statically + // example: slice( 'h', (0..10) ) + // means that if h < 10 at runtime, the resulting axis is < 10 + // and if h > 10, resulting axis is always 10 + + let d = &input_shape[axis]; + if let (Some(begin), Some(end), Ok(_)) = (begin, end, d.to_usize()) { + // this is the case where you can know the resulting axis statically + let preped = self.prepare_one_dim(ix, d, begin, end, &strides)?; let (left, right) = if preped.stride > 0 { (preped.begin, preped.end) @@ -279,6 +287,8 @@ impl Expansion for StridedSlice { )?[0]; } } else if strides[ix] == 1 { + // this is the case where we can't know the resulting axis statically + let left = target.wire_node( format!("{}.slice-axis-{}-start", prefix, axis), crate::ops::array::Slice::new(0, ix, ix + 1), diff --git a/onnx/src/ops/array/mod.rs b/onnx/src/ops/array/mod.rs index b912095392..7fad3a2236 100644 --- a/onnx/src/ops/array/mod.rs +++ b/onnx/src/ops/array/mod.rs @@ -33,8 +33,8 @@ pub fn register_all_ops(reg: &mut OnnxOpRegister) { reg.insert("Scatter", scatter_elements); reg.insert("ScatterElements", scatter_elements); reg.insert("ScatterND", |_, _| Ok((Box::new(array::ScatterNd), vec![]))); - reg.insert("Shape", |_, _| Ok((expand(array::Shape::new(DatumType::TDim)), vec![]))); - reg.insert("Size", |_, _| Ok((expand(array::Size::new(DatumType::TDim)), vec![]))); + reg.insert("Shape", |_, _| Ok((expand(array::Shape::new(DatumType::I64)), vec![]))); + reg.insert("Size", |_, _| Ok((expand(array::Size::new(DatumType::I64)), vec![]))); reg.insert("Slice", slice::slice); reg.insert("Split", split::split); reg.insert("Squeeze", squeeze::squeeze); diff --git a/onnx/src/ops/logic.rs b/onnx/src/ops/logic.rs index fba5b990d5..614c672f6f 100644 --- a/onnx/src/ops/logic.rs +++ b/onnx/src/ops/logic.rs @@ -57,9 +57,9 @@ pub fn _if( #[derive(Debug, Clone, new, Hash)] struct If { then_body: InferenceModel, - then_input_mapping: Vec, + then_input_mapping: TVec, else_body: InferenceModel, - else_input_mapping: Vec, + else_input_mapping: TVec, } impl_dyn_hash!(If); @@ -167,7 +167,6 @@ impl InferenceOp for If { Ok(body.output_outlets()?.iter().map(|o| inner_mapping[o]).collect()) } else { - target.wire_node( &node.name, IfMir { @@ -184,12 +183,54 @@ impl InferenceOp for If { as_op!(); } +/// Returns the output fact that is the result of the If control flow. +/// This could be thought of as the output fact of the Phi node of the Then and Else subgraphs, +/// (but it's arguably not as fancy as that.) +pub fn phi_result(then: &TypedFact, elze: &TypedFact) -> TractResult { + if then.konst.is_some() && elze.konst.is_some() && then.konst == elze.konst { + return Ok(then.clone()); + } + + if then.datum_type != elze.datum_type { + bail!( + "If operator branches has incompatible datum types (then: {:?}; else: {:?})", + then.datum_type, + elze.datum_type + ) + } + + if then.shape.rank() != elze.shape.rank() { + bail!( + "If operator branches has incompatible ranks (then: {:?}; else: {:?})", + then.shape.rank(), + elze.shape.rank() + ) + } + + // [4, 'n', 18] . [4, 'k', 3] => [4, '?', '?'] + let shape: TVec<_> = then + .shape + .iter() + .zip(elze.shape.iter()) + .map(|(then_dim, else_dim)| { + let then_dim = then_dim.eval(&SymbolValues::default()); + if then_dim == else_dim.eval(&SymbolValues::default()) { + then_dim + } else { + Symbol::new('h').to_dim() + } + }) + .collect(); + + Ok(TypedFact::dt_shape(then.datum_type, shape)) +} + #[derive(Debug, Clone, new, Hash)] struct IfMir { then_body: TypedModel, - then_input_mapping: Vec, + then_input_mapping: TVec, else_body: TypedModel, - else_input_mapping: Vec, + else_input_mapping: TVec, } impl_dyn_hash!(IfMir); @@ -227,23 +268,17 @@ impl TypedOp for IfMir { fn output_facts(&self, _inputs: &[&TypedFact]) -> TractResult> { let then_outputs = self.then_body.outputs.iter().copied().map(|outlet| self.then_body.outlet_fact(outlet)); - // let else_outputs = - // self.else_body.outputs.iter().copied().map(|outlet| self.else_body.outlet_fact(outlet)); - - // then_outputs - // .zip(else_outputs) - // .map(|(tfact, efact)| { - // let (tfact, _efact) = (tfact?.without_value(), efact?.without_value()); - // ensure!( - // tfact.same_as(&efact), - // "Then and Else body have different output types {:?} and {:?}", - // tfact, - // efact - // ); - // Ok(tfact) - // }) - // .collect() - - then_outputs.map(|e| Ok(e?.without_value())).collect() + let else_outputs = + self.else_body.outputs.iter().copied().map(|outlet| self.else_body.outlet_fact(outlet)); + + let facts = then_outputs + .zip(else_outputs) + .map(|(tfact, efact)| { + let (tfact, efact) = (tfact?.without_value(), efact?.without_value()); + phi_result(&tfact, &efact) + }) + .collect(); + + facts } }