diff --git a/crates/burn-dataset/src/vision/image_folder.rs b/crates/burn-dataset/src/vision/image_folder.rs index f850714f35..94b3d74eb3 100644 --- a/crates/burn-dataset/src/vision/image_folder.rs +++ b/crates/burn-dataset/src/vision/image_folder.rs @@ -104,7 +104,8 @@ pub struct ImageDatasetItem { enum AnnotationRaw { Label(String), MultiLabel(Vec), - // TODO: bounding boxes and segmentation mask + SegmentationMask(PathBuf), + // TODO: bounding boxes } #[derive(Deserialize, Serialize, Debug, Clone)] @@ -129,6 +130,33 @@ struct PathToImageDatasetItem { classes: HashMap, } +fn segmentation_mask_to_vec_usize(mask_path: &PathBuf) -> Vec { + // Load image from disk + let image = image::open(mask_path).unwrap(); + + // Image as Vec + // if rgb8 or rgb16, keep only the first channel assuming all channels are the same + let img_vec = match image.color() { + ColorType::L8 => image.into_luma8().iter().map(|&x| x as usize).collect(), + ColorType::L16 => image.into_luma16().iter().map(|&x| x as usize).collect(), + ColorType::Rgb8 => image + .into_rgb8() + .iter() + .step_by(3) + .map(|&x| x as usize) + .collect(), + ColorType::Rgb16 => image + .into_rgb16() + .iter() + .step_by(3) + .map(|&x| x as usize) + .collect(), + _ => panic!("Unrecognized image color type"), + }; + + img_vec +} + /// Parse the image annotation to the corresponding type. fn parse_image_annotation( annotation: &AnnotationRaw, @@ -136,8 +164,8 @@ fn parse_image_annotation( ) -> Annotation { // TODO: add support for other annotations // - [ ] Object bounding boxes - // - [ ] Segmentation mask - // For now, only image classification labels are supported. + // - [x] Segmentation mask + // For now, only image classification labels and segmentation are supported. // Map class string to label id match annotation { @@ -148,6 +176,11 @@ fn parse_image_annotation( .map(|name| *classes.get(name).unwrap()) .collect(), ), + AnnotationRaw::SegmentationMask(mask_path) => { + Annotation::SegmentationMask(SegmentationMask { + mask: segmentation_mask_to_vec_usize(mask_path), + }) + } } } @@ -401,6 +434,36 @@ impl ImageFolderDataset { Self::with_items(items, classes) } + /// Create an image segmentation dataset with the specified items. + /// + /// # Arguments + /// + /// * `items` - List of dataset items, each item represented by a tuple `(image path, annotation path)`. + /// * `classes` - Dataset class names. + /// + /// # Returns + /// A new dataset instance. + pub fn new_segmentation_with_items, S: AsRef>( + items: Vec<(P, P)>, + classes: &[S], + ) -> Result { + // Parse items and check valid image extension types + let items = items + .into_iter() + .map(|(image_path, mask_path)| { + // Map image path and segmentation mask path + let image_path = image_path.as_ref(); + let annotation = AnnotationRaw::SegmentationMask(mask_path.as_ref().to_path_buf()); + + Self::check_extension(&image_path.extension().unwrap().to_str().unwrap())?; + + Ok(ImageDatasetItemRaw::new(image_path, annotation)) + }) + .collect::, _>>()?; + + Self::with_items(items, classes) + } + /// Create an image dataset with the specified items. /// /// # Arguments @@ -451,6 +514,7 @@ impl ImageFolderDataset { mod tests { use super::*; const DATASET_ROOT: &str = "tests/data/image_folder"; + const SEGMASK_ROOT: &str = "tests/data/segmask_folder"; #[test] pub fn image_folder_dataset() { @@ -611,4 +675,134 @@ mod tests { Annotation::MultiLabel(vec![0, 2]) ); } + + #[test] + pub fn segmask_image_path_to_vec_usize() { + let root = Path::new(SEGMASK_ROOT); + + // checkerboard mask + const TEST_CHECKERBOARD_MASK_PATTERN: [u8; 64] = [ + 1, 2, 1, 2, 1, 2, 1, 2, 2, 1, 2, 1, 2, 1, 2, 1, 1, 2, 1, 2, 1, 2, 1, 2, 2, 1, 2, 1, 2, + 1, 2, 1, 1, 2, 1, 2, 1, 2, 1, 2, 2, 1, 2, 1, 2, 1, 2, 1, 1, 2, 1, 2, 1, 2, 1, 2, 2, 1, + 2, 1, 2, 1, 2, 1, + ]; + assert_eq!( + TEST_CHECKERBOARD_MASK_PATTERN + .iter() + .map(|&x| x as usize) + .collect::>(), + segmentation_mask_to_vec_usize(&root.join("annotations").join("mask_checkerboard.png")), + ); + + // random 2 colors mask + const TEST_RANDOM2COLORS_MASK_PATTERN: [u8; 64] = [ + 1, 2, 1, 1, 1, 2, 1, 1, 1, 2, 1, 1, 1, 1, 2, 1, 2, 2, 2, 1, 2, 1, 2, 2, 2, 2, 2, 2, 2, + 2, 1, 1, 2, 2, 2, 1, 2, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 1, 2, 2, 1, 2, 1, 2, 1, 2, 2, 1, + 1, 1, 1, 1, 1, 1, + ]; + assert_eq!( + TEST_RANDOM2COLORS_MASK_PATTERN + .iter() + .map(|&x| x as usize) + .collect::>(), + segmentation_mask_to_vec_usize( + &root.join("annotations").join("mask_random_2colors.png") + ), + ); + // random 3 colors mask + const TEST_RANDOM3COLORS_MASK_PATTERN: [u8; 64] = [ + 3, 1, 3, 3, 1, 1, 3, 2, 3, 3, 3, 3, 1, 3, 2, 1, 2, 2, 2, 2, 1, 1, 2, 2, 1, 1, 1, 3, 3, + 3, 2, 3, 2, 2, 3, 2, 3, 3, 1, 3, 1, 3, 3, 1, 1, 3, 2, 1, 2, 2, 2, 1, 2, 1, 2, 3, 3, 1, + 3, 3, 2, 1, 2, 2, + ]; + assert_eq!( + TEST_RANDOM3COLORS_MASK_PATTERN + .iter() + .map(|&x| x as usize) + .collect::>(), + segmentation_mask_to_vec_usize( + &root.join("annotations").join("mask_random_3colors.png") + ), + ); + } + + #[test] + pub fn segmask_folder_dataset() { + let root = Path::new(SEGMASK_ROOT); + + let items = vec![ + ( + root.join("images").join("image_checkerboard.png"), + root.join("annotations").join("mask_checkerboard.png"), + ), + ( + root.join("images").join("image_random_2colors.png"), + root.join("annotations").join("mask_random_2colors.png"), + ), + ( + root.join("images").join("image_random_3colors.png"), + root.join("annotations").join("mask_random_3colors.png"), + ), + ]; + let dataset = ImageFolderDataset::new_segmentation_with_items( + items, + &[ + "foo", // 0 + "bar", // 1 + "baz", // 2 + "qux", // 3 + ], + ) + .unwrap(); + + // Dataset has 3 elements; each (image, annotation) is a single item + assert_eq!(dataset.len(), 3); + assert_eq!(dataset.get(3), None); + + // checkerboard mask + const TEST_CHECKERBOARD_MASK_PATTERN: [u8; 64] = [ + 1, 2, 1, 2, 1, 2, 1, 2, 2, 1, 2, 1, 2, 1, 2, 1, 1, 2, 1, 2, 1, 2, 1, 2, 2, 1, 2, 1, 2, + 1, 2, 1, 1, 2, 1, 2, 1, 2, 1, 2, 2, 1, 2, 1, 2, 1, 2, 1, 1, 2, 1, 2, 1, 2, 1, 2, 2, 1, + 2, 1, 2, 1, 2, 1, + ]; + assert_eq!( + dataset.get(0).unwrap().annotation, + Annotation::SegmentationMask(SegmentationMask { + mask: TEST_CHECKERBOARD_MASK_PATTERN + .iter() + .map(|&x| x as usize) + .collect() + }) + ); + // random 2 colors mask + const TEST_RANDOM2COLORS_MASK_PATTERN: [u8; 64] = [ + 1, 2, 1, 1, 1, 2, 1, 1, 1, 2, 1, 1, 1, 1, 2, 1, 2, 2, 2, 1, 2, 1, 2, 2, 2, 2, 2, 2, 2, + 2, 1, 1, 2, 2, 2, 1, 2, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 1, 2, 2, 1, 2, 1, 2, 1, 2, 2, 1, + 1, 1, 1, 1, 1, 1, + ]; + assert_eq!( + dataset.get(1).unwrap().annotation, + Annotation::SegmentationMask(SegmentationMask { + mask: TEST_RANDOM2COLORS_MASK_PATTERN + .iter() + .map(|&x| x as usize) + .collect() + }) + ); + // random 3 colors mask + const TEST_RANDOM3COLORS_MASK_PATTERN: [u8; 64] = [ + 3, 1, 3, 3, 1, 1, 3, 2, 3, 3, 3, 3, 1, 3, 2, 1, 2, 2, 2, 2, 1, 1, 2, 2, 1, 1, 1, 3, 3, + 3, 2, 3, 2, 2, 3, 2, 3, 3, 1, 3, 1, 3, 3, 1, 1, 3, 2, 1, 2, 2, 2, 1, 2, 1, 2, 3, 3, 1, + 3, 3, 2, 1, 2, 2, + ]; + assert_eq!( + dataset.get(2).unwrap().annotation, + Annotation::SegmentationMask(SegmentationMask { + mask: TEST_RANDOM3COLORS_MASK_PATTERN + .iter() + .map(|&x| x as usize) + .collect() + }) + ); + } } diff --git a/crates/burn-dataset/tests/data/segmask_folder/annotations/mask_checkerboard.png b/crates/burn-dataset/tests/data/segmask_folder/annotations/mask_checkerboard.png new file mode 100644 index 0000000000..3c87252ebe Binary files /dev/null and b/crates/burn-dataset/tests/data/segmask_folder/annotations/mask_checkerboard.png differ diff --git a/crates/burn-dataset/tests/data/segmask_folder/annotations/mask_checkerboard.txt b/crates/burn-dataset/tests/data/segmask_folder/annotations/mask_checkerboard.txt new file mode 100644 index 0000000000..2e01635db9 --- /dev/null +++ b/crates/burn-dataset/tests/data/segmask_folder/annotations/mask_checkerboard.txt @@ -0,0 +1,8 @@ +1 2 1 2 1 2 1 2 +2 1 2 1 2 1 2 1 +1 2 1 2 1 2 1 2 +2 1 2 1 2 1 2 1 +1 2 1 2 1 2 1 2 +2 1 2 1 2 1 2 1 +1 2 1 2 1 2 1 2 +2 1 2 1 2 1 2 1 diff --git a/crates/burn-dataset/tests/data/segmask_folder/annotations/mask_random_2colors.png b/crates/burn-dataset/tests/data/segmask_folder/annotations/mask_random_2colors.png new file mode 100644 index 0000000000..f0a129ab26 Binary files /dev/null and b/crates/burn-dataset/tests/data/segmask_folder/annotations/mask_random_2colors.png differ diff --git a/crates/burn-dataset/tests/data/segmask_folder/annotations/mask_random_2colors.txt b/crates/burn-dataset/tests/data/segmask_folder/annotations/mask_random_2colors.txt new file mode 100644 index 0000000000..4fa2b7c26c --- /dev/null +++ b/crates/burn-dataset/tests/data/segmask_folder/annotations/mask_random_2colors.txt @@ -0,0 +1,8 @@ +1 2 1 1 1 2 1 1 +1 2 1 1 1 1 2 1 +2 2 2 1 2 1 2 2 +2 2 2 2 2 2 1 1 +2 2 2 1 2 1 1 1 +1 1 2 2 2 2 2 1 +2 2 1 2 1 2 1 2 +2 1 1 1 1 1 1 1 diff --git a/crates/burn-dataset/tests/data/segmask_folder/annotations/mask_random_3colors.png b/crates/burn-dataset/tests/data/segmask_folder/annotations/mask_random_3colors.png new file mode 100644 index 0000000000..38984e71ca Binary files /dev/null and b/crates/burn-dataset/tests/data/segmask_folder/annotations/mask_random_3colors.png differ diff --git a/crates/burn-dataset/tests/data/segmask_folder/annotations/mask_random_3colors.txt b/crates/burn-dataset/tests/data/segmask_folder/annotations/mask_random_3colors.txt new file mode 100644 index 0000000000..08f222b630 --- /dev/null +++ b/crates/burn-dataset/tests/data/segmask_folder/annotations/mask_random_3colors.txt @@ -0,0 +1,8 @@ +3 1 3 3 1 1 3 2 +3 3 3 3 1 3 2 1 +2 2 2 2 1 1 2 2 +1 1 1 3 3 3 2 3 +2 2 3 2 3 3 1 3 +1 3 3 1 1 3 2 1 +2 2 2 1 2 1 2 3 +3 1 3 3 2 1 2 2 diff --git a/crates/burn-dataset/tests/data/segmask_folder/images/image_checkerboard.png b/crates/burn-dataset/tests/data/segmask_folder/images/image_checkerboard.png new file mode 100644 index 0000000000..2087501e41 Binary files /dev/null and b/crates/burn-dataset/tests/data/segmask_folder/images/image_checkerboard.png differ diff --git a/crates/burn-dataset/tests/data/segmask_folder/images/image_random_2colors.png b/crates/burn-dataset/tests/data/segmask_folder/images/image_random_2colors.png new file mode 100644 index 0000000000..f433fd9321 Binary files /dev/null and b/crates/burn-dataset/tests/data/segmask_folder/images/image_random_2colors.png differ diff --git a/crates/burn-dataset/tests/data/segmask_folder/images/image_random_3colors.png b/crates/burn-dataset/tests/data/segmask_folder/images/image_random_3colors.png new file mode 100644 index 0000000000..880bac466f Binary files /dev/null and b/crates/burn-dataset/tests/data/segmask_folder/images/image_random_3colors.png differ