diff --git a/Quick_demo/Language_files/config.json b/Quick_demo/Language_files/config.json
new file mode 100644
index 0000000..0f9f43e
--- /dev/null
+++ b/Quick_demo/Language_files/config.json
@@ -0,0 +1,23 @@
+{
+ "_name_or_path": "/home/cs/leijiayu/wuchaoyi/Finetune_LLAMA/LLAMA_Model/llama-13b-hf",
+ "architectures": [
+ "LlamaForCausalLM"
+ ],
+ "bos_token_id": 0,
+ "eos_token_id": 1,
+ "hidden_act": "silu",
+ "hidden_size": 5120,
+ "initializer_range": 0.02,
+ "intermediate_size": 13824,
+ "max_sequence_length": 2048,
+ "model_type": "llama",
+ "num_attention_heads": 40,
+ "num_hidden_layers": 40,
+ "pad_token_id": -1,
+ "rms_norm_eps": 1e-06,
+ "tie_word_embeddings": false,
+ "torch_dtype": "float32",
+ "transformers_version": "4.28.0.dev0",
+ "use_cache": true,
+ "vocab_size": 32000
+}
diff --git a/Quick_demo/Language_files/special_tokens_map.json b/Quick_demo/Language_files/special_tokens_map.json
new file mode 100644
index 0000000..9e26dfe
--- /dev/null
+++ b/Quick_demo/Language_files/special_tokens_map.json
@@ -0,0 +1 @@
+{}
\ No newline at end of file
diff --git a/Quick_demo/Language_files/tokenizer.model b/Quick_demo/Language_files/tokenizer.model
new file mode 100644
index 0000000..22bccbc
Binary files /dev/null and b/Quick_demo/Language_files/tokenizer.model differ
diff --git a/Quick_demo/Language_files/tokenizer_config.json b/Quick_demo/Language_files/tokenizer_config.json
new file mode 100644
index 0000000..a54b01a
--- /dev/null
+++ b/Quick_demo/Language_files/tokenizer_config.json
@@ -0,0 +1 @@
+{"bos_token": "", "eos_token": "", "model_max_length": 1000000000000000019884624838656, "tokenizer_class": "LlamaTokenizer", "unk_token": ""}
\ No newline at end of file
diff --git a/Quick_demo/Model/RadFM/__init__.py b/Quick_demo/Model/RadFM/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/Quick_demo/Model/RadFM/__pycache__/__init__.cpython-39.pyc b/Quick_demo/Model/RadFM/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000..de287b2
Binary files /dev/null and b/Quick_demo/Model/RadFM/__pycache__/__init__.cpython-39.pyc differ
diff --git a/Quick_demo/Model/RadFM/__pycache__/blocks.cpython-39.pyc b/Quick_demo/Model/RadFM/__pycache__/blocks.cpython-39.pyc
new file mode 100644
index 0000000..0349241
Binary files /dev/null and b/Quick_demo/Model/RadFM/__pycache__/blocks.cpython-39.pyc differ
diff --git a/Quick_demo/Model/RadFM/__pycache__/helpers.cpython-39.pyc b/Quick_demo/Model/RadFM/__pycache__/helpers.cpython-39.pyc
new file mode 100644
index 0000000..fd945c9
Binary files /dev/null and b/Quick_demo/Model/RadFM/__pycache__/helpers.cpython-39.pyc differ
diff --git a/Quick_demo/Model/RadFM/__pycache__/multimodality_model.cpython-39.pyc b/Quick_demo/Model/RadFM/__pycache__/multimodality_model.cpython-39.pyc
new file mode 100644
index 0000000..f2b8999
Binary files /dev/null and b/Quick_demo/Model/RadFM/__pycache__/multimodality_model.cpython-39.pyc differ
diff --git a/Quick_demo/Model/RadFM/__pycache__/my_embedding_layer.cpython-39.pyc b/Quick_demo/Model/RadFM/__pycache__/my_embedding_layer.cpython-39.pyc
new file mode 100644
index 0000000..5147283
Binary files /dev/null and b/Quick_demo/Model/RadFM/__pycache__/my_embedding_layer.cpython-39.pyc differ
diff --git a/Quick_demo/Model/RadFM/__pycache__/position_encoding.cpython-39.pyc b/Quick_demo/Model/RadFM/__pycache__/position_encoding.cpython-39.pyc
new file mode 100644
index 0000000..0a33999
Binary files /dev/null and b/Quick_demo/Model/RadFM/__pycache__/position_encoding.cpython-39.pyc differ
diff --git a/Quick_demo/Model/RadFM/__pycache__/transformer_decoder.cpython-39.pyc b/Quick_demo/Model/RadFM/__pycache__/transformer_decoder.cpython-39.pyc
new file mode 100644
index 0000000..c8cedd6
Binary files /dev/null and b/Quick_demo/Model/RadFM/__pycache__/transformer_decoder.cpython-39.pyc differ
diff --git a/Quick_demo/Model/RadFM/__pycache__/utils.cpython-39.pyc b/Quick_demo/Model/RadFM/__pycache__/utils.cpython-39.pyc
new file mode 100644
index 0000000..3cf863e
Binary files /dev/null and b/Quick_demo/Model/RadFM/__pycache__/utils.cpython-39.pyc differ
diff --git a/Quick_demo/Model/RadFM/__pycache__/vit_3d.cpython-39.pyc b/Quick_demo/Model/RadFM/__pycache__/vit_3d.cpython-39.pyc
new file mode 100644
index 0000000..0c892a0
Binary files /dev/null and b/Quick_demo/Model/RadFM/__pycache__/vit_3d.cpython-39.pyc differ
diff --git a/Quick_demo/Model/RadFM/blocks.py b/Quick_demo/Model/RadFM/blocks.py
new file mode 100644
index 0000000..ccf1a34
--- /dev/null
+++ b/Quick_demo/Model/RadFM/blocks.py
@@ -0,0 +1,400 @@
+from collections import OrderedDict
+from typing import Tuple, Union, Callable, Optional
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+from torch.utils.checkpoint import checkpoint
+
+class PMC_CLIP_cfg:
+ backbone: str = 'ModifiedRN50' # ['RN50', 'ModifiedRN50', 'MAE']
+ layers: Union[Tuple[int, int, int, int], int] = [3,4,6,3]
+ width: int = 64
+ head_width: int = 64
+ mlp_ratio: float = 4.0
+ patch_size: int = 16
+ image_size: Union[Tuple[int, int], int] = 224
+ timm_model_name: str = None # a valid model name overrides layers, width, patch_size
+ timm_model_pretrained: bool = False # use (imagenet) pretrained weights for named model
+ timm_pool: str = 'avg' # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '')
+ timm_proj: str = 'linear' # linear projection for timm model output ('linear', 'mlp', '')
+ patch_dropout: float = 0.0 # patch dropout rate, no dropout by default
+ drop_attention_rate: float = 0. # Transformer Dropout
+
+class Bottleneck(nn.Module):
+ expansion = 4
+
+ def __init__(self, inplanes, planes, stride=1):
+ super().__init__()
+
+ # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
+ self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
+ self.bn1 = nn.BatchNorm2d(planes)
+ self.relu1 = nn.ReLU(inplace=True)
+
+ self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
+ self.bn2 = nn.BatchNorm2d(planes)
+ self.relu2 = nn.ReLU(inplace=True)
+
+ self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
+
+ self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
+ self.bn3 = nn.BatchNorm2d(planes * self.expansion)
+ self.relu3 = nn.ReLU(inplace=True)
+
+ self.downsample = None
+ self.stride = stride
+
+ if stride > 1 or inplanes != planes * Bottleneck.expansion:
+ # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
+ self.downsample = nn.Sequential(OrderedDict([
+ ("-1", nn.AvgPool2d(stride)),
+ ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
+ ("1", nn.BatchNorm2d(planes * self.expansion))
+ ]))
+
+ def forward(self, x: torch.Tensor):
+ identity = x
+
+ out = self.relu1(self.bn1(self.conv1(x)))
+ out = self.relu2(self.bn2(self.conv2(out)))
+ out = self.avgpool(out)
+ out = self.bn3(self.conv3(out))
+
+ if self.downsample is not None:
+ identity = self.downsample(x)
+
+ out += identity
+ out = self.relu3(out)
+ return out
+
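+# Bottleneck usage sketch (illustrative only; shapes assumed for a standard RN50 stage):
+# block = Bottleneck(inplanes=64, planes=64, stride=2)
+# y = block(torch.randn(1, 64, 56, 56))   # -> [1, 256, 28, 28]; the stride is applied via avgpool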
+
+class AttentionPool2d(nn.Module):
+ def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
+ super().__init__()
+ self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
+ self.k_proj = nn.Linear(embed_dim, embed_dim)
+ self.q_proj = nn.Linear(embed_dim, embed_dim)
+ self.v_proj = nn.Linear(embed_dim, embed_dim)
+ self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
+ self.num_heads = num_heads
+
+ def forward(self, x):
+ x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC
+ x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
+ x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
+ x, _ = F.multi_head_attention_forward(
+ query=x, key=x, value=x,
+ embed_dim_to_check=x.shape[-1],
+ num_heads=self.num_heads,
+ q_proj_weight=self.q_proj.weight,
+ k_proj_weight=self.k_proj.weight,
+ v_proj_weight=self.v_proj.weight,
+ in_proj_weight=None,
+ in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
+ bias_k=None,
+ bias_v=None,
+ add_zero_attn=False,
+ dropout_p=0,
+ out_proj_weight=self.c_proj.weight,
+ out_proj_bias=self.c_proj.bias,
+ use_separate_proj_weight=True,
+ training=self.training,
+ need_weights=False
+ )
+
+ return x[0]
+
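+# AttentionPool2d usage sketch (illustrative only): pools an NCHW feature map into a single vector.
+# pool = AttentionPool2d(spacial_dim=7, embed_dim=2048, num_heads=32, output_dim=1024)
+# y = pool(torch.randn(2, 2048, 7, 7))   # -> [2, 1024]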
+
+class ResNet(nn.Module):
+ """
+ RN50
+ """
+
+ def __init__(
+ self, layers, output_dim, heads, image_size=224, width=64,
+ block=Bottleneck,
+ ):
+ super().__init__()
+ self.output_dim = output_dim
+ self.image_size = image_size
+
+ # the 1-layer stem
+ self.conv1 = nn.Conv2d(3, width, kernel_size=3, stride=2, padding=1, bias=False)
+ self.bn1 = nn.BatchNorm2d(width)
+ self.relu1 = nn.ReLU(inplace=True)
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+
+ # residual layers
+ self._inplanes = width # this is a *mutable* variable used during construction
+ self.layer1 = self._make_layer(width, layers[0])
+ self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
+ self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
+ self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
+ self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
+ # self.head = nn.Linear(512 * 6, output_dim)
+ self.head = nn.Linear(512 * block.expansion, output_dim)
+
+ # embed_dim = width * 32 # the ResNet feature dimension
+ # self.attnpool = AttentionPool2d(image_size // 32, embed_dim, heads, output_dim)
+
+ self.init_parameters()
+
+ def _make_layer(
+ self,
+ planes, blocks, stride=1,
+ block=Bottleneck,
+ ):
+ layers = [block(self._inplanes, planes, stride)]
+
+ self._inplanes = planes * block.expansion
+ for _ in range(1, blocks):
+ layers.append(block(self._inplanes, planes))
+
+ return nn.Sequential(*layers)
+
+ def init_parameters(self):
+ for resnet_block in [self.layer1, self.layer2, self.layer3, self.layer4]:
+ for name, param in resnet_block.named_parameters():
+ if name.endswith("bn3.weight"):
+ nn.init.zeros_(param)
+
+ def lock(self, unlocked_groups=0, freeze_bn_stats=False):
+ assert unlocked_groups == 0, 'partial locking not currently supported for this model'
+ for param in self.parameters():
+ param.requires_grad = False
+ if freeze_bn_stats:
+ freeze_batch_norm_2d(self)
+
+ @torch.jit.ignore
+ def set_grad_checkpointing(self, enable=True):
+ # FIXME support for non-transformer
+ pass
+
+ def stem(self, x):
+ x = self.relu1(self.bn1(self.conv1(x)))
+ x = self.maxpool(x)
+ return x
+
+ def forward(self, x):
+ # x: [batch_size, 3, 224, 224]
+ x = self.stem(x) # [batch_size, 64, 56, 56]
+ x = self.layer1(x)
+ x = self.layer2(x)
+ x = self.layer3(x)
+ x = self.layer4(x) # [batch_size, 2048, 7, 7]
+ x = self.avgpool(x) # [batch_size, 2048, 1, 1]
+ x = torch.flatten(x, 1) # [batch_size, 2048*1*1]
+ x = self.head(x) # [batch_size, 1024]
+
+ visual_output = dict.fromkeys(["image_features", "mim_loss"], None)
+ visual_output.update({
+ 'image_features': x,
+ })
+
+ return visual_output
+
+
+class ModifiedResNet(nn.Module):
+ """
+ A ResNet class that is similar to torchvision's but contains the following changes:
+ - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
+ - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
+ - The final pooling layer is a QKV attention instead of an average pool
+ """
+
+ def __init__(self, layers, output_dim, heads, image_size=224, width=64):
+ super().__init__()
+ self.output_dim = output_dim
+ self.image_size = image_size
+
+ # the 3-layer stem
+ self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
+ self.bn1 = nn.BatchNorm2d(width // 2)
+ self.relu1 = nn.ReLU(inplace=True)
+ self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
+ self.bn2 = nn.BatchNorm2d(width // 2)
+ self.relu2 = nn.ReLU(inplace=True)
+ self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
+ self.bn3 = nn.BatchNorm2d(width)
+ self.relu3 = nn.ReLU(inplace=True)
+ self.avgpool = nn.AvgPool2d(2)
+
+ # residual layers
+ self._inplanes = width # this is a *mutable* variable used during construction
+ self.layer1 = self._make_layer(width, layers[0])
+ self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
+ self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
+ self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
+
+ embed_dim = width * 32 # the ResNet feature dimension
+ self.attnpool = AttentionPool2d(image_size // 32, embed_dim, heads, output_dim)
+
+ self.init_parameters()
+
+ def _make_layer(self, planes, blocks, stride=1):
+ layers = [Bottleneck(self._inplanes, planes, stride)]
+
+ self._inplanes = planes * Bottleneck.expansion
+ for _ in range(1, blocks):
+ layers.append(Bottleneck(self._inplanes, planes))
+
+ return nn.Sequential(*layers)
+
+ def init_parameters(self):
+ if self.attnpool is not None:
+ std = self.attnpool.c_proj.in_features ** -0.5
+ nn.init.normal_(self.attnpool.q_proj.weight, std=std)
+ nn.init.normal_(self.attnpool.k_proj.weight, std=std)
+ nn.init.normal_(self.attnpool.v_proj.weight, std=std)
+ nn.init.normal_(self.attnpool.c_proj.weight, std=std)
+
+ for resnet_block in [self.layer1, self.layer2, self.layer3, self.layer4]:
+ for name, param in resnet_block.named_parameters():
+ if name.endswith("bn3.weight"):
+ nn.init.zeros_(param)
+
+ def lock(self, unlocked_groups=0, freeze_bn_stats=False):
+ assert unlocked_groups == 0, 'partial locking not currently supported for this model'
+ for param in self.parameters():
+ param.requires_grad = False
+ if freeze_bn_stats:
+ freeze_batch_norm_2d(self)
+
+ @torch.jit.ignore
+ def set_grad_checkpointing(self, enable=True):
+ # FIXME support for non-transformer
+ pass
+
+ def stem(self, x):
+ x = self.relu1(self.bn1(self.conv1(x)))
+ x = self.relu2(self.bn2(self.conv2(x)))
+ x = self.relu3(self.bn3(self.conv3(x)))
+ x = self.avgpool(x)
+ return x
+
+ def forward(self, x):
+ x = self.stem(x)
+ x = self.layer1(x)
+ x = self.layer2(x)
+ x = self.layer3(x)
+ x = self.layer4(x)
+ x = self.attnpool(x)
+
+ visual_output = dict.fromkeys(["image_features", "mim_loss"], None)
+ visual_output.update({
+ 'image_features': x,
+ })
+
+ return visual_output
+
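+# ModifiedResNet usage sketch (illustrative only; mirrors the commented examples used elsewhere in this repo):
+# cfg = PMC_CLIP_cfg()
+# model = ModifiedResNet(layers=cfg.layers, output_dim=768, heads=cfg.width * 32 // cfg.head_width,
+#                        image_size=cfg.image_size, width=cfg.width)
+# out = model(torch.randn(2, 3, 224, 224))   # out['image_features']: [2, 768]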
+
+class LayerNorm(nn.LayerNorm):
+ """Subclass torch's LayerNorm to handle fp16."""
+
+ def forward(self, x: torch.Tensor):
+ orig_type = x.dtype
+ x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
+ return x.to(orig_type)
+
+
+class QuickGELU(nn.Module):
+ # NOTE This is slower than nn.GELU or nn.SiLU and uses more GPU memory
+ def forward(self, x: torch.Tensor):
+ return x * torch.sigmoid(1.702 * x)
+
+
+class ResidualAttentionBlock(nn.Module):
+ def __init__(
+ self, d_model: int, n_head: int, mlp_ratio: float = 4.0, act_layer: Callable = nn.GELU,
+ drop_attention_rate: float = 0.,
+ ):
+ super().__init__()
+
+ self.attn = nn.MultiheadAttention(
+ embed_dim=d_model,
+ num_heads=n_head,
+ dropout=drop_attention_rate,
+ )
+ self.ln_1 = LayerNorm(d_model)
+ mlp_width = int(d_model * mlp_ratio)
+ self.mlp = nn.Sequential(OrderedDict([
+ ("c_fc", nn.Linear(d_model, mlp_width)),
+ ("gelu", act_layer()),
+ ("c_proj", nn.Linear(mlp_width, d_model))
+ ]))
+ self.ln_2 = LayerNorm(d_model)
+
+ def attention(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
+ return self.attn(x, x, x, need_weights=False, attn_mask=attn_mask)[0]
+
+ def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
+ x = x + self.attention(self.ln_1(x), attn_mask=attn_mask)
+ x = x + self.mlp(self.ln_2(x))
+ return x
+
+
+class PatchDropout(nn.Module):
+ """
+ https://arxiv.org/abs/2212.00794
+ """
+
+ def __init__(self, prob, exclude_first_token=True):
+ super().__init__()
+ assert 0 <= prob < 1.
+ self.prob = prob
+ self.exclude_first_token = exclude_first_token # exclude CLS token
+
+ def forward(self, x):
+ if not self.training or self.prob == 0.:
+ return x
+
+ if self.exclude_first_token:
+ cls_tokens, x = x[:, :1], x[:, 1:]
+ else:
+ cls_tokens = torch.jit.annotate(torch.Tensor, x[:, :1])
+
+ batch = x.size()[0]
+ num_tokens = x.size()[1]
+
+ batch_indices = torch.arange(batch)
+ batch_indices = batch_indices[..., None]
+
+ keep_prob = 1 - self.prob
+ num_patches_keep = max(1, int(num_tokens * keep_prob))
+
+ rand = torch.randn(batch, num_tokens)
+ patch_indices_keep = rand.topk(num_patches_keep, dim=-1).indices
+
+ x = x[batch_indices, patch_indices_keep]
+
+ if self.exclude_first_token:
+ x = torch.cat((cls_tokens, x), dim=1)
+
+ return x
+
+
+class Transformer(nn.Module):
+ def __init__(
+ self, width: int, layers: int, heads: int, mlp_ratio: float = 4.0, act_layer: Callable = nn.GELU,
+ drop_attention_rate: float = 0.,
+ ):
+ super().__init__()
+ self.width = width
+ self.layers = layers
+ self.grad_checkpointing = False
+
+ self.resblocks = nn.ModuleList([
+ ResidualAttentionBlock(width, heads, mlp_ratio, act_layer=act_layer, drop_attention_rate=drop_attention_rate)
+ for _ in range(layers)
+ ])
+
+ def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
+ for r in self.resblocks:
+ if self.grad_checkpointing and not torch.jit.is_scripting():
+ x = checkpoint(r, x, attn_mask)
+ else:
+ x = r(x, attn_mask=attn_mask)
+ return x
\ No newline at end of file
diff --git a/Quick_demo/Model/RadFM/helpers.py b/Quick_demo/Model/RadFM/helpers.py
new file mode 100644
index 0000000..78b4896
--- /dev/null
+++ b/Quick_demo/Model/RadFM/helpers.py
@@ -0,0 +1,275 @@
+"""
+Taken from https://github.com/lucidrains/flamingo-pytorch
+"""
+
+import torch
+from einops import rearrange, repeat
+from einops_exts import rearrange_many
+from torch import einsum, nn
+
+
+def exists(val):
+ return val is not None
+
+
+def FeedForward(dim, mult=4):
+ inner_dim = int(dim * mult)
+ return nn.Sequential(
+ nn.LayerNorm(dim),
+ nn.Linear(dim, inner_dim, bias=False),
+ nn.GELU(),
+ nn.Linear(inner_dim, dim, bias=False),
+ )
+
+
+class PerceiverAttention(nn.Module):
+ def __init__(self, *, dim, dim_head=64, heads=8):
+ super().__init__()
+ self.scale = dim_head**-0.5
+ self.heads = heads
+ inner_dim = dim_head * heads
+
+ self.norm_media = nn.LayerNorm(dim)
+ self.norm_latents = nn.LayerNorm(dim)
+
+ self.to_q = nn.Linear(dim, inner_dim, bias=False)
+ self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
+ self.to_out = nn.Linear(inner_dim, dim, bias=False)
+
+ def forward(self, x, latents):
+ """
+ Args:
+ x (torch.Tensor): image features
+ shape (b, T, n1, D)
+ latents (torch.Tensor): latent features
+ shape (b, T, n2, D)
+ """
+ x = self.norm_media(x)
+ latents = self.norm_latents(latents)
+
+ h = self.heads
+
+ q = self.to_q(latents)
+ kv_input = torch.cat((x, latents), dim=-2)
+ k, v = self.to_kv(kv_input).chunk(2, dim=-1)
+ q, k, v = rearrange_many((q, k, v), "b t n (h d) -> b h t n d", h=h)
+ q = q * self.scale
+
+ # attention
+ sim = einsum("... i d, ... j d -> ... i j", q, k)
+ sim = sim - sim.amax(dim=-1, keepdim=True).detach()
+ attn = sim.softmax(dim=-1)
+
+ out = einsum("... i j, ... j d -> ... i d", attn, v)
+ out = rearrange(out, "b h t n d -> b t n (h d)", h=h)
+ return self.to_out(out)
+
+
+class PerceiverResampler(nn.Module):
+ def __init__(
+ self,
+ *,
+ dim,
+ depth=6,
+ dim_head=64,
+ heads=8,
+ num_latents=64,
+ max_num_media=None,
+ max_num_frames=None,
+ ff_mult=4,
+ ):
+ super().__init__()
+ self.latents = nn.Parameter(torch.randn(num_latents, dim))
+ self.frame_embs = (
+ nn.Parameter(torch.randn(max_num_frames, dim))
+ if exists(max_num_frames)
+ else None
+ )
+ self.media_time_embs = (
+ nn.Parameter(torch.randn(max_num_media, 1, dim))
+ if exists(max_num_media)
+ else None
+ )
+
+ self.layers = nn.ModuleList([])
+ for _ in range(depth):
+ self.layers.append(
+ nn.ModuleList(
+ [
+ PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
+ FeedForward(dim=dim, mult=ff_mult),
+ ]
+ )
+ )
+
+ self.norm = nn.LayerNorm(dim)
+
+ def forward(self, x):
+ """
+ Args:
+ x (torch.Tensor): image features
+ shape (b, T, F, v, D)
+ Returns:
+ shape (b, T, n, D) where n is self.num_latents
+ """
+ b, T, F, v = x.shape[:4]
+
+ # frame and media time embeddings
+ if exists(self.frame_embs):
+ frame_embs = repeat(self.frame_embs[:F], "F d -> b T F v d", b=b, T=T, v=v)
+ x = x + frame_embs
+ x = rearrange(
+ x, "b T F v d -> b T (F v) d"
+ ) # flatten the frame and spatial dimensions
+ if exists(self.media_time_embs):
+ x = x + self.media_time_embs[:T]
+
+ # blocks
+ latents = repeat(self.latents, "n d -> b T n d", b=b, T=T)
+ for attn, ff in self.layers:
+ latents = attn(x, latents) + latents
+ latents = ff(latents) + latents
+ return self.norm(latents)
+
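+# PerceiverResampler usage sketch (illustrative only): (b, T, F, v, D) visual tokens -> (b, T, num_latents, D).
+# resampler = PerceiverResampler(dim=768, num_latents=32)
+# out = resampler(torch.randn(2, 1, 1, 256, 768))   # -> [2, 1, 32, 768]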
+
+# gated cross attention
+
+
+class MaskedCrossAttention(nn.Module):
+ def __init__(
+ self,
+ *,
+ dim,
+ dim_visual,
+ dim_head=64,
+ heads=8,
+ only_attend_immediate_media=True,
+ ):
+ super().__init__()
+ self.scale = dim_head**-0.5
+ self.heads = heads
+ inner_dim = dim_head * heads
+
+ self.norm = nn.LayerNorm(dim)
+
+ self.to_q = nn.Linear(dim, inner_dim, bias=False)
+ self.to_kv = nn.Linear(dim_visual, inner_dim * 2, bias=False)
+ self.to_out = nn.Linear(inner_dim, dim, bias=False)
+
+ # whether for text to only attend to immediate preceding image, or all previous images
+ self.only_attend_immediate_media = only_attend_immediate_media
+
+ def forward(self, x, media, media_locations=None, attend_previous=True):
+ """
+ Args:
+ x (torch.Tensor): text features
+ shape (B, T_txt, D_txt)
+ media (torch.Tensor): image features
+ shape (B, T_img, n, D_img) where n is the dim of the latents
+ media_locations: boolean mask identifying the media tokens in x
+ shape (B, T_txt)
+ attend_previous: bool
+ If False, ignore the immediately preceding image and only start attending from the following image
+ """
+ _, T_img, n = media.shape[:3]
+ h = self.heads
+
+ x = self.norm(x)
+
+ q = self.to_q(x)
+ media = rearrange(media, "b t n d -> b (t n) d")
+
+ k, v = self.to_kv(media).chunk(2, dim=-1)
+ q, k, v = rearrange_many((q, k, v), "b n (h d) -> b h n d", h=h)
+
+ q = q * self.scale
+
+ sim = einsum("... i d, ... j d -> ... i j", q, k)
+
+ if exists(media_locations):
+ # at each boolean of True, increment the time counter (relative to media time)
+ text_time = media_locations.cumsum(dim=-1)
+ media_time = torch.arange(T_img, device=x.device) + 1
+
+ if not attend_previous:
+ text_time[~media_locations] += 1
+ # make sure max is still the number of images in the sequence
+ text_time[
+ text_time
+ > repeat(
+ torch.count_nonzero(media_locations, dim=1),
+ "b -> b i",
+ i=text_time.shape[1],
+ )
+ ] = 0
+
+ # text time must equal media time if only attending to most immediate image
+ # otherwise, as long as text time is greater than media time (if attending to all previous images / media)
+ mask_op = torch.eq if self.only_attend_immediate_media else torch.ge
+
+ text_to_media_mask = mask_op(
+ rearrange(text_time, "b i -> b 1 i 1"),
+ repeat(media_time, "j -> 1 1 1 (j n)", n=n),
+ )
+ sim = sim.masked_fill(~text_to_media_mask, -torch.finfo(sim.dtype).max)
+
+ sim = sim - sim.amax(dim=-1, keepdim=True).detach()
+ attn = sim.softmax(dim=-1)
+
+ if exists(media_locations) and self.only_attend_immediate_media:
+ # any text without a preceding media needs to have attention zeroed out
+ text_without_media_mask = text_time == 0
+ text_without_media_mask = rearrange(
+ text_without_media_mask, "b i -> b 1 i 1"
+ )
+ attn = attn.masked_fill(text_without_media_mask, 0.0)
+
+ out = einsum("... i j, ... j d -> ... i d", attn, v)
+ out = rearrange(out, "b h n d -> b n (h d)")
+ return self.to_out(out)
+
+
+class GatedCrossAttentionBlock(nn.Module):
+ def __init__(
+ self,
+ *,
+ dim,
+ dim_visual,
+ dim_head=64,
+ heads=8,
+ ff_mult=4,
+ only_attend_immediate_media=True,
+ ):
+ super().__init__()
+ self.attn = MaskedCrossAttention(
+ dim=dim,
+ dim_visual=dim_visual,
+ dim_head=dim_head,
+ heads=heads,
+ only_attend_immediate_media=only_attend_immediate_media,
+ )
+ self.attn_gate = nn.Parameter(torch.tensor([0.0]))
+
+ self.ff = FeedForward(dim, mult=ff_mult)
+ self.ff_gate = nn.Parameter(torch.tensor([0.0]))
+
+ def forward(
+ self,
+ x,
+ media,
+ media_locations=None,
+ attend_previous=True,
+ ):
+ x = (
+ self.attn(
+ x,
+ media,
+ media_locations=media_locations,
+ attend_previous=attend_previous,
+ )
+ * self.attn_gate.tanh()
+ + x
+ )
+ x = self.ff(x) * self.ff_gate.tanh() + x
+
+ return x
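+# GatedCrossAttentionBlock usage sketch (illustrative only): text tokens attend to resampled media latents.
+# block = GatedCrossAttentionBlock(dim=5120, dim_visual=768)
+# x = block(torch.randn(2, 128, 5120), torch.randn(2, 1, 32, 768))   # -> [2, 128, 5120]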
diff --git a/Quick_demo/Model/RadFM/multimodality_model.py b/Quick_demo/Model/RadFM/multimodality_model.py
new file mode 100644
index 0000000..0855346
--- /dev/null
+++ b/Quick_demo/Model/RadFM/multimodality_model.py
@@ -0,0 +1,84 @@
+from torch import nn
+from transformers.models.llama import LlamaForCausalLM
+from transformers import AutoConfig
+from .my_embedding_layer import MyEmbedding
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+import tqdm.auto as tqdm
+import torch.nn as nn
+import torch
+from torch.utils.checkpoint import checkpoint
+from torch.autograd import Variable
+import numpy as np
+class MultiLLaMAForCausalLM(nn.Module):
+ def __init__(self, lang_model_path):
+ super(MultiLLaMAForCausalLM, self).__init__()
+ try:
+ self.lang_model = LlamaForCausalLM.from_pretrained(
+ lang_model_path,
+ )
+ except:
+ config = AutoConfig.from_pretrained(lang_model_path)
+ self.lang_model = LlamaForCausalLM(config)
+ self.lang_model.gradient_checkpointing_enable()
+ self.lang_model.enable_input_require_grads()
+ # self.lang_model.requires_grad_(False)
+ self.embedding_layer = MyEmbedding()
+ self.embedding_layer.weight = self.lang_model.get_input_embeddings().weight
+ self.hidden_dim = 5120
+ self.voc_size = 32000
+
+ def forward(self,lang_x, vision_x, attention_mask, labels, loss_reweight,key_words_query):
+ if labels.shape == lang_x.shape:
+ self.embedding_layer.flag = 'Text'
+ # lang_x = lang_x.to(vision_x.dtype)
+ # lang_x = lang_x + torch.zeros(1, dtype=lang_x.dtype, device=lang_x.device, requires_grad=True)
+ # vision_x = vision_x + torch.zeros(1, dtype=vision_x.dtype, device=vision_x.device, requires_grad=True)
+ # input_embedding = checkpoint(self.embedding_layer, lang_x, vision_x)
+ input_embedding,loss_match= self.embedding_layer(lang_x, vision_x,key_words_query) # ,loss_matching
+ output = self.lang_model(inputs_embeds = input_embedding,attention_mask = attention_mask, labels = labels)
+ logits = output['logits']
+
+ loss_reg = None
+ if labels is not None:
+ # Shift so that tokens < n predict n
+ shift_logits = logits[..., :-1, :].contiguous()
+ shift_labels = labels[..., 1:].contiguous()
+ shift_loss_reweight = loss_reweight[...,1:].contiguous()
+ # Flatten the tokens
+ loss_fct = CrossEntropyLoss(reduction = 'none')
+ shift_logits = shift_logits.view(-1, self.voc_size)
+ shift_labels = shift_labels.view(-1)
+ shift_loss_reweight = shift_loss_reweight.view(-1)
+ # Enable model parallelism
+ shift_labels = shift_labels.to(shift_logits.device)
+ shift_loss_reweight = shift_loss_reweight.to(shift_logits.device)
+ loss_reg = loss_fct(shift_logits, shift_labels)
+ loss_reg = torch.sum(shift_loss_reweight*loss_reg)/torch.sum(shift_loss_reweight)
+ loss = loss_reg
+ if loss_match!= None:
+ loss = 0.8*loss + 0.2*loss_match
+
+ logits = output['logits'][..., :-1, :].contiguous().detach()
+ total = len(labels)
+ predictions = torch.argmax(logits, dim=-1)
+ labels = labels[..., 1:].contiguous()
+ Acc = torch.sum(torch.all(torch.logical_or(predictions == labels, labels == -100),dim = -1))
+ Accuracy = Acc /total
+
+ return dict(
+ # loss_reg = loss_reg,
+ # loss_matching = loss_matching,
+ logits = Accuracy,
+ loss = output['loss'],
+ )
+ ### unused for now; ignore the following code ###
+ # if labels.shape == vision_x.shape:
+ # self.embedding_layer.flag = 'Seg'
+ # input_embedding = self.embedding_layer(lang_x, vision_x)
+
+ def generate(self, lang_x,vision_x):
+ self.embedding_layer.flag = 'Text'
+ with torch.no_grad():
+ input_embedding,_ = self.embedding_layer(lang_x, vision_x)
+ generation = self.lang_model.generate(inputs_embeds = input_embedding, max_new_tokens =200,top_k=50)
+ return generation
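+# Minimal usage sketch (paths and shapes are placeholders; see Quick_demo/test.py for the full demo):
+# model = MultiLLaMAForCausalLM(lang_model_path='./Language_files')
+# lang_x = torch.randint(0, 32000, (1, 128))        # tokenized prompt, including image placeholder tokens
+# vision_x = torch.randn(1, 1, 3, 512, 512, 4)      # B, S, C, H, W, D
+# generation = model.generate(lang_x, vision_x)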
diff --git a/Quick_demo/Model/RadFM/my_embedding_layer.py b/Quick_demo/Model/RadFM/my_embedding_layer.py
new file mode 100644
index 0000000..0c1b9b2
--- /dev/null
+++ b/Quick_demo/Model/RadFM/my_embedding_layer.py
@@ -0,0 +1,163 @@
+import torch.nn as nn
+import torch.nn.functional as F
+import torch
+from .helpers import PerceiverResampler
+from .utils import get_visual_encoder
+from einops import rearrange, repeat
+from einops_exts import rearrange_many
+import torchvision
+from .vit_3d import ViT
+from einops.layers.torch import Rearrange
+from .transformer_decoder import TransformerDecoder,TransformerDecoderLayer
+from torch.utils.checkpoint import checkpoint
+from torch.autograd import Variable
+import random
+from transformers import AutoTokenizer, AutoModel
+
+class MyEmbedding(nn.Module):
+ def __init__(self, num_embeddings=32000, embedding_dim=5120, perceiver_num=32,vis_dim = 768, patch_size=32, frame_patch_size = 4 ,seg_channel = 256):
+ super().__init__()
+ self.num_embeddings = num_embeddings
+ self.embedding_dim = embedding_dim
+ self.weight = nn.Parameter(torch.randn((num_embeddings, embedding_dim)))
+ self.figure_token_weight = nn.Parameter(torch.randn((2, embedding_dim)))
+ self.flag = 'Text'
+ self.patch_size = patch_size
+ self.frame_patch_size = frame_patch_size
+ self.seg_channel = seg_channel
+
+ self.bert_tokenizer = AutoTokenizer.from_pretrained("/gpfs/home/cs/leijiayu/wuchaoyi/multi_modal/src/MedKEBERT")
+ self.bert_model = AutoModel.from_pretrained("/gpfs/home/cs/leijiayu/wuchaoyi/multi_modal/src/MedKEBERT")
+ self.bert_projection_fc = nn.Linear(768,vis_dim)
+
+ self.vision_encoder = ViT(
+ image_size = 512, # image size
+ frames = 512, # max number of frames
+ image_patch_size = patch_size, # image patch size
+ frame_patch_size = frame_patch_size, # frame patch size
+ dim = vis_dim,
+ depth = 12,
+ heads = 8,
+ mlp_dim = 2048,
+ dropout = 0.1,
+ emb_dropout = 0.1
+ )
+
+ self.output_upscaling = nn.Sequential(
+ nn.ConvTranspose3d(vis_dim, vis_dim // 4, kernel_size=2, stride=2),
+ nn.BatchNorm3d(vis_dim // 4),
+ nn.GELU(),
+ nn.ConvTranspose3d(vis_dim // 4, vis_dim // 8, kernel_size=2, stride=2),
+ nn.GELU(),
+ )
+
+ decoder_layer = TransformerDecoderLayer(d_model = vis_dim, nhead = 8, normalize_before=True)
+ decoder_norm = nn.LayerNorm(vis_dim)
+ self.transformer_decoder = TransformerDecoder(decoder_layer = decoder_layer, num_layers = 4, norm=decoder_norm)
+ self.transformer_decoder_mlp = nn.Sequential(
+ nn.Linear(vis_dim,vis_dim // 4),
+ nn.GELU(),
+ nn.Linear(vis_dim // 4,vis_dim // 8),
+ nn.GELU(),
+ )
+ self.vis_dim = vis_dim
+
+ self.perceiver = PerceiverResampler(dim=self.vis_dim, num_latents = perceiver_num)
+ self.fc = nn.Linear(self.vis_dim,self.embedding_dim)
+ self.cls_head = nn.Linear(self.vis_dim // 8, 1)
+
+
+ def forward(self, text_input, vision_x, key_words_query = None):
+ if self.flag == 'Text':
+ B,S,C,H,W,D = vision_x.shape
+ vision_x = rearrange(vision_x, "b S c h w d-> (b S) c h w d")
+
+
+ vision_x, pos_embedding = self.vision_encoder(vision_x)
+ # vision_x = Variable(vision_x,requires_grad=True)
+ # vision_x, _ = checkpoint(self.vision_encoder,vision_x)
+
+ vision_x = rearrange(vision_x, "(b s F) v d -> b s F v d", b=B, s=S,F=1)
+
+ loss_matching = None
+ if key_words_query != None:
+ # key_words_query: list[list[str]] of length B; each inner list holds key words to match against the corresponding vision_x embedding
+ query_words = [item for sublist in key_words_query for item in sublist]
+ query_words = list(set(query_words))
+ if len(query_words)>16:
+ random.shuffle(query_words)
+ query_words = query_words[0:16]
+ if query_words != []:
+ contrastive_labels = torch.zeros(B,len(query_words)) #B Q
+ for i,sublist in enumerate(key_words_query):
+ for j,item in enumerate(query_words):
+ if item in sublist:
+ contrastive_labels[i,j] = 1
+ contrastive_labels = contrastive_labels.to(vision_x.dtype).to(vision_x.device)
+
+ with torch.no_grad():
+ query_words_embedding = self.bert_tokenizer(query_words, padding='max_length', truncation=True, max_length=256,return_tensors="pt")
+ query_words_embedding = self.bert_model(input_ids = query_words_embedding['input_ids'].to(vision_x.device),attention_mask = query_words_embedding['attention_mask'].to(vision_x.device))['last_hidden_state'][:,0,:].to(vision_x.dtype).to(vision_x.device) # Q,D
+ query_words_embedding = self.bert_projection_fc(query_words_embedding)
+ query_words_embedding = query_words_embedding.unsqueeze(0).repeat(B,1,1) # B,Q,D
+ _,N,_ = query_words_embedding.shape
+
+ image_embedding = vision_x.mean(dim=1) # b F v d: average pool over the s dimension to collapse multiple images/modalities
+ image_embedding = rearrange(image_embedding, "b F v d -> b (F v) d")
+ pos_embedding = rearrange(pos_embedding, "(b s) v d -> b s v d", b=B, s=S)[:,0,:,:]
+
+ image_embedding = image_embedding.transpose(0,1) # (H/P W/P D/P) B D
+ pos_embedding = pos_embedding.transpose(0,1) # (H/P W/P D/P) B D
+ query_words_embedding = query_words_embedding.transpose(0,1) # N B D
+
+ oo_embedding,_ = self.transformer_decoder(query_words_embedding, image_embedding, pos = pos_embedding)
+ oo_embedding = oo_embedding.transpose(0,1) # B Q D
+ oo_embedding = rearrange(oo_embedding, 'b n d -> (b n) d')
+ oo_embedding = self.transformer_decoder_mlp(oo_embedding)
+ oo_embedding = self.cls_head(oo_embedding).mean(dim = -1)
+ oo_embedding = rearrange(oo_embedding, '(b n) -> b n', b=B, n=N) # B Q
+ # oo_embedding = rearrange(oo_embedding, 'b n d -> b (n d)') # B Q
+ loss_matching = F.binary_cross_entropy_with_logits(oo_embedding, contrastive_labels)
+
+ vision_x = self.perceiver(vision_x) # reshapes to (b, S, n, d)
+ #vision_x = checkpoint(self.perceiver,vision_x)
+
+ n = vision_x.shape[2]
+
+ vision_x = rearrange(vision_x, "b s n d -> (b s n) d")
+ vision_x = self.fc(vision_x)
+ vision_x = rearrange(vision_x, "(b T) d -> b T d", b=B, T=n*S)
+
+ embedding_weight = torch.cat([self.weight, self.figure_token_weight],dim = 0)
+ embedding_weight = embedding_weight.unsqueeze(0).repeat(B, 1, 1)
+ embedding_weight = torch.cat([embedding_weight,vision_x],dim = 1)
+ text_input = F.one_hot(text_input,embedding_weight.shape[1]).to(vision_x.dtype).to(vision_x.device)
+ out_put = torch.matmul(text_input, embedding_weight)
+
+ ## unused for now; ignore the following code ##
+ # if self.flag == 'Seg':
+ # B,C,H,W,D = vision_x.shape
+ # _,N,_ = text_input.shape
+ # latent_embedding, pos_embedding = self.vision_encoder(vision_x) # B (H/P W/P D/P) D
+
+ # image_embedding = latent_embedding.transpose(0,1) # (H/P W/P D/P) B D
+ # pos_embedding = pos_embedding.transpose(0,1) # (H/P W/P D/P) B D
+ # text_input = text_input.transpose(0,1) # N B D
+
+ # mask_embedding,_ = self.transformer_decoder(text_input, image_embedding, pos = pos_embedding)
+ # mask_embedding = mask_embedding.transpose(0,1) # B N D
+ # mask_embedding = rearrange(mask_embedding, 'b n d -> (b n) d')
+ # mask_embedding = self.transformer_decoder_mlp(mask_embedding)
+ # mask_embedding = rearrange(mask_embedding, '(b n) d -> b n d', b=B, n=N,d = self.vis_dim // 8)
+
+ # vision_x = rearrange(latent_embedding,'b (h w d) c -> b c h w d', h = (H // self.patch_size), w = (W // self.patch_size), d = (D // self.frame_patch_size), c=self.vis_dim)
+ # vision_x = self.output_upscaling(vision_x) #B C H/4 W/4 D/4
+ # out_put = torch.einsum('bchwd,bnc->bnhwd', vision_x, mask_embedding)
+
+ return out_put,loss_matching
+
+# model = MyEmbedding()
+# text_input = torch.randint(low=0, high=3210, size=(4,2048))
+# image_input = torch.randn((4,3,3,512,512,4))
+# key_words_query = [[],[],[],['consolidation']]
+# print(model(text_input, image_input, key_words_query))
diff --git a/Quick_demo/Model/RadFM/position_encoding.py b/Quick_demo/Model/RadFM/position_encoding.py
new file mode 100644
index 0000000..9d0af4b
--- /dev/null
+++ b/Quick_demo/Model/RadFM/position_encoding.py
@@ -0,0 +1,121 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+"""
+Various positional encodings for the transformer.
+"""
+import math
+import torch
+from torch import nn
+from einops.layers.torch import Rearrange
+from einops import rearrange, repeat
+
+class PositionEmbeddingSine(nn.Module):
+ """
+ This is a more standard version of the position embedding, very similar to the one
+ used by the Attention is all you need paper, generalized to work on images.
+ """
+ def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
+ super().__init__()
+ self.num_pos_feats = num_pos_feats
+ self.temperature = temperature
+ self.normalize = normalize
+ if scale is not None and normalize is False:
+ raise ValueError("normalize should be True if scale is passed")
+ if scale is None:
+ scale = 2 * math.pi
+ self.scale = scale
+
+ def forward(self, tensor_list):
+ x = tensor_list.tensors
+ mask = tensor_list.mask
+ assert mask is not None
+ not_mask = ~mask
+ y_embed = not_mask.cumsum(1, dtype=torch.float32)
+ x_embed = not_mask.cumsum(2, dtype=torch.float32)
+ if self.normalize:
+ eps = 1e-6
+ y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
+ x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
+
+ dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
+ dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
+
+ pos_x = x_embed[:, :, :, None] / dim_t
+ pos_y = y_embed[:, :, :, None] / dim_t
+ pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
+ pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
+ pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
+ return pos
+
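+# Note: PositionEmbeddingSine.forward expects a DETR-style NestedTensor-like input exposing
+# .tensors (B, C, H, W) and .mask (B, H, W); only the learned 3D embedding below is used by the files in this diff.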
+
+class PositionEmbeddingLearned(nn.Module):
+ """
+ Absolute pos embedding, learned.
+ """
+ def __init__(self, num_pos_feats=256):
+ super().__init__()
+ self.row_embed = nn.Embedding(50, num_pos_feats)
+ self.col_embed = nn.Embedding(50, num_pos_feats)
+ self.reset_parameters()
+
+ def reset_parameters(self):
+ nn.init.uniform_(self.row_embed.weight)
+ nn.init.uniform_(self.col_embed.weight)
+
+ def forward(self, tensor_list):
+ x = tensor_list.tensors
+ h, w = x.shape[-2:]
+ i = torch.arange(w, device=x.device)
+ j = torch.arange(h, device=x.device)
+ x_emb = self.col_embed(i)
+ y_emb = self.row_embed(j)
+ pos = torch.cat([
+ x_emb.unsqueeze(0).repeat(h, 1, 1),
+ y_emb.unsqueeze(1).repeat(1, w, 1),
+ ], dim=-1).permute(2, 0, 1).unsqueeze(0).repeat(x.shape[0], 1, 1, 1)
+ return pos
+
+class PositionEmbeddingLearned3d(nn.Module):
+ """
+ Absolute pos embedding, learned.
+ """
+ def __init__(self, num_pos_feats=256,h_patch_num = 16, w_patch_num = 16,d_patch_num = 64):
+ super().__init__()
+ self.h_patch_num = h_patch_num
+ self.w_patch_num = w_patch_num
+ self.d_patch_num = d_patch_num
+ self.row_embed = nn.Embedding(h_patch_num, num_pos_feats)
+ self.col_embed = nn.Embedding(w_patch_num, num_pos_feats)
+ self.dep_embed = nn.Embedding(d_patch_num, num_pos_feats)
+ self.reset_parameters()
+
+ def reset_parameters(self):
+ nn.init.uniform_(self.row_embed.weight)
+ nn.init.uniform_(self.col_embed.weight)
+ nn.init.uniform_(self.dep_embed.weight)
+
+ def forward(self, B, h, w, d,x):
+ i = (torch.arange(h, device=x.device) + 1)* (self.h_patch_num // h) -1
+ j = (torch.arange(w, device=x.device) + 1)* (self.w_patch_num // w) -1
+ k = (torch.arange(d, device=x.device) + 1)* (self.d_patch_num // d) -1
+ x_emb = self.row_embed(i).unsqueeze(1).unsqueeze(2).repeat(1,w,d,1)
+ y_emb = self.col_embed(j).unsqueeze(0).unsqueeze(2).repeat(h,1,d,1)
+ z_emb = self.dep_embed(k).unsqueeze(0).unsqueeze(1).repeat(h,w,1,1)
+ pos = torch.cat([x_emb,y_emb,z_emb,], dim=-1).unsqueeze(0).repeat(B, 1, 1, 1, 1)
+ pos = rearrange(pos,'b h w d c -> b (h w d) c')
+ return pos
+
+def build_position_encoding(args):
+ N_steps = args.hidden_dim // 2
+ if args.position_embedding in ('v2', 'sine'):
+ # TODO find a better way of exposing other arguments
+ position_embedding = PositionEmbeddingSine(N_steps, normalize=True)
+ elif args.position_embedding in ('v3', 'learned'):
+ position_embedding = PositionEmbeddingLearned(N_steps)
+ else:
+ raise ValueError(f"not supported {args.position_embedding}")
+
+ return position_embedding
+
+# Pos = PositionEmbeddingLearned3d()
+# x = torch.randn((8,3,32,32,1))
+# print(Pos(8,16,16,1,x))
\ No newline at end of file
diff --git a/Quick_demo/Model/RadFM/transformer_decoder.py b/Quick_demo/Model/RadFM/transformer_decoder.py
new file mode 100644
index 0000000..b50a350
--- /dev/null
+++ b/Quick_demo/Model/RadFM/transformer_decoder.py
@@ -0,0 +1,160 @@
+"""
+Code modified from the DETR transformer:
+https://github.com/facebookresearch/detr
+Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+"""
+
+import copy
+from typing import Optional, List
+import pickle as cp
+
+import torch
+import torch.nn.functional as F
+from torch import nn, Tensor
+
+
+class TransformerDecoder(nn.Module):
+
+ def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):
+ super().__init__()
+ self.layers = _get_clones(decoder_layer, num_layers)
+ self.num_layers = num_layers
+ self.norm = norm
+ self.return_intermediate = return_intermediate
+
+ def forward(self, tgt, memory,
+ tgt_mask: Optional[Tensor] = None,
+ memory_mask: Optional[Tensor] = None,
+ tgt_key_padding_mask: Optional[Tensor] = None,
+ memory_key_padding_mask: Optional[Tensor] = None,
+ pos: Optional[Tensor] = None,
+ query_pos: Optional[Tensor] = None):
+ output = tgt
+ T,B,C = memory.shape
+ intermediate = []
+ atten_layers = []
+ for n,layer in enumerate(self.layers):
+
+ residual=True
+ output,ws = layer(output, memory, tgt_mask=tgt_mask,
+ memory_mask=memory_mask,
+ tgt_key_padding_mask=tgt_key_padding_mask,
+ memory_key_padding_mask=memory_key_padding_mask,
+ pos=pos, query_pos=query_pos,residual=residual)
+ atten_layers.append(ws)
+ if self.return_intermediate:
+ intermediate.append(self.norm(output))
+ if self.norm is not None:
+ output = self.norm(output)
+ if self.return_intermediate:
+ intermediate.pop()
+ intermediate.append(output)
+
+ if self.return_intermediate:
+ return torch.stack(intermediate)
+ return output,atten_layers
+
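+# Usage sketch (illustrative only): inputs are (seq_len, batch, d_model) tensors, as in the DETR decoder.
+# layer = TransformerDecoderLayer(d_model=768, nhead=8, normalize_before=True)
+# decoder = TransformerDecoder(layer, num_layers=4, norm=nn.LayerNorm(768))
+# out, attn_layers = decoder(torch.randn(16, 2, 768), torch.randn(1024, 2, 768))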
+
+
+class TransformerDecoderLayer(nn.Module):
+
+ def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
+ activation="relu", normalize_before=False):
+ super().__init__()
+ self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
+ self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
+ # Implementation of Feedforward model
+ self.linear1 = nn.Linear(d_model, dim_feedforward)
+ self.dropout = nn.Dropout(dropout)
+ self.linear2 = nn.Linear(dim_feedforward, d_model)
+
+ self.norm1 = nn.LayerNorm(d_model)
+ self.norm2 = nn.LayerNorm(d_model)
+ self.norm3 = nn.LayerNorm(d_model)
+ self.dropout1 = nn.Dropout(dropout)
+ self.dropout2 = nn.Dropout(dropout)
+ self.dropout3 = nn.Dropout(dropout)
+
+ self.activation = _get_activation_fn(activation)
+ self.normalize_before = normalize_before
+ def with_pos_embed(self, tensor, pos: Optional[Tensor]):
+ return tensor if pos is None else tensor + pos
+
+ def forward_post(self, tgt, memory,
+ tgt_mask: Optional[Tensor] = None,
+ memory_mask: Optional[Tensor] = None,
+ tgt_key_padding_mask: Optional[Tensor] = None,
+ memory_key_padding_mask: Optional[Tensor] = None,
+ pos: Optional[Tensor] = None,
+ query_pos: Optional[Tensor] = None,
+ residual=True):
+ q = k = self.with_pos_embed(tgt, query_pos)
+ tgt2,ws = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask,
+ key_padding_mask=tgt_key_padding_mask)
+ tgt = self.norm1(tgt)
+ tgt2,ws = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos),
+ key=self.with_pos_embed(memory, pos),
+ value=memory, attn_mask=memory_mask,
+ key_padding_mask=memory_key_padding_mask)
+
+
+ # attn_weights [B,NUM_Q,T]
+ tgt = tgt + self.dropout2(tgt2)
+ tgt = self.norm2(tgt)
+ tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
+ tgt = tgt + self.dropout3(tgt2)
+ tgt = self.norm3(tgt)
+ return tgt,ws
+
+ def forward_pre(self, tgt, memory,
+ tgt_mask: Optional[Tensor] = None,
+ memory_mask: Optional[Tensor] = None,
+ tgt_key_padding_mask: Optional[Tensor] = None,
+ memory_key_padding_mask: Optional[Tensor] = None,
+ pos: Optional[Tensor] = None,
+ query_pos: Optional[Tensor] = None):
+ tgt2 = self.norm1(tgt)
+ q = k = self.with_pos_embed(tgt2, query_pos)
+ tgt2,ws = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask,
+ key_padding_mask=tgt_key_padding_mask)
+ tgt = tgt + self.dropout1(tgt2)
+ tgt2 = self.norm2(tgt)
+ tgt2,attn_weights = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos),
+ key=self.with_pos_embed(memory, pos),
+ value=memory, attn_mask=memory_mask,
+ key_padding_mask=memory_key_padding_mask)
+ tgt = tgt + self.dropout2(tgt2)
+ tgt2 = self.norm3(tgt)
+ tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
+ tgt = tgt + self.dropout3(tgt2)
+ return tgt,attn_weights
+
+ def forward(self, tgt, memory,
+ tgt_mask: Optional[Tensor] = None,
+ memory_mask: Optional[Tensor] = None,
+ tgt_key_padding_mask: Optional[Tensor] = None,
+ memory_key_padding_mask: Optional[Tensor] = None,
+ pos: Optional[Tensor] = None,
+ query_pos: Optional[Tensor] = None,
+ residual=True):
+ if self.normalize_before:
+ return self.forward_pre(tgt, memory, tgt_mask, memory_mask,
+ tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos)
+ return self.forward_post(tgt, memory, tgt_mask, memory_mask,
+ tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos,residual)
+
+
+def _get_clones(module, N):
+ return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
+
+
+
+def _get_activation_fn(activation):
+ """Return an activation function given a string"""
+ if activation == "relu":
+ return F.relu
+ if activation == "gelu":
+ return F.gelu
+ if activation == "glu":
+ return F.glu
+ raise RuntimeError(F"activation should be relu/gelu/glu, not {activation}.")
diff --git a/Quick_demo/Model/RadFM/utils.py b/Quick_demo/Model/RadFM/utils.py
new file mode 100644
index 0000000..1624348
--- /dev/null
+++ b/Quick_demo/Model/RadFM/utils.py
@@ -0,0 +1,74 @@
+from .blocks import ModifiedResNet,PMC_CLIP_cfg
+import torch
+from torchvision import transforms
+from PIL import Image
+import torch.nn as nn
+def extend_instance(obj, mixin):
+ """Apply mixins to a class instance after creation"""
+ base_cls = obj.__class__
+ base_cls_name = obj.__class__.__name__
+ obj.__class__ = type(
+ base_cls_name, (mixin, base_cls), {}
+ ) # mixin needs to go first for our forward() logic to work
+
+
+def getattr_recursive(obj, att):
+ """
+ Return nested attribute of obj
+ Example: getattr_recursive(obj, 'a.b.c') is equivalent to obj.a.b.c
+ """
+ if att == "":
+ return obj
+ i = att.find(".")
+ if i < 0:
+ return getattr(obj, att)
+ else:
+ return getattr_recursive(getattr(obj, att[:i]), att[i + 1 :])
+
+
+def setattr_recursive(obj, att, val):
+ """
+ Set nested attribute of obj
+ Example: setattr_recursive(obj, 'a.b.c', val) is equivalent to obj.a.b.c = val
+ """
+ if "." in att:
+ obj = getattr_recursive(obj, ".".join(att.split(".")[:-1]))
+ setattr(obj, att.split(".")[-1], val)
+
+
+
+def get_visual_encoder(model_str):
+ """
+ Args:
+ model_str (str): path to the visual encoder checkpoint
+ Returns:
+ vision_model, visual_dim, img_preprocessor
+ """
+ normalize = transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
+ img_preprocessor = transforms.Compose([
+ transforms.Resize((512,512), interpolation=Image.BICUBIC),
+ transforms.ToTensor(),
+ normalize,
+ ])
+ if 'PMC-CLIP' in model_str:
+ #vision_cfg = json.load(open(model_args.visual_model_config,'r'))['vision_cfg']
+ vision_cfg = PMC_CLIP_cfg()
+ vision_heads = vision_cfg.width * 32 // vision_cfg.head_width
+ vision_model = ModifiedResNet(
+ layers=vision_cfg.layers,
+ heads=vision_heads,
+ output_dim = 768,
+ image_size=vision_cfg.image_size,
+ width=vision_cfg.width
+ )
+ vision_model = vision_load_pretrain(vision_model,model_str)
+ vision_model = nn.Sequential(*list(vision_model.children())[:-2])
+ visual_dim = 1024
+ return vision_model,visual_dim,img_preprocessor
+
+def vision_load_pretrain(resnet,model_path):
+ checkpoint = torch.load(model_path, map_location='cpu')
+ state_dict = checkpoint['state_dict']
+ state_dict = {k.replace('module.visual.',''): v for k, v in state_dict.items() if '.visual' in k}
+ resnet.load_state_dict(state_dict)
+ return resnet
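+# Usage sketch (the checkpoint path is a placeholder; assumes a PMC-CLIP ResNet checkpoint on disk):
+# vision_model, visual_dim, img_preprocessor = get_visual_encoder('/path/to/PMC-CLIP/checkpoint.pt')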
diff --git a/Quick_demo/Model/RadFM/vit_3d.py b/Quick_demo/Model/RadFM/vit_3d.py
new file mode 100644
index 0000000..1c36b2b
--- /dev/null
+++ b/Quick_demo/Model/RadFM/vit_3d.py
@@ -0,0 +1,123 @@
+import torch
+from torch import nn
+
+from einops import rearrange, repeat
+from einops.layers.torch import Rearrange
+from .position_encoding import PositionEmbeddingLearned3d
+
+# helpers
+
+def pair(t):
+ return t if isinstance(t, tuple) else (t, t)
+
+# classes
+
+class PreNorm(nn.Module):
+ def __init__(self, dim, fn):
+ super().__init__()
+ self.norm = nn.LayerNorm(dim)
+ self.fn = fn
+ def forward(self, x, **kwargs):
+ return self.fn(self.norm(x), **kwargs)
+
+class FeedForward(nn.Module):
+ def __init__(self, dim, hidden_dim, dropout = 0.):
+ super().__init__()
+ self.net = nn.Sequential(
+ nn.Linear(dim, hidden_dim),
+ nn.GELU(),
+ nn.Dropout(dropout),
+ nn.Linear(hidden_dim, dim),
+ nn.Dropout(dropout)
+ )
+ def forward(self, x):
+ return self.net(x)
+
+class Attention(nn.Module):
+ def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
+ super().__init__()
+ inner_dim = dim_head * heads
+ project_out = not (heads == 1 and dim_head == dim)
+
+ self.heads = heads
+ self.scale = dim_head ** -0.5
+
+ self.attend = nn.Softmax(dim = -1)
+ self.dropout = nn.Dropout(dropout)
+
+ self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
+
+ self.to_out = nn.Sequential(
+ nn.Linear(inner_dim, dim),
+ nn.Dropout(dropout)
+ ) if project_out else nn.Identity()
+
+ def forward(self, x):
+ qkv = self.to_qkv(x).chunk(3, dim = -1)
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
+
+ dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
+
+ attn = self.attend(dots)
+ attn = self.dropout(attn)
+
+ out = torch.matmul(attn, v)
+ out = rearrange(out, 'b h n d -> b n (h d)')
+ return self.to_out(out)
+
+class Transformer(nn.Module):
+ def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
+ super().__init__()
+ self.layers = nn.ModuleList([])
+ for _ in range(depth):
+ self.layers.append(nn.ModuleList([
+ PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
+ PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
+ ]))
+ def forward(self, x):
+ for attn, ff in self.layers:
+ x = attn(x) + x
+ x = ff(x) + x
+ return x
+
+class ViT(nn.Module):
+ def __init__(self, *, image_size, image_patch_size, frames, frame_patch_size, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
+ super().__init__()
+ image_height, image_width = pair(image_size)
+ patch_height, patch_width = pair(image_patch_size)
+
+ assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
+ assert frames % frame_patch_size == 0, 'Frames must be divisible by frame patch size'
+
+ self.patch_height = patch_height
+ self.patch_width = patch_width
+ self.frame_patch_size = frame_patch_size
+
+ num_patches = (image_height // patch_height) * (image_width // patch_width) * (frames // frame_patch_size)
+ patch_dim = channels * patch_height * patch_width * frame_patch_size
+
+ assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
+
+ self.to_patch_embedding = nn.Sequential(
+ Rearrange('b c (h p1) (w p2) (f pf) -> b (h w f) (p1 p2 pf c)', p1 = patch_height, p2 = patch_width, pf = frame_patch_size),
+ nn.LayerNorm(patch_dim),
+ nn.Linear(patch_dim, dim),
+ nn.LayerNorm(dim),
+ )
+
+ self.pos_embedding = PositionEmbeddingLearned3d(dim // 3,(image_height // patch_height), (image_width // patch_width), (frames // frame_patch_size))
+ self.dropout = nn.Dropout(emb_dropout)
+
+ self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
+
+ def forward(self, video):
+ B, C, H, W, D = video.shape
+ x = self.to_patch_embedding(video)
+ b, n, _ = x.shape
+
+ pos = self.pos_embedding(B, H // self.patch_height, W // self.patch_width, D // self.frame_patch_size,x)
+ x += pos
+ x = self.dropout(x)
+
+ x = self.transformer(x)
+ return x,pos
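+# Usage sketch (illustrative only): a 3D volume of shape (B, C, H, W, D).
+# vit = ViT(image_size=512, frames=512, image_patch_size=32, frame_patch_size=4,
+#           dim=768, depth=12, heads=8, mlp_dim=2048)
+# feats, pos = vit(torch.randn(1, 3, 512, 512, 4))   # feats: [1, 256, 768]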
diff --git a/Quick_demo/test.py b/Quick_demo/test.py
new file mode 100644
index 0000000..62719ae
--- /dev/null
+++ b/Quick_demo/test.py
@@ -0,0 +1,122 @@
+import tqdm.auto as tqdm
+import torch.nn.functional as F
+from typing import Optional, Dict, Sequence
+from typing import List, Optional, Tuple, Union
+import transformers
+from dataclasses import dataclass, field
+from Model.RadFM.multimodality_model import MultiLLaMAForCausalLM
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaTokenizer
+from torchvision import transforms
+from PIL import Image
+
+def get_tokenizer(tokenizer_path, max_img_size = 100, image_num = 32):
+ '''
+ Initialize the image special tokens
+ max_img_size denotes the maximum number of images per input, and image_num denotes how many patch embeddings each image will be encoded into
+ '''
+ if isinstance(tokenizer_path,str):
+ image_padding_tokens = []
+ text_tokenizer = LlamaTokenizer.from_pretrained(
+ tokenizer_path,
+ )
+ special_token = {"additional_special_tokens": ["<image>","</image>"]}
+ for i in range(max_img_size):
+ image_padding_token = ""
+
+ for j in range(image_num):
+ image_token = "<image"+str(i*image_num+j)+">"
+ image_padding_token = image_padding_token + image_token
+ special_token["additional_special_tokens"].append("<image"+str(i*image_num+j)+">")
+ image_padding_tokens.append(image_padding_token)
+ text_tokenizer.add_special_tokens(
+ special_token
+ )
+ ## make sure the bos eos pad tokens are correct for LLaMA-like models
+ text_tokenizer.pad_token_id = 0
+ text_tokenizer.bos_token_id = 1
+ text_tokenizer.eos_token_id = 2
+
+ return text_tokenizer,image_padding_tokens
+
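+# With the defaults above, image_padding_tokens[i] is the concatenation of the 32 placeholder tokens
+# "<image{i*32}>" ... "<image{i*32+31}>", so each image slot in the prompt expands to 32 tokens.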
+def combine_and_preprocess(question,image_list,image_padding_tokens):
+
+ transform = transforms.Compose([
+ transforms.RandomResizedCrop([512,512],scale=(0.8, 1.0), interpolation=transforms.InterpolationMode.BICUBIC),
+ transforms.ToTensor(),
+ ])
+ images = []
+ new_questions = [_ for _ in question]
+ padding_index = 0
+ for img in image_list:
+ img_path = img['img_path']
+ position = img['position']
+
+
+
+ image = Image.open(img_path).convert('RGB')
+ image = transform(image)
+ image = image.unsqueeze(0).unsqueeze(-1) # 1, C, H, W, D (D=1 for a 2D image)
+
+ ## pre-process the img first
+ target_H = 512
+ target_W = 512
+ target_D = 4
+ # These sizes can differ for 3D and 2D images. For this demonstration we use the default sizes for 2D images.
+ images.append(torch.nn.functional.interpolate(image, size = (target_H,target_W,target_D)))
+
+ ## add img placeholder to text
+ new_questions[position] = "<image>"+ image_padding_tokens[padding_index] +"</image>" + new_questions[position]
+ padding_index +=1
+
+ vision_x = torch.cat(images,dim = 1).unsqueeze(0) #cat tensors and expand the batch_size dim
+ text = ''.join(new_questions)
+ return [text], vision_x,
+
+
+def main():
+
+ print("Setup tokenizer")
+ text_tokenizer,image_padding_tokens = get_tokenizer('./Language_files')
+ print("Finish loading tokenizer")
+
+ ### Initialize a simple case for demo ###
+ print("Setup demo case")
+ question = "Can you identify any visible signs of Cardiomegaly in the image?"
+ image =[
+ {
+ 'img_path': './view1_frontal.jpg',
+ 'position': 0, # index in the question string where the image is inserted, in [0, len(question)-1]
+ }, # an arbitrary number of images can be added
+ ]
+
+ text,vision_x = combine_and_preprocess(question,image,image_padding_tokens)
+
+ print("Finish loading demo case")
+
+ print("Setup Model")
+ model = MultiLLaMAForCausalLM(
+ lang_model_path='./Language_files', ### Build the model from the LLaMA-13B config
+ )
+ ckpt = torch.load('./pytorch_model.bin',map_location ='cpu') # Please download our checkpoint from Hugging Face and decompress the original zip file first
+ model.load_state_dict(ckpt)
+ print("Finish loading model")
+
+ model = model.to('cuda')
+ model.eval()
+ with torch.no_grad():
+ lang_x = text_tokenizer(
+ text, max_length=2048, truncation=True, return_tensors="pt"
+ )['input_ids'].to('cuda')
+
+ vision_x = vision_x.to('cuda')
+ generation = model.generate(lang_x,vision_x)
+ generated_texts = text_tokenizer.batch_decode(generation, skip_special_tokens=True)
+ print('---------------------------------------------------')
+ print(question)
+ print(generated_texts[0])
+
+
+if __name__ == "__main__":
+ main()
+
\ No newline at end of file
diff --git a/Quick_demo/view1_frontal.jpg b/Quick_demo/view1_frontal.jpg
new file mode 100644
index 0000000..a11c4d0
Binary files /dev/null and b/Quick_demo/view1_frontal.jpg differ
diff --git a/src/Model/RadFM/multimodality_model.py b/src/Model/RadFM/multimodality_model.py
index 8d9d558..e88f258 100644
--- a/src/Model/RadFM/multimodality_model.py
+++ b/src/Model/RadFM/multimodality_model.py
@@ -9,7 +9,7 @@
from torch.autograd import Variable
import numpy as np
class MultiLLaMAForCausalLM(nn.Module):
- def __init__(self, lang_model_path, vision_encoder_path):
+ def __init__(self, lang_model_path):
super(MultiLLaMAForCausalLM, self).__init__()
self.lang_model = LlamaForCausalLM.from_pretrained(
lang_model_path,
@@ -17,7 +17,7 @@ def __init__(self, lang_model_path, vision_encoder_path):
self.lang_model.gradient_checkpointing_enable()
self.lang_model.enable_input_require_grads()
# self.lang_model.requires_grad_(False)
- self.embedding_layer = MyEmbedding(vision_encoder_path)
+ self.embedding_layer = MyEmbedding()
self.embedding_layer.weight = self.lang_model.get_input_embeddings().weight
self.hidden_dim = 5120
self.voc_size = 32000
diff --git a/src/Model/RadFM/my_embedding_layer.py b/src/Model/RadFM/my_embedding_layer.py
index a2c36f4..0c1b9b2 100644
--- a/src/Model/RadFM/my_embedding_layer.py
+++ b/src/Model/RadFM/my_embedding_layer.py
@@ -15,7 +15,7 @@
from transformers import AutoTokenizer, AutoModel
class MyEmbedding(nn.Module):
- def __init__(self, vision_encoder_path, num_embeddings=32000, embedding_dim=5120, perceiver_num=32,vis_dim = 768, patch_size=32, frame_patch_size = 4 ,seg_channel = 256):
+ def __init__(self, num_embeddings=32000, embedding_dim=5120, perceiver_num=32,vis_dim = 768, patch_size=32, frame_patch_size = 4 ,seg_channel = 256):
super().__init__()
self.num_embeddings = num_embeddings
self.embedding_dim = embedding_dim
diff --git a/src/test.py b/src/test.py
index b59dc72..73ef782 100644
--- a/src/test.py
+++ b/src/test.py
@@ -27,7 +27,7 @@ def setup_seed(seed):
class ModelArguments:
lang_encoder_path: Optional[str] = field(default="/home/cs/leijiayu/wuchaoyi/book_pretrain/Results/Book_mix_2048_13B_full/checkpoint-45800")
tokenizer_path: str = field(default='/home/cs/leijiayu/wuchaoyi/Finetune_LLAMA/LLAMA_Model/tokenizer', metadata={"help": "Path to the tokenizer data."})
- vision_encoder_path: str = field(default='/home/cs/leijiayu/wuchaoyi/multi_modal/src/PMC-CLIP/checkpoint.pt', metadata={"help": "Path to the vision_encoder."})
+ #vision_encoder_path: str = field(default='/home/cs/leijiayu/wuchaoyi/multi_modal/src/PMC-CLIP/checkpoint.pt', metadata={"help": "Path to the vision_encoder."})
@dataclass
@@ -117,7 +117,6 @@ def main():
print("Setup Model")
model = MultiLLaMAForCausalLM(
lang_model_path=model_args.lang_encoder_path,
- vision_encoder_path=model_args.vision_encoder_path,
)
ckpt = torch.load('/gpfs/home/cs/leijiayu/wuchaoyi/wangyingjie/src/Results/backup/checkpoint-17600/pytorch_model.bin',map_location ='cpu')
# ckpt.pop('embedding_layer.figure_token_weight')
diff --git a/src/train.py b/src/train.py
index bb2cee6..12d324e 100644
--- a/src/train.py
+++ b/src/train.py
@@ -24,7 +24,7 @@ def compute_metrics(eval_preds):
class ModelArguments:
lang_encoder_path: Optional[str] = field(default="/home/cs/leijiayu/wuchaoyi/book_pretrain/Results/Book_mix_2048_13B_full/checkpoint-45800")
tokenizer_path: str = field(default='/home/cs/leijiayu/wuchaoyi/Finetune_LLAMA/LLAMA_Model/tokenizer', metadata={"help": "Path to the tokenizer data."})
- vision_encoder_path: str = field(default='/home/cs/leijiayu/wuchaoyi/multi_modal/src/PMC-CLIP/checkpoint.pt', metadata={"help": "Path to the vision_encoder."})
+
@dataclass
@@ -107,7 +107,6 @@ def main():
model = MultiLLaMAForCausalLM(
lang_model_path=model_args.lang_encoder_path,
- vision_encoder_path=model_args.vision_encoder_path,
)
trainer = Trainer(model=model,