From 5506574a2b66c44e96caa3417220c20df84c23ef Mon Sep 17 00:00:00 2001 From: Robert Knight Date: Fri, 15 Nov 2024 22:18:25 +0000 Subject: [PATCH] Avoid redundant zeroing of output buffer in GatherND --- src/ops/gather.rs | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/src/ops/gather.rs b/src/ops/gather.rs index ad3ed612..4ae28d33 100644 --- a/src/ops/gather.rs +++ b/src/ops/gather.rs @@ -288,7 +288,7 @@ pub fn gather_nd( let out_slice_len = out_shape[out_shape.len() - out_slice_ndim..] .iter() .product(); - let mut output = Tensor::::zeros_in(pool, &out_shape); + let mut output = Tensor::::uninit_in(pool, &out_shape); let output_non_batch_dims = output.ndim() - batch_dims; let input_non_batch_dims = input.ndim() - batch_dims; @@ -297,6 +297,7 @@ pub fn gather_nd( // This allows the loop below to rely on index tuples being contiguous. let indices = indices.to_contiguous_in(pool).auto_return(pool); + let mut n_init = 0; for (mut output, (input, indices)) in output.inner_iter_dyn_mut(output_non_batch_dims).zip( input .inner_iter_dyn(input_non_batch_dims) @@ -313,12 +314,15 @@ pub fn gather_nd( .map_err(|_| OpError::InvalidValue("Invalid index"))?; for (out, x) in out_slice.iter_mut().zip(in_slice.iter()) { - *out = x.clone(); + out.write(x.clone()); } + n_init += out_slice.len(); } } - Ok(output) + // Safety: All elements of `output` are initialized. + assert!(n_init == output.len()); + Ok(unsafe { output.assume_init() }) } #[derive(Debug)]