Skip to content

Commit

Permalink
Merge pull request #13 from tracel-ai/resnet
Browse files Browse the repository at this point in the history
Add ResNet implementation with pre-trained weights
  • Loading branch information
nathanielsimard authored Feb 10, 2024
2 parents acaed68 + ec5fe2b commit 14eecd2
Show file tree
Hide file tree
Showing 13 changed files with 1,671 additions and 0 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ examples constructed using the [Burn](https://github.com/burn-rs/burn) deep lear
| Model | Description | Repository Link |
| ---------------------------------------------- | ------------------------------------------------- | -------------------------------------------- |
| [SqueezeNet](https://arxiv.org/abs/1602.07360) | A small CNN-based model for image classification. | [squeezenet-burn](squeezenet-burn/README.md) |
| [ResNet](https://arxiv.org/abs/1512.03385) | A CNN based on residual blocks with skip connections. | [resnet-burn](resnet-burn/README.md) |

## Community Contributions

Expand Down
22 changes: 22 additions & 0 deletions resnet-burn/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
[package]
authors = ["guillaumelagrange <[email protected]>"]
license = "MIT OR Apache-2.0"
name = "resnet-burn"
version = "0.1.0"
edition = "2021"

[features]
default = ["burn/default"]


[dependencies]
# Note: default-features = false is needed to disable std
burn = { git = "https://github.com/tracel-ai/burn.git", rev = "75cb5b6d5633c1c6092cf5046419da75e7f74b11", default-features = false }
serde = { version = "1.0.192", default-features = false, features = [
"derive",
"alloc",
] } # alloc is for no_std, derive is needed

[dev-dependencies]
burn-import = { git = "https://github.com/tracel-ai/burn.git", rev = "75cb5b6d5633c1c6092cf5046419da75e7f74b11"}
image = { version = "0.24.7", features = ["png", "jpeg"] }
1 change: 1 addition & 0 deletions resnet-burn/LICENSE-APACHE
1 change: 1 addition & 0 deletions resnet-burn/LICENSE-MIT
16 changes: 16 additions & 0 deletions resnet-burn/NOTICES.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# NOTICES AND INFORMATION

This file contains notices and information required by libraries that this repository copied or derived from. The use of the following resources complies with the licenses provided.

## Sample Image

Image Title: Standing yellow Labrador Retriever dog.
Author: Djmirko
Source: https://commons.wikimedia.org/wiki/File:YellowLabradorLooking_new.jpg
License: https://creativecommons.org/licenses/by-sa/3.0/

## Pre-trained Model

The ImageNet pre-trained model was ported from [`torchvision.models.ResNet18_Weights.IMAGENET1K_V1`](https://pytorch.org/vision/stable/models/generated/torchvision.models.resnet18.html#torchvision.models.ResNet18_Weights).

As opposed to [other pre-trained models](https://pytorch.org/vision/stable/models/generated/torchvision.models.regnet_y_128gf.html#torchvision.models.RegNet_Y_128GF_Weights) in `torchvision`, no specific license was linked to the weights, which are assumed to be under the library's [BSD-3-Clause license](https://github.com/pytorch/vision/blob/main/LICENSE) ([ref](https://github.com/pytorch/vision/issues/160)).
33 changes: 33 additions & 0 deletions resnet-burn/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# ResNet Burn

To this day, [ResNet](https://arxiv.org/abs/1512.03385)s are still a strong baseline for your image
classification tasks. You can find the [Burn](https://github.com/tracel-ai/burn) implementation for
the ResNet variants in [src/model/resnet.rs](src/model/resnet.rs).

The model is [no_std compatible](https://docs.rust-embedded.org/book/intro/no-std.html).

## Usage

### `Cargo.toml`

Add this to your `Cargo.toml`:

```toml
[dependencies]
resnet-burn = { git = "https://github.com/burn-rs/models", package = "resnet-burn", default-features = false }
```

### Example Usage

The [inference example](examples/inference.rs) initializes a ResNet-18 with the `NdArray` backend,
imports the ImageNet pre-trained weights from
[`torchvision`](https://download.pytorch.org/models/resnet18-f37072fd.pth) and performs inference on
the provided input image.

After downloading the
[pre-trained weights](https://download.pytorch.org/models/resnet18-f37072fd.pth) to the current
directory, you can run the example with the following command:

```sh
cargo run --release --example inference samples/dog.jpg
```
100 changes: 100 additions & 0 deletions resnet-burn/examples/inference.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
use resnet_burn::model::{imagenet, resnet::ResNet};

use burn::{
backend::NdArray,
module::Module,
record::{FullPrecisionSettings, NamedMpkFileRecorder, Recorder},
tensor::{backend::Backend, Data, Device, Element, Shape, Tensor},
};
use burn_import::pytorch::{LoadArgs, PyTorchFileRecorder};

const TORCH_WEIGHTS: &str = "resnet18-f37072fd.pth";
const MODEL_PATH: &str = "resnet18-ImageNet1k";
const NUM_CLASSES: usize = 1000;
const HEIGHT: usize = 224;
const WIDTH: usize = 224;

fn to_tensor<B: Backend, T: Element>(
data: Vec<T>,
shape: [usize; 3],
device: &Device<B>,
) -> Tensor<B, 3> {
Tensor::<B, 3>::from_data(Data::new(data, Shape::new(shape)).convert(), device)
// permute(2, 0, 1)
.swap_dims(2, 1) // [H, C, W]
.swap_dims(1, 0) // [C, H, W]
/ 255 // normalize between [0, 1]
}

pub fn main() {
// Parse arguments
let img_path = std::env::args().nth(1).expect("No image path provided");

// Create ResNet-18
let device = Default::default();
let model: ResNet<NdArray, _> = ResNet::resnet18(NUM_CLASSES, &device);

// Load weights from torch state_dict
let load_args = LoadArgs::new(TORCH_WEIGHTS.into())
// Map *.downsample.0.* -> *.downsample.conv.*
.with_key_remap("(.+)\\.downsample\\.0\\.(.+)", "$1.downsample.conv.$2")
// Map *.downsample.1.* -> *.downsample.bn.*
.with_key_remap("(.+)\\.downsample\\.1\\.(.+)", "$1.downsample.bn.$2")
// Map layer[i].[j].* -> layer[i].blocks.[j].*
.with_key_remap("(layer[1-4])\\.([0-9])\\.(.+)", "$1.blocks.$2.$3");
let record = PyTorchFileRecorder::<FullPrecisionSettings>::new()
.load(load_args, &device)
.map_err(|err| format!("Failed to load weights.\nError: {err}"))
.unwrap();

let model = model.load_record(record);

// Save the model to a supported format and load it back
let recorder = NamedMpkFileRecorder::<FullPrecisionSettings>::new();
model
.clone() // `save_file` takes ownership but we want to load the file after
.save_file(MODEL_PATH, &recorder)
.map_err(|err| format!("Failed to save weights to file {MODEL_PATH}.\nError: {err}"))
.unwrap();
let model = model
.load_file(MODEL_PATH, &recorder, &device)
.map_err(|err| format!("Failed to load weights from file {MODEL_PATH}.\nError: {err}"))
.unwrap();

// Load image
let img = image::open(&img_path)
.map_err(|err| format!("Failed to load image {img_path}.\nError: {err}"))
.unwrap();

// Resize to 224x224
let resized_img = img.resize_exact(
WIDTH as u32,
HEIGHT as u32,
image::imageops::FilterType::Triangle, // also known as bilinear in 2D
);

// Create tensor from image data
let img_tensor = to_tensor(
resized_img.into_rgb8().into_raw(),
[HEIGHT, WIDTH, 3],
&device,
)
.unsqueeze::<4>(); // [B, C, H, W]

// Normalize the image
let x = imagenet::Normalizer::new(&device).normalize(img_tensor);

// Forward pass
let out = model.forward(x);

// Output class index w/ score (raw)
let (score, idx) = out.max_dim_with_indices(1);
let idx = idx.into_scalar() as usize;

println!(
"Predicted: {}\nCategory Id: {}\nScore: {:.4}",
imagenet::CLASSES[idx],
idx,
score.into_scalar()
);
}
Binary file added resnet-burn/samples/dog.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
3 changes: 3 additions & 0 deletions resnet-burn/src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
#![no_std]
pub mod model;
extern crate alloc;
Loading

0 comments on commit 14eecd2

Please sign in to comment.