-
-
Notifications
You must be signed in to change notification settings - Fork 109
/
Copy path: main.rs
91 lines (78 loc) · 3.02 KB
/
main.rs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
use std::{ops::Mul, path::Path};
use cudarc::driver::{sys::CUdeviceptr, CudaDevice, DevicePtr, DevicePtrMut};
use image::{imageops::FilterType, GenericImageView, ImageBuffer, Rgba};
use ndarray::Array;
use ort::{AllocationDevice, AllocatorType, CUDAExecutionProvider, ExecutionProvider, MemoryInfo, MemoryType, Session, TensorRefMut};
use show_image::{event, AsImageView, WindowOptions};
/// Runs MODNet photographic portrait matting on `data/photo.jpg`, feeding the
/// model a CUDA-device-resident input tensor, then composites the predicted
/// alpha matte onto the original photo and displays it until Escape is pressed.
///
/// # Errors
/// Returns an error if ort/CUDA initialization, model download, image
/// loading, inference, or window creation fails.
#[show_image::main]
fn main() -> anyhow::Result<()> {
    tracing_subscriber::fmt::init();

    // Register CUDA as the execution provider and fail loudly rather than
    // silently falling back to CPU — the input tensor below lives on-device,
    // so a CPU session could not consume it.
    ort::init()
        .with_execution_providers([CUDAExecutionProvider::default().build().error_on_failure()])
        .commit()?;

    let model =
        Session::builder()?.commit_from_url("https://parcel.pyke.io/v2/cdn/assetdelivery/ortrsv2/ex_models/modnet_photographic_portrait_matting.onnx")?;

    // Fix: propagate image-loading failure via `?` instead of panicking with `unwrap()`.
    let original_img = image::open(Path::new(env!("CARGO_MANIFEST_DIR")).join("data").join("photo.jpg"))?;
    let (img_width, img_height) = (original_img.width(), original_img.height());
    // MODNet takes a fixed 512x512 input.
    let img = original_img.resize_exact(512, 512, FilterType::Triangle);

    // Build an NCHW f32 tensor, normalizing each channel from [0, 255] to [-1, 1].
    let mut input = Array::zeros((1, 3, 512, 512));
    for pixel in img.pixels() {
        let x = pixel.0 as _;
        let y = pixel.1 as _;
        let [r, g, b, _] = pixel.2.0;
        input[[0, 0, y, x]] = (r as f32 - 127.5) / 127.5;
        input[[0, 1, y, x]] = (g as f32 - 127.5) / 127.5;
        input[[0, 2, y, x]] = (b as f32 - 127.5) / 127.5;
    }

    // Copy the input to GPU memory and wrap the raw device pointer for ort.
    let device = CudaDevice::new(0)?;
    let device_data = device.htod_sync_copy(&input.into_raw_vec())?;
    // SAFETY: `device_data` is a live device allocation holding exactly
    // 1*3*512*512 f32 elements — the same shape passed to `from_raw` — and it
    // outlives `tensor` (it is not dropped until after `model.run` returns).
    let tensor: TensorRefMut<'_, f32> = unsafe {
        TensorRefMut::from_raw(
            MemoryInfo::new(AllocationDevice::CUDA, 0, AllocatorType::Device, MemoryType::Default)?,
            (*device_data.device_ptr() as usize as *mut ()).cast(),
            vec![1, 3, 512, 512]
        )? // fix: propagate the ort::Error instead of unwrap()
    };

    let outputs = model.run([tensor.into()])?;
    let output = outputs["output"].try_extract_tensor::<f32>()?;
    // Convert the f32 matte to 8-bit.
    let output = output.mul(255.0).map(|x| *x as u8);
    let output = output.into_raw_vec();

    // Expand the single-channel matte to RGBA; the alpha channel is filled in
    // from the matte during compositing below.
    let output_img = ImageBuffer::from_fn(512, 512, |x, y| {
        let i = (x + y * 512) as usize;
        Rgba([output[i], output[i], output[i], 255])
    });

    // Resize the matte back to the original resolution, then composite:
    // RGB from the source photo, alpha from the matte value.
    let mut output = image::imageops::resize(&output_img, img_width, img_height, FilterType::Triangle);
    output.enumerate_pixels_mut().for_each(|(x, y, pixel)| {
        let origin = original_img.get_pixel(x, y);
        pixel.0[3] = pixel.0[0];
        pixel.0[0] = origin.0[0];
        pixel.0[1] = origin.0[1];
        pixel.0[2] = origin.0[2];
    });

    // Open a window sized to the original photo and show the composite.
    let window = show_image::context()
        .run_function_wait(move |context| -> Result<_, String> {
            let mut window = context
                .create_window(
                    "ort + modnet",
                    WindowOptions {
                        size: Some([img_width, img_height]),
                        ..WindowOptions::default()
                    }
                )
                .map_err(|e| e.to_string())?;
            window.set_image("photo", &output.as_image_view().map_err(|e| e.to_string())?);
            Ok(window.proxy())
        })
        // fix: surface the window-creation error instead of unwrap()
        .map_err(|e| anyhow::anyhow!(e))?;

    // Block until the user presses Escape.
    for event in window.event_channel()? {
        if let event::WindowEvent::KeyboardInput(event) = event {
            if event.input.key_code == Some(event::VirtualKeyCode::Escape) && event.input.state.is_pressed() {
                break;
            }
        }
    }

    Ok(())
}