-
Notifications
You must be signed in to change notification settings - Fork 26
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #13 from tracel-ai/resnet
Add ResNet implementation with pre-trained weights
- Loading branch information
Showing
13 changed files
with
1,671 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
[package] | ||
authors = ["guillaumelagrange <[email protected]>"] | ||
license = "MIT OR Apache-2.0" | ||
name = "resnet-burn" | ||
version = "0.1.0" | ||
edition = "2021" | ||
|
||
[features] | ||
default = ["burn/default"] | ||
|
||
|
||
[dependencies] | ||
# Note: default-features = false is needed to disable std | ||
burn = { git = "https://github.com/tracel-ai/burn.git", rev = "75cb5b6d5633c1c6092cf5046419da75e7f74b11", default-features = false } | ||
serde = { version = "1.0.192", default-features = false, features = [ | ||
"derive", | ||
"alloc", | ||
] } # alloc is for no_std, derive is needed | ||
|
||
[dev-dependencies] | ||
burn-import = { git = "https://github.com/tracel-ai/burn.git", rev = "75cb5b6d5633c1c6092cf5046419da75e7f74b11"} | ||
image = { version = "0.24.7", features = ["png", "jpeg"] } |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
../LICENSE-APACHE |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
../LICENSE-MIT |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
# NOTICES AND INFORMATION | ||
|
||
This file contains notices and information required by libraries that this repository copied or derived from. The use of the following resources complies with the licenses provided. | ||
|
||
## Sample Image | ||
|
||
Image Title: Standing yellow Labrador Retriever dog. | ||
Author: Djmirko | ||
Source: https://commons.wikimedia.org/wiki/File:YellowLabradorLooking_new.jpg | ||
License: https://creativecommons.org/licenses/by-sa/3.0/ | ||
|
||
## Pre-trained Model | ||
|
||
The ImageNet pre-trained model was ported from [`torchvision.models.ResNet18_Weights.IMAGENET1K_V1`](https://pytorch.org/vision/stable/models/generated/torchvision.models.resnet18.html#torchvision.models.ResNet18_Weights). | ||
|
||
As opposed to [other pre-trained models](https://pytorch.org/vision/stable/models/generated/torchvision.models.regnet_y_128gf.html#torchvision.models.RegNet_Y_128GF_Weights) in `torchvision`, no specific license was linked to the weights, which are assumed to be under the library's [BSD-3-Clause license](https://github.com/pytorch/vision/blob/main/LICENSE) ([ref](https://github.com/pytorch/vision/issues/160)). |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
# ResNet Burn | ||
|
||
To this day, [ResNet](https://arxiv.org/abs/1512.03385)s are still a strong baseline for your image | ||
classification tasks. You can find the [Burn](https://github.com/tracel-ai/burn) implementation for | ||
the ResNet variants in [src/model/resnet.rs](src/model/resnet.rs). | ||
|
||
The model is [no_std compatible](https://docs.rust-embedded.org/book/intro/no-std.html). | ||
|
||
## Usage | ||
|
||
### `Cargo.toml` | ||
|
||
Add this to your `Cargo.toml`: | ||
|
||
```toml | ||
[dependencies] | ||
resnet-burn = { git = "https://github.com/tracel-ai/models", package = "resnet-burn", default-features = false } | ||
``` | ||
|
||
### Example Usage | ||
|
||
The [inference example](examples/inference.rs) initializes a ResNet-18 with the `NdArray` backend, | ||
imports the ImageNet pre-trained weights from | ||
[`torchvision`](https://download.pytorch.org/models/resnet18-f37072fd.pth) and performs inference on | ||
the provided input image. | ||
|
||
After downloading the | ||
[pre-trained weights](https://download.pytorch.org/models/resnet18-f37072fd.pth) to the current | ||
directory, you can run the example with the following command: | ||
|
||
```sh | ||
cargo run --release --example inference samples/dog.jpg | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,100 @@ | ||
use resnet_burn::model::{imagenet, resnet::ResNet}; | ||
|
||
use burn::{ | ||
backend::NdArray, | ||
module::Module, | ||
record::{FullPrecisionSettings, NamedMpkFileRecorder, Recorder}, | ||
tensor::{backend::Backend, Data, Device, Element, Shape, Tensor}, | ||
}; | ||
use burn_import::pytorch::{LoadArgs, PyTorchFileRecorder}; | ||
|
||
/// Path of the PyTorch state_dict file that `main` hands to `LoadArgs::new`.
const TORCH_WEIGHTS: &str = "resnet18-f37072fd.pth"; | ||
/// Path `main` passes to `save_file` / `load_file` when round-tripping the model.
const MODEL_PATH: &str = "resnet18-ImageNet1k"; | ||
/// Number of output classes passed to `ResNet::resnet18`.
const NUM_CLASSES: usize = 1000; | ||
/// Height the input image is resized to, and the first dimension of the raw image tensor shape.
const HEIGHT: usize = 224; | ||
/// Width the input image is resized to, and the second dimension of the raw image tensor shape.
const WIDTH: usize = 224; | ||
|
||
fn to_tensor<B: Backend, T: Element>( | ||
data: Vec<T>, | ||
shape: [usize; 3], | ||
device: &Device<B>, | ||
) -> Tensor<B, 3> { | ||
Tensor::<B, 3>::from_data(Data::new(data, Shape::new(shape)).convert(), device) | ||
// permute(2, 0, 1) | ||
.swap_dims(2, 1) // [H, C, W] | ||
.swap_dims(1, 0) // [C, H, W] | ||
/ 255 // normalize between [0, 1] | ||
} | ||
|
||
pub fn main() { | ||
// Parse arguments | ||
let img_path = std::env::args().nth(1).expect("No image path provided"); | ||
|
||
// Create ResNet-18 | ||
let device = Default::default(); | ||
let model: ResNet<NdArray, _> = ResNet::resnet18(NUM_CLASSES, &device); | ||
|
||
// Load weights from torch state_dict | ||
let load_args = LoadArgs::new(TORCH_WEIGHTS.into()) | ||
// Map *.downsample.0.* -> *.downsample.conv.* | ||
.with_key_remap("(.+)\\.downsample\\.0\\.(.+)", "$1.downsample.conv.$2") | ||
// Map *.downsample.1.* -> *.downsample.bn.* | ||
.with_key_remap("(.+)\\.downsample\\.1\\.(.+)", "$1.downsample.bn.$2") | ||
// Map layer[i].[j].* -> layer[i].blocks.[j].* | ||
.with_key_remap("(layer[1-4])\\.([0-9])\\.(.+)", "$1.blocks.$2.$3"); | ||
let record = PyTorchFileRecorder::<FullPrecisionSettings>::new() | ||
.load(load_args, &device) | ||
.map_err(|err| format!("Failed to load weights.\nError: {err}")) | ||
.unwrap(); | ||
|
||
let model = model.load_record(record); | ||
|
||
// Save the model to a supported format and load it back | ||
let recorder = NamedMpkFileRecorder::<FullPrecisionSettings>::new(); | ||
model | ||
.clone() // `save_file` takes ownership but we want to load the file after | ||
.save_file(MODEL_PATH, &recorder) | ||
.map_err(|err| format!("Failed to save weights to file {MODEL_PATH}.\nError: {err}")) | ||
.unwrap(); | ||
let model = model | ||
.load_file(MODEL_PATH, &recorder, &device) | ||
.map_err(|err| format!("Failed to load weights from file {MODEL_PATH}.\nError: {err}")) | ||
.unwrap(); | ||
|
||
// Load image | ||
let img = image::open(&img_path) | ||
.map_err(|err| format!("Failed to load image {img_path}.\nError: {err}")) | ||
.unwrap(); | ||
|
||
// Resize to 224x224 | ||
let resized_img = img.resize_exact( | ||
WIDTH as u32, | ||
HEIGHT as u32, | ||
image::imageops::FilterType::Triangle, // also known as bilinear in 2D | ||
); | ||
|
||
// Create tensor from image data | ||
let img_tensor = to_tensor( | ||
resized_img.into_rgb8().into_raw(), | ||
[HEIGHT, WIDTH, 3], | ||
&device, | ||
) | ||
.unsqueeze::<4>(); // [B, C, H, W] | ||
|
||
// Normalize the image | ||
let x = imagenet::Normalizer::new(&device).normalize(img_tensor); | ||
|
||
// Forward pass | ||
let out = model.forward(x); | ||
|
||
// Output class index w/ score (raw) | ||
let (score, idx) = out.max_dim_with_indices(1); | ||
let idx = idx.into_scalar() as usize; | ||
|
||
println!( | ||
"Predicted: {}\nCategory Id: {}\nScore: {:.4}", | ||
imagenet::CLASSES[idx], | ||
idx, | ||
score.into_scalar() | ||
); | ||
} |
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
#![no_std] | ||
pub mod model; | ||
extern crate alloc; |
Oops, something went wrong.