Skip to content

Commit

Permalink
Add test for normalizing multiple axes in LayerNormalization
Browse files Browse the repository at this point in the history
  • Loading branch information
robertknight committed Dec 17, 2024
1 parent 2b61acf commit 7685b8b
Showing 1 changed file with 24 additions and 1 deletion.
25 changes: 24 additions & 1 deletion src/ops/norm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -705,6 +705,27 @@ mod tests {
[1.0744, 0.0683],
]])),
},
// Normalize multiple axes
Case {
// Sample values generated using `torch.rand`.
input: Tensor::from([[
[0.9562, 0.0572],
[0.4366, 0.5655],
[0.2017, 0.0230],
[0.7941, 0.1554],
[0.3226, 0.120],
]]),
scale: Tensor::full(&[5, 2], 1.1),
bias: Some(Tensor::full(&[5, 2], 0.1)),
axis: -2,
expected: Ok(Tensor::from([[
[2.2467697, -1.0079411],
[0.36562642, 0.83229196],
[-0.48479798, -1.1317577],
[1.6599079, -0.65242106],
[-0.04709549, -0.7805821],
]])),
},
// Unsupported scale shape
Case {
input: Tensor::from([[1., 2., 3.], [4., 5., 6.]]),
Expand Down Expand Up @@ -745,7 +766,9 @@ mod tests {
);

match (result, expected) {
(Ok(result), Ok(expected)) => expect_eq_1e4(&result, &expected)?,
(Ok(result), Ok(expected)) => {
expect_eq_1e4(&result, &expected)?;
}
(result, expected) => assert_eq!(result, expected),
}
}
Expand Down

0 comments on commit 7685b8b

Please sign in to comment.