Skip to content

Commit

Permalink
[rust] Load model on given device
Browse files Browse the repository at this point in the history
  • Loading branch information
xyang16 committed Aug 16, 2024
1 parent a03a324 commit 2f3371b
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 10 deletions.
34 changes: 26 additions & 8 deletions extensions/tokenizers/rust/src/models/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,11 +58,16 @@ fn load_model<'local>(
env: &mut JNIEnv,
model_path: JString,
dtype: jint,
device: JString,
) -> Result<Box<dyn Model>> {
let model_path: String = env
.get_string(&model_path)
.expect("Couldn't get java string!")
.into();
let device: String = env
.get_string(&device)
.expect("Couldn't get java string!")
.into();

let model_path = PathBuf::from(model_path);

Expand All @@ -71,13 +76,7 @@ fn load_model<'local>(
let config: Config = serde_json::from_str(&config).map_err(Error::msg)?;

// Get candle device
let device = if candle::utils::cuda_is_available() {
Device::new_cuda(0)
} else if candle::utils::metal_is_available() {
Device::new_metal(0)
} else {
Ok(Device::Cpu)
}?;
let device = as_device(&device).expect("Couldn't get device!");

// Get candle dtype
let dtype = as_data_type(dtype).unwrap();
Expand Down Expand Up @@ -171,8 +170,9 @@ pub extern "system" fn Java_ai_djl_engine_rust_RustLibrary_loadModel<'local>(
_: JObject,
model_path: JString,
dtype: jint,
device: JString,
) -> jlong {
let model = load_model(&mut env, model_path, dtype);
let model = load_model(&mut env, model_path, dtype, device);

match model {
Ok(output) => to_handle(output),
Expand Down Expand Up @@ -234,3 +234,21 @@ pub extern "system" fn Java_ai_djl_engine_rust_RustLibrary_runInference<'local>(
}
}
}

pub fn as_device(device: &String) -> Result<Device> {
if device.starts_with("gpu") {
if let Some(id_str) = device
.strip_prefix("gpu(")
.and_then(|s| s.strip_suffix(")"))
{
if let Ok(id) = id_str.parse::<usize>() {
return Device::new_cuda(id);
}
}
panic!("Invalid GPU format!");
} else if device == "cpu()" {
return Ok(Device::Cpu);
} else {
panic!("Unsupported device string!");
};
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
/** {@code RsModel} is the Rust implementation of {@link Model}. */
public class RsModel extends BaseModel {

private Device device;
private final AtomicReference<Long> handle;

/**
Expand All @@ -38,6 +39,7 @@ public class RsModel extends BaseModel {
*/
RsModel(String name, Device device) {
super(name);
this.device = device;
manager = RsNDManager.getSystemManager().newSubManager(device);
manager.setName("RsModel");
dataType = DataType.FLOAT16;
Expand All @@ -54,7 +56,9 @@ public void load(Path modelPath, String prefix, Map<String, ?> options)
}
setModelDir(modelPath);
if (block == null) {
handle.set(RustLibrary.loadModel(modelDir.toString(), dataType.ordinal()));
handle.set(
RustLibrary.loadModel(
modelDir.toString(), dataType.ordinal(), device.toString()));
block = new RsSymbolBlock((RsNDManager) manager, handle.get());
} else {
loadBlock(prefix, options);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ private RustLibrary() {}

public static native boolean isCudaAvailable();

public static native long loadModel(String modelPath, int dtype);
public static native long loadModel(String modelPath, int dtype, String device);

public static native long deleteModel(long handle);

Expand Down

0 comments on commit 2f3371b

Please sign in to comment.