From 6e71aafbd57161fa82ca9cfc3c456e1dbb0ff0c6 Mon Sep 17 00:00:00 2001 From: Anthony Torlucci Date: Mon, 11 Nov 2024 11:23:31 -0600 Subject: [PATCH] Add segmentation mask to ImageFolderDataset (#2426) * 2361-SegmentationMask implementation and initial test * 2361-SegmentationMask validated tests for test data * 2361-SegmentationMask removed unnecessary serialize/deserialize * 2361-SegmentationMask raw mask as path rather than Vec<_> * 2361-SegmentationMask updated synthetic images and fixed test * 2361-SegmentationMask revert back to Vec --- .../burn-dataset/src/vision/image_folder.rs | 200 +++++++++++++++++- .../annotations/mask_checkerboard.png | Bin 0 -> 117 bytes .../annotations/mask_checkerboard.txt | 8 + .../annotations/mask_random_2colors.png | Bin 0 -> 123 bytes .../annotations/mask_random_2colors.txt | 8 + .../annotations/mask_random_3colors.png | Bin 0 -> 137 bytes .../annotations/mask_random_3colors.txt | 8 + .../images/image_checkerboard.png | Bin 0 -> 165 bytes .../images/image_random_2colors.png | Bin 0 -> 133 bytes .../images/image_random_3colors.png | Bin 0 -> 204 bytes 10 files changed, 221 insertions(+), 3 deletions(-) create mode 100644 crates/burn-dataset/tests/data/segmask_folder/annotations/mask_checkerboard.png create mode 100644 crates/burn-dataset/tests/data/segmask_folder/annotations/mask_checkerboard.txt create mode 100644 crates/burn-dataset/tests/data/segmask_folder/annotations/mask_random_2colors.png create mode 100644 crates/burn-dataset/tests/data/segmask_folder/annotations/mask_random_2colors.txt create mode 100644 crates/burn-dataset/tests/data/segmask_folder/annotations/mask_random_3colors.png create mode 100644 crates/burn-dataset/tests/data/segmask_folder/annotations/mask_random_3colors.txt create mode 100644 crates/burn-dataset/tests/data/segmask_folder/images/image_checkerboard.png create mode 100644 crates/burn-dataset/tests/data/segmask_folder/images/image_random_2colors.png create mode 100644 
crates/burn-dataset/tests/data/segmask_folder/images/image_random_3colors.png diff --git a/crates/burn-dataset/src/vision/image_folder.rs b/crates/burn-dataset/src/vision/image_folder.rs index f850714f35..94b3d74eb3 100644 --- a/crates/burn-dataset/src/vision/image_folder.rs +++ b/crates/burn-dataset/src/vision/image_folder.rs @@ -104,7 +104,8 @@ pub struct ImageDatasetItem { enum AnnotationRaw { Label(String), MultiLabel(Vec), - // TODO: bounding boxes and segmentation mask + SegmentationMask(PathBuf), + // TODO: bounding boxes } #[derive(Deserialize, Serialize, Debug, Clone)] @@ -129,6 +130,33 @@ struct PathToImageDatasetItem { classes: HashMap, } +fn segmentation_mask_to_vec_usize(mask_path: &PathBuf) -> Vec { + // Load image from disk + let image = image::open(mask_path).unwrap(); + + // Image as Vec + // if rgb8 or rgb16, keep only the first channel assuming all channels are the same + let img_vec = match image.color() { + ColorType::L8 => image.into_luma8().iter().map(|&x| x as usize).collect(), + ColorType::L16 => image.into_luma16().iter().map(|&x| x as usize).collect(), + ColorType::Rgb8 => image + .into_rgb8() + .iter() + .step_by(3) + .map(|&x| x as usize) + .collect(), + ColorType::Rgb16 => image + .into_rgb16() + .iter() + .step_by(3) + .map(|&x| x as usize) + .collect(), + _ => panic!("Unrecognized image color type"), + }; + + img_vec +} + /// Parse the image annotation to the corresponding type. fn parse_image_annotation( annotation: &AnnotationRaw, @@ -136,8 +164,8 @@ fn parse_image_annotation( ) -> Annotation { // TODO: add support for other annotations // - [ ] Object bounding boxes - // - [ ] Segmentation mask - // For now, only image classification labels are supported. + // - [x] Segmentation mask + // For now, only image classification labels and segmentation are supported. 
// Map class string to label id match annotation { @@ -148,6 +176,11 @@ fn parse_image_annotation( .map(|name| *classes.get(name).unwrap()) .collect(), ), + AnnotationRaw::SegmentationMask(mask_path) => { + Annotation::SegmentationMask(SegmentationMask { + mask: segmentation_mask_to_vec_usize(mask_path), + }) + } } } @@ -401,6 +434,36 @@ impl ImageFolderDataset { Self::with_items(items, classes) } + /// Create an image segmentation dataset with the specified items. + /// + /// # Arguments + /// + /// * `items` - List of dataset items, each item represented by a tuple `(image path, annotation path)`. + /// * `classes` - Dataset class names. + /// + /// # Returns + /// A new dataset instance. + pub fn new_segmentation_with_items, S: AsRef>( + items: Vec<(P, P)>, + classes: &[S], + ) -> Result { + // Parse items and check valid image extension types + let items = items + .into_iter() + .map(|(image_path, mask_path)| { + // Map image path and segmentation mask path + let image_path = image_path.as_ref(); + let annotation = AnnotationRaw::SegmentationMask(mask_path.as_ref().to_path_buf()); + + Self::check_extension(&image_path.extension().unwrap().to_str().unwrap())?; + + Ok(ImageDatasetItemRaw::new(image_path, annotation)) + }) + .collect::, _>>()?; + + Self::with_items(items, classes) + } + /// Create an image dataset with the specified items. 
/// /// # Arguments @@ -451,6 +514,7 @@ impl ImageFolderDataset { mod tests { use super::*; const DATASET_ROOT: &str = "tests/data/image_folder"; + const SEGMASK_ROOT: &str = "tests/data/segmask_folder"; #[test] pub fn image_folder_dataset() { @@ -611,4 +675,134 @@ mod tests { Annotation::MultiLabel(vec![0, 2]) ); } + + #[test] + pub fn segmask_image_path_to_vec_usize() { + let root = Path::new(SEGMASK_ROOT); + + // checkerboard mask + const TEST_CHECKERBOARD_MASK_PATTERN: [u8; 64] = [ + 1, 2, 1, 2, 1, 2, 1, 2, 2, 1, 2, 1, 2, 1, 2, 1, 1, 2, 1, 2, 1, 2, 1, 2, 2, 1, 2, 1, 2, + 1, 2, 1, 1, 2, 1, 2, 1, 2, 1, 2, 2, 1, 2, 1, 2, 1, 2, 1, 1, 2, 1, 2, 1, 2, 1, 2, 2, 1, + 2, 1, 2, 1, 2, 1, + ]; + assert_eq!( + TEST_CHECKERBOARD_MASK_PATTERN + .iter() + .map(|&x| x as usize) + .collect::>(), + segmentation_mask_to_vec_usize(&root.join("annotations").join("mask_checkerboard.png")), + ); + + // random 2 colors mask + const TEST_RANDOM2COLORS_MASK_PATTERN: [u8; 64] = [ + 1, 2, 1, 1, 1, 2, 1, 1, 1, 2, 1, 1, 1, 1, 2, 1, 2, 2, 2, 1, 2, 1, 2, 2, 2, 2, 2, 2, 2, + 2, 1, 1, 2, 2, 2, 1, 2, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 1, 2, 2, 1, 2, 1, 2, 1, 2, 2, 1, + 1, 1, 1, 1, 1, 1, + ]; + assert_eq!( + TEST_RANDOM2COLORS_MASK_PATTERN + .iter() + .map(|&x| x as usize) + .collect::>(), + segmentation_mask_to_vec_usize( + &root.join("annotations").join("mask_random_2colors.png") + ), + ); + // random 3 colors mask + const TEST_RANDOM3COLORS_MASK_PATTERN: [u8; 64] = [ + 3, 1, 3, 3, 1, 1, 3, 2, 3, 3, 3, 3, 1, 3, 2, 1, 2, 2, 2, 2, 1, 1, 2, 2, 1, 1, 1, 3, 3, + 3, 2, 3, 2, 2, 3, 2, 3, 3, 1, 3, 1, 3, 3, 1, 1, 3, 2, 1, 2, 2, 2, 1, 2, 1, 2, 3, 3, 1, + 3, 3, 2, 1, 2, 2, + ]; + assert_eq!( + TEST_RANDOM3COLORS_MASK_PATTERN + .iter() + .map(|&x| x as usize) + .collect::>(), + segmentation_mask_to_vec_usize( + &root.join("annotations").join("mask_random_3colors.png") + ), + ); + } + + #[test] + pub fn segmask_folder_dataset() { + let root = Path::new(SEGMASK_ROOT); + + let items = vec![ + ( + 
root.join("images").join("image_checkerboard.png"), + root.join("annotations").join("mask_checkerboard.png"), + ), + ( + root.join("images").join("image_random_2colors.png"), + root.join("annotations").join("mask_random_2colors.png"), + ), + ( + root.join("images").join("image_random_3colors.png"), + root.join("annotations").join("mask_random_3colors.png"), + ), + ]; + let dataset = ImageFolderDataset::new_segmentation_with_items( + items, + &[ + "foo", // 0 + "bar", // 1 + "baz", // 2 + "qux", // 3 + ], + ) + .unwrap(); + + // Dataset has 3 elements; each (image, annotation) is a single item + assert_eq!(dataset.len(), 3); + assert_eq!(dataset.get(3), None); + + // checkerboard mask + const TEST_CHECKERBOARD_MASK_PATTERN: [u8; 64] = [ + 1, 2, 1, 2, 1, 2, 1, 2, 2, 1, 2, 1, 2, 1, 2, 1, 1, 2, 1, 2, 1, 2, 1, 2, 2, 1, 2, 1, 2, + 1, 2, 1, 1, 2, 1, 2, 1, 2, 1, 2, 2, 1, 2, 1, 2, 1, 2, 1, 1, 2, 1, 2, 1, 2, 1, 2, 2, 1, + 2, 1, 2, 1, 2, 1, + ]; + assert_eq!( + dataset.get(0).unwrap().annotation, + Annotation::SegmentationMask(SegmentationMask { + mask: TEST_CHECKERBOARD_MASK_PATTERN + .iter() + .map(|&x| x as usize) + .collect() + }) + ); + // random 2 colors mask + const TEST_RANDOM2COLORS_MASK_PATTERN: [u8; 64] = [ + 1, 2, 1, 1, 1, 2, 1, 1, 1, 2, 1, 1, 1, 1, 2, 1, 2, 2, 2, 1, 2, 1, 2, 2, 2, 2, 2, 2, 2, + 2, 1, 1, 2, 2, 2, 1, 2, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 1, 2, 2, 1, 2, 1, 2, 1, 2, 2, 1, + 1, 1, 1, 1, 1, 1, + ]; + assert_eq!( + dataset.get(1).unwrap().annotation, + Annotation::SegmentationMask(SegmentationMask { + mask: TEST_RANDOM2COLORS_MASK_PATTERN + .iter() + .map(|&x| x as usize) + .collect() + }) + ); + // random 3 colors mask + const TEST_RANDOM3COLORS_MASK_PATTERN: [u8; 64] = [ + 3, 1, 3, 3, 1, 1, 3, 2, 3, 3, 3, 3, 1, 3, 2, 1, 2, 2, 2, 2, 1, 1, 2, 2, 1, 1, 1, 3, 3, + 3, 2, 3, 2, 2, 3, 2, 3, 3, 1, 3, 1, 3, 3, 1, 1, 3, 2, 1, 2, 2, 2, 1, 2, 1, 2, 3, 3, 1, + 3, 3, 2, 1, 2, 2, + ]; + assert_eq!( + dataset.get(2).unwrap().annotation, + 
Annotation::SegmentationMask(SegmentationMask { + mask: TEST_RANDOM3COLORS_MASK_PATTERN + .iter() + .map(|&x| x as usize) + .collect() + }) + ); + } } diff --git a/crates/burn-dataset/tests/data/segmask_folder/annotations/mask_checkerboard.png b/crates/burn-dataset/tests/data/segmask_folder/annotations/mask_checkerboard.png new file mode 100644 index 0000000000000000000000000000000000000000..3c87252ebe032f7f28544c96150a30651e7f55f4 GIT binary patch literal 117 zcmeAS@N?(olHy`uVBq!ia0vp^93afW1SGw4HSYi^8&4O<5Dr<@gN>XF1_I2AdcWW2 z$giqh5wLJGuc%hY3+Gm@xw$}|lRS`<1LTSR0&>!&wL(^WV7O!(qN_G@GcV8_22WQ% Jmvv4FO#oZUApZaW literal 0 HcmV?d00001 diff --git a/crates/burn-dataset/tests/data/segmask_folder/annotations/mask_checkerboard.txt b/crates/burn-dataset/tests/data/segmask_folder/annotations/mask_checkerboard.txt new file mode 100644 index 0000000000..2e01635db9 --- /dev/null +++ b/crates/burn-dataset/tests/data/segmask_folder/annotations/mask_checkerboard.txt @@ -0,0 +1,8 @@ +1 2 1 2 1 2 1 2 +2 1 2 1 2 1 2 1 +1 2 1 2 1 2 1 2 +2 1 2 1 2 1 2 1 +1 2 1 2 1 2 1 2 +2 1 2 1 2 1 2 1 +1 2 1 2 1 2 1 2 +2 1 2 1 2 1 2 1 diff --git a/crates/burn-dataset/tests/data/segmask_folder/annotations/mask_random_2colors.png b/crates/burn-dataset/tests/data/segmask_folder/annotations/mask_random_2colors.png new file mode 100644 index 0000000000000000000000000000000000000000..f0a129ab263b2438c63a5a860f1752d12e999e67 GIT binary patch literal 123 zcmeAS@N?(olHy`uVBq!ia0vp^93afW1SGw4HSYi^Cr=m05DwYYgN>Xm4h%;Qr2Kww zsjuGj$@rcY-&UqQp*B}e<;qNrTG?>>$0~;Hg&|o-AF>8iuDrt*lC?A}Zd>S|pN!kG WwbVCupA!L^#^CAd=d#Wzp$PzXXeti? 
literal 0 HcmV?d00001 diff --git a/crates/burn-dataset/tests/data/segmask_folder/annotations/mask_random_2colors.txt b/crates/burn-dataset/tests/data/segmask_folder/annotations/mask_random_2colors.txt new file mode 100644 index 0000000000..4fa2b7c26c --- /dev/null +++ b/crates/burn-dataset/tests/data/segmask_folder/annotations/mask_random_2colors.txt @@ -0,0 +1,8 @@ +1 2 1 1 1 2 1 1 +1 2 1 1 1 1 2 1 +2 2 2 1 2 1 2 2 +2 2 2 2 2 2 1 1 +2 2 2 1 2 1 1 1 +1 1 2 2 2 2 2 1 +2 2 1 2 1 2 1 2 +2 1 1 1 1 1 1 1 diff --git a/crates/burn-dataset/tests/data/segmask_folder/annotations/mask_random_3colors.png b/crates/burn-dataset/tests/data/segmask_folder/annotations/mask_random_3colors.png new file mode 100644 index 0000000000000000000000000000000000000000..38984e71cad9b1508bf8659605a18091146c8fa9 GIT binary patch literal 137 zcmeAS@N?(olHy`uVBq!ia0vp^93afW1SGw4HSYka08bak5Dr<>gN?Z?jsk}bA}`mU zF5{E7nYlZAHB&BQY4p3k`P*L2Tk$7{c@o3V4=ZeLe+q0=xjl&?_i9bZ0^J4gI!|rZ lJ-O}A?%mt8HmGxcV7G}fv#~WQQU;pM;OXk;vd$@?2>>74Fth*w literal 0 HcmV?d00001 diff --git a/crates/burn-dataset/tests/data/segmask_folder/annotations/mask_random_3colors.txt b/crates/burn-dataset/tests/data/segmask_folder/annotations/mask_random_3colors.txt new file mode 100644 index 0000000000..08f222b630 --- /dev/null +++ b/crates/burn-dataset/tests/data/segmask_folder/annotations/mask_random_3colors.txt @@ -0,0 +1,8 @@ +3 1 3 3 1 1 3 2 +3 3 3 3 1 3 2 1 +2 2 2 2 1 1 2 2 +1 1 1 3 3 3 2 3 +2 2 3 2 3 3 1 3 +1 3 3 1 1 3 2 1 +2 2 2 1 2 1 2 3 +3 1 3 3 2 1 2 2 diff --git a/crates/burn-dataset/tests/data/segmask_folder/images/image_checkerboard.png b/crates/burn-dataset/tests/data/segmask_folder/images/image_checkerboard.png new file mode 100644 index 0000000000000000000000000000000000000000..2087501e41b6f2663152e5a335d1161f2e4c1aab GIT binary patch literal 165 zcmeAS@N?(olHy`uVBq!ia0vp^93afW1SGw4HSYka98VX=5Dr<^gN%#~1`JG&*Bzw< zoG%4GmKI35CwuR_c2~H_dX2dA#rhCNAm97vqDR4^&v%ML7#i*0KX*L>iC+JFryI-& 
ppK|zjrRztK=Dv2?zfoWPX-X^jhPHh)a3G-5mQxc}yY137zc hR&4&*bneq*W}ZgFzjxQ>_W;dh@O1TaS?83{1ORC6F(&{3 literal 0 HcmV?d00001 diff --git a/crates/burn-dataset/tests/data/segmask_folder/images/image_random_3colors.png b/crates/burn-dataset/tests/data/segmask_folder/images/image_random_3colors.png new file mode 100644 index 0000000000000000000000000000000000000000..880bac466f179eccc75613d429a4993c69510ae6 GIT binary patch literal 204 zcmeAS@N?(olHy`uVBq!ia0vp^93afW1SGw4HSYka$(}BbAsn)%2N&`lRS#qG&nfqQ`0$<%JF6?cA*VDeVjGtWP)}Oz`AgCQbJ@M}gPkWO z=I#8UBKdQJr<>iIO`)G($j#N?%Lo*o$j|+UVX}|gf*@<5oj~U?c)I$ztaD0e0ssYW BOJV>3 literal 0 HcmV?d00001