From 2f3371b61f76dd70174dd389b3b149d63a083ed9 Mon Sep 17 00:00:00 2001
From: Xin Yang
Date: Thu, 15 Aug 2024 19:01:56 -0700
Subject: [PATCH] [rust] Load model on given device

---
Review notes (text between "---" and the first "diff --git" is ignored by git am):
 - as_device now returns Err instead of panicking on a malformed device string,
   and load_model propagates that error with `?` instead of `.expect(...)`.
 - The diffstat and the size of the last mod.rs hunk were updated for the new
   as_device body; the blob ids on the mod.rs `index` line still name the
   original revision.
 - NOTE(review): generic arguments missing from the extracted text
   (`Result<Box<dyn Model>>`, `AtomicReference<Long>`, `Map<String, ?>`), the
   indentation, and blank context lines were reconstructed -- confirm against
   the tree before applying.

 extensions/tokenizers/rust/src/models/mod.rs  | 42 ++++++++++++++----
 .../main/java/ai/djl/engine/rust/RsModel.java |  6 +++-
 .../java/ai/djl/engine/rust/RustLibrary.java  |  2 +-
 3 files changed, 40 insertions(+), 10 deletions(-)

diff --git a/extensions/tokenizers/rust/src/models/mod.rs b/extensions/tokenizers/rust/src/models/mod.rs
index cfae04cec36..bbe57e152aa 100644
--- a/extensions/tokenizers/rust/src/models/mod.rs
+++ b/extensions/tokenizers/rust/src/models/mod.rs
@@ -58,11 +58,16 @@ fn load_model<'local>(
     env: &mut JNIEnv,
     model_path: JString,
     dtype: jint,
+    device: JString,
 ) -> Result<Box<dyn Model>> {
     let model_path: String = env
         .get_string(&model_path)
         .expect("Couldn't get java string!")
         .into();
+    let device: String = env
+        .get_string(&device)
+        .expect("Couldn't get java string!")
+        .into();
 
     let model_path = PathBuf::from(model_path);
 
@@ -71,13 +76,7 @@ fn load_model<'local>(
     let config: Config = serde_json::from_str(&config).map_err(Error::msg)?;
 
     // Get candle device
-    let device = if candle::utils::cuda_is_available() {
-        Device::new_cuda(0)
-    } else if candle::utils::metal_is_available() {
-        Device::new_metal(0)
-    } else {
-        Ok(Device::Cpu)
-    }?;
+    let device = as_device(&device)?;
 
     // Get candle dtype
     let dtype = as_data_type(dtype).unwrap();
@@ -171,8 +170,9 @@ pub extern "system" fn Java_ai_djl_engine_rust_RustLibrary_loadModel<'local>(
     _: JObject,
     model_path: JString,
     dtype: jint,
+    device: JString,
 ) -> jlong {
-    let model = load_model(&mut env, model_path, dtype);
+    let model = load_model(&mut env, model_path, dtype, device);
 
     match model {
         Ok(output) => to_handle(output),
@@ -234,3 +234,29 @@ pub extern "system" fn Java_ai_djl_engine_rust_RustLibrary_runInference<'local>(
         }
     }
 }
+
+/// Converts a DJL device string into a candle [`Device`].
+///
+/// Accepts `cpu()` and `gpu(<ordinal>)`, where the ordinal is passed to `Device::new_cuda`.
+///
+/// # Errors
+///
+/// Returns an error, rather than panicking, when the string has any other form, when the
+/// ordinal does not parse as `usize`, or when `Device::new_cuda` itself fails.
+pub fn as_device(device: &str) -> Result<Device> {
+    if device == "cpu()" {
+        return Ok(Device::Cpu);
+    }
+    // Build the error through `Error::msg`, as `load_model` does for its JSON errors.
+    let invalid = |reason: String| {
+        Error::msg(std::io::Error::new(std::io::ErrorKind::InvalidInput, reason))
+    };
+    let ordinal = device
+        .strip_prefix("gpu(")
+        .and_then(|rest| rest.strip_suffix(')'))
+        .ok_or_else(|| invalid(format!("Unsupported device string: {device}")))?;
+    let id = ordinal
+        .parse::<usize>()
+        .map_err(|_| invalid(format!("Invalid GPU format: {device}")))?;
+    Device::new_cuda(id)
+}
diff --git a/extensions/tokenizers/src/main/java/ai/djl/engine/rust/RsModel.java b/extensions/tokenizers/src/main/java/ai/djl/engine/rust/RsModel.java
index be9430577fc..63f6b0e3333 100644
--- a/extensions/tokenizers/src/main/java/ai/djl/engine/rust/RsModel.java
+++ b/extensions/tokenizers/src/main/java/ai/djl/engine/rust/RsModel.java
@@ -28,6 +28,7 @@
 /** {@code RsModel} is the Rust implementation of {@link Model}. */
 public class RsModel extends BaseModel {
 
+    private Device device;
     private final AtomicReference<Long> handle;
 
     /**
@@ -38,6 +39,7 @@ public class RsModel extends BaseModel {
      */
     RsModel(String name, Device device) {
         super(name);
+        this.device = device;
         manager = RsNDManager.getSystemManager().newSubManager(device);
         manager.setName("RsModel");
         dataType = DataType.FLOAT16;
@@ -54,7 +56,9 @@ public void load(Path modelPath, String prefix, Map<String, ?> options)
         }
         setModelDir(modelPath);
         if (block == null) {
-            handle.set(RustLibrary.loadModel(modelDir.toString(), dataType.ordinal()));
+            handle.set(
+                    RustLibrary.loadModel(
+                            modelDir.toString(), dataType.ordinal(), device.toString()));
             block = new RsSymbolBlock((RsNDManager) manager, handle.get());
         } else {
             loadBlock(prefix, options);
diff --git a/extensions/tokenizers/src/main/java/ai/djl/engine/rust/RustLibrary.java b/extensions/tokenizers/src/main/java/ai/djl/engine/rust/RustLibrary.java
index 852c392f13c..440fe6bde03 100644
--- a/extensions/tokenizers/src/main/java/ai/djl/engine/rust/RustLibrary.java
+++ b/extensions/tokenizers/src/main/java/ai/djl/engine/rust/RustLibrary.java
@@ -22,7 +22,7 @@ private RustLibrary() {}
 
     public static native boolean isCudaAvailable();
 
-    public static native long loadModel(String modelPath, int dtype);
+    public static native long loadModel(String modelPath, int dtype, String device);
 
     public static native long deleteModel(long handle);
 