diff --git a/crates/burn-import/src/onnx/op_configuration.rs b/crates/burn-import/src/onnx/op_configuration.rs index 638ca153b5..c62e8da864 100644 --- a/crates/burn-import/src/onnx/op_configuration.rs +++ b/crates/burn-import/src/onnx/op_configuration.rs @@ -768,9 +768,20 @@ pub fn tile_config(node: &Node) -> TileConfig { /// Create a PadConfig from the attributes of the node pub fn pad_config(node: &Node) -> PadConfig { + fn get_pads_input(node: &Node) -> Vec { + // If the input is not provided, return an empty vector + if node.inputs.get(1).is_none() { + return Vec::new(); + } + + match &node.inputs[1].value { + Some(Data::Int64s(shape)) => shape.clone(), + _ => panic!("Tensor data type must be int64"), + } + } fn get_pads(node: &Node) -> Vec { - if node.inputs.len() < 2 { - panic!("Pad: must provide at least two inputs") + if node.inputs.is_empty() { + panic!("Pad: must provide data as input") } if node.inputs.len() >= 4 { panic!("Pad: axes input is not supported") @@ -781,19 +792,41 @@ pub fn pad_config(node: &Node) -> PadConfig { _ => panic!("Pad: Only tensor input is valid"), }; - let pads: Vec = match &node.inputs[1].value { - Some(Data::Int64s(shape)) => shape - .iter() - .map(|&x| { - if x < 0 { - // TODO: support negative pads - panic!("Pad: Negative pad is not supported"); + //TODO : handle more possible attributes + let mut pads: Vec = get_pads_input(node) + .into_iter() + .map(|x| x as usize) + .collect(); + + for (key, value) in node.attrs.iter() { + match key.as_str() { + "pads" => { + pads = value + .clone() + .into_i64s() + .iter() + .map(|&x| { + if x < 0 { + panic!("Pad: Negative pad is not supported"); + } + x as usize + }) + .collect() + } + "mode" => { + let mode = value.clone().into_string(); + if mode != "constant" { + panic!("only constant mode is supported, given mode is {}", mode); } - x as usize - }) - .collect(), - _ => panic!("Pad: pads data type must be int64"), - }; + } + + _ => {} + } + } + + if pads.is_empty() { + panic!("Pad: pads should be given as attribute or as input"); + } if pads.len() != input_dim * 2 { panic!("Pad: pads should be a 1D tensor of shape [2 * num_axes]"); @@ -823,7 +856,7 @@ pub fn pad_config(node: &Node) -> PadConfig { } fn get_constant_value(node: &Node) -> f32 { // TODO: support int, boolean - node.inputs + let mut constant_value = node.inputs .get(2) .and_then(|input| match &input.value { Some(Data::Float16s(constant_value)) => { @@ -840,7 +873,15 @@ pub fn pad_config(node: &Node) -> PadConfig { Some(Data::Float64(constant_value)) => Some(*constant_value as f32), _ => panic!("Pad: only float values are currently supported for constant value, submit an issue on github"), }) - .unwrap_or(0.0) + .unwrap_or(0.0); + + if node.attrs.contains_key("value") { + constant_value = node.attrs.get("value").map(|value| match value { + AttributeValue::Float32(value) => *value, + _ => panic!("Pad: only float32 values are currently supported for constant value as attribute, submit an issue on github"), + }).expect("constant_value should have had a value now"); + } + constant_value } let pads = get_pads(node);