Add VQVAE + Transformer inferer #242

Merged (9 commits) on Feb 13, 2023
2 changes: 1 addition & 1 deletion generative/inferers/__init__.py
@@ -11,4 +11,4 @@

from __future__ import annotations

from .inferer import DiffusionInferer, LatentDiffusionInferer
from .inferer import DiffusionInferer, LatentDiffusionInferer, VQVAETransformerInferer
109 changes: 108 additions & 1 deletion generative/inferers/inferer.py
@@ -12,10 +12,11 @@
from __future__ import annotations

import math
from collections.abc import Callable
from collections.abc import Callable, Sequence

import torch
import torch.nn as nn
import torch.nn.functional as F
from monai.inferers import Inferer
from monai.utils import optional_import

@@ -414,3 +415,109 @@ def get_likelihood(
intermediates = [resizer(x) for x in intermediates]
outputs = (outputs[0], intermediates)
return outputs


class VQVAETransformerInferer(Inferer):
"""
Class to perform inference with a VQVAE + Transformer model.
"""

def __init__(self) -> None:
Inferer.__init__(self)

def __call__(
self,
inputs: torch.Tensor,
vqvae_model: Callable[..., torch.Tensor],
transformer_model: Callable[..., torch.Tensor],
ordering: Callable[..., torch.Tensor],
condition: torch.Tensor | None = None,
) -> torch.Tensor:
"""
Implements the forward pass for a supervised training iteration.

Args:
inputs: input image from which the latent representation will be extracted.
vqvae_model: first stage model.
transformer_model: autoregressive transformer model.
ordering: ordering of the quantised latent representation.
condition: conditioning for network input.
"""
with torch.no_grad():
latent = vqvae_model.index_quantize(inputs)

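# flatten the quantised latent indices into a 1D sequence and reorder it to match the autoregressive ordering (e.g. raster scan)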
latent = latent.reshape(latent.shape[0], -1)
latent = latent[:, ordering.get_sequence_ordering()]

# Use the value from vqvae_model's num_embeddings as the starting token, the "Begin Of Sentence" (BOS) token.
# Note the transformer_model must have vqvae_model.num_embeddings + 1 defined as num_tokens.
latent = F.pad(latent, (1, 0), "constant", vqvae_model.num_embeddings)
latent = latent.long()

prediction = transformer_model(x=latent, context=condition)

return prediction

@torch.no_grad()
def sample(
self,
latent_spatial_dim: Sequence[int],
starting_tokens: torch.Tensor,
vqvae_model: Callable[..., torch.Tensor],
transformer_model: Callable[..., torch.Tensor],
ordering: Callable[..., torch.Tensor],
conditioning: torch.Tensor | None = None,
temperature: float = 1.0,
top_k: int | None = None,
verbose: bool = True,
) -> torch.Tensor:
"""
Sampling function for the VQVAE + Transformer model.

Args:
latent_spatial_dim: spatial shape of the latent sequence to be sampled, e.g. (H, W) or (D, H, W).
starting_tokens: starting tokens for the sampling. Each entry must be the BOS token, i.e. vqvae_model.num_embeddings.
vqvae_model: first stage model.
transformer_model: model to sample from.
ordering: ordering of the quantised latent representation.
conditioning: conditioning for network input.
temperature: temperature for sampling.
top_k: top k sampling.
verbose: if True, prints a progress bar of the sampling process.
"""
seq_len = math.prod(latent_spatial_dim)

if verbose and has_tqdm:
progress_bar = tqdm(range(seq_len))
else:
progress_bar = iter(range(seq_len))

latent_seq = starting_tokens.long()
for _ in progress_bar:
# if the sequence context grows too long, crop it to the transformer's max_seq_len
if latent_seq.size(1) <= transformer_model.max_seq_len:
idx_cond = latent_seq
else:
idx_cond = latent_seq[:, -transformer_model.max_seq_len :]

# forward the model to get the logits for the index in the sequence
logits = transformer_model(x=idx_cond, context=conditioning)
# pluck the logits at the final step and scale by desired temperature
logits = logits[:, -1, :] / temperature
# optionally crop the logits to only the top k options
if top_k is not None:
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
logits[logits < v[:, [-1]]] = -float("Inf")
# apply softmax to convert logits to (normalized) probabilities
probs = F.softmax(logits, dim=-1)
# prevent the BOS token from being sampled
probs[:, vqvae_model.num_embeddings] = 0
# sample from the distribution
idx_next = torch.multinomial(probs, num_samples=1)
# append sampled index to the running sequence and continue
latent_seq = torch.cat((latent_seq, idx_next), dim=1)

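# drop the BOS token, undo the sequence ordering, and reshape the sequence back to the latent spatial grid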
latent_seq = latent_seq[:, 1:]
latent_seq = latent_seq[:, ordering.get_revert_sequence_ordering()]
latent = latent_seq.reshape((starting_tokens.shape[0],) + latent_spatial_dim)

return vqvae_model.decode_samples(latent)
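
A minimal usage sketch of the new inferer, mirroring the configurations exercised in the tests added below; all hyperparameter values are illustrative and the random input tensor is a stand-in for real images:

import torch

from generative.inferers import VQVAETransformerInferer
from generative.networks.nets import VQVAE, DecoderOnlyTransformer
from generative.utils.ordering import Ordering, OrderingType

# illustrative 2D configuration: a 2x2 quantised latent grid with 16 codebook entries
vqvae = VQVAE(
    spatial_dims=2,
    in_channels=1,
    out_channels=1,
    num_levels=2,
    downsample_parameters=((2, 4, 1, 1),) * 2,
    upsample_parameters=((2, 4, 1, 1, 0),) * 2,
    num_res_layers=1,
    num_channels=8,
    num_res_channels=(8, 8),
    num_embeddings=16,
    embedding_dim=8,
)
# num_tokens = num_embeddings + 1 and max_seq_len = latent sequence length + 1 to make room for the BOS token
transformer = DecoderOnlyTransformer(
    num_tokens=16 + 1,
    max_seq_len=4 + 1,
    attn_layers_dim=4,
    attn_layers_depth=2,
    attn_layers_heads=1,
    with_cross_attention=False,
)
ordering = Ordering(ordering_type=OrderingType.RASTER_SCAN.value, spatial_dims=2, dimensions=(2, 2, 2))

inferer = VQVAETransformerInferer()
images = torch.randn(2, 1, 8, 8)

# training-style forward pass: logits over the latent sequence, shape (2, 5, 17)
logits = inferer(inputs=images, vqvae_model=vqvae, transformer_model=transformer, ordering=ordering)

# sampling: every sequence starts from the BOS token (index vqvae.num_embeddings)
starting_tokens = vqvae.num_embeddings * torch.ones((2, 1), dtype=torch.long)
samples = inferer.sample(
    latent_spatial_dim=(2, 2),
    starting_tokens=starting_tokens,
    vqvae_model=vqvae,
    transformer_model=transformer,
    ordering=ordering,
)
print(samples.shape)  # torch.Size([2, 1, 8, 8])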
143 changes: 143 additions & 0 deletions tests/test_vqvaetransformer_inferer.py
@@ -0,0 +1,143 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import unittest

import torch
from parameterized import parameterized

from generative.inferers import VQVAETransformerInferer
from generative.networks.nets import VQVAE, DecoderOnlyTransformer
from generative.utils.ordering import Ordering, OrderingType

TEST_CASES = [
[
{
"spatial_dims": 2,
"in_channels": 1,
"out_channels": 1,
"num_levels": 2,
"downsample_parameters": ((2, 4, 1, 1),) * 2,
"upsample_parameters": ((2, 4, 1, 1, 0),) * 2,
"num_res_layers": 1,
"num_channels": 8,
"num_res_channels": [8, 8],
"num_embeddings": 16,
"embedding_dim": 8,
},
{
"num_tokens": 16 + 1,
"max_seq_len": 4 + 1,
"attn_layers_dim": 4,
"attn_layers_depth": 2,
"attn_layers_heads": 1,
"with_cross_attention": False,
},
{"ordering_type": OrderingType.RASTER_SCAN.value, "spatial_dims": 2, "dimensions": (2, 2, 2)},
(2, 1, 8, 8),
(2, 5, 17),
],
[
{
"spatial_dims": 3,
"in_channels": 1,
"out_channels": 1,
"num_levels": 2,
"downsample_parameters": ((2, 4, 1, 1),) * 2,
"upsample_parameters": ((2, 4, 1, 1, 0),) * 2,
"num_res_layers": 1,
"num_channels": 8,
"num_res_channels": [8, 8],
"num_embeddings": 16,
"embedding_dim": 8,
},
{
"num_tokens": 16 + 1,
"max_seq_len": 9 + 1,
"attn_layers_dim": 4,
"attn_layers_depth": 2,
"attn_layers_heads": 1,
"with_cross_attention": False,
},
{"ordering_type": OrderingType.RASTER_SCAN.value, "spatial_dims": 3, "dimensions": (2, 2, 2, 2)},
(2, 1, 8, 8, 8),
(2, 9, 17),
],
]


class TestVQVAETransformerInferer(unittest.TestCase):
@parameterized.expand(TEST_CASES)
def test_prediction_shape(self, stage_1_params, stage_2_params, ordering_params, input_shape, latent_shape):
stage_1 = VQVAE(**stage_1_params)
stage_2 = DecoderOnlyTransformer(**stage_2_params)
ordering = Ordering(**ordering_params)

device = "cuda:0" if torch.cuda.is_available() else "cpu"
stage_1.to(device)
stage_2.to(device)
stage_1.eval()
stage_2.eval()

input = torch.randn(input_shape).to(device)

inferer = VQVAETransformerInferer()
prediction = inferer(inputs=input, vqvae_model=stage_1, transformer_model=stage_2, ordering=ordering)
self.assertEqual(prediction.shape, latent_shape)

def test_sample(self):
stage_1 = VQVAE(
spatial_dims=2,
in_channels=1,
out_channels=1,
num_levels=2,
downsample_parameters=((2, 4, 1, 1),) * 2,
upsample_parameters=((2, 4, 1, 1, 0),) * 2,
num_res_layers=1,
num_channels=8,
num_res_channels=(8, 8),
num_embeddings=16,
embedding_dim=8,
)
stage_2 = DecoderOnlyTransformer(
num_tokens=16 + 1,
max_seq_len=4 + 1,
attn_layers_dim=4,
attn_layers_depth=2,
attn_layers_heads=1,
with_cross_attention=False,
)
ordering = Ordering(ordering_type=OrderingType.RASTER_SCAN.value, spatial_dims=2, dimensions=(2, 2, 2))

device = "cuda:0" if torch.cuda.is_available() else "cpu"
stage_1.to(device)
stage_2.to(device)
stage_1.eval()
stage_2.eval()

inferer = VQVAETransformerInferer()

starting_token = 16 # from stage_1 num_embeddings

sample = inferer.sample(
latent_spatial_dim=(2, 2),
starting_tokens=starting_token * torch.ones((2, 1), device=device),
vqvae_model=stage_1,
transformer_model=stage_2,
ordering=ordering,
)
self.assertEqual(sample.shape, (2, 1, 8, 8))


if __name__ == "__main__":
unittest.main()