Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Migrate video classification example to keras 3.0 #1674

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 26 additions & 13 deletions examples/vision/ipynb/video_classification.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
"\n",
"**Author:** [Sayak Paul](https://twitter.com/RisingSayak)<br>\n",
"**Date created:** 2021/05/28<br>\n",
"**Last modified:** 2023/08/28<br>\n",
"**Last modified:** 2023/12/08<br>\n",
"**Description:** Training a video classifier with transfer learning and a recurrent model on the UCF101 dataset."
]
},
Expand Down Expand Up @@ -72,7 +72,7 @@
},
"outputs": [],
"source": [
"!wget -q https://github.com/sayakpaul/Action-Recognition-in-TensorFlow/releases/download/v1.0.0/ucf101_top5.tar.gz\n",
"!!wget -q https://github.com/sayakpaul/Action-Recognition-in-TensorFlow/releases/download/v1.0.0/ucf101_top5.tar.gz\n",
"!tar xf ucf101_top5.tar.gz"
]
},
Expand All @@ -93,17 +93,17 @@
},
"outputs": [],
"source": [
"from tensorflow_docs.vis import embed\n",
"from tensorflow import keras\n",
"import os\n",
"\n",
"import keras\n",
"from imutils import paths\n",
"\n",
"import matplotlib.pyplot as plt\n",
"import tensorflow as tf\n",
"import pandas as pd\n",
"import numpy as np\n",
"import imageio\n",
"import cv2\n",
"import os"
"from IPython.display import Image"
]
},
{
Expand Down Expand Up @@ -314,7 +314,7 @@
" num_samples = len(df)\n",
" video_paths = df[\"video_name\"].values.tolist()\n",
" labels = df[\"tag\"].values\n",
" labels = label_processor(labels[..., None]).numpy()\n",
" labels = keras.ops.convert_to_numpy(label_processor(labels[..., None]))\n",
"\n",
" # `frame_masks` and `frame_features` are what we will feed to our sequence model.\n",
" # `frame_masks` will contain a bunch of booleans denoting if a timestep is\n",
Expand All @@ -331,7 +331,13 @@
" frames = frames[None, ...]\n",
"\n",
" # Initialize placeholders to store the masks and features of the current video.\n",
" temp_frame_mask = np.zeros(shape=(1, MAX_SEQ_LENGTH,), dtype=\"bool\")\n",
" temp_frame_mask = np.zeros(\n",
" shape=(\n",
" 1,\n",
" MAX_SEQ_LENGTH,\n",
" ),\n",
" dtype=\"bool\",\n",
" )\n",
" temp_frame_features = np.zeros(\n",
" shape=(1, MAX_SEQ_LENGTH, NUM_FEATURES), dtype=\"float32\"\n",
" )\n",
Expand Down Expand Up @@ -388,6 +394,7 @@
},
"outputs": [],
"source": [
"\n",
"# Utility for our sequence model.\n",
"def get_sequence_model():\n",
" class_vocab = label_processor.get_vocabulary()\n",
Expand Down Expand Up @@ -415,7 +422,7 @@
"\n",
"# Utility for running experiments.\n",
"def run_experiment():\n",
" filepath = \"/tmp/video_classifier\"\n",
" filepath = \"/tmp/video_classifier/ckpt.weights.h5\"\n",
" checkpoint = keras.callbacks.ModelCheckpoint(\n",
" filepath, save_weights_only=True, save_best_only=True, verbose=1\n",
" )\n",
Expand Down Expand Up @@ -471,7 +478,13 @@
"\n",
"def prepare_single_video(frames):\n",
" frames = frames[None, ...]\n",
" frame_mask = np.zeros(shape=(1, MAX_SEQ_LENGTH,), dtype=\"bool\")\n",
" frame_mask = np.zeros(\n",
" shape=(\n",
" 1,\n",
" MAX_SEQ_LENGTH,\n",
" ),\n",
" dtype=\"bool\",\n",
" )\n",
" frame_features = np.zeros(shape=(1, MAX_SEQ_LENGTH, NUM_FEATURES), dtype=\"float32\")\n",
"\n",
" for i, batch in enumerate(frames):\n",
Expand Down Expand Up @@ -502,7 +515,7 @@
"def to_gif(images):\n",
" converted_images = images.astype(np.uint8)\n",
" imageio.mimsave(\"animation.gif\", converted_images, duration=100)\n",
" return embed.embed_file(\"animation.gif\")\n",
" return Image(\"animation.gif\")\n",
"\n",
"\n",
"test_video = np.random.choice(test_df[\"video_name\"].values.tolist())\n",
Expand All @@ -523,7 +536,7 @@
"from video frames. You could also fine-tune the pre-trained network to notice how that\n",
"affects the end results.\n",
"* For speed-accuracy trade-offs, you can try out other models present inside\n",
"`tf.keras.applications`.\n",
"`keras.applications`.\n",
"* Try different combinations of `MAX_SEQ_LENGTH` to observe how that affects the\n",
"performance.\n",
"* Train on a higher number of classes and see if you are able to get good performance.\n",
Expand Down Expand Up @@ -571,4 +584,4 @@
},
"nbformat": 4,
"nbformat_minor": 0
}
}
Loading