Skip to content

Commit

Permalink
[rust] Allows -2 as dims for sum() (#3221)
Browse files Browse the repository at this point in the history
  • Loading branch information
frankfliu authored May 28, 2024
1 parent 5d951ba commit 45e41bc
Showing 1 changed file with 8 additions and 1 deletion.
9 changes: 8 additions & 1 deletion extensions/tokenizers/rust/src/ndarray/reduce.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,17 @@ pub extern "system" fn Java_ai_djl_engine_rust_RustLibrary_sumWithAxis<'local>(
keep_dims: jboolean,
) -> jlong {
let tensor = cast_handle::<Tensor>(handle);
let rank = tensor.shape().rank() as i32;
let axes = unsafe { env.get_array_elements(&axes, ReleaseMode::NoCopyBack) }.unwrap();
let dims = axes
.into_iter()
.map(|i| *i as usize)
.map(|i| {
let mut dim = *i as i32;
if dim < 0 {
dim = rank + dim;
}
return dim as usize;
})
.collect::<Vec<usize>>();

let ret = if keep_dims == JNI_TRUE {
Expand Down

0 comments on commit 45e41bc

Please sign in to comment.