Skip to content

Commit

Permalink
Change ndarray mask_where implementation to correctly deal with NaNs (#…
Browse files Browse the repository at this point in the history
…2272)

* Change ndarray mask_where implementation to correctly deal with NaNs

* Add test
  • Loading branch information
laggui authored Sep 13, 2024
1 parent 2fbad48 commit 6f0e61a
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 11 deletions.
19 changes: 8 additions & 11 deletions crates/burn-ndarray/src/ops/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -409,17 +409,14 @@ where
mask: NdArrayTensor<bool, D>,
source: NdArrayTensor<E, D>,
) -> NdArrayTensor<E, D> {
let mask_mul_4tensor = mask.array.mapv(|x| match x {
true => 0.elem(),
false => 1.elem(),
});
let mask_mul_4source = mask.array.mapv(|x| match x {
true => 1.elem(),
false => 0.elem(),
});
let array = (tensor.array * mask_mul_4tensor) + (source.array * mask_mul_4source);

NdArrayTensor::new(array)
let tensor = tensor.array.broadcast(mask.array.dim()).unwrap();
let source = source.array.broadcast(mask.array.dim()).unwrap();
let output = Zip::from(&tensor)
.and(&mask.array)
.and(&source)
.map_collect(|&x, &mask_val, &y| if mask_val { y } else { x })
.into_shared();
NdArrayTensor::new(output)
}

pub fn mask_fill<const D: usize>(
Expand Down
34 changes: 34 additions & 0 deletions crates/burn-tensor/src/tests/ops/mask.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,40 @@ mod tests {
output.into_data().assert_eq(&expected, false);
}

#[test]
fn should_handle_mask_where_nans() {
let device = Default::default();
let tensor = TestTensor::from_data(
[
[f32::NAN, f32::NAN, f32::NAN],
[f32::NAN, f32::NAN, f32::NAN],
[f32::NAN, f32::NAN, f32::NAN],
],
&device,
);
let mask = Tensor::<TestBackend, 2, Bool>::from_bool(
TensorData::from([
[true, true, true],
[true, true, false],
[false, false, false],
]),
&device,
);
let value = Tensor::<TestBackend, 2>::from_data(
TensorData::from([[0.9, 0.8, 0.7], [0.6, 0.5, 0.4], [0.3, 0.2, 0.1]]),
&device,
);

let output = tensor.mask_where(mask, value);
let expected = TensorData::from([
[0.9, 0.8, 0.7],
[0.6, 0.5, f32::NAN],
[f32::NAN, f32::NAN, f32::NAN],
]);

output.into_data().assert_eq(&expected, false);
}

#[test]
fn should_support_mask_fill_ops() {
let device = Default::default();
Expand Down

0 comments on commit 6f0e61a

Please sign in to comment.