Skip to content

Commit

Permalink
Add segmentation mask to ImageFolderDataset (#2426)
Browse files Browse the repository at this point in the history
* 2361-SegmentationMask implementation and initial test

* 2361-SegmentationMask validated tests for test data

* 2361-SegmentationMask removed unnecessary serialize/deserialize

* 2361-SegmentationMask raw mask as path rather than Vec<_>

* 2361-SegmentationMask updated synthetic images and fixed test

* 2361-SegmentationMask revert back to Vec<usize>
  • Loading branch information
anthonytorlucci authored Nov 11, 2024
1 parent b4fa1fc commit 6e71aaf
Show file tree
Hide file tree
Showing 10 changed files with 221 additions and 3 deletions.
200 changes: 197 additions & 3 deletions crates/burn-dataset/src/vision/image_folder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,8 @@ pub struct ImageDatasetItem {
enum AnnotationRaw {
    /// Image-level label given as a string (presumably a class name that is
    /// resolved to an id when the annotation is parsed — confirm against
    /// `parse_image_annotation`).
    Label(String),
    /// Several image-level labels given as strings (presumably class names,
    /// each resolved to an id through the dataset's class map when parsed).
    MultiLabel(Vec<String>),
    /// Path to the segmentation mask image; the mask is loaded from this path
    /// and converted to a `Vec<usize>` when the annotation is parsed.
    SegmentationMask(PathBuf),
    // TODO: bounding boxes
}

#[derive(Deserialize, Serialize, Debug, Clone)]
Expand All @@ -129,15 +130,42 @@ struct PathToImageDatasetItem {
classes: HashMap<String, usize>,
}

/// Load a segmentation mask image from disk and flatten it into one value per pixel.
///
/// The file at `mask_path` is decoded with the `image` crate. Single-channel masks
/// (`L8`, `L16`) are used as-is. For `Rgb8` and `Rgb16` masks only the first channel
/// of every pixel is kept, assuming all three channels hold the same value.
///
/// # Arguments
///
/// * `mask_path` - Path to the mask image file.
///
/// # Returns
///
/// The mask values as `usize`, one per pixel, in the order of the decoded image buffer.
///
/// # Panics
///
/// Panics if the image cannot be opened or decoded, or if its color type is not one of
/// `L8`, `L16`, `Rgb8` or `Rgb16`.
fn segmentation_mask_to_vec_usize(mask_path: &Path) -> Vec<usize> {
    // Load image from disk, reporting the offending path if it cannot be read.
    let image = image::open(mask_path).unwrap_or_else(|err| {
        panic!(
            "Failed to load segmentation mask {}: {}",
            mask_path.display(),
            err
        )
    });

    // Image as Vec<PixelDepth>
    // if rgb8 or rgb16, keep only the first channel assuming all channels are the same
    match image.color() {
        ColorType::L8 => image.into_luma8().iter().copied().map(usize::from).collect(),
        ColorType::L16 => image
            .into_luma16()
            .iter()
            .copied()
            .map(usize::from)
            .collect(),
        ColorType::Rgb8 => image
            .into_rgb8()
            .iter()
            .step_by(3)
            .copied()
            .map(usize::from)
            .collect(),
        ColorType::Rgb16 => image
            .into_rgb16()
            .iter()
            .step_by(3)
            .copied()
            .map(usize::from)
            .collect(),
        other => panic!("Unrecognized image color type: {:?}", other),
    }
}

/// Parse the image annotation to the corresponding type.
fn parse_image_annotation(
annotation: &AnnotationRaw,
classes: &HashMap<String, usize>,
) -> Annotation {
// TODO: add support for other annotations
// - [ ] Object bounding boxes
// - [ ] Segmentation mask
// For now, only image classification labels are supported.
// - [x] Segmentation mask
// For now, only image classification labels and segmentation are supported.

// Map class string to label id
match annotation {
Expand All @@ -148,6 +176,11 @@ fn parse_image_annotation(
.map(|name| *classes.get(name).unwrap())
.collect(),
),
AnnotationRaw::SegmentationMask(mask_path) => {
Annotation::SegmentationMask(SegmentationMask {
mask: segmentation_mask_to_vec_usize(mask_path),
})
}
}
}

Expand Down Expand Up @@ -401,6 +434,36 @@ impl ImageFolderDataset {
Self::with_items(items, classes)
}

/// Build a segmentation dataset from explicit `(image, mask)` path pairs.
///
/// # Arguments
///
/// * `items` - Dataset entries; each tuple is `(image path, segmentation mask path)`.
/// * `classes` - Dataset class names.
///
/// # Returns
/// A new dataset instance, or the error reported for the first image whose file
/// extension is rejected by `check_extension`.
///
/// # Panics
/// Panics if an image path has no file extension or the extension is not valid UTF-8.
pub fn new_segmentation_with_items<P: AsRef<Path>, S: AsRef<str>>(
    items: Vec<(P, P)>,
    classes: &[S],
) -> Result<Self, ImageLoaderError> {
    let mut raw_items = Vec::with_capacity(items.len());

    for (image, mask) in items {
        let image = image.as_ref();
        // The mask is stored as a path only; it is not opened here.
        let mask = AnnotationRaw::SegmentationMask(mask.as_ref().to_path_buf());

        // Only the image path is validated against the supported extensions.
        let extension = image.extension().unwrap().to_str().unwrap();
        Self::check_extension(&extension)?;

        raw_items.push(ImageDatasetItemRaw::new(image, mask));
    }

    Self::with_items(raw_items, classes)
}

/// Create an image dataset with the specified items.
///
/// # Arguments
Expand Down Expand Up @@ -451,6 +514,7 @@ impl ImageFolderDataset {
mod tests {
use super::*;
const DATASET_ROOT: &str = "tests/data/image_folder";
const SEGMASK_ROOT: &str = "tests/data/segmask_folder";

#[test]
pub fn image_folder_dataset() {
Expand Down Expand Up @@ -611,4 +675,134 @@ mod tests {
Annotation::MultiLabel(vec![0, 2])
);
}

/// Each fixture mask must decode to its known sequence of 64 pixel values.
#[test]
pub fn segmask_image_path_to_vec_usize() {
    // checkerboard mask
    #[rustfmt::skip]
    const CHECKERBOARD: [usize; 64] = [
        1, 2, 1, 2, 1, 2, 1, 2,
        2, 1, 2, 1, 2, 1, 2, 1,
        1, 2, 1, 2, 1, 2, 1, 2,
        2, 1, 2, 1, 2, 1, 2, 1,
        1, 2, 1, 2, 1, 2, 1, 2,
        2, 1, 2, 1, 2, 1, 2, 1,
        1, 2, 1, 2, 1, 2, 1, 2,
        2, 1, 2, 1, 2, 1, 2, 1,
    ];
    // random 2 colors mask
    #[rustfmt::skip]
    const RANDOM_2COLORS: [usize; 64] = [
        1, 2, 1, 1, 1, 2, 1, 1,
        1, 2, 1, 1, 1, 1, 2, 1,
        2, 2, 2, 1, 2, 1, 2, 2,
        2, 2, 2, 2, 2, 2, 1, 1,
        2, 2, 2, 1, 2, 1, 1, 1,
        1, 1, 2, 2, 2, 2, 2, 1,
        2, 2, 1, 2, 1, 2, 1, 2,
        2, 1, 1, 1, 1, 1, 1, 1,
    ];
    // random 3 colors mask
    #[rustfmt::skip]
    const RANDOM_3COLORS: [usize; 64] = [
        3, 1, 3, 3, 1, 1, 3, 2,
        3, 3, 3, 3, 1, 3, 2, 1,
        2, 2, 2, 2, 1, 1, 2, 2,
        1, 1, 1, 3, 3, 3, 2, 3,
        2, 2, 3, 2, 3, 3, 1, 3,
        1, 3, 3, 1, 1, 3, 2, 1,
        2, 2, 2, 1, 2, 1, 2, 3,
        3, 1, 3, 3, 2, 1, 2, 2,
    ];

    let annotations_dir = Path::new(SEGMASK_ROOT).join("annotations");

    // (mask file name, expected decoded values), checked in this order.
    let cases = [
        ("mask_checkerboard.png", CHECKERBOARD),
        ("mask_random_2colors.png", RANDOM_2COLORS),
        ("mask_random_3colors.png", RANDOM_3COLORS),
    ];

    for (file_name, expected) in cases.iter() {
        assert_eq!(
            expected.to_vec(),
            segmentation_mask_to_vec_usize(&annotations_dir.join(file_name)),
        );
    }
}

/// A dataset built from (image, mask) pairs must expose one item per pair, each
/// carrying the decoded segmentation mask as its annotation.
#[test]
pub fn segmask_folder_dataset() {
    // checkerboard mask
    #[rustfmt::skip]
    const CHECKERBOARD: [usize; 64] = [
        1, 2, 1, 2, 1, 2, 1, 2,
        2, 1, 2, 1, 2, 1, 2, 1,
        1, 2, 1, 2, 1, 2, 1, 2,
        2, 1, 2, 1, 2, 1, 2, 1,
        1, 2, 1, 2, 1, 2, 1, 2,
        2, 1, 2, 1, 2, 1, 2, 1,
        1, 2, 1, 2, 1, 2, 1, 2,
        2, 1, 2, 1, 2, 1, 2, 1,
    ];
    // random 2 colors mask
    #[rustfmt::skip]
    const RANDOM_2COLORS: [usize; 64] = [
        1, 2, 1, 1, 1, 2, 1, 1,
        1, 2, 1, 1, 1, 1, 2, 1,
        2, 2, 2, 1, 2, 1, 2, 2,
        2, 2, 2, 2, 2, 2, 1, 1,
        2, 2, 2, 1, 2, 1, 1, 1,
        1, 1, 2, 2, 2, 2, 2, 1,
        2, 2, 1, 2, 1, 2, 1, 2,
        2, 1, 1, 1, 1, 1, 1, 1,
    ];
    // random 3 colors mask
    #[rustfmt::skip]
    const RANDOM_3COLORS: [usize; 64] = [
        3, 1, 3, 3, 1, 1, 3, 2,
        3, 3, 3, 3, 1, 3, 2, 1,
        2, 2, 2, 2, 1, 1, 2, 2,
        1, 1, 1, 3, 3, 3, 2, 3,
        2, 2, 3, 2, 3, 3, 1, 3,
        1, 3, 3, 1, 1, 3, 2, 1,
        2, 2, 2, 1, 2, 1, 2, 3,
        3, 1, 3, 3, 2, 1, 2, 2,
    ];

    let root = Path::new(SEGMASK_ROOT);
    let images_dir = root.join("images");
    let annotations_dir = root.join("annotations");

    // (image file name, mask file name, expected decoded mask), in dataset order.
    let cases = [
        ("image_checkerboard.png", "mask_checkerboard.png", CHECKERBOARD),
        (
            "image_random_2colors.png",
            "mask_random_2colors.png",
            RANDOM_2COLORS,
        ),
        (
            "image_random_3colors.png",
            "mask_random_3colors.png",
            RANDOM_3COLORS,
        ),
    ];

    let items: Vec<_> = cases
        .iter()
        .map(|(image, mask, _)| (images_dir.join(image), annotations_dir.join(mask)))
        .collect();

    let class_names = [
        "foo", // 0
        "bar", // 1
        "baz", // 2
        "qux", // 3
    ];
    let dataset =
        ImageFolderDataset::new_segmentation_with_items(items, &class_names).unwrap();

    // Dataset has 3 elements; each (image, annotation) is a single item
    assert_eq!(dataset.len(), 3);
    assert_eq!(dataset.get(3), None);

    for (index, (_, _, expected)) in cases.iter().enumerate() {
        assert_eq!(
            dataset.get(index).unwrap().annotation,
            Annotation::SegmentationMask(SegmentationMask {
                mask: expected.to_vec()
            })
        );
    }
}
}
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
1 2 1 2 1 2 1 2
2 1 2 1 2 1 2 1
1 2 1 2 1 2 1 2
2 1 2 1 2 1 2 1
1 2 1 2 1 2 1 2
2 1 2 1 2 1 2 1
1 2 1 2 1 2 1 2
2 1 2 1 2 1 2 1
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
1 2 1 1 1 2 1 1
1 2 1 1 1 1 2 1
2 2 2 1 2 1 2 2
2 2 2 2 2 2 1 1
2 2 2 1 2 1 1 1
1 1 2 2 2 2 2 1
2 2 1 2 1 2 1 2
2 1 1 1 1 1 1 1
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
3 1 3 3 1 1 3 2
3 3 3 3 1 3 2 1
2 2 2 2 1 1 2 2
1 1 1 3 3 3 2 3
2 2 3 2 3 3 1 3
1 3 3 1 1 3 2 1
2 2 2 1 2 1 2 3
3 1 3 3 2 1 2 2
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.

0 comments on commit 6e71aaf

Please sign in to comment.