Add VQVAE + Transformer inferer (#242)
* [WIP] Add inferer

Signed-off-by: Walter Hugo Lopez Pinaya <[email protected]>

* [WIP] Add sample method

Signed-off-by: Walter Hugo Lopez Pinaya <[email protected]>

* Add ordering and complete Inferer methods

Signed-off-by: Walter Hugo Lopez Pinaya <[email protected]>

* Update inferer

Signed-off-by: Walter Hugo Lopez Pinaya <[email protected]>

* Add test_prediction_shape

Signed-off-by: Walter Hugo Lopez Pinaya <[email protected]>

* Add test_sample

Signed-off-by: Walter Hugo Lopez Pinaya <[email protected]>

* Fix comments and docstring

Signed-off-by: Walter Hugo Lopez Pinaya <[email protected]>

* Remove starting_token from __call__

Signed-off-by: Walter Hugo Lopez Pinaya <[email protected]>

---------

Signed-off-by: Walter Hugo Lopez Pinaya <[email protected]>
Warvito authored Feb 13, 2023
1 parent 57bf5c4 commit 91e9429
Showing 3 changed files with 252 additions and 2 deletions.
2 changes: 1 addition & 1 deletion generative/inferers/__init__.py
@@ -11,4 +11,4 @@

from __future__ import annotations

from .inferer import DiffusionInferer, LatentDiffusionInferer
from .inferer import DiffusionInferer, LatentDiffusionInferer, VQVAETransformerInferer
109 changes: 108 additions & 1 deletion generative/inferers/inferer.py
@@ -12,10 +12,11 @@
from __future__ import annotations

import math
from collections.abc import Callable
from collections.abc import Callable, Sequence

import torch
import torch.nn as nn
import torch.nn.functional as F
from monai.inferers import Inferer
from monai.utils import optional_import

@@ -414,3 +415,109 @@ def get_likelihood(
intermediates = [resizer(x) for x in intermediates]
outputs = (outputs[0], intermediates)
return outputs


class VQVAETransformerInferer(Inferer):
"""
Class to perform inference with a VQVAE + Transformer model.
"""

def __init__(self) -> None:
Inferer.__init__(self)

def __call__(
self,
inputs: torch.Tensor,
vqvae_model: Callable[..., torch.Tensor],
transformer_model: Callable[..., torch.Tensor],
ordering: Callable[..., torch.Tensor],
condition: torch.Tensor | None = None,
) -> torch.Tensor:
"""
Implements the forward pass for a supervised training iteration.
Args:
            inputs: input image from which the latent representation will be extracted.
vqvae_model: first stage model.
transformer_model: autoregressive transformer model.
ordering: ordering of the quantised latent representation.
condition: conditioning for network input.
"""
with torch.no_grad():
latent = vqvae_model.index_quantize(inputs)

latent = latent.reshape(latent.shape[0], -1)
latent = latent[:, ordering.get_sequence_ordering()]

# Use the value from vqvae_model's num_embeddings as the starting token, the "Begin Of Sentence" (BOS) token.
        # Note that transformer_model must be configured with num_tokens = vqvae_model.num_embeddings + 1.
latent = F.pad(latent, (1, 0), "constant", vqvae_model.num_embeddings)
latent = latent.long()

prediction = transformer_model(x=latent, context=condition)
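        # prediction contains next-token logits of shape (batch, latent_seq_len + 1, num_tokens),
        # one distribution per position of the BOS-prefixed sequence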

return prediction

@torch.no_grad()
def sample(
self,
        latent_spatial_dim: Sequence[int],
starting_tokens: torch.Tensor,
vqvae_model: Callable[..., torch.Tensor],
transformer_model: Callable[..., torch.Tensor],
ordering: Callable[..., torch.Tensor],
conditioning: torch.Tensor | None = None,
temperature: float = 1.0,
top_k: int | None = None,
verbose: bool | None = True,
) -> torch.Tensor:
"""
Sampling function for the VQVAE + Transformer model.
Args:
            latent_spatial_dim: spatial shape of the quantised latent representation to be sampled.
            starting_tokens: starting tokens for the sampling, with shape (batch_size, 1). Every value must be the
                Begin Of Sentence (BOS) token, i.e. vqvae_model.num_embeddings.
            vqvae_model: first stage model.
            transformer_model: autoregressive transformer model to sample from.
            ordering: ordering of the quantised latent representation.
            conditioning: conditioning for network input.
            temperature: temperature for sampling.
            top_k: top k sampling.
            verbose: if true, prints the progress bar of the sampling process.
"""
seq_len = math.prod(latent_spatial_dim)

if verbose and has_tqdm:
progress_bar = tqdm(range(seq_len))
else:
progress_bar = iter(range(seq_len))

latent_seq = starting_tokens.long()
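        # latent_seq starts as the BOS tokens only, shape (batch_size, 1); each sampled index is appended below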
for _ in progress_bar:
            # if the sequence context is growing too long, crop it to the transformer's max_seq_len
if latent_seq.size(1) <= transformer_model.max_seq_len:
idx_cond = latent_seq
else:
idx_cond = latent_seq[:, -transformer_model.max_seq_len :]

# forward the model to get the logits for the index in the sequence
logits = transformer_model(x=idx_cond, context=conditioning)
# pluck the logits at the final step and scale by desired temperature
logits = logits[:, -1, :] / temperature
# optionally crop the logits to only the top k options
if top_k is not None:
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
logits[logits < v[:, [-1]]] = -float("Inf")
# apply softmax to convert logits to (normalized) probabilities
probs = F.softmax(logits, dim=-1)
            # prevent the BOS token from being sampled
probs[:, vqvae_model.num_embeddings] = 0
# sample from the distribution
idx_next = torch.multinomial(probs, num_samples=1)
# append sampled index to the running sequence and continue
latent_seq = torch.cat((latent_seq, idx_next), dim=1)

        # remove the BOS token
        latent_seq = latent_seq[:, 1:]
        # revert the ordering and restore the spatial shape of the latent representation
        latent_seq = latent_seq[:, ordering.get_revert_sequence_ordering()]
        latent = latent_seq.reshape((starting_tokens.shape[0],) + latent_spatial_dim)

return vqvae_model.decode_samples(latent)
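
For context, a brief usage sketch (not part of this commit): one plausible way to turn the inferer's output into a cross-entropy training loss for the transformer, reusing the 2D configuration from the new test file below. The variable names (stage_1, stage_2, images) and the loss formulation are illustrative assumptions, not a reference training recipe.

import torch
import torch.nn.functional as F

from generative.inferers import VQVAETransformerInferer
from generative.networks.nets import VQVAE, DecoderOnlyTransformer
from generative.utils.ordering import Ordering, OrderingType

# first stage (VQVAE) and autoregressive transformer, configured as in the tests below
stage_1 = VQVAE(
    spatial_dims=2,
    in_channels=1,
    out_channels=1,
    num_levels=2,
    downsample_parameters=((2, 4, 1, 1),) * 2,
    upsample_parameters=((2, 4, 1, 1, 0),) * 2,
    num_res_layers=1,
    num_channels=8,
    num_res_channels=(8, 8),
    num_embeddings=16,
    embedding_dim=8,
)
stage_2 = DecoderOnlyTransformer(
    num_tokens=16 + 1,  # num_embeddings + 1 to reserve an index for the BOS token
    max_seq_len=4 + 1,  # latent_seq_len + 1 to accommodate the prepended BOS token
    attn_layers_dim=4,
    attn_layers_depth=2,
    attn_layers_heads=1,
    with_cross_attention=False,
)
ordering = Ordering(ordering_type=OrderingType.RASTER_SCAN.value, spatial_dims=2, dimensions=(2, 2, 2))

stage_1.eval()  # the first stage is typically frozen while the transformer is trained
inferer = VQVAETransformerInferer()
images = torch.randn(2, 1, 8, 8)

# next-token logits, shape (batch, latent_seq_len + 1, num_tokens)
logits = inferer(inputs=images, vqvae_model=stage_1, transformer_model=stage_2, ordering=ordering)

# build the target sequence exactly as the inferer flattens and reorders the latent
with torch.no_grad():
    target = stage_1.index_quantize(images)
    target = target.reshape(target.shape[0], -1)
    target = target[:, ordering.get_sequence_ordering()]

# logits[:, i] predicts the token at position i of the latent sequence, so the last position has no target;
# cross_entropy expects the class dimension on axis 1
loss = F.cross_entropy(logits[:, :-1, :].transpose(1, 2), target.long())
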
143 changes: 143 additions & 0 deletions tests/test_vqvaetransformer_inferer.py
@@ -0,0 +1,143 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import unittest

import torch
from parameterized import parameterized

from generative.inferers import VQVAETransformerInferer
from generative.networks.nets import VQVAE, DecoderOnlyTransformer
from generative.utils.ordering import Ordering, OrderingType
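
# Expected prediction shapes below are (batch, latent_seq_len + 1, num_embeddings + 1): the extra position and the
# extra token both come from the Begin Of Sentence (BOS) token prepended by the inferer.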

TEST_CASES = [
[
{
"spatial_dims": 2,
"in_channels": 1,
"out_channels": 1,
"num_levels": 2,
"downsample_parameters": ((2, 4, 1, 1),) * 2,
"upsample_parameters": ((2, 4, 1, 1, 0),) * 2,
"num_res_layers": 1,
"num_channels": 8,
"num_res_channels": [8, 8],
"num_embeddings": 16,
"embedding_dim": 8,
},
{
"num_tokens": 16 + 1,
"max_seq_len": 4 + 1,
"attn_layers_dim": 4,
"attn_layers_depth": 2,
"attn_layers_heads": 1,
"with_cross_attention": False,
},
{"ordering_type": OrderingType.RASTER_SCAN.value, "spatial_dims": 2, "dimensions": (2, 2, 2)},
(2, 1, 8, 8),
(2, 5, 17),
],
[
{
"spatial_dims": 3,
"in_channels": 1,
"out_channels": 1,
"num_levels": 2,
"downsample_parameters": ((2, 4, 1, 1),) * 2,
"upsample_parameters": ((2, 4, 1, 1, 0),) * 2,
"num_res_layers": 1,
"num_channels": 8,
"num_res_channels": [8, 8],
"num_embeddings": 16,
"embedding_dim": 8,
},
{
"num_tokens": 16 + 1,
"max_seq_len": 9 + 1,
"attn_layers_dim": 4,
"attn_layers_depth": 2,
"attn_layers_heads": 1,
"with_cross_attention": False,
},
{"ordering_type": OrderingType.RASTER_SCAN.value, "spatial_dims": 3, "dimensions": (2, 2, 2, 2)},
(2, 1, 8, 8, 8),
(2, 9, 17),
],
]


class TestVQVAETransformerInferer(unittest.TestCase):
@parameterized.expand(TEST_CASES)
def test_prediction_shape(self, stage_1_params, stage_2_params, ordering_params, input_shape, latent_shape):
stage_1 = VQVAE(**stage_1_params)
stage_2 = DecoderOnlyTransformer(**stage_2_params)
ordering = Ordering(**ordering_params)

device = "cuda:0" if torch.cuda.is_available() else "cpu"
stage_1.to(device)
stage_2.to(device)
stage_1.eval()
stage_2.eval()

input = torch.randn(input_shape).to(device)

inferer = VQVAETransformerInferer()
prediction = inferer(inputs=input, vqvae_model=stage_1, transformer_model=stage_2, ordering=ordering)
self.assertEqual(prediction.shape, latent_shape)

def test_sample(self):
stage_1 = VQVAE(
spatial_dims=2,
in_channels=1,
out_channels=1,
num_levels=2,
downsample_parameters=((2, 4, 1, 1),) * 2,
upsample_parameters=((2, 4, 1, 1, 0),) * 2,
num_res_layers=1,
num_channels=8,
num_res_channels=(8, 8),
num_embeddings=16,
embedding_dim=8,
)
stage_2 = DecoderOnlyTransformer(
num_tokens=16 + 1,
max_seq_len=4 + 1,
attn_layers_dim=4,
attn_layers_depth=2,
attn_layers_heads=1,
with_cross_attention=False,
)
ordering = Ordering(ordering_type=OrderingType.RASTER_SCAN.value, spatial_dims=2, dimensions=(2, 2, 2))

device = "cuda:0" if torch.cuda.is_available() else "cpu"
stage_1.to(device)
stage_2.to(device)
stage_1.eval()
stage_2.eval()

inferer = VQVAETransformerInferer()

starting_token = 16 # from stage_1 num_embeddings

sample = inferer.sample(
latent_spatial_dim=(2, 2),
starting_tokens=starting_token * torch.ones((2, 1), device=device),
vqvae_model=stage_1,
transformer_model=stage_2,
ordering=ordering,
)
self.assertEqual(sample.shape, (2, 1, 8, 8))


if __name__ == "__main__":
unittest.main()
