Skip to content

Commit

Permalink
ONNX debug improvements (#1712)
Browse files Browse the repository at this point in the history
* Minor debug improvements

* Change warn to panic

* Log improvements
  • Loading branch information
antimora authored Apr 30, 2024
1 parent 587b8f8 commit ff9e875
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 10 deletions.
15 changes: 10 additions & 5 deletions crates/burn-import/src/onnx/dim_inference.rs
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,12 @@ fn cast_update_outputs(node: &mut Node) {
}
_ => panic!("Cast: only scalar and tensor inputs are valid"),
}

log::debug!(
"Cast: input type: {:?}, output type: {:?}",
input.ty,
output.ty
);
}

fn concat_update_outputs(node: &mut Node) {
Expand Down Expand Up @@ -300,11 +306,10 @@ fn same_as_input(node: &mut Node) {
}

/// Temporary pass-through stub for dimension inference so that we can export the IR model.
fn temporary_pass_through_stub(node: &Node) {
log::warn!(
"Must implement dimension inference for {:?}",
node.node_type
);
fn temporary_pass_through_stub(node: &mut Node) {
log::warn!("Must implement dimension inference for {:?}", node);
log::warn!("Temporarily setting the output type to the input type.");
node.outputs[0].ty = node.inputs[0].ty.clone();
}

fn equal_update_outputs(node: &mut Node) {
Expand Down
5 changes: 2 additions & 3 deletions crates/burn-import/src/onnx/from_onnx.rs
Original file line number Diff line number Diff line change
Expand Up @@ -503,13 +503,12 @@ fn rename_io(node: &mut Node, graph_io: &mut OnnxGraphIO) {
node_input.passed = false;
}
}
log::debug!("\n\nchecking outputs");
let mut out_count = 1;
if node.node_type == NodeType::Constant || node.node_type == NodeType::Identity {
log::debug!("it's a constant");
let new_name = format!("{}_out{}", node.name, out_count);
graph_io.insert(&node.outputs[0], &new_name);
node.outputs[0].name = new_name;
node.outputs[0].name = new_name.clone();
log::debug!("Found {} constant", new_name);
} else {
for output in node.outputs.iter_mut() {
log::debug!("output name: {}", &output.name);
Expand Down
3 changes: 2 additions & 1 deletion crates/burn-import/src/onnx/op_configuration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -569,8 +569,9 @@ pub fn reshape_config(node: &Node) -> Vec<i64> {
panic!("Zero shape size is not supported");
}

// TODO: check "shape" attribute
if node.inputs.len() != 2 || node.inputs[1].value.is_none() {
panic!("Reshape: shape tensor must be present");
panic!("Reshape: shape tensor must be present for {:?}", node);
}

let input_value = &node.inputs[1].value;
Expand Down
8 changes: 7 additions & 1 deletion crates/burn-import/src/onnx/to_burn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,8 @@ impl OnnxGraph {
pub fn into_burn<PS: PrecisionSettings + 'static>(self) -> BurnGraph<PS> {
let mut graph = BurnGraph::<PS>::default();

let mut unsupported_ops = vec![];

for node in self.nodes {
match node.node_type {
NodeType::Add => graph.register(Self::add_conversion(node)),
Expand Down Expand Up @@ -279,10 +281,14 @@ impl OnnxGraph {
NodeType::Unsqueeze => graph.register(Self::unsqueeze_conversion(node)),
NodeType::Where => graph.register(Self::where_conversion(node)),
NodeType::Sign => graph.register(Self::sign_conversion(node)),
_ => panic!("Unsupported node conversion {}", node.node_type),
node_type => unsupported_ops.push(node_type),
}
}

if !unsupported_ops.is_empty() {
panic!("Unsupported ops: {:?}", unsupported_ops);
}

// Get input and output names
let input_names = self
.inputs
Expand Down

0 comments on commit ff9e875

Please sign in to comment.