diff --git a/crates/burn-import/onnx-tests/tests/onnx_tests.rs b/crates/burn-import/onnx-tests/tests/onnx_tests.rs index f0ca41a635..746dd9984f 100644 --- a/crates/burn-import/onnx-tests/tests/onnx_tests.rs +++ b/crates/burn-import/onnx-tests/tests/onnx_tests.rs @@ -686,18 +686,23 @@ mod tests { let model: transpose::Model = transpose::Model::new(&device); // Run the model - let input = Tensor::::from_floats( + let input = Tensor::::from_floats( [ - [0.33669037, 0.128_809_4, 0.23446237], - [0.23033303, -1.122_856_4, -0.18632829], + [[0., 1., 2., 3.], [4., 5., 6., 7.], [8., 9., 10., 11.]], + [ + [12., 13., 14., 15.], + [16., 17., 18., 19.], + [20., 21., 22., 23.], + ], ], &device, ); let output = model.forward(input); let expected = Data::from([ - [0.33669037, 0.23033303], - [0.128_809_4, -1.122_856_4], - [0.23446237, -0.18632829], + [[0., 4., 8.], [12., 16., 20.]], + [[1., 5., 9.], [13., 17., 21.]], + [[2., 6., 10.], [14., 18., 22.]], + [[3., 7., 11.], [15., 19., 23.]], ]); assert_eq!(output.to_data(), expected); diff --git a/crates/burn-import/onnx-tests/tests/transpose/transpose.onnx b/crates/burn-import/onnx-tests/tests/transpose/transpose.onnx index 910dc6cc71..4b765bac9a 100644 Binary files a/crates/burn-import/onnx-tests/tests/transpose/transpose.onnx and b/crates/burn-import/onnx-tests/tests/transpose/transpose.onnx differ diff --git a/crates/burn-import/onnx-tests/tests/transpose/transpose.py b/crates/burn-import/onnx-tests/tests/transpose/transpose.py index 17d9803f01..50e4a5a0d9 100755 --- a/crates/burn-import/onnx-tests/tests/transpose/transpose.py +++ b/crates/burn-import/onnx-tests/tests/transpose/transpose.py @@ -11,8 +11,7 @@ def __init__(self): super(Model, self).__init__() def forward(self, x): - x = x.transpose(0, 1) - return x + return x.permute(2, 0, 1) def main(): @@ -28,18 +27,18 @@ def main(): device = torch.device("cpu") file_name = "transpose.onnx" - test_input = torch.randn(2, 3, device=device) + test_input = torch.arange(24, dtype=torch.float, device=device).reshape(2, 3, 4) torch.onnx.export(model, test_input, file_name, verbose=False, opset_version=16) - print("Finished exporting model to {}".format(file_name)) + print(f"Finished exporting model to {file_name}") # Output some test data for use in the test - print("Test input data: {}".format(test_input)) - print("Test input data shape: {}".format(test_input.shape)) + print(f"Test input data: {test_input}") + print(f"Test input data shape: {test_input.shape}") output = model.forward(test_input) - print("Test output data shape: {}".format(output.shape)) - print("Test output data: {}".format(output)) + print(f"Test output data shape: {output.shape}") + print(f"Test output data: {output}") if __name__ == '__main__': diff --git a/crates/burn-import/src/burn/node/unary.rs b/crates/burn-import/src/burn/node/unary.rs index e09a559b1c..7309a504d7 100644 --- a/crates/burn-import/src/burn/node/unary.rs +++ b/crates/burn-import/src/burn/node/unary.rs @@ -198,8 +198,9 @@ impl UnaryNode { Self::new(input, output, UnaryNodeKind::Tanh, Rc::new(function)) } - pub(crate) fn transpose(input: Type, output: Type) -> Self { - let function = move |input| quote! { #input.transpose() }; + pub(crate) fn transpose(input: Type, output: Type, perm: Vec) -> Self { + let perm = perm.to_tokens(); + let function = move |input| quote! { #input.permute(#perm) }; Self::new(input, output, UnaryNodeKind::Transpose, Rc::new(function)) } @@ -538,10 +539,11 @@ mod tests { UnaryNode::transpose( Type::Tensor(TensorType::new_float("tensor1", 4)), Type::Tensor(TensorType::new_float("tensor2", 4)), + vec![0, 3, 1, 2], ), quote! { pub fn forward(&self, tensor1: Tensor) -> Tensor { - let tensor2 = tensor1.transpose(); + let tensor2 = tensor1.permute([0, 3, 1, 2]); tensor2 } diff --git a/crates/burn-import/src/onnx/op_configuration.rs b/crates/burn-import/src/onnx/op_configuration.rs index 5c4ac43cec..784186747e 100644 --- a/crates/burn-import/src/onnx/op_configuration.rs +++ b/crates/burn-import/src/onnx/op_configuration.rs @@ -786,3 +786,27 @@ pub fn shape_config(curr: &Node) -> (usize, usize) { (start_dim as usize, end_dim as usize) } + +pub fn transpose_config(curr: &Node) -> Vec { + if curr.inputs.len() != 1 { + panic!( + "Transpose: multiple inputs are not supported (got {:?})", + curr.inputs.len() + ); + } + + // Extract the shape of the input tensor + let tensor = match curr.inputs.first().unwrap().clone().ty { + ArgType::Tensor(tensor) => tensor, + _ => panic!("Only tensor input is valid"), + }; + + // Default: reverse the dimensions + let mut perm = (0..tensor.dim as i64).rev().collect::>(); + + if let Some(axes) = curr.attrs.get("perm") { + perm = axes.clone().into_i64s(); + } + + perm +} diff --git a/crates/burn-import/src/onnx/to_burn.rs b/crates/burn-import/src/onnx/to_burn.rs index 3bd15a4576..fec995e97f 100644 --- a/crates/burn-import/src/onnx/to_burn.rs +++ b/crates/burn-import/src/onnx/to_burn.rs @@ -451,8 +451,9 @@ impl OnnxGraph { fn transpose_conversion(node: Node) -> UnaryNode { let input = node.inputs.first().unwrap().to_type(); let output = node.outputs.first().unwrap().to_type(); + let perm = transpose_config(&node); - UnaryNode::transpose(input, output) + UnaryNode::transpose(input, output, perm) } fn cast_conversion(node: Node) -> UnaryNode {