diff --git a/src/ops/norm.rs b/src/ops/norm.rs index 1bf204b9..8393b87a 100644 --- a/src/ops/norm.rs +++ b/src/ops/norm.rs @@ -705,6 +705,27 @@ mod tests { [1.0744, 0.0683], ]])), }, + // Normalize multiple axes + Case { + // Sample values generated using `torch.rand`. + input: Tensor::from([[ + [0.9562, 0.0572], + [0.4366, 0.5655], + [0.2017, 0.0230], + [0.7941, 0.1554], + [0.3226, 0.120], + ]]), + scale: Tensor::full(&[5, 2], 1.1), + bias: Some(Tensor::full(&[5, 2], 0.1)), + axis: -2, + expected: Ok(Tensor::from([[ + [2.2467697, -1.0079411], + [0.36562642, 0.83229196], + [-0.48479798, -1.1317577], + [1.6599079, -0.65242106], + [-0.04709549, -0.7805821], + ]])), + }, // Unsupported scale shape Case { input: Tensor::from([[1., 2., 3.], [4., 5., 6.]]), @@ -745,7 +766,9 @@ mod tests { ); match (result, expected) { - (Ok(result), Ok(expected)) => expect_eq_1e4(&result, &expected)?, + (Ok(result), Ok(expected)) => { + expect_eq_1e4(&result, &expected)?; + } (result, expected) => assert_eq!(result, expected), } }