Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

pad-input-fix: adding support for pads as attributes #2195

Merged
merged 4 commits into from
Aug 23, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 57 additions & 16 deletions crates/burn-import/src/onnx/op_configuration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -768,9 +768,20 @@ pub fn tile_config(node: &Node) -> TileConfig {

/// Create a PadConfig from the attributes of the node
pub fn pad_config(node: &Node) -> PadConfig {
fn get_pads_input(node: &Node) -> Vec<i64> {
// If the input is not provided, return an empty vector
if node.inputs.get(1).is_none() {
return Vec::new();
}

match &node.inputs[1].value {
Some(Data::Int64s(shape)) => shape.clone(),
_ => panic!("Tensor data type must be int64"),
}
}
fn get_pads(node: &Node) -> Vec<usize> {
if node.inputs.len() < 2 {
panic!("Pad: must provide at least two inputs")
if node.inputs.is_empty() {
panic!("Pad: must provide data as input")
}
if node.inputs.len() >= 4 {
panic!("Pad: axes input is not supported")
Expand All @@ -781,19 +792,41 @@ pub fn pad_config(node: &Node) -> PadConfig {
_ => panic!("Pad: Only tensor input is valid"),
};

let pads: Vec<usize> = match &node.inputs[1].value {
Some(Data::Int64s(shape)) => shape
.iter()
.map(|&x| {
if x < 0 {
// TODO: support negative pads
panic!("Pad: Negative pad is not supported");
//TODO : handle more possible attributes
let mut pads: Vec<usize> = get_pads_input(node)
.into_iter()
.map(|x| x as usize)
.collect();

for (key, value) in node.attrs.iter() {
match key.as_str() {
"pads" => {
pads = value
.clone()
.into_i64s()
.iter()
.map(|&x| {
if x < 0 {
panic!("Pad: Negative pad is not supported");
}
x as usize
})
.collect()
Comment on lines +803 to +814
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we are supporting the pads as attribute, which corresponds to version 2 of the operator, then we should also parse the value attribute to make sure we are correctly capturing all parameters of the node.

Make sure it doesn't come in conflict with the constant_value input parsing for later versions of the operator 🙂

}
"mode" => {
let mode = value.clone().into_string();
if mode != "constant" {
panic!("only constant mode is supported, given mode is {}", mode);
}
x as usize
})
.collect(),
_ => panic!("Pad: pads data type must be int64"),
};
}

_ => {}
}
}

if pads.is_empty() {
panic!("Pad: pads should be given as attribute or as input");
}

if pads.len() != input_dim * 2 {
panic!("Pad: pads should be a 1D tensor of shape [2 * num_axes]");
Expand Down Expand Up @@ -823,7 +856,7 @@ pub fn pad_config(node: &Node) -> PadConfig {
}
fn get_constant_value(node: &Node) -> f32 {
// TODO: support int, boolean
node.inputs
let mut constant_value = node.inputs
.get(2)
.and_then(|input| match &input.value {
Some(Data::Float16s(constant_value)) => {
Expand All @@ -840,7 +873,15 @@ pub fn pad_config(node: &Node) -> PadConfig {
Some(Data::Float64(constant_value)) => Some(*constant_value as f32),
_ => panic!("Pad: only float values are currently supported for constant value, submit an issue on github"),
})
.unwrap_or(0.0)
.unwrap_or(0.0);

if node.attrs.contains_key("value") {
constant_value = node.attrs.get("value").map(|value| match value {
AttributeValue::Float32(value) => *value,
_ => panic!("Pad: only float32 values are currently supported for constant value as attribute, submit an issue on github"),
}).expect("constant_value should have had a value now");
}
constant_value
}

let pads = get_pads(node);
Expand Down
Loading