[Multimodal model] Add the decoder of the multimodal model (pytorch#626)
After pytorch#589, we are adding the decoder model to torchtitan, with a simple unit test.
fduwjj authored and mori360 committed Nov 25, 2024
1 parent 5260370 commit 57ba1b8
Showing 3 changed files with 637 additions and 84 deletions.
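
For context, here is a minimal sketch of how the new decoder is exercised, mirroring the unit test added in this commit. The ModelArgs fields and the forward signature (tokens, encoder_input, encoder_mask) are taken from the test code below rather than from a documented public API, so treat them as illustrative.

# Minimal sketch (not part of the commit): build the decoder from the test's
# ModelArgs values and run one forward pass.
import torch

from torchtitan.models.llama_multimodal import ModelArgs, MultimodalDecoder

args = ModelArgs(
    decoder_embed_dim=512,
    vocab_size=10000,
    fusion_interval=2,  # spacing of fusion layers (value taken from the test fixture)
    num_special_tokens=3,
    decoder_num_layers=6,
    decoder_num_heads=8,
    decoder_num_kv_heads=4,
    max_seq_len=512,
    rope_theta=50000.0,
)
decoder = MultimodalDecoder(args)

batch_size, seq_len = 1, 128
tokens = torch.arange(batch_size * seq_len).reshape(batch_size, seq_len)
# Stand-in for vision-encoder features already projected to decoder_embed_dim.
encoder_input = torch.randn(batch_size, seq_len, args.decoder_embed_dim)

with torch.no_grad():
    logits = decoder(tokens=tokens, encoder_input=encoder_input, encoder_mask=None)

assert logits.shape == (batch_size, seq_len, args.vocab_size)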
76 changes: 65 additions & 11 deletions test/multimodal_model/test_multimodal_model.py
@@ -6,31 +6,50 @@
 
 import pytest
 import torch
-from torchtitan.models.llama_multimodal import ModelArgs, VisionEncoder
+from torchtitan.models.llama_multimodal import (
+    ModelArgs,
+    MultimodalDecoder,
+    VisionEncoder,
+)
 
 from test.multimodal_model.test_utils import fixed_init_model, fixed_init_tensor
 
 
 @pytest.fixture
-def model_config():
+def encoder_config():
     return ModelArgs(
-        dim=32,
-        num_layers=2,
-        num_heads=4,
+        encoder_embed_dim=32,
+        encoder_num_layers=2,
+        encoder_num_heads=4,
         tile_size=49,
         patch_size=9,
         max_num_tiles=4,
         in_channels=3,
         return_intermediates=[0, 1],
-        num_layers_learnable_head=2,
+        num_layers_projection=2,
+        decoder_embed_dim=128,
     )
 
 
+@pytest.fixture
+def decoder_config():
+    return ModelArgs(
+        decoder_embed_dim=512,
+        vocab_size=10000,
+        fusion_interval=2,
+        num_special_tokens=3,
+        decoder_num_layers=6,
+        decoder_num_heads=8,
+        decoder_num_kv_heads=4,
+        max_seq_len=512,
+        rope_theta=50000.0,
+    )
+
+
 class TestMultimodalModelVisionEncoder:
     @pytest.fixture(autouse=True)
-    def setup_class(self, model_config):
-        self.model_args = model_config
+    def setup_class(self, encoder_config):
+        self.model_args = encoder_config
         self.batch_size = 1
         self.num_imgs = 2
         self.num_tiles = 4
@@ -52,10 +71,7 @@ def setup_class(self, model_config):
     def test_llama_mm_vision_encoder(self):
         model = VisionEncoder(self.model_args)
         fixed_init_model(model, min_val=-1, max_val=1)
-        # call model
         output = model(self.image, self.aspect_ratio)
-
-        # assertion
         expected_shape = (
             self.batch_size,
             self.num_imgs * self.num_tiles * (model.vit.patches_per_tile + 1),
@@ -71,3 +87,41 @@ def test_llama_mm_vision_encoder(self):
         # assert torch.allclose(
         #     output.mean(), torch.tensor(5.28800), atol=1e-3, rtol=1e-3
         # )
+
+
+class TestMultimodalModelDecoder:
+    @pytest.fixture(autouse=True)
+    def setup_class(self, decoder_config):
+        self.model_args = decoder_config
+        self.batch_size = 1
+        self.decoder_embed_dim = self.model_args.decoder_embed_dim
+        self.vocab_size = self.model_args.vocab_size
+        self.seq_len = 128
+        self.input = {
+            "tokens": torch.arange(self.batch_size * self.seq_len).reshape(
+                self.batch_size, self.seq_len
+            ),
+            "encoder_input": fixed_init_tensor(
+                (self.batch_size, self.seq_len, self.decoder_embed_dim),
+                min_val=-1,
+                max_val=1,
+            ),
+            "encoder_mask": None,
+        }
+
+    @torch.no_grad()
+    def test_llama_mm_decoder(self):
+        model = MultimodalDecoder(self.model_args)
+        fixed_init_model(model, min_val=-1, max_val=1)
+        output = model(**self.input)
+        expected_shape = (self.batch_size, self.seq_len, self.vocab_size)
+        assert (
+            output.shape == expected_shape
+        ), f"Expected shape {expected_shape}, but got {output.shape}"
+
+        # TODO: Need to ensure numerical stability before doing convergence test.
+        # output.mean() = -0.0134; we need to debug why it is not close to -9.47548e-5, which is
+        # the test value from the original torchtune test.
+        # assert torch.allclose(
+        #     output.mean(), torch.tensor(-9.47548e-5), atol=1e-3, rtol=1e-3
+        # )
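
Assuming the repository's usual pytest setup, the new case can be run directly with: pytest test/multimodal_model/test_multimodal_model.py -k test_llama_mm_decoder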
10 changes: 7 additions & 3 deletions torchtitan/models/llama_multimodal/__init__.py
@@ -4,12 +4,16 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 #
-# Llama 2 is licensed under the LLAMA 2 Community License,
+# Llama 3 is licensed under the LLAMA 3 Community License,
 # Copyright (c) Meta Platforms, Inc. All Rights Reserved.
 
-from torchtitan.models.llama_multimodal.model import ModelArgs, VisionEncoder
+from torchtitan.models.llama_multimodal.model import (
+    ModelArgs,
+    MultimodalDecoder,
+    VisionEncoder,
+)
 
-__all__ = ["VisionEncoder", "ModelArgs"]
+__all__ = ["VisionEncoder", "ModelArgs", "MultimodalDecoder"]
 
 llama3_2_configs = {
     # TODO: add configs for llama3.2