Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat/2361 segmentation mask #2426

Merged
merged 6 commits into from
Nov 11, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
241 changes: 237 additions & 4 deletions crates/burn-dataset/src/vision/image_folder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ pub enum Annotation {
#[derive(Debug, Clone, PartialEq)]
pub struct SegmentationMask {
/// Segmentation mask.
pub mask: Vec<usize>,
pub mask: Vec<PixelDepth>,
anthonytorlucci marked this conversation as resolved.
Show resolved Hide resolved
}

/// Object detection bounding box annotation.
Expand Down Expand Up @@ -104,7 +104,8 @@ pub struct ImageDatasetItem {
enum AnnotationRaw {
Label(String),
MultiLabel(Vec<String>),
// TODO: bounding boxes and segmentation mask
SegmentationMask(PathBuf),
// TODO: bounding boxes
}

#[derive(Deserialize, Serialize, Debug, Clone)]
Expand All @@ -129,15 +130,77 @@ struct PathToImageDatasetItem {
classes: HashMap<String, usize>,
}

fn image_path_to_vec_pixel_depth(image_path: &PathBuf) -> Vec<PixelDepth> {
// Load image from disk
let image = image::open(image_path).unwrap();

// Image as Vec<PixelDepth>
let img_vec = match image.color() {
ColorType::L8 => image
.into_luma8()
.iter()
.map(|&x| PixelDepth::U8(x))
.collect(),
ColorType::La8 => image
.into_luma_alpha8()
.iter()
.map(|&x| PixelDepth::U8(x))
.collect(),
ColorType::L16 => image
.into_luma16()
.iter()
.map(|&x| PixelDepth::U16(x))
.collect(),
ColorType::La16 => image
.into_luma_alpha16()
.iter()
.map(|&x| PixelDepth::U16(x))
.collect(),
ColorType::Rgb8 => image
.into_rgb8()
.iter()
.map(|&x| PixelDepth::U8(x))
.collect(),
ColorType::Rgba8 => image
.into_rgba8()
.iter()
.map(|&x| PixelDepth::U8(x))
.collect(),
ColorType::Rgb16 => image
.into_rgb16()
.iter()
.map(|&x| PixelDepth::U16(x))
.collect(),
ColorType::Rgba16 => image
.into_rgba16()
.iter()
.map(|&x| PixelDepth::U16(x))
.collect(),
ColorType::Rgb32F => image
.into_rgb32f()
.iter()
.map(|&x| PixelDepth::F32(x))
.collect(),
ColorType::Rgba32F => image
.into_rgba32f()
.iter()
.map(|&x| PixelDepth::F32(x))
.collect(),
_ => panic!("Unrecognized image color type"),
laggui marked this conversation as resolved.
Show resolved Hide resolved
};

img_vec
}

/// Parse the image annotation to the corresponding type.
fn parse_image_annotation(
annotation: &AnnotationRaw,
classes: &HashMap<String, usize>,
) -> Annotation {
// TODO: add support for other annotations
// - [ ] Object bounding boxes
// - [ ] Segmentation mask
// For now, only image classification labels are supported.
// - [x] Segmentation mask
// For now, only image classification labels and segmentation are supported.

// Map class string to label id
match annotation {
Expand All @@ -148,6 +211,20 @@ fn parse_image_annotation(
.map(|name| *classes.get(name).unwrap())
.collect(),
),
AnnotationRaw::SegmentationMask(mask_path) => {
let mask_image = image_path_to_vec_pixel_depth(mask_path);
// assume that each channel in the mask image is the same and
// each pixel in the first channel corresponds to a class.
// multi-channel image segmentation is not supported at this time.
Annotation::SegmentationMask(SegmentationMask {
mask: mask_image
.into_iter()
.enumerate()
.filter(|(i, _)| i % 3 == 0)
.map(|(_, pixel)| pixel)
.collect(),
})
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The filtering here will probably not be required given the suggested changes in the previous comment.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

filtering was removed and put in the segmentation_mask_to_vec_usize. fixed in the next commit.

}
}
}

Expand All @@ -160,6 +237,7 @@ impl Mapper<ImageDatasetItemRaw, ImageDatasetItem> for PathToImageDatasetItem {
let image = image::open(&item.image_path).unwrap();

// Image as Vec<PixelDepth>
// NOTE: the following logic has been copied to a separate function to be used for Segmentation Masks as well
laggui marked this conversation as resolved.
Show resolved Hide resolved
let img_vec = match image.color() {
ColorType::L8 => image
.into_luma8()
Expand Down Expand Up @@ -401,6 +479,36 @@ impl ImageFolderDataset {
Self::with_items(items, classes)
}

/// Create an image segmentation dataset with the specified items.
///
/// # Arguments
///
/// * `items` - List of dataset items, each item represented by a tuple `(image path, labels)`.
laggui marked this conversation as resolved.
Show resolved Hide resolved
/// * `classes` - Dataset class names.
///
/// # Returns
/// A new dataset instance.
pub fn new_segmentation_with_items<P: AsRef<Path>, S: AsRef<str>>(
items: Vec<(P, P)>,
classes: &[S],
) -> Result<Self, ImageLoaderError> {
Comment on lines +446 to +449
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Awesome! That's exactly what I meant, so there should not be any memory issues and users are not forced to pre-load all of their segmentation masks into memory to create a dataset 👍

// Parse items and check valid image extension types
let items = items
.into_iter()
.map(|(image_path, mask_path)| {
// Map image path and segmentation mask path
let image_path = image_path.as_ref();
let annotation = AnnotationRaw::SegmentationMask(mask_path.as_ref().to_path_buf());

Self::check_extension(&image_path.extension().unwrap().to_str().unwrap())?;

Ok(ImageDatasetItemRaw::new(image_path, annotation))
})
.collect::<Result<Vec<_>, _>>()?;

Self::with_items(items, classes)
}

/// Create an image dataset with the specified items.
///
/// # Arguments
Expand Down Expand Up @@ -451,6 +559,7 @@ impl ImageFolderDataset {
mod tests {
use super::*;
const DATASET_ROOT: &str = "tests/data/image_folder";
const SEGMASK_ROOT: &str = "tests/data/segmask_folder";

#[test]
pub fn image_folder_dataset() {
Expand Down Expand Up @@ -611,4 +720,128 @@ mod tests {
Annotation::MultiLabel(vec![0, 2])
);
}

#[test]
pub fn segmask_image_path_to_vec_pixel_depth() {
let root = Path::new(SEGMASK_ROOT);
// test checkerboard mask
const TEST_CHECKERBOARD_MASK_PATTERN: [u8; 64 * 3] = [
1, 1, 1, 2, 2, 2, 1, 1, 1, 2, 2, 2, 1, 1, 1, 2, 2, 2, 1, 1, 1, 2, 2, 2, 2, 2, 2, 1, 1,
1, 2, 2, 2, 1, 1, 1, 2, 2, 2, 1, 1, 1, 2, 2, 2, 1, 1, 1, 1, 1, 1, 2, 2, 2, 1, 1, 1, 2,
2, 2, 1, 1, 1, 2, 2, 2, 1, 1, 1, 2, 2, 2, 2, 2, 2, 1, 1, 1, 2, 2, 2, 1, 1, 1, 2, 2, 2,
1, 1, 1, 2, 2, 2, 1, 1, 1, 1, 1, 1, 2, 2, 2, 1, 1, 1, 2, 2, 2, 1, 1, 1, 2, 2, 2, 1, 1,
1, 2, 2, 2, 2, 2, 2, 1, 1, 1, 2, 2, 2, 1, 1, 1, 2, 2, 2, 1, 1, 1, 2, 2, 2, 1, 1, 1, 1,
1, 1, 2, 2, 2, 1, 1, 1, 2, 2, 2, 1, 1, 1, 2, 2, 2, 1, 1, 1, 2, 2, 2, 2, 2, 2, 1, 1, 1,
2, 2, 2, 1, 1, 1, 2, 2, 2, 1, 1, 1, 2, 2, 2, 1, 1, 1,
];
assert_eq!(
TEST_CHECKERBOARD_MASK_PATTERN
.iter()
.map(|&x| PixelDepth::U8(x))
.collect::<Vec<PixelDepth>>(),
image_path_to_vec_pixel_depth(&root.join("annotations").join("mask_checkerboard.png")),
);

// checkerboard image
// TODO: investigate why the channels appear to be reversed, i.e (blue, green, red) rather than (red, green, blue)
anthonytorlucci marked this conversation as resolved.
Show resolved Hide resolved
const TEST_CHECKERBOARD_IMAGE_PATTERN: [u8; 64 * 3] = [
220, 20, 60, 0, 255, 255, 220, 20, 60, 0, 255, 255, 220, 20, 60, 0, 255, 255, 220, 20,
60, 0, 255, 255, 0, 255, 255, 220, 20, 60, 0, 255, 255, 220, 20, 60, 0, 255, 255, 220,
20, 60, 0, 255, 255, 220, 20, 60, 220, 20, 60, 0, 255, 255, 220, 20, 60, 0, 255, 255,
220, 20, 60, 0, 255, 255, 220, 20, 60, 0, 255, 255, 0, 255, 255, 220, 20, 60, 0, 255,
255, 220, 20, 60, 0, 255, 255, 220, 20, 60, 0, 255, 255, 220, 20, 60, 220, 20, 60, 0,
255, 255, 220, 20, 60, 0, 255, 255, 220, 20, 60, 0, 255, 255, 220, 20, 60, 0, 255, 255,
0, 255, 255, 220, 20, 60, 0, 255, 255, 220, 20, 60, 0, 255, 255, 220, 20, 60, 0, 255,
255, 220, 20, 60, 220, 20, 60, 0, 255, 255, 220, 20, 60, 0, 255, 255, 220, 20, 60, 0,
255, 255, 220, 20, 60, 0, 255, 255, 0, 255, 255, 220, 20, 60, 0, 255, 255, 220, 20, 60,
0, 255, 255, 220, 20, 60, 0, 255, 255, 220, 20, 60,
];
assert_eq!(
TEST_CHECKERBOARD_IMAGE_PATTERN
.iter()
.map(|&x| PixelDepth::U8(x))
.collect::<Vec<PixelDepth>>(),
image_path_to_vec_pixel_depth(&root.join("images").join("image_checkerboard.png")),
);
}

#[test]
pub fn segmask_folder_dataset() {
let root = Path::new(SEGMASK_ROOT);

let items = vec![
(
root.join("images").join("image_checkerboard.png"),
root.join("annotations").join("mask_checkerboard.png"),
),
(
root.join("images").join("image_random_2colors.png"),
root.join("annotations").join("mask_random_2colors.png"),
),
(
root.join("images").join("image_random_3colors.png"),
root.join("annotations").join("mask_random_3colors.png"),
),
];
let dataset = ImageFolderDataset::new_segmentation_with_items(
items,
&[
"foo", // 0
"bar", // 1
"baz", // 2
"qux", // 3
],
)
.unwrap();

// Dataset has 3 elements; each (image, annotation) is a single item
assert_eq!(dataset.len(), 3);
assert_eq!(dataset.get(3), None);

// checkerboard mask
const TEST_CHECKERBOARD_MASK_PATTERN: [u8; 64] = [
1, 2, 1, 2, 1, 2, 1, 2, 2, 1, 2, 1, 2, 1, 2, 1, 1, 2, 1, 2, 1, 2, 1, 2, 2, 1, 2, 1, 2,
1, 2, 1, 1, 2, 1, 2, 1, 2, 1, 2, 2, 1, 2, 1, 2, 1, 2, 1, 1, 2, 1, 2, 1, 2, 1, 2, 2, 1,
2, 1, 2, 1, 2, 1,
];
assert_eq!(
dataset.get(0).unwrap().annotation,
Annotation::SegmentationMask(SegmentationMask {
mask: TEST_CHECKERBOARD_MASK_PATTERN
.iter()
.map(|&x| PixelDepth::U8(x))
.collect()
})
);
// random 2 colors mask
const TEST_RANDOM2COLORS_MASK_PATTERN: [u8; 64] = [
1, 2, 1, 1, 1, 2, 1, 1, 1, 2, 1, 1, 1, 1, 2, 1, 2, 2, 2, 1, 2, 1, 2, 2, 2, 2, 2, 2, 2,
2, 1, 1, 2, 2, 2, 1, 2, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 1, 2, 2, 1, 2, 1, 2, 1, 2, 2, 1,
1, 1, 1, 1, 1, 1,
];
assert_eq!(
dataset.get(1).unwrap().annotation,
Annotation::SegmentationMask(SegmentationMask {
mask: TEST_RANDOM2COLORS_MASK_PATTERN
.iter()
.map(|&x| PixelDepth::U8(x))
.collect()
})
);
// random 3 colors mask
const TEST_RANDOM3COLORS_MASK_PATTERN: [u8; 64] = [
3, 1, 3, 3, 1, 1, 3, 2, 3, 3, 3, 3, 1, 3, 2, 1, 2, 2, 2, 2, 1, 1, 2, 2, 1, 1, 1, 3, 3,
3, 2, 3, 2, 2, 3, 2, 3, 3, 1, 3, 1, 3, 3, 1, 1, 3, 2, 1, 2, 2, 2, 1, 2, 1, 2, 3, 3, 1,
3, 3, 2, 1, 2, 2,
];
assert_eq!(
dataset.get(2).unwrap().annotation,
Annotation::SegmentationMask(SegmentationMask {
mask: TEST_RANDOM3COLORS_MASK_PATTERN
.iter()
.map(|&x| PixelDepth::U8(x))
.collect()
})
);
}
}
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
1 2 1 2 1 2 1 2
2 1 2 1 2 1 2 1
1 2 1 2 1 2 1 2
2 1 2 1 2 1 2 1
1 2 1 2 1 2 1 2
2 1 2 1 2 1 2 1
1 2 1 2 1 2 1 2
2 1 2 1 2 1 2 1
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
1 2 1 1 1 2 1 1
1 2 1 1 1 1 2 1
2 2 2 1 2 1 2 2
2 2 2 2 2 2 1 1
2 2 2 1 2 1 1 1
1 1 2 2 2 2 2 1
2 2 1 2 1 2 1 2
2 1 1 1 1 1 1 1
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
3 1 3 3 1 1 3 2
3 3 3 3 1 3 2 1
2 2 2 2 1 1 2 2
1 1 1 3 3 3 2 3
2 2 3 2 3 3 1 3
1 3 3 1 1 3 2 1
2 2 2 1 2 1 2 3
3 1 3 3 2 1 2 2
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.