
Commit

add info on inference site
innat committed Mar 24, 2024
1 parent cd6dddf commit ac08937
Showing 3 changed files with 32 additions and 194 deletions.
46 changes: 9 additions & 37 deletions MODEL_ZOO.md
@@ -1,11 +1,9 @@

# Video Swin Transformer Model Zoo

Video Swin in `keras` can be used with multiple backends, i.e. `tensorflow`, `torch`, and `jax`. The input shape is expected to be `channels_last`, i.e. `(depth, height, width, channel)`.

## Note

While evaluating the video model for the classification task, multiple clips are sampled from a video. This process also involves multiple crops on each sample.
**Note**: While evaluating the video model for the classification task, multiple clips are sampled from a video. Additionally, this process involves multiple crops on each sample. So, when evaluating on a benchmark dataset, this standard protocol should be considered.

- `#Frame = #input_frame x #clip x #crop`. The frame interval is `2` when evaluating on the benchmark datasets.
- `#input_frame` means how many frames are fed to the model during the test phase. For Video Swin, it is `32`.
@@ -15,52 +13,26 @@ While evaluating the video model for classification task, multiple clips from a

# Checkpoints

In the training phase, the video swin models are initialized with the pretrained weights of image swin models. In that case, `IN` refers to **ImageNet**. In the following, the `keras` checkpoints are the complete model, so the `keras.saving.load_model` API can be used. In contrast, the `h5` checkpoints are weight-only files.
In the training phase, the video swin models are initialized with the pretrained weights of image swin models. In the following table, `IN` refers to **ImageNet**. By default, video swin is trained with an input shape of `(32, 224, 224, 3)`.
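
For instance, a released weight file from the tables below can be loaded into a freshly built model. A minimal sketch, assuming `VideoSwinT` is importable from `videoswin.model` (as in this repo) and that the Swin-T Kinetics-400 weight file has already been downloaded:

```python
from videoswin.model import VideoSwinT

# Build the Kinetics-400 classifier variant with the default training input shape.
model = VideoSwinT(input_shape=(32, 224, 224, 3), num_classes=400)

# Load the released weight-only checkpoint.
model.load_weights("videoswin_tiny_kinetics400.weights.h5")
```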

### Kinetics 400

| Model | Pretrain | #Frame | Top-1 | Top-5 | Checkpoints | config |
| :---: | :---: | :---: | :---: | :---: | :---: | :---: |
| Swin-T | IN-1K | 32x4x3 | 78.8 | 93.6 | [keras](https://github.com/innat/VideoSwin/releases/download/v2.0/videoswin_tiny_kinetics400_classifier.keras)/[h5](https://github.com/innat/VideoSwin/releases/download/v2.0/videoswin_tiny_kinetics400.weights.h5) | [swin-t](https://github.com/SwinTransformer/Video-Swin-Transformer/blob/master/configs/recognition/swin/swin_tiny_patch244_window877_kinetics400_1k.py) |
| Swin-S | IN-1K | 32x4x3 | 80.6 | 94.5 | [keras](https://github.com/innat/VideoSwin/releases/download/v2.0/videoswin_small_kinetics400_classifier.keras)/[h5](https://github.com/innat/VideoSwin/releases/download/v2.0/videoswin_small_kinetics400.weights.h5) | [swin-s](https://github.com/SwinTransformer/Video-Swin-Transformer/blob/master/configs/recognition/swin/swin_small_patch244_window877_kinetics400_1k.py) |
| Swin-B | IN-1K | 32x4x3 | 80.6 | 94.6 | [keras](https://github.com/innat/VideoSwin/releases/download/v2.0/videoswin_base_kinetics400_classifier.keras)/[h5](https://github.com/innat/VideoSwin/releases/download/v2.0/videoswin_base_kinetics400.weights.h5) | [swin-b](https://github.com/SwinTransformer/Video-Swin-Transformer/blob/master/configs/recognition/swin/swin_base_patch244_window877_kinetics400_1k.py) |
| Swin-B | IN-22K | 32x4x3 | 82.7 | 95.5 | [keras](https://github.com/innat/VideoSwin/releases/download/v2.0/videoswin_base_kinetics400_imagenet22k_classifier.keras)/[h5](https://github.com/innat/VideoSwin/releases/download/v2.0/videoswin_base_kinetics400_imagenet22k.weights.h5) | [swin-b](https://github.com/SwinTransformer/Video-Swin-Transformer/blob/master/configs/recognition/swin/swin_base_patch244_window877_kinetics400_22k.py) |
| Swin-T | IN-1K | 32x4x3 | 78.8 | 93.6 | [h5](https://github.com/innat/VideoSwin/releases/download/v2.0/videoswin_tiny_kinetics400.weights.h5) | [swin-t](https://github.com/SwinTransformer/Video-Swin-Transformer/blob/master/configs/recognition/swin/swin_tiny_patch244_window877_kinetics400_1k.py) |
| Swin-S | IN-1K | 32x4x3 | 80.6 | 94.5 | [h5](https://github.com/innat/VideoSwin/releases/download/v2.0/videoswin_small_kinetics400.weights.h5) | [swin-s](https://github.com/SwinTransformer/Video-Swin-Transformer/blob/master/configs/recognition/swin/swin_small_patch244_window877_kinetics400_1k.py) |
| Swin-B | IN-1K | 32x4x3 | 80.6 | 94.6 | [h5](https://github.com/innat/VideoSwin/releases/download/v2.0/videoswin_base_kinetics400.weights.h5) | [swin-b](https://github.com/SwinTransformer/Video-Swin-Transformer/blob/master/configs/recognition/swin/swin_base_patch244_window877_kinetics400_1k.py) |
| Swin-B | IN-22K | 32x4x3 | 82.7 | 95.5 | [h5](https://github.com/innat/VideoSwin/releases/download/v2.0/videoswin_base_kinetics400_imagenet22k.weights.h5) | [swin-b](https://github.com/SwinTransformer/Video-Swin-Transformer/blob/master/configs/recognition/swin/swin_base_patch244_window877_kinetics400_22k.py) |

### Kinetics 600

| Model | Pretrain | #Frame | Top-1 | Top-5 | Checkpoints | config |
| :---: | :---: | :---: | :---: | :---: | :---: | :---: |
| Swin-B | IN-22K | 32x4x3 | 84.0 | 96.5 | [keras](https://github.com/innat/VideoSwin/releases/download/v2.0/videoswin_base_kinetics600_imagenet22k_classifier.keras)/[h5](https://github.com/innat/VideoSwin/releases/download/v2.0/videoswin_base_kinetics600_imagenet22k.weights.h5) | [swin-b](https://github.com/SwinTransformer/Video-Swin-Transformer/blob/master/configs/recognition/swin/swin_base_patch244_window877_kinetics600_22k.py) |
| Swin-B | IN-22K | 32x4x3 | 84.0 | 96.5 | [h5](https://github.com/innat/VideoSwin/releases/download/v2.0/videoswin_base_kinetics600_imagenet22k.weights.h5) | [swin-b](https://github.com/SwinTransformer/Video-Swin-Transformer/blob/master/configs/recognition/swin/swin_base_patch244_window877_kinetics600_22k.py) |

### Something-Something V2

| Model | Pretrain | #Frame | Top-1 | Top-5 | Checkpoints | config |
| :---: | :---: | :---: | :---: | :---: | :---: | :---: |
| Swin-B | Kinetics 400 | 32x1x3 | 69.6 | 92.7 | [keras](https://github.com/innat/VideoSwin/releases/download/v2.0/videoswin_base_something_something_v2_classifier.keras)/[h5](https://github.com/innat/VideoSwin/releases/download/v2.0/videoswin_base_something_something_v2.weights.h5) | [swin-b](https://github.com/SwinTransformer/Video-Swin-Transformer/blob/master/configs/recognition/swin/swin_base_patch244_window1677_sthv2.py) |
| Swin-B | Kinetics 400 | 32x1x3 | 69.6 | 92.7 | [h5](https://github.com/innat/VideoSwin/releases/download/v2.0/videoswin_base_something_something_v2.weights.h5) | [swin-b](https://github.com/SwinTransformer/Video-Swin-Transformer/blob/master/configs/recognition/swin/swin_base_patch244_window1677_sthv2.py) |

## Weight Comparison

The `torch` VideoSwin model can be loaded from the official [repo](https://github.com/SwinTransformer/Video-Swin-Transformer). The following is a quick test of both implementations showing that their logits match.

```python
import numpy as np
import torch

# model_torch: the official PyTorch model, model_keras: this Keras port,
# both assumed to be built and loaded with the same weights beforehand.
input = np.random.rand(4, 32, 224, 224, 3).astype('float32')
inputs = torch.tensor(input)
inputs = torch.einsum('nthwc->ncthw', inputs)
# inputs.shape: torch.Size([4, 3, 32, 224, 224])

# torch model
model_torch.eval()
x = model_torch(inputs.float())
x = x.detach().numpy()
x.shape # (4, 174) (Sth-Sth dataset)

# keras model
y = model_keras(input, training=False)
y = y.numpy()
y.shape # (4, 174) (Sth-Sth dataset)

np.testing.assert_allclose(x, y, 1e-4, 1e-4)
np.testing.assert_allclose(x, y, 1e-5, 1e-5)
# OK
```
25 changes: 23 additions & 2 deletions README.md
@@ -10,7 +10,7 @@

VideoSwin is a pure transformer-based video modeling algorithm that attained top accuracy on the major video recognition benchmarks. In this model, the authors advocate an inductive bias of locality in video transformers, which leads to a better speed-accuracy trade-off compared to previous approaches that compute self-attention globally even with spatial-temporal factorization. The locality of the proposed video architecture is realized by adapting the [**Swin Transformer**](https://arxiv.org/abs/2103.14030) designed for the image domain, while continuing to leverage the power of pre-trained image models.

This is an unofficial `Keras 3` implementation of [Video Swin transformers](https://arxiv.org/abs/2106.13230). The official `PyTorch` implementation is [here](https://github.com/SwinTransformer/Video-Swin-Transformer), based on [mmaction2](https://github.com/open-mmlab/mmaction2). The official PyTorch weights have been converted to be `Keras 3` compatible. This implementation supports running the model on multiple backends, i.e. TensorFlow, PyTorch, and Jax.
This is an unofficial `Keras 3` implementation of [Video Swin transformers](https://arxiv.org/abs/2106.13230). The official `PyTorch` implementation is [here](https://github.com/SwinTransformer/Video-Swin-Transformer), based on [mmaction2](https://github.com/open-mmlab/mmaction2). The official PyTorch weights have been converted to be `Keras 3` compatible. This implementation supports running the model on multiple backends, i.e. TensorFlow, PyTorch, and Jax. However, to work with `tensorflow.keras`, check the `tfkeras` branch.
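
As a quick illustration of the multi-backend support, the backend is selected via the standard Keras 3 `KERAS_BACKEND` environment variable, which must be set before `keras` is imported; this is generic Keras 3 behaviour rather than anything specific to this repo:

```python
import os

# One of "tensorflow", "torch", or "jax"; must be set before `import keras`.
os.environ["KERAS_BACKEND"] = "jax"

import keras
print(keras.backend.backend())  # -> "jax"
```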


# Install
@@ -23,7 +23,7 @@ This is an unofficial `Keras 3` implementation of [Video Swin transformers](https

# Checkpoints

The **VideoSwin** checkpoints are available in both `.weights.h5` and `.keras` formats for the Kinetics 400/600 and Something Something V2 datasets. Here, the `h5` format is the **weights-only** file and the `keras` format holds the **weights + model architecture**. The variants of this model are `tiny`, `small`, and `base`. Check the [model zoo](https://github.com/innat/VideoSwin/blob/main/MODEL_ZOO.md) page for details.
The **VideoSwin** checkpoints are available in `.weights.h5` format for the Kinetics 400/600 and Something Something V2 datasets. The variants of this model are `tiny`, `small`, and `base`. Check the [model zoo](https://github.com/innat/VideoSwin/blob/main/MODEL_ZOO.md) page for details.


# Inference
@@ -65,6 +65,27 @@ A classification results on a sample from [Kinetics-400](https://paperswithcode.
| ![](./assets/view1.gif) | <pre>{<br> 'playing_cello': 0.9941741824150085,<br> 'playing_violin': 0.0016851733671501279,<br> 'playing_recorder': 0.0011555481469258666,<br> 'playing_clarinet': 0.0009695519111119211,<br> 'playing_harp': 0.0007713600643910468<br>}</pre> |
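
For reference, a minimal top-5 inference sketch along these lines could look as follows; it assumes a Kinetics-400 classifier `model` (for example, built and loaded from the checkpoints listed in the model zoo), an already preprocessed clip `video` of shape `(32, 224, 224, 3)`, and a hypothetical `label_map` list mapping class indices to Kinetics-400 label names (none of these helpers are provided here):

```python
import numpy as np

# Add a batch axis and run the classifier; `predict` returns softmax scores.
probs = model.predict(video[np.newaxis, ...])[0]  # shape: (400,)

# Report the five highest-scoring Kinetics-400 labels.
top5 = probs.argsort()[::-1][:5]
print({label_map[i]: float(probs[i]) for i in top5})
```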


To get the backbone of Video Swin, we can pass `include_top=False` to exclude the classification layer. For example:

```python
from videoswin.model import VideoSwinT  # VideoSwinT is defined in videoswin/model.py

backbone = VideoSwinT(
    include_top=False, input_shape=(32, 224, 224, 3)
)
```
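
As a usage note (based on the default tiny configuration, so treat the exact numbers as an assumption), the backbone returns a 5D channels-last feature map; for a `(32, 224, 224, 3)` input this is expected to be `(batch, 16, 7, 7, 768)`, which can then be pooled or passed to a custom head:

```python
import numpy as np

features = backbone(np.zeros((1, 32, 224, 224, 3), dtype="float32"))
print(features.shape)  # expected: (1, 16, 7, 7, 768) for the tiny variant
```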

By default, Video Swin is officially trained with an input shape of `(32, 224, 224, 3)`. But we can build the model with a different input shape and also load the pretrained weights partially.

```python
model = VideoSwinT(
    input_shape=(8, 224, 256, 3),
    include_rescaling=False,
    num_classes=10,
)
model.load_weights('model.weights.h5', skip_mismatch=True)
```
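
Here `skip_mismatch=True` tells Keras to skip any saved weights whose shapes do not match the newly built layers (for example, the classification head when `num_classes` differs from the checkpoint), while all compatible weights are still loaded.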

**Guides**

- To verify the Keras reimplementation against the official torch model: [logit comparison](guides/video-swin-transformer-keras-and-torchvision.ipynb)
155 changes: 0 additions & 155 deletions videoswin/model.py
@@ -1,14 +1,7 @@
import os
import warnings

warnings.simplefilter(action="ignore")
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

import keras

from videoswin.backbone import VideoSwinBackbone


@keras.utils.register_keras_serializable(package="swin.transformer.tiny.3d")
def VideoSwinT(
    input_shape=(32, 224, 224, 3),
@@ -56,80 +49,6 @@ def VideoSwinT(
    return keras.Model(inputs=inputs, outputs=outputs, name="VideoSwinT", **kwargs)


# @keras.utils.register_keras_serializable(package="swin.transformer.tiny.3d")
# class VideoSwinT(keras.Model):
#     def __init__(
#         self,
#         input_shape=(32, 224, 224, 3),
#         num_classes=400,
#         pooling="avg",
#         activation="softmax",
#         window_size=(8,7,7),
#         embed_size=96,
#         depths=[2, 2, 6, 2],
#         num_heads=[3, 6, 12, 24],
#         include_rescaling=False,
#         include_top=True,
#         **kwargs,
#     ):

#         if pooling == "avg":
#             pooling_layer = keras.layers.GlobalAveragePooling3D(name="avg_pool")
#         elif pooling == "max":
#             pooling_layer = keras.layers.GlobalMaxPooling3D(name="max_pool")
#         else:
#             raise ValueError(
#                 f'`pooling` must be one of "avg", "max". Received: {pooling}.'
#             )

#         backbone = VideoSwinBackbone(
#             input_shape=input_shape,
#             window_size=window_size,
#             embed_dim=embed_size,
#             depths=depths,
#             num_heads=num_heads,
#             include_rescaling=include_rescaling,
#         )

#         if not include_top:
#             return backbone

#         inputs = backbone.input
#         x = backbone(inputs)
#         x = pooling_layer(x)
#         outputs = keras.layers.Dense(
#             num_classes,
#             activation=activation,
#             name="predictions",
#             dtype="float32",
#         )(x)
#         super().__init__(inputs=inputs, outputs=outputs, name='VideoSwinT', **kwargs)
#         self.window_size = window_size
#         self.num_classes = num_classes
#         self.pooling = pooling
#         self.activation = activation
#         self.embed_size = embed_size
#         self.depths = depths
#         self.num_heads = num_heads
#         self.include_rescaling = include_rescaling
#         self.include_top = include_top

#     def get_config(self):
#         config = {
#             "input_shape": self.input_shape[1:],
#             "window_size": self.window_size,
#             "num_classes": self.num_classes,
#             "pooling": self.pooling,
#             "activation": self.activation,
#             "embed_size": self.embed_size,
#             "depths": self.depths,
#             "num_heads": self.num_heads,
#             "include_rescaling": self.include_rescaling,
#             "include_top": self.include_top,
#         }
#         return config


@keras.utils.register_keras_serializable(package="swin.transformer.small.3d")
class VideoSwinS(keras.Model):
    def __init__(
@@ -249,77 +168,3 @@ def VideoSwinB(
    )(x)

    return keras.Model(inputs=inputs, outputs=outputs, name="VideoSwinB", **kwargs)


# @keras.utils.register_keras_serializable(package="swin.transformer.base.3d")
# class VideoSwinB(keras.Model):
#     def __init__(
#         self,
#         input_shape=(32, 224, 224, 3),
#         num_classes=400,
#         pooling="avg",
#         activation="softmax",
#         window_size=(8, 7, 7),
#         embed_size=128,
#         depths=[2, 2, 18, 2],
#         num_heads=[4, 8, 16, 32],
#         include_rescaling=False,
#         include_top=True,
#         **kwargs,
#     ):

#         if pooling == "avg":
#             pooling_layer = keras.layers.GlobalAveragePooling3D(name="avg_pool")
#         elif pooling == "max":
#             pooling_layer = keras.layers.GlobalMaxPooling3D(name="max_pool")
#         else:
#             raise ValueError(
#                 f'`pooling` must be one of "avg", "max". Received: {pooling}.'
#             )

#         backbone = VideoSwinBackbone(
#             input_shape=input_shape,
#             embed_dim=embed_size,
#             window_size=window_size,
#             depths=depths,
#             num_heads=num_heads,
#             include_rescaling=include_rescaling,
#         )

#         if not include_top:
#             return backbone

#         inputs = backbone.input
#         x = backbone(inputs)
#         x = pooling_layer(x)
#         outputs = keras.layers.Dense(
#             num_classes,
#             activation=activation,
#             name="predictions",
#             dtype="float32",
#         )(x)
#         super().__init__(inputs=inputs, outputs=outputs, name="VideoSwinB", **kwargs)
#         self.window_size = window_size
#         self.num_classes = num_classes
#         self.pooling = pooling
#         self.activation = activation
#         self.embed_size = embed_size
#         self.depths = depths
#         self.num_heads = num_heads
#         self.include_rescaling = include_rescaling
#         self.include_top = include_top

#     def get_config(self):
#         config = {
#             "input_shape": self.input_shape[1:],
#             "window_size": self.window_size,
#             "num_classes": self.num_classes,
#             "pooling": self.pooling,
#             "activation": self.activation,
#             "embed_size": self.embed_size,
#             "depths": self.depths,
#             "num_heads": self.num_heads,
#             "include_rescaling": self.include_rescaling,
#             "include_top": self.include_top,
#         }
#         return config
