Skip to content

Commit

Permalink
Avoid redundant zeroing of output buffer in GatherND
Browse files Browse the repository at this point in the history
  • Loading branch information
robertknight committed Nov 15, 2024
1 parent bf73c6f commit 5506574
Showing 1 changed file with 7 additions and 3 deletions.
10 changes: 7 additions & 3 deletions src/ops/gather.rs
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,7 @@ pub fn gather_nd<T: Clone + Default>(
let out_slice_len = out_shape[out_shape.len() - out_slice_ndim..]
.iter()
.product();
let mut output = Tensor::<T>::zeros_in(pool, &out_shape);
let mut output = Tensor::<T>::uninit_in(pool, &out_shape);

let output_non_batch_dims = output.ndim() - batch_dims;
let input_non_batch_dims = input.ndim() - batch_dims;
Expand All @@ -297,6 +297,7 @@ pub fn gather_nd<T: Clone + Default>(
// This allows the loop below to rely on index tuples being contiguous.
let indices = indices.to_contiguous_in(pool).auto_return(pool);

let mut n_init = 0;
for (mut output, (input, indices)) in output.inner_iter_dyn_mut(output_non_batch_dims).zip(
input
.inner_iter_dyn(input_non_batch_dims)
Expand All @@ -313,12 +314,15 @@ pub fn gather_nd<T: Clone + Default>(
.map_err(|_| OpError::InvalidValue("Invalid index"))?;

for (out, x) in out_slice.iter_mut().zip(in_slice.iter()) {
*out = x.clone();
out.write(x.clone());
}
n_init += out_slice.len();
}
}

Ok(output)
// Safety: All elements of `output` are initialized.
assert!(n_init == output.len());
Ok(unsafe { output.assume_init() })
}

#[derive(Debug)]
Expand Down

0 comments on commit 5506574

Please sign in to comment.