From c236ff0264f6eb5484c27bb8e3d7dbffd220f08f Mon Sep 17 00:00:00 2001 From: Eugene Hauptmann Date: Mon, 16 Dec 2024 16:53:31 -0500 Subject: [PATCH] fixed URLs and show_image --- .DS_Store | Bin 0 -> 8196 bytes .vscode/settings.json | 4 +++ Cargo.toml | 9 +++--- README.md | 16 ++++++++++ examples/fashion_mnist.rs | 27 ++++++++-------- examples/mnist.rs | 28 ++++++++--------- src/download.rs | 16 ++++++---- src/lib.rs | 64 +++++++++++++++++++++++--------------- 8 files changed, 100 insertions(+), 64 deletions(-) create mode 100644 .DS_Store create mode 100644 .vscode/settings.json diff --git a/.DS_Store b/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..109fe0489b1aaf952a9ae11bc2c80652b3d42cd9 GIT binary patch literal 8196 zcmeHMU2GIp6u#edp) z74$(3QD4*;{CibV9yEpr;)99!;)7|7n1Bx^zUYJcXiW6X+*#5V`d~r~33HQs&pqd! zbLZZ3zcY8wEMp8EC1V|94U93FYC!oMYVJ^i@u1#Tq)3t=3bJSUF>f?)yLsZjxJ^42 z#2$z}5PKl@KhqvOG!@ZIM8^fC5h@U-M1?7$6a&JX&rhenhg6lAHM_c6sh(4PBs=VmWqsW* zc^!K3P{8%LuAQ&+^?aXYj^%iLr|bB-<(S3Lz%ooyIoxkKhC9~bmK?(iZL?xPQDh~@ z8_t|rw_!_5YQv`V*_PCqE$h=Qsn)jEvu71~Vbj`e2XiOw5yyQ=Yyr`0fX$n>rI#n~ z$M(vd;(J6@ZDl_{MEj8yb&l5Crw-(JU6~f$J?dMoWFphb+2GLhUJ_X zBF{O#b=)$GN-bKH9!DQBC-|bp4NI14+H~#8ECZ#y^>oSdPxB^O7SCwZ9#Hsa@?nD( z__%8XGaBKD)~GF)d|e3ip(rlXSeobiLXtb3^I1 z^}J~d(I$0^%=at(W#L~TtWzJAm4SjgGD4g2ge;?V$o!goGIIDHEl4Ms@*V5gErl;XffBG(y}uDj%>K) zlr4fYowz`?+>~h+hZAl`Y09Ldvy#~owuY|3J?t_>Kk-DE$r-`Jn*ZzM1eHQ=blQfNqG8CIbc8_|Xi?8gB-fo}BR2#(?ybR36^Q550f zG@i#fynq*R0dL?D7A@oOO-$woQgSV2l2N8p8_TNa3|S!3Y28mE_Kg65h&N^2BA?~x1fIEvZj@* z*SBrkaSK2KZxx75hHi+~DHsorH;kI;NE0!nCVL6a=Bb%NiB|_Isu8u##}ug4Kk}&De@<*p4pj#$N0rl=h$x z0~o{*3NUd3!>|aWK1wL#Sv*HLJx@rzfLHMvUdP+GhAC4^_aWM9;g!YH3a#8*UaDl@A6aQ>0%GW9=Ml0fW*%1&JJ49Oz+(w z*G^GAMl~pb-ngKl2{mClPBcu%iQf9dkoqZ74Qw(I9T$`&RQ~r50pd4qeE-Mye~7>P GwfPrOE)pUD literal 0 HcmV?d00001 diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 0000000..0f51e55 --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,4 @@ +{ + "editor.formatOnSave": true, + "editor.defaultFormatter": "rust-lang.rust-analyzer" +} \ No newline at end of file diff --git a/Cargo.toml b/Cargo.toml index cae609e..78286d2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "mnist" description = "MNIST data set parser." -version = "0.6.0" +version = "0.6.1" authors = ["David McNeil "] repository = "https://github.com/davidMcneil/mnist" documentation = "https://docs.rs/mnist" @@ -28,7 +28,6 @@ pbr = {version = "1.0", optional = true} flate2 = {version = "1.0.2", optional = true, features = ["rust_backend"], default-features = false} [dev-dependencies] -ndarray = "0.14" -image = "0.23" -# show-image is used to visualize datasets in a pop-up window -show-image = {version = "0.6", features = ["image"]} \ No newline at end of file +ndarray = "0.16.1" +image = "0.25.5" +show-image = { version = "0.14.0", git = "https://github.com/robohouse-delft/show-image-rs.git", features = ["image", "png"] } diff --git a/README.md b/README.md index fa8e9f3..8e5750d 100644 --- a/README.md +++ b/README.md @@ -57,3 +57,19 @@ An example of downloading this dataset may be found by running: ```sh $ cargo run --features download --example fashion_mnist ``` + +## Troubleshooting + +### On Mac + + +`ld: library not found for -lSDL2` + +```shell +brew link sdl2 +brew unlink sdl2 && brew link sdl2 +export LIBRARY_PATH="$LIBRARY_PATH:/opt/homebrew/lib" # needed to link to SDL2 library + +cargo clean +cargo run --example mnist -F download +``` \ No newline at end of file diff --git a/examples/fashion_mnist.rs b/examples/fashion_mnist.rs index 806b946..9b1df08 100644 --- a/examples/fashion_mnist.rs +++ b/examples/fashion_mnist.rs @@ -1,8 +1,9 @@ use image::*; use mnist::*; use ndarray::prelude::*; -use show_image::{make_window_full, Event, WindowOptions}; +use show_image::{create_window, event, WindowOptions}; +#[show_image::main] fn main() { let (trn_size, _rows, _cols) = (50_000, 28, 28); @@ -27,24 +28,22 @@ fn main() { .mapv(|x| x as f32 / 256.); let image = bw_ndarray2_to_rgb_image(train_data.slice(s![item_num, .., ..]).to_owned()); - let window_options = WindowOptions { - name: "image".to_string(), - size: [100, 100], - resizable: true, - preserve_aspect_ratio: true, - }; - let window = make_window_full(window_options).unwrap(); - window.set_image(image, "test_result").unwrap(); + let window_options = WindowOptions::new().set_size(Some([100, 100])); + let window = create_window("image", window_options).unwrap(); + window.set_image("test_result", image).unwrap(); - for event in window.events() { - if let Event::KeyboardEvent(event) = event { - if event.key == show_image::KeyCode::Escape { + // Wait for the window to be closed or Escape to be pressed. + for event in window.event_channel().map_err(|e| e.to_string()).unwrap() { + if let event::WindowEvent::KeyboardInput(event) = event { + if !event.is_synthetic + && event.input.key_code == Some(event::VirtualKeyCode::Escape) + && event.input.state.is_pressed() + { + println!("Escape pressed!"); break; } } } - - show_image::stop().unwrap(); } fn return_item_description_from_number(val: u8) { diff --git a/examples/mnist.rs b/examples/mnist.rs index 9597972..4e3972f 100644 --- a/examples/mnist.rs +++ b/examples/mnist.rs @@ -1,8 +1,9 @@ use image::*; use mnist::*; use ndarray::prelude::*; -use show_image::{make_window_full, Event, WindowOptions}; +use show_image::{create_window, event, WindowOptions}; +#[show_image::main] fn main() { let (trn_size, _rows, _cols) = (50_000, 28, 28); @@ -28,24 +29,23 @@ fn main() { .mapv(|x| x as f32 / 256.); let image = bw_ndarray2_to_rgb_image(train_data.slice(s![item_num, .., ..]).to_owned()); - let window_options = WindowOptions { - name: "image".to_string(), - size: [100, 100], - resizable: true, - preserve_aspect_ratio: true, - }; - let window = make_window_full(window_options).unwrap(); - window.set_image(image, "test_result").unwrap(); + let window_options = WindowOptions::new().set_size(Some([100, 100])); + let window = create_window("image", window_options).unwrap(); + + window.set_image("test_result", image).unwrap(); - for event in window.events() { - if let Event::KeyboardEvent(event) = event { - if event.key == show_image::KeyCode::Escape { + // Wait for the window to be closed or Escape to be pressed. + for event in window.event_channel().map_err(|e| e.to_string()).unwrap() { + if let event::WindowEvent::KeyboardInput(event) = event { + if !event.is_synthetic + && event.input.key_code == Some(event::VirtualKeyCode::Escape) + && event.input.state.is_pressed() + { + println!("Escape pressed!"); break; } } } - - show_image::stop().unwrap(); } fn return_item_description_from_number(val: u8) { diff --git a/src/download.rs b/src/download.rs index 2a0987b..ceef1c9 100644 --- a/src/download.rs +++ b/src/download.rs @@ -11,15 +11,18 @@ use pbr::ProgressBar; use std::convert::TryInto; use std::thread; +#[allow(unused)] use log::Level; +use std::time::Duration; + #[cfg(target_family = "unix")] use std::os::unix::fs::MetadataExt; #[cfg(target_family = "windows")] use std::os::windows::fs::MetadataExt; #[cfg(target_family = "unix")] -fn file_size(meta: &MetadataExt) -> usize { +fn file_size(meta: &dyn MetadataExt) -> usize { meta.size() as usize } @@ -70,6 +73,7 @@ pub(super) fn download_and_extract( Ok(()) } +#[allow(unused_variables)] fn download( base_url: &str, archive: &str, @@ -81,10 +85,10 @@ fn download( let url = Path::new(base_url).join(archive); let file_name = download_dir.to_str().unwrap().to_owned() + archive; //.clone(); if Path::new(&file_name).exists() { - log::info!( - " File {:?} already exists, skipping downloading.", - file_name - ); + log::info!( + " File {:?} already exists, skipping downloading.", + file_name + ); } else { log::info!( "- Downloading from file from {} and saving to file as: {}", @@ -105,7 +109,7 @@ fn download( current_size = file_size(&meta); pb.set(current_size.try_into().unwrap()); - thread::sleep_ms(10); + thread::sleep(Duration::from_millis(10)); } pb.finish_println(" "); }); diff --git a/src/lib.rs b/src/lib.rs index 2182194..380828f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -100,8 +100,22 @@ use std::fs::File; use std::io::prelude::*; use std::path::Path; +// From https://github.com/cvdfoundation/mnist + +// Training +// https://storage.googleapis.com/cvdf-datasets/mnist/train-images-idx3-ubyte.gz +// https://storage.googleapis.com/cvdf-datasets/mnist/train-labels-idx1-ubyte.gz + +// Testing +// https://storage.googleapis.com/cvdf-datasets/mnist/t10k-images-idx3-ubyte.gz +// https://storage.googleapis.com/cvdf-datasets/mnist/t10k-labels-idx1-ubyte.gz + static BASE_PATH: &str = "data/"; -static BASE_URL: &str = "http://yann.lecun.com/exdb/mnist"; +// old, doesn't work, gives 404 +// static BASE_URL: &str = "http://yann.lecun.com/exdb/mnist"; +static BASE_URL: &str = "https://storage.googleapis.com/cvdf-datasets/mnist/"; + +#[allow(dead_code)] static FASHION_BASE_URL: &str = "http://fashion-mnist.s3-website.eu-central-1.amazonaws.com"; static TRN_IMG_FILENAME: &str = "train-images-idx3-ubyte"; static TRN_LBL_FILENAME: &str = "train-labels-idx1-ubyte"; @@ -401,6 +415,7 @@ impl<'a> MnistBuilder<'a> { /// If `trn_len + val_len + tst_len > 70,000`. pub fn finalize(&self) -> Mnist { if self.download_and_extract { + #[cfg(feature = "download")] let base_url = if self.use_fashion_data { FASHION_BASE_URL } else if self.base_url != BASE_URL { @@ -430,10 +445,9 @@ impl<'a> MnistBuilder<'a> { let available_length = (TRN_LEN + TST_LEN) as usize; assert!( total_length <= available_length, - format!( - "Total data set length ({}) greater than maximum possible length ({}).", - total_length, available_length - ) + "Total data set length ({}) greater than maximum possible length ({}).", + total_length, + available_length ); let mut trn_img = images( &Path::new(self.base_path).join(self.trn_img_filename), @@ -457,8 +471,8 @@ impl<'a> MnistBuilder<'a> { let mut val_lbl = trn_lbl.split_off(trn_len); let mut tst_img = val_img.split_off(val_len * ROWS * COLS); let mut tst_lbl = val_lbl.split_off(val_len); - tst_img.split_off(tst_len * ROWS * COLS); - tst_lbl.split_off(tst_len); + let _ = tst_img.split_off(tst_len * ROWS * COLS); + let _ = tst_lbl.split_off(tst_len); if self.lbl_format == LabelFormat::OneHotVector { fn digit2one_hot(v: Vec) -> Vec { v.iter() @@ -559,20 +573,18 @@ fn labels(path: &Path, expected_length: u32) -> Vec { .unwrap_or_else(|_| panic!("Unable to read magic number from {:?}.", path)); assert!( LBL_MAGIC_NUMBER == magic_number, - format!( - "Expected magic number {} got {}.", - LBL_MAGIC_NUMBER, magic_number - ) + "Expected magic number {} got {}.", + LBL_MAGIC_NUMBER, + magic_number ); let length = file .read_u32::() .unwrap_or_else(|_| panic!("Unable to length from {:?}.", path)); assert!( expected_length == length, - format!( - "Expected data set length of {} got {}.", - expected_length, length - ) + "Expected data set length of {} got {}.", + expected_length, + length ); file.bytes().map(|b| b.unwrap()).collect() } @@ -596,20 +608,18 @@ fn images(path: &Path, expected_length: u32) -> Vec { .unwrap_or_else(|_| panic!("Unable to read magic number from {:?}.", path)); assert!( IMG_MAGIC_NUMBER == magic_number, - format!( - "Expected magic number {} got {}.", - IMG_MAGIC_NUMBER, magic_number - ) + "Expected magic number {} got {}.", + IMG_MAGIC_NUMBER, + magic_number ); let length = file .read_u32::() .unwrap_or_else(|_| panic!("Unable to length from {:?}.", path)); assert!( expected_length == length, - format!( - "Expected data set length of {} got {}.", - expected_length, length - ) + "Expected data set length of {} got {}.", + expected_length, + length ); let rows = file .read_u32::() @@ -617,7 +627,9 @@ fn images(path: &Path, expected_length: u32) -> Vec { as usize; assert!( ROWS == rows, - format!("Expected rows length of {} got {}.", ROWS, rows) + "Expected rows length of {} got {}.", + ROWS, + rows ); let cols = file .read_u32::() @@ -625,7 +637,9 @@ fn images(path: &Path, expected_length: u32) -> Vec { as usize; assert!( COLS == cols, - format!("Expected cols length of {} got {}.", COLS, cols) + "Expected cols length of {} got {}.", + COLS, + cols ); // Convert `file` from a Vec to a slice. file.to_vec()