Skip to content

Commit

Permalink
Fix transpose onnx op (permute) (#1657)
Browse files Browse the repository at this point in the history
  • Loading branch information
laggui authored Apr 19, 2024
1 parent ee12aee commit b65a487
Show file tree
Hide file tree
Showing 6 changed files with 49 additions and 18 deletions.
17 changes: 11 additions & 6 deletions crates/burn-import/onnx-tests/tests/onnx_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -686,18 +686,23 @@ mod tests {
let model: transpose::Model<Backend> = transpose::Model::new(&device);

// Run the model
let input = Tensor::<Backend, 2>::from_floats(
let input = Tensor::<Backend, 3>::from_floats(
[
[0.33669037, 0.128_809_4, 0.23446237],
[0.23033303, -1.122_856_4, -0.18632829],
[[0., 1., 2., 3.], [4., 5., 6., 7.], [8., 9., 10., 11.]],
[
[12., 13., 14., 15.],
[16., 17., 18., 19.],
[20., 21., 22., 23.],
],
],
&device,
);
let output = model.forward(input);
let expected = Data::from([
[0.33669037, 0.23033303],
[0.128_809_4, -1.122_856_4],
[0.23446237, -0.18632829],
[[0., 4., 8.], [12., 16., 20.]],
[[1., 5., 9.], [13., 17., 21.]],
[[2., 6., 10.], [14., 18., 22.]],
[[3., 7., 11.], [15., 19., 23.]],
]);

assert_eq!(output.to_data(), expected);
Expand Down
Binary file modified crates/burn-import/onnx-tests/tests/transpose/transpose.onnx
Binary file not shown.
15 changes: 7 additions & 8 deletions crates/burn-import/onnx-tests/tests/transpose/transpose.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,7 @@ def __init__(self):
super(Model, self).__init__()

def forward(self, x):
x = x.transpose(0, 1)
return x
return x.permute(2, 0, 1)


def main():
Expand All @@ -28,18 +27,18 @@ def main():
device = torch.device("cpu")

file_name = "transpose.onnx"
test_input = torch.randn(2, 3, device=device)
test_input = torch.arange(24, dtype=torch.float, device=device).reshape(2, 3, 4)
torch.onnx.export(model, test_input, file_name,
verbose=False, opset_version=16)

print("Finished exporting model to {}".format(file_name))
print(f"Finished exporting model to {file_name}")

# Output some test data for use in the test
print("Test input data: {}".format(test_input))
print("Test input data shape: {}".format(test_input.shape))
print(f"Test input data: {test_input}")
print(f"Test input data shape: {test_input.shape}")
output = model.forward(test_input)
print("Test output data shape: {}".format(output.shape))
print("Test output data: {}".format(output))
print(f"Test output data shape: {output.shape}")
print(f"Test output data: {output}")


if __name__ == '__main__':
Expand Down
8 changes: 5 additions & 3 deletions crates/burn-import/src/burn/node/unary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -198,8 +198,9 @@ impl UnaryNode {
Self::new(input, output, UnaryNodeKind::Tanh, Rc::new(function))
}

pub(crate) fn transpose(input: Type, output: Type) -> Self {
let function = move |input| quote! { #input.transpose() };
pub(crate) fn transpose(input: Type, output: Type, perm: Vec<i64>) -> Self {
let perm = perm.to_tokens();
let function = move |input| quote! { #input.permute(#perm) };
Self::new(input, output, UnaryNodeKind::Transpose, Rc::new(function))
}

Expand Down Expand Up @@ -538,10 +539,11 @@ mod tests {
UnaryNode::transpose(
Type::Tensor(TensorType::new_float("tensor1", 4)),
Type::Tensor(TensorType::new_float("tensor2", 4)),
vec![0, 3, 1, 2],
),
quote! {
pub fn forward(&self, tensor1: Tensor<B, 4>) -> Tensor<B, 4> {
let tensor2 = tensor1.transpose();
let tensor2 = tensor1.permute([0, 3, 1, 2]);

tensor2
}
Expand Down
24 changes: 24 additions & 0 deletions crates/burn-import/src/onnx/op_configuration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -786,3 +786,27 @@ pub fn shape_config(curr: &Node) -> (usize, usize) {

(start_dim as usize, end_dim as usize)
}

pub fn transpose_config(curr: &Node) -> Vec<i64> {
if curr.inputs.len() != 1 {
panic!(
"Transpose: multiple inputs are not supported (got {:?})",
curr.inputs.len()
);
}

// Extract the shape of the input tensor
let tensor = match curr.inputs.first().unwrap().clone().ty {
ArgType::Tensor(tensor) => tensor,
_ => panic!("Only tensor input is valid"),
};

// Default: reverse the dimensions
let mut perm = (0..tensor.dim as i64).rev().collect::<Vec<i64>>();

if let Some(axes) = curr.attrs.get("perm") {
perm = axes.clone().into_i64s();
}

perm
}
3 changes: 2 additions & 1 deletion crates/burn-import/src/onnx/to_burn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -451,8 +451,9 @@ impl OnnxGraph {
fn transpose_conversion(node: Node) -> UnaryNode {
let input = node.inputs.first().unwrap().to_type();
let output = node.outputs.first().unwrap().to_type();
let perm = transpose_config(&node);

UnaryNode::transpose(input, output)
UnaryNode::transpose(input, output, perm)
}

fn cast_conversion(node: Node) -> UnaryNode {
Expand Down

0 comments on commit b65a487

Please sign in to comment.