-
Notifications
You must be signed in to change notification settings - Fork 89
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add VQVAE + Transformer inferer (#242)
* [WIP] Add inferer Signed-off-by: Walter Hugo Lopez Pinaya <[email protected]> * [WIP] Add sample method Signed-off-by: Walter Hugo Lopez Pinaya <[email protected]> * Add ordering and complete Inferer methods Signed-off-by: Walter Hugo Lopez Pinaya <[email protected]> * Update inferer Signed-off-by: Walter Hugo Lopez Pinaya <[email protected]> * Add test_prediction_shape Signed-off-by: Walter Hugo Lopez Pinaya <[email protected]> * Add test_sample Signed-off-by: Walter Hugo Lopez Pinaya <[email protected]> * Fix comments and docstring Signed-off-by: Walter Hugo Lopez Pinaya <[email protected]> * Remove starting_token from __call__ Signed-off-by: Walter Hugo Lopez Pinaya <[email protected]> --------- Signed-off-by: Walter Hugo Lopez Pinaya <[email protected]>
- Loading branch information
Showing
3 changed files
with
252 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,143 @@ | ||
# Copyright (c) MONAI Consortium | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from __future__ import annotations | ||
|
||
import unittest | ||
|
||
import torch | ||
from parameterized import parameterized | ||
|
||
from generative.inferers import VQVAETransformerInferer | ||
from generative.networks.nets import VQVAE, DecoderOnlyTransformer | ||
from generative.utils.ordering import Ordering, OrderingType | ||
|
||
TEST_CASES = [ | ||
[ | ||
{ | ||
"spatial_dims": 2, | ||
"in_channels": 1, | ||
"out_channels": 1, | ||
"num_levels": 2, | ||
"downsample_parameters": ((2, 4, 1, 1),) * 2, | ||
"upsample_parameters": ((2, 4, 1, 1, 0),) * 2, | ||
"num_res_layers": 1, | ||
"num_channels": 8, | ||
"num_res_channels": [8, 8], | ||
"num_embeddings": 16, | ||
"embedding_dim": 8, | ||
}, | ||
{ | ||
"num_tokens": 16 + 1, | ||
"max_seq_len": 4 + 1, | ||
"attn_layers_dim": 4, | ||
"attn_layers_depth": 2, | ||
"attn_layers_heads": 1, | ||
"with_cross_attention": False, | ||
}, | ||
{"ordering_type": OrderingType.RASTER_SCAN.value, "spatial_dims": 2, "dimensions": (2, 2, 2)}, | ||
(2, 1, 8, 8), | ||
(2, 5, 17), | ||
], | ||
[ | ||
{ | ||
"spatial_dims": 3, | ||
"in_channels": 1, | ||
"out_channels": 1, | ||
"num_levels": 2, | ||
"downsample_parameters": ((2, 4, 1, 1),) * 2, | ||
"upsample_parameters": ((2, 4, 1, 1, 0),) * 2, | ||
"num_res_layers": 1, | ||
"num_channels": 8, | ||
"num_res_channels": [8, 8], | ||
"num_embeddings": 16, | ||
"embedding_dim": 8, | ||
}, | ||
{ | ||
"num_tokens": 16 + 1, | ||
"max_seq_len": 9 + 1, | ||
"attn_layers_dim": 4, | ||
"attn_layers_depth": 2, | ||
"attn_layers_heads": 1, | ||
"with_cross_attention": False, | ||
}, | ||
{"ordering_type": OrderingType.RASTER_SCAN.value, "spatial_dims": 3, "dimensions": (2, 2, 2, 2)}, | ||
(2, 1, 8, 8, 8), | ||
(2, 9, 17), | ||
], | ||
] | ||
|
||
|
||
class TestVQVAETransformerInferer(unittest.TestCase): | ||
@parameterized.expand(TEST_CASES) | ||
def test_prediction_shape(self, stage_1_params, stage_2_params, ordering_params, input_shape, latent_shape): | ||
stage_1 = VQVAE(**stage_1_params) | ||
stage_2 = DecoderOnlyTransformer(**stage_2_params) | ||
ordering = Ordering(**ordering_params) | ||
|
||
device = "cuda:0" if torch.cuda.is_available() else "cpu" | ||
stage_1.to(device) | ||
stage_2.to(device) | ||
stage_1.eval() | ||
stage_2.eval() | ||
|
||
input = torch.randn(input_shape).to(device) | ||
|
||
inferer = VQVAETransformerInferer() | ||
prediction = inferer(inputs=input, vqvae_model=stage_1, transformer_model=stage_2, ordering=ordering) | ||
self.assertEqual(prediction.shape, latent_shape) | ||
|
||
def test_sample(self): | ||
stage_1 = VQVAE( | ||
spatial_dims=2, | ||
in_channels=1, | ||
out_channels=1, | ||
num_levels=2, | ||
downsample_parameters=((2, 4, 1, 1),) * 2, | ||
upsample_parameters=((2, 4, 1, 1, 0),) * 2, | ||
num_res_layers=1, | ||
num_channels=8, | ||
num_res_channels=(8, 8), | ||
num_embeddings=16, | ||
embedding_dim=8, | ||
) | ||
stage_2 = DecoderOnlyTransformer( | ||
num_tokens=16 + 1, | ||
max_seq_len=4 + 1, | ||
attn_layers_dim=4, | ||
attn_layers_depth=2, | ||
attn_layers_heads=1, | ||
with_cross_attention=False, | ||
) | ||
ordering = Ordering(ordering_type=OrderingType.RASTER_SCAN.value, spatial_dims=2, dimensions=(2, 2, 2)) | ||
|
||
device = "cuda:0" if torch.cuda.is_available() else "cpu" | ||
stage_1.to(device) | ||
stage_2.to(device) | ||
stage_1.eval() | ||
stage_2.eval() | ||
|
||
inferer = VQVAETransformerInferer() | ||
|
||
starting_token = 16 # from stage_1 num_embeddings | ||
|
||
sample = inferer.sample( | ||
latent_spatial_dim=(2, 2), | ||
starting_tokens=starting_token * torch.ones((2, 1), device=device), | ||
vqvae_model=stage_1, | ||
transformer_model=stage_2, | ||
ordering=ordering, | ||
) | ||
self.assertEqual(sample.shape, (2, 1, 8, 8)) | ||
|
||
|
||
if __name__ == "__main__": | ||
unittest.main() |