diff --git a/Quick_demo/Language_files/config.json b/Quick_demo/Language_files/config.json new file mode 100644 index 0000000..0f9f43e --- /dev/null +++ b/Quick_demo/Language_files/config.json @@ -0,0 +1,23 @@ +{ + "_name_or_path": "/home/cs/leijiayu/wuchaoyi/Finetune_LLAMA/LLAMA_Model/llama-13b-hf", + "architectures": [ + "LlamaForCausalLM" + ], + "bos_token_id": 0, + "eos_token_id": 1, + "hidden_act": "silu", + "hidden_size": 5120, + "initializer_range": 0.02, + "intermediate_size": 13824, + "max_sequence_length": 2048, + "model_type": "llama", + "num_attention_heads": 40, + "num_hidden_layers": 40, + "pad_token_id": -1, + "rms_norm_eps": 1e-06, + "tie_word_embeddings": false, + "torch_dtype": "float32", + "transformers_version": "4.28.0.dev0", + "use_cache": true, + "vocab_size": 32000 +} diff --git a/Quick_demo/Language_files/special_tokens_map.json b/Quick_demo/Language_files/special_tokens_map.json new file mode 100644 index 0000000..9e26dfe --- /dev/null +++ b/Quick_demo/Language_files/special_tokens_map.json @@ -0,0 +1 @@ +{} \ No newline at end of file diff --git a/Quick_demo/Language_files/tokenizer.model b/Quick_demo/Language_files/tokenizer.model new file mode 100644 index 0000000..22bccbc Binary files /dev/null and b/Quick_demo/Language_files/tokenizer.model differ diff --git a/Quick_demo/Language_files/tokenizer_config.json b/Quick_demo/Language_files/tokenizer_config.json new file mode 100644 index 0000000..a54b01a --- /dev/null +++ b/Quick_demo/Language_files/tokenizer_config.json @@ -0,0 +1 @@ +{"bos_token": "", "eos_token": "", "model_max_length": 1000000000000000019884624838656, "tokenizer_class": "LlamaTokenizer", "unk_token": ""} \ No newline at end of file diff --git a/Quick_demo/Model/RadFM/__init__.py b/Quick_demo/Model/RadFM/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/Quick_demo/Model/RadFM/__pycache__/__init__.cpython-39.pyc b/Quick_demo/Model/RadFM/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000..de287b2 Binary files /dev/null and b/Quick_demo/Model/RadFM/__pycache__/__init__.cpython-39.pyc differ diff --git a/Quick_demo/Model/RadFM/__pycache__/blocks.cpython-39.pyc b/Quick_demo/Model/RadFM/__pycache__/blocks.cpython-39.pyc new file mode 100644 index 0000000..0349241 Binary files /dev/null and b/Quick_demo/Model/RadFM/__pycache__/blocks.cpython-39.pyc differ diff --git a/Quick_demo/Model/RadFM/__pycache__/helpers.cpython-39.pyc b/Quick_demo/Model/RadFM/__pycache__/helpers.cpython-39.pyc new file mode 100644 index 0000000..fd945c9 Binary files /dev/null and b/Quick_demo/Model/RadFM/__pycache__/helpers.cpython-39.pyc differ diff --git a/Quick_demo/Model/RadFM/__pycache__/multimodality_model.cpython-39.pyc b/Quick_demo/Model/RadFM/__pycache__/multimodality_model.cpython-39.pyc new file mode 100644 index 0000000..f2b8999 Binary files /dev/null and b/Quick_demo/Model/RadFM/__pycache__/multimodality_model.cpython-39.pyc differ diff --git a/Quick_demo/Model/RadFM/__pycache__/my_embedding_layer.cpython-39.pyc b/Quick_demo/Model/RadFM/__pycache__/my_embedding_layer.cpython-39.pyc new file mode 100644 index 0000000..5147283 Binary files /dev/null and b/Quick_demo/Model/RadFM/__pycache__/my_embedding_layer.cpython-39.pyc differ diff --git a/Quick_demo/Model/RadFM/__pycache__/position_encoding.cpython-39.pyc b/Quick_demo/Model/RadFM/__pycache__/position_encoding.cpython-39.pyc new file mode 100644 index 0000000..0a33999 Binary files /dev/null and b/Quick_demo/Model/RadFM/__pycache__/position_encoding.cpython-39.pyc differ diff 
--git a/Quick_demo/Model/RadFM/__pycache__/transformer_decoder.cpython-39.pyc b/Quick_demo/Model/RadFM/__pycache__/transformer_decoder.cpython-39.pyc new file mode 100644 index 0000000..c8cedd6 Binary files /dev/null and b/Quick_demo/Model/RadFM/__pycache__/transformer_decoder.cpython-39.pyc differ diff --git a/Quick_demo/Model/RadFM/__pycache__/utils.cpython-39.pyc b/Quick_demo/Model/RadFM/__pycache__/utils.cpython-39.pyc new file mode 100644 index 0000000..3cf863e Binary files /dev/null and b/Quick_demo/Model/RadFM/__pycache__/utils.cpython-39.pyc differ diff --git a/Quick_demo/Model/RadFM/__pycache__/vit_3d.cpython-39.pyc b/Quick_demo/Model/RadFM/__pycache__/vit_3d.cpython-39.pyc new file mode 100644 index 0000000..0c892a0 Binary files /dev/null and b/Quick_demo/Model/RadFM/__pycache__/vit_3d.cpython-39.pyc differ diff --git a/Quick_demo/Model/RadFM/blocks.py b/Quick_demo/Model/RadFM/blocks.py new file mode 100644 index 0000000..ccf1a34 --- /dev/null +++ b/Quick_demo/Model/RadFM/blocks.py @@ -0,0 +1,400 @@ +from collections import OrderedDict +from typing import Tuple, Union, Callable, Optional + +import torch +import torch.nn.functional as F +from torch import nn +from torch.utils.checkpoint import checkpoint + +class PMC_CLIP_cfg: + backbone: str = 'ModifiedRN50' # ['RN50', 'ModifiedRN50', 'MAE'] + layers: Union[Tuple[int, int, int, int], int] = [3,4,6,3] + width: int = 64 + head_width: int = 64 + mlp_ratio: float = 4.0 + patch_size: int = 16 + image_size: Union[Tuple[int, int], int] = 224 + timm_model_name: str = None # a valid model name overrides layers, width, patch_size + timm_model_pretrained: bool = False # use (imagenet) pretrained weights for named model + timm_pool: str = 'avg' # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '') + timm_proj: str = 'linear' # linear projection for timm model output ('linear', 'mlp', '') + patch_dropout: float = 0.0 # patch dropout rate, no dropout by default + drop_attention_rate: float = 0. # Transformer Dropout + patch_size: None + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1): + super().__init__() + + # all conv layers have stride 1. 
an avgpool is performed after the second convolution when stride > 1 + self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + self.relu1 = nn.ReLU(inplace=True) + + self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + self.relu2 = nn.ReLU(inplace=True) + + self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() + + self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * self.expansion) + self.relu3 = nn.ReLU(inplace=True) + + self.downsample = None + self.stride = stride + + if stride > 1 or inplanes != planes * Bottleneck.expansion: + # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 + self.downsample = nn.Sequential(OrderedDict([ + ("-1", nn.AvgPool2d(stride)), + ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)), + ("1", nn.BatchNorm2d(planes * self.expansion)) + ])) + + def forward(self, x: torch.Tensor): + identity = x + + out = self.relu1(self.bn1(self.conv1(x))) + out = self.relu2(self.bn2(self.conv2(out))) + out = self.avgpool(out) + out = self.bn3(self.conv3(out)) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu3(out) + return out + + +class AttentionPool2d(nn.Module): + def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): + super().__init__() + self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5) + self.k_proj = nn.Linear(embed_dim, embed_dim) + self.q_proj = nn.Linear(embed_dim, embed_dim) + self.v_proj = nn.Linear(embed_dim, embed_dim) + self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) + self.num_heads = num_heads + + def forward(self, x): + x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC + x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC + x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC + x, _ = F.multi_head_attention_forward( + query=x, key=x, value=x, + embed_dim_to_check=x.shape[-1], + num_heads=self.num_heads, + q_proj_weight=self.q_proj.weight, + k_proj_weight=self.k_proj.weight, + v_proj_weight=self.v_proj.weight, + in_proj_weight=None, + in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), + bias_k=None, + bias_v=None, + add_zero_attn=False, + dropout_p=0, + out_proj_weight=self.c_proj.weight, + out_proj_bias=self.c_proj.bias, + use_separate_proj_weight=True, + training=self.training, + need_weights=False + ) + + return x[0] + + +class ResNet(nn.Module): + """ + RN50 + """ + + def __init__( + self, layers, output_dim, heads, image_size=224, width=64, + block=Bottleneck, + ): + super().__init__() + self.output_dim = output_dim + self.image_size = image_size + + # the 1-layer stem + self.conv1 = nn.Conv2d(3, width, kernel_size=3, stride=2, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(width) + self.relu1 = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + + # residual layers + self._inplanes = width # this is a *mutable* variable used during construction + self.layer1 = self._make_layer(width, layers[0]) + self.layer2 = self._make_layer(width * 2, layers[1], stride=2) + self.layer3 = self._make_layer(width * 4, layers[2], stride=2) + self.layer4 = self._make_layer(width * 8, layers[3], stride=2) + self.avgpool = nn.AdaptiveAvgPool2d((1, 
1)) + # self.head = nn.Linear(512 * 6, output_dim) + self.head = nn.Linear(512 * block.expansion, output_dim) + + # embed_dim = width * 32 # the ResNet feature dimension + # self.attnpool = AttentionPool2d(image_size // 32, embed_dim, heads, output_dim) + + self.init_parameters() + + def _make_layer( + self, + planes, blocks, stride=1, + block=Bottleneck, + ): + layers = [block(self._inplanes, planes, stride)] + + self._inplanes = planes * block.expansion + for _ in range(1, blocks): + layers.append(block(self._inplanes, planes)) + + return nn.Sequential(*layers) + + def init_parameters(self): + for resnet_block in [self.layer1, self.layer2, self.layer3, self.layer4]: + for name, param in resnet_block.named_parameters(): + if name.endswith("bn3.weight"): + nn.init.zeros_(param) + + def lock(self, unlocked_groups=0, freeze_bn_stats=False): + assert unlocked_groups == 0, 'partial locking not currently supported for this model' + for param in self.parameters(): + param.requires_grad = False + if freeze_bn_stats: + freeze_batch_norm_2d(self) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + # FIXME support for non-transformer + pass + + def stem(self, x): + x = self.relu1(self.bn1(self.conv1(x))) + x = self.maxpool(x) + return x + + def forward(self, x): + # x[0]: [batch_size, 3, 224, 224] + # x[1]: [batch_size, 1] + x = self.stem(x) # [batch_size, 64, 56, 56] + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) # [batch_size, 2048, 7, 7] + x = self.avgpool(x) # [batch_size, 2048, 1, 1] + x = torch.flatten(x, 1) # [batch_size, 2048*1*1] + x = self.head(x) # [batch_size, 1024] + + visual_output = dict.fromkeys(["image_features", "mim_loss"], None) + visual_output.update({ + 'image_features': x, + }) + + return visual_output + + +class ModifiedResNet(nn.Module): + """ + A ResNet class that is similar to torchvision's but contains the following changes: + - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. 
+ - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 + - The final pooling layer is a QKV attention instead of an average pool + """ + + def __init__(self, layers, output_dim, heads, image_size=224, width=64): + super().__init__() + self.output_dim = output_dim + self.image_size = image_size + + # the 3-layer stem + self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(width // 2) + self.relu1 = nn.ReLU(inplace=True) + self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(width // 2) + self.relu2 = nn.ReLU(inplace=True) + self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) + self.bn3 = nn.BatchNorm2d(width) + self.relu3 = nn.ReLU(inplace=True) + self.avgpool = nn.AvgPool2d(2) + + # residual layers + self._inplanes = width # this is a *mutable* variable used during construction + self.layer1 = self._make_layer(width, layers[0]) + self.layer2 = self._make_layer(width * 2, layers[1], stride=2) + self.layer3 = self._make_layer(width * 4, layers[2], stride=2) + self.layer4 = self._make_layer(width * 8, layers[3], stride=2) + + embed_dim = width * 32 # the ResNet feature dimension + self.attnpool = AttentionPool2d(image_size // 32, embed_dim, heads, output_dim) + + self.init_parameters() + + def _make_layer(self, planes, blocks, stride=1): + layers = [Bottleneck(self._inplanes, planes, stride)] + + self._inplanes = planes * Bottleneck.expansion + for _ in range(1, blocks): + layers.append(Bottleneck(self._inplanes, planes)) + + return nn.Sequential(*layers) + + def init_parameters(self): + if self.attnpool is not None: + std = self.attnpool.c_proj.in_features ** -0.5 + nn.init.normal_(self.attnpool.q_proj.weight, std=std) + nn.init.normal_(self.attnpool.k_proj.weight, std=std) + nn.init.normal_(self.attnpool.v_proj.weight, std=std) + nn.init.normal_(self.attnpool.c_proj.weight, std=std) + + for resnet_block in [self.layer1, self.layer2, self.layer3, self.layer4]: + for name, param in resnet_block.named_parameters(): + if name.endswith("bn3.weight"): + nn.init.zeros_(param) + + def lock(self, unlocked_groups=0, freeze_bn_stats=False): + assert unlocked_groups == 0, 'partial locking not currently supported for this model' + for param in self.parameters(): + param.requires_grad = False + if freeze_bn_stats: + freeze_batch_norm_2d(self) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + # FIXME support for non-transformer + pass + + def stem(self, x): + x = self.relu1(self.bn1(self.conv1(x))) + x = self.relu2(self.bn2(self.conv2(x))) + x = self.relu3(self.bn3(self.conv3(x))) + x = self.avgpool(x) + return x + + def forward(self, x): + x = self.stem(x) + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + x = self.attnpool(x) + + visual_output = dict.fromkeys(["image_features", "mim_loss"], None) + visual_output.update({ + 'image_features': x, + }) + + return visual_output + + +class LayerNorm(nn.LayerNorm): + """Subclass torch's LayerNorm to handle fp16.""" + + def forward(self, x: torch.Tensor): + orig_type = x.dtype + x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) + return x.to(orig_type) + + +class QuickGELU(nn.Module): + # NOTE This is slower than nn.GELU or nn.SiLU and uses more GPU memory + def forward(self, x: torch.Tensor): + return x * torch.sigmoid(1.702 * x) + + +class ResidualAttentionBlock(nn.Module): + 
def __init__( + self, d_model: int, n_head: int, mlp_ratio: float = 4.0, act_layer: Callable = nn.GELU, + drop_attention_rate: float = 0., + ): + super().__init__() + + self.attn = nn.MultiheadAttention( + embed_dim=d_model, + num_heads=n_head, + dropout=drop_attention_rate, + ) + self.ln_1 = LayerNorm(d_model) + mlp_width = int(d_model * mlp_ratio) + self.mlp = nn.Sequential(OrderedDict([ + ("c_fc", nn.Linear(d_model, mlp_width)), + ("gelu", act_layer()), + ("c_proj", nn.Linear(mlp_width, d_model)) + ])) + self.ln_2 = LayerNorm(d_model) + + def attention(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): + return self.attn(x, x, x, need_weights=False, attn_mask=attn_mask)[0] + + def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): + x = x + self.attention(self.ln_1(x), attn_mask=attn_mask) + x = x + self.mlp(self.ln_2(x)) + return x + + +class PatchDropout(nn.Module): + """ + https://arxiv.org/abs/2212.00794 + """ + + def __init__(self, prob, exclude_first_token=True): + super().__init__() + assert 0 <= prob < 1. + self.prob = prob + self.exclude_first_token = exclude_first_token # exclude CLS token + + def forward(self, x): + if not self.training or self.prob == 0.: + return x + + if self.exclude_first_token: + cls_tokens, x = x[:, :1], x[:, 1:] + else: + cls_tokens = torch.jit.annotate(torch.Tensor, x[:, :1]) + + batch = x.size()[0] + num_tokens = x.size()[1] + + batch_indices = torch.arange(batch) + batch_indices = batch_indices[..., None] + + keep_prob = 1 - self.prob + num_patches_keep = max(1, int(num_tokens * keep_prob)) + + rand = torch.randn(batch, num_tokens) + patch_indices_keep = rand.topk(num_patches_keep, dim=-1).indices + + x = x[batch_indices, patch_indices_keep] + + if self.exclude_first_token: + x = torch.cat((cls_tokens, x), dim=1) + + return x + + +class Transformer(nn.Module): + def __init__( + self, width: int, layers: int, heads: int, mlp_ratio: float = 4.0, act_layer: Callable = nn.GELU, + drop_attention_rate: float = 0., + ): + super().__init__() + self.width = width + self.layers = layers + self.grad_checkpointing = False + + self.resblocks = nn.ModuleList([ + ResidualAttentionBlock(width, heads, mlp_ratio, act_layer=act_layer, drop_attention_rate=drop_attention_rate) + for _ in range(layers) + ]) + + def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): + for r in self.resblocks: + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint(r, x, attn_mask) + else: + x = r(x, attn_mask=attn_mask) + return x \ No newline at end of file diff --git a/Quick_demo/Model/RadFM/helpers.py b/Quick_demo/Model/RadFM/helpers.py new file mode 100644 index 0000000..78b4896 --- /dev/null +++ b/Quick_demo/Model/RadFM/helpers.py @@ -0,0 +1,275 @@ +""" +Taken from https://github.com/lucidrains/flamingo-pytorch +""" + +import torch +from einops import rearrange, repeat +from einops_exts import rearrange_many +from torch import einsum, nn + + +def exists(val): + return val is not None + + +def FeedForward(dim, mult=4): + inner_dim = int(dim * mult) + return nn.Sequential( + nn.LayerNorm(dim), + nn.Linear(dim, inner_dim, bias=False), + nn.GELU(), + nn.Linear(inner_dim, dim, bias=False), + ) + + +class PerceiverAttention(nn.Module): + def __init__(self, *, dim, dim_head=64, heads=8): + super().__init__() + self.scale = dim_head**-0.5 + self.heads = heads + inner_dim = dim_head * heads + + self.norm_media = nn.LayerNorm(dim) + self.norm_latents = nn.LayerNorm(dim) + + self.to_q = nn.Linear(dim, 
inner_dim, bias=False) + self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False) + self.to_out = nn.Linear(inner_dim, dim, bias=False) + + def forward(self, x, latents): + """ + Args: + x (torch.Tensor): image features + shape (b, T, n1, D) + latent (torch.Tensor): latent features + shape (b, T, n2, D) + """ + x = self.norm_media(x) + latents = self.norm_latents(latents) + + h = self.heads + + q = self.to_q(latents) + kv_input = torch.cat((x, latents), dim=-2) + k, v = self.to_kv(kv_input).chunk(2, dim=-1) + q, k, v = rearrange_many((q, k, v), "b t n (h d) -> b h t n d", h=h) + q = q * self.scale + + # attention + sim = einsum("... i d, ... j d -> ... i j", q, k) + sim = sim - sim.amax(dim=-1, keepdim=True).detach() + attn = sim.softmax(dim=-1) + + out = einsum("... i j, ... j d -> ... i d", attn, v) + out = rearrange(out, "b h t n d -> b t n (h d)", h=h) + return self.to_out(out) + + +class PerceiverResampler(nn.Module): + def __init__( + self, + *, + dim, + depth=6, + dim_head=64, + heads=8, + num_latents=64, + max_num_media=None, + max_num_frames=None, + ff_mult=4, + ): + super().__init__() + self.latents = nn.Parameter(torch.randn(num_latents, dim)) + self.frame_embs = ( + nn.Parameter(torch.randn(max_num_frames, dim)) + if exists(max_num_frames) + else None + ) + self.media_time_embs = ( + nn.Parameter(torch.randn(max_num_media, 1, dim)) + if exists(max_num_media) + else None + ) + + self.layers = nn.ModuleList([]) + for _ in range(depth): + self.layers.append( + nn.ModuleList( + [ + PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads), + FeedForward(dim=dim, mult=ff_mult), + ] + ) + ) + + self.norm = nn.LayerNorm(dim) + + def forward(self, x): + """ + Args: + x (torch.Tensor): image features + shape (b, T, F, v, D) + Returns: + shape (b, T, n, D) where n is self.num_latents + """ + b, T, F, v = x.shape[:4] + + # frame and media time embeddings + if exists(self.frame_embs): + frame_embs = repeat(self.frame_embs[:F], "F d -> b T F v d", b=b, T=T, v=v) + x = x + frame_embs + x = rearrange( + x, "b T F v d -> b T (F v) d" + ) # flatten the frame and spatial dimensions + if exists(self.media_time_embs): + x = x + self.media_time_embs[:T] + + # blocks + latents = repeat(self.latents, "n d -> b T n d", b=b, T=T) + for attn, ff in self.layers: + latents = attn(x, latents) + latents + latents = ff(latents) + latents + return self.norm(latents) + + +# gated cross attention + + +class MaskedCrossAttention(nn.Module): + def __init__( + self, + *, + dim, + dim_visual, + dim_head=64, + heads=8, + only_attend_immediate_media=True, + ): + super().__init__() + self.scale = dim_head**-0.5 + self.heads = heads + inner_dim = dim_head * heads + + self.norm = nn.LayerNorm(dim) + + self.to_q = nn.Linear(dim, inner_dim, bias=False) + self.to_kv = nn.Linear(dim_visual, inner_dim * 2, bias=False) + self.to_out = nn.Linear(inner_dim, dim, bias=False) + + # whether for text to only attend to immediate preceding image, or all previous images + self.only_attend_immediate_media = only_attend_immediate_media + + def forward(self, x, media, media_locations=None, attend_previous=True): + """ + Args: + x (torch.Tensor): text features + shape (B, T_txt, D_txt) + media (torch.Tensor): image features + shape (B, T_img, n, D_img) where n is the dim of the latents + media_locations: boolean mask identifying the media tokens in x + shape (B, T_txt) + attend_previous: bool + If false, ignores immediately preceding image and starts attending when following image + """ + _, T_img, n = media.shape[:3] + h = self.heads + + 
x = self.norm(x) + + q = self.to_q(x) + media = rearrange(media, "b t n d -> b (t n) d") + + k, v = self.to_kv(media).chunk(2, dim=-1) + q, k, v = rearrange_many((q, k, v), "b n (h d) -> b h n d", h=h) + + q = q * self.scale + + sim = einsum("... i d, ... j d -> ... i j", q, k) + + if exists(media_locations): + # at each boolean of True, increment the time counter (relative to media time) + text_time = media_locations.cumsum(dim=-1) + media_time = torch.arange(T_img, device=x.device) + 1 + + if not attend_previous: + text_time[~media_locations] += 1 + # make sure max is still the number of images in the sequence + text_time[ + text_time + > repeat( + torch.count_nonzero(media_locations, dim=1), + "b -> b i", + i=text_time.shape[1], + ) + ] = 0 + + # text time must equal media time if only attending to most immediate image + # otherwise, as long as text time is greater than media time (if attending to all previous images / media) + mask_op = torch.eq if self.only_attend_immediate_media else torch.ge + + text_to_media_mask = mask_op( + rearrange(text_time, "b i -> b 1 i 1"), + repeat(media_time, "j -> 1 1 1 (j n)", n=n), + ) + sim = sim.masked_fill(~text_to_media_mask, -torch.finfo(sim.dtype).max) + + sim = sim - sim.amax(dim=-1, keepdim=True).detach() + attn = sim.softmax(dim=-1) + + if exists(media_locations) and self.only_attend_immediate_media: + # any text without a preceding media needs to have attention zeroed out + text_without_media_mask = text_time == 0 + text_without_media_mask = rearrange( + text_without_media_mask, "b i -> b 1 i 1" + ) + attn = attn.masked_fill(text_without_media_mask, 0.0) + + out = einsum("... i j, ... j d -> ... i d", attn, v) + out = rearrange(out, "b h n d -> b n (h d)") + return self.to_out(out) + + +class GatedCrossAttentionBlock(nn.Module): + def __init__( + self, + *, + dim, + dim_visual, + dim_head=64, + heads=8, + ff_mult=4, + only_attend_immediate_media=True, + ): + super().__init__() + self.attn = MaskedCrossAttention( + dim=dim, + dim_visual=dim_visual, + dim_head=dim_head, + heads=heads, + only_attend_immediate_media=only_attend_immediate_media, + ) + self.attn_gate = nn.Parameter(torch.tensor([0.0])) + + self.ff = FeedForward(dim, mult=ff_mult) + self.ff_gate = nn.Parameter(torch.tensor([0.0])) + + def forward( + self, + x, + media, + media_locations=None, + attend_previous=True, + ): + x = ( + self.attn( + x, + media, + media_locations=media_locations, + attend_previous=attend_previous, + ) + * self.attn_gate.tanh() + + x + ) + x = self.ff(x) * self.ff_gate.tanh() + x + + return x diff --git a/Quick_demo/Model/RadFM/multimodality_model.py b/Quick_demo/Model/RadFM/multimodality_model.py new file mode 100644 index 0000000..0855346 --- /dev/null +++ b/Quick_demo/Model/RadFM/multimodality_model.py @@ -0,0 +1,84 @@ +from torch import nn +from transformers.models.llama import LlamaForCausalLM +from transformers import AutoConfig +from .my_embedding_layer import MyEmbedding +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +import tqdm.auto as tqdm +import torch.nn as nn +import torch +from torch.utils.checkpoint import checkpoint +from torch.autograd import Variable +import numpy as np +class MultiLLaMAForCausalLM(nn.Module): + def __init__(self, lang_model_path): + super(MultiLLaMAForCausalLM, self).__init__() + try: + self.lang_model = LlamaForCausalLM.from_pretrained( + lang_model_path, + ) + except: + config = AutoConfig.from_pretrained(lang_model_path) + self.lang_model = LlamaForCausalLM(config) + 
self.lang_model.gradient_checkpointing_enable() + self.lang_model.enable_input_require_grads() + # self.lang_model.requires_grad_(False) + self.embedding_layer = MyEmbedding() + self.embedding_layer.weight = self.lang_model.get_input_embeddings().weight + self.hidden_dim = 5120 + self.voc_size = 32000 + + def forward(self,lang_x, vision_x, attention_mask, labels, loss_reweight,key_words_query): + if labels.shape == lang_x.shape: + self.embedding_layer.flag = 'Text' + # lang_x = lang_x.to(vision_x.dtype) + # lang_x = lang_x + torch.zeros(1, dtype=lang_x.dtype, device=lang_x.device, requires_grad=True) + # vision_x = vision_x + torch.zeros(1, dtype=vision_x.dtype, device=vision_x.device, requires_grad=True) + # input_embedding = checkpoint(self.embedding_layer, lang_x, vision_x) + input_embedding,loss_match= self.embedding_layer(lang_x, vision_x,key_words_query) # ,loss_matching + output = self.lang_model(inputs_embeds = input_embedding,attention_mask = attention_mask, labels = labels) + logits = output['logits'] + + loss_reg = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + shift_loss_reweight = loss_reweight[...,1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss(reduction = 'none') + shift_logits = shift_logits.view(-1, self.voc_size) + shift_labels = shift_labels.view(-1) + shift_loss_reweight = shift_loss_reweight.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + shift_loss_reweight = shift_loss_reweight.to(shift_logits.device) + loss_reg = loss_fct(shift_logits, shift_labels) + loss_reg = torch.sum(shift_loss_reweight*loss_reg)/torch.sum(shift_loss_reweight) + loss = loss_reg + if loss_match!= None: + loss = 0.8*loss + 0.2*loss_match + + logits = output['logits'][..., :-1, :].contiguous().detach() + total = len(labels) + predictions = torch.argmax(logits, dim=-1) + labels = labels[..., 1:].contiguous() + Acc = torch.sum(torch.all(torch.logical_or(predictions == labels, labels == -100),dim = -1)) + Accuracy = Acc /total + + return dict( + # loss_reg = loss_reg, + # loss_matching = loss_matching, + logits = Accuracy, + loss = output['loss'], + ) + ### useless for now ignore the folowing codes ### + # if labels.shape == vision_x.shape: + # self.embedding_layer.flag = 'Seg' + # input_embedding = self.embedding_layer(lang_x, vision_x) + + def generate(self, lang_x,vision_x): + self.embedding_layer.flag = 'Text' + with torch.no_grad(): + input_embedding,_ = self.embedding_layer(lang_x, vision_x) + generation = self.lang_model.generate(inputs_embeds = input_embedding, max_new_tokens =200,top_k=50) + return generation diff --git a/Quick_demo/Model/RadFM/my_embedding_layer.py b/Quick_demo/Model/RadFM/my_embedding_layer.py new file mode 100644 index 0000000..0c1b9b2 --- /dev/null +++ b/Quick_demo/Model/RadFM/my_embedding_layer.py @@ -0,0 +1,163 @@ +import torch.nn as nn +import torch.nn.functional as F +import torch +from .helpers import PerceiverResampler +from .utils import get_visual_encoder +from einops import rearrange, repeat +from einops_exts import rearrange_many +import torchvision +from .vit_3d import ViT +from einops.layers.torch import Rearrange +from .transformer_decoder import TransformerDecoder,TransformerDecoderLayer +from torch.utils.checkpoint import checkpoint +from torch.autograd import Variable +import random +from transformers import AutoTokenizer, AutoModel + +class MyEmbedding(nn.Module): + def 
__init__(self, num_embeddings=32000, embedding_dim=5120, perceiver_num=32,vis_dim = 768, patch_size=32, frame_patch_size = 4 ,seg_channel = 256): + super().__init__() + self.num_embeddings = num_embeddings + self.embedding_dim = embedding_dim + self.weight = nn.Parameter(torch.torch.randn((num_embeddings, embedding_dim))) + self.figure_token_weight = nn.Parameter(torch.randn((2, embedding_dim))) + self.flag = 'Text' + self.patch_size = patch_size + self.frame_patch_size = frame_patch_size + self.seg_channel = seg_channel + + self.bert_tokenizer = AutoTokenizer.from_pretrained("/gpfs/home/cs/leijiayu/wuchaoyi/multi_modal/src/MedKEBERT") + self.bert_model = AutoModel.from_pretrained("/gpfs/home/cs/leijiayu/wuchaoyi/multi_modal/src/MedKEBERT") + self.bert_projection_fc = nn.Linear(768,vis_dim) + + self.vision_encoder = ViT( + image_size = 512, # image size + frames = 512, # max number of frames + image_patch_size = patch_size, # image patch size + frame_patch_size = frame_patch_size, # frame patch size + dim = vis_dim, + depth = 12, + heads = 8, + mlp_dim = 2048, + dropout = 0.1, + emb_dropout = 0.1 + ) + + self.output_upscaling = nn.Sequential( + nn.ConvTranspose3d(vis_dim, vis_dim // 4, kernel_size=2, stride=2), + nn.BatchNorm3d(vis_dim // 4), + nn.GELU(), + nn.ConvTranspose3d(vis_dim // 4, vis_dim // 8, kernel_size=2, stride=2), + nn.GELU(), + ) + + decoder_layer = TransformerDecoderLayer(d_model = vis_dim, nhead = 8, normalize_before=True) + decoder_norm = nn.LayerNorm(vis_dim) + self.transformer_decoder = TransformerDecoder(decoder_layer = decoder_layer, num_layers = 4, norm=decoder_norm) + self.transformer_decoder_mlp = nn.Sequential( + nn.Linear(vis_dim,vis_dim // 4), + nn.GELU(), + nn.Linear(vis_dim // 4,vis_dim // 8), + nn.GELU(), + ) + self.vis_dim = vis_dim + + self.perceiver = PerceiverResampler(dim=self.vis_dim, num_latents = perceiver_num) + self.fc = nn.Linear(self.vis_dim,self.embedding_dim) + self.cls_head = nn.Linear(self.vis_dim // 8, 1) + + + def forward(self, text_input, vision_x, key_words_query = None): + if self.flag == 'Text': + B,S,C,H,W,D = vision_x.shape + vision_x = rearrange(vision_x, "b S c h w d-> (b S) c h w d") + + + vision_x, pos_embedding = self.vision_encoder(vision_x) + # vision_x = Variable(vision_x,requires_grad=True) + # vision_x, _ = checkpoint(self.vision_encoder,vision_x) + + vision_x = rearrange(vision_x, "(b s F) v d -> b s F v d", b=B, s=S,F=1) + + loss_matching = None + if key_words_query != None: + # key_words_query list[list[str]] B, words, each word matches corresponding vision_x embedding + query_words = [item for sublist in key_words_query for item in sublist] + query_words = list(set(query_words)) + if len(query_words)>16: + random.shuffle(query_words) + query_words = query_words[0:16] + if query_words != []: + contrastive_labels = torch.zeros(B,len(query_words)) #B Q + for i,sublist in enumerate(key_words_query): + for j,item in enumerate(query_words): + if item in sublist: + contrastive_labels[i,j] = 1 + contrastive_labels = contrastive_labels.to(vision_x.dtype).to(vision_x.device) + + with torch.no_grad(): + query_words_embedding = self.bert_tokenizer(query_words, padding='max_length', truncation=True, max_length=256,return_tensors="pt") + query_words_embedding = self.bert_model(input_ids = query_words_embedding['input_ids'].to(vision_x.device),attention_mask = query_words_embedding['attention_mask'].to(vision_x.device))['last_hidden_state'][:,0,:].to(vision_x.dtype).to(vision_x.device) # Q,D + query_words_embedding = 
self.bert_projection_fc(query_words_embedding) + query_words_embedding = query_words_embedding.unsqueeze(0).repeat(B,1,1) # B,Q,D + _,N,_ = query_words_embedding.shape + + image_embedding = vision_x.mean(dim=1) # B V D average pooling 去除掉多模态。 + image_embedding = rearrange(image_embedding, "b F v d -> b (F v) d") + pos_embedding = rearrange(pos_embedding, "(b s) v d -> b s v d", b=B, s=S)[:,0,:,:] + + image_embedding = image_embedding.transpose(0,1) # (H/P W/P D/P) B D + pos_embedding = pos_embedding.transpose(0,1) # (H/P W/P D/P) B D + query_words_embedding = query_words_embedding.transpose(0,1) # N B D + + oo_embedding,_ = self.transformer_decoder(query_words_embedding, image_embedding, pos = pos_embedding) + oo_embedding = oo_embedding.transpose(0,1) # B Q D + oo_embedding = rearrange(oo_embedding, 'b n d -> (b n) d') + oo_embedding = self.transformer_decoder_mlp(oo_embedding) + oo_embedding = self.cls_head(oo_embedding).mean(dim = -1) + oo_embedding = rearrange(oo_embedding, '(b n) -> b n', b=B, n=N) # B Q + # oo_embedding = rearrange(oo_embedding, 'b n d -> b (n d)') # B Q + loss_matching = F.binary_cross_entropy_with_logits(oo_embedding, contrastive_labels) + + vision_x = self.perceiver(vision_x) # reshapes to (b, S, n, d) + #vision_x = checkpoint(self.perceiver,vision_x) + + n = vision_x.shape[2] + + vision_x = rearrange(vision_x, "b s n d -> (b s n) d") + vision_x = self.fc(vision_x) + vision_x = rearrange(vision_x, "(b T) d -> b T d", b=B, T=n*S) + + embedding_weight = torch.cat([self.weight, self.figure_token_weight],dim = 0) + embedding_weight = embedding_weight.unsqueeze(0).repeat(B, 1, 1) + embedding_weight = torch.cat([embedding_weight,vision_x],dim = 1) + text_input = F.one_hot(text_input,embedding_weight.shape[1]).to(vision_x.dtype).to(vision_x.device) + out_put = torch.matmul(text_input, embedding_weight) + + ## useless for now. ignore the folowing code## + # if self.flag == 'Seg': + # B,C,H,W,D = vision_x.shape + # _,N,_ = text_input.shape + # latent_embedding, pos_embedding = self.vision_encoder(vision_x) # B (H/P W/P D/P) D + + # image_embedding = latent_embedding.transpose(0,1) # (H/P W/P D/P) B D + # pos_embedding = pos_embedding.transpose(0,1) # (H/P W/P D/P) B D + # text_input = text_input.transpose(0,1) # N B D + + # mask_embedding,_ = self.transformer_decoder(text_input, image_embedding, pos = pos_embedding) + # mask_embedding = mask_embedding.transpose(0,1) # B N D + # mask_embedding = rearrange(mask_embedding, 'b n d -> (b n) d') + # mask_embedding = self.transformer_decoder_mlp(mask_embedding) + # mask_embedding = rearrange(mask_embedding, '(b n) d -> b n d', b=B, n=N,d = self.vis_dim // 8) + + # vision_x = rearrange(latent_embedding,'b (h w d) c -> b c h w d', h = (H // self.patch_size), w = (W // self.patch_size), d = (D // self.frame_patch_size), c=self.vis_dim) + # vision_x = self.output_upscaling(vision_x) #B C H/4 W/4 D/4 + # out_put = torch.einsum('bchwd,bnc->bnhwd', vision_x, mask_embedding) + + return out_put,loss_matching + +# model = MyEmbedding(vision_encoder_path = '') +# text_input = torch.randint(low=0, high=3210, size=(4,2048)) +# image_input = torch.randn((4,3,3,512,512,4)) +# key_words_query = [[],[],[],['consoliation']] +# print(model(text_input, image_input, key_words_query)) diff --git a/Quick_demo/Model/RadFM/position_encoding.py b/Quick_demo/Model/RadFM/position_encoding.py new file mode 100644 index 0000000..9d0af4b --- /dev/null +++ b/Quick_demo/Model/RadFM/position_encoding.py @@ -0,0 +1,121 @@ +# Copyright (c) Facebook, Inc. 
and its affiliates. All Rights Reserved +""" +Various positional encodings for the transformer. +""" +import math +import torch +from torch import nn +from einops.layers.torch import Rearrange +from einops import rearrange, repeat + +class PositionEmbeddingSine(nn.Module): + """ + This is a more standard version of the position embedding, very similar to the one + used by the Attention is all you need paper, generalized to work on images. + """ + def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None): + super().__init__() + self.num_pos_feats = num_pos_feats + self.temperature = temperature + self.normalize = normalize + if scale is not None and normalize is False: + raise ValueError("normalize should be True if scale is passed") + if scale is None: + scale = 2 * math.pi + self.scale = scale + + def forward(self, tensor_list): + x = tensor_list.tensors + mask = tensor_list.mask + assert mask is not None + not_mask = ~mask + y_embed = not_mask.cumsum(1, dtype=torch.float32) + x_embed = not_mask.cumsum(2, dtype=torch.float32) + if self.normalize: + eps = 1e-6 + y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale + x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale + + dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) + dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) + + pos_x = x_embed[:, :, :, None] / dim_t + pos_y = y_embed[:, :, :, None] / dim_t + pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) + pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) + pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) + return pos + + +class PositionEmbeddingLearned(nn.Module): + """ + Absolute pos embedding, learned. + """ + def __init__(self, num_pos_feats=256): + super().__init__() + self.row_embed = nn.Embedding(50, num_pos_feats) + self.col_embed = nn.Embedding(50, num_pos_feats) + self.reset_parameters() + + def reset_parameters(self): + nn.init.uniform_(self.row_embed.weight) + nn.init.uniform_(self.col_embed.weight) + + def forward(self, tensor_list): + x = tensor_list.tensors + h, w = x.shape[-2:] + i = torch.arange(w, device=x.device) + j = torch.arange(h, device=x.device) + x_emb = self.col_embed(i) + y_emb = self.row_embed(j) + pos = torch.cat([ + x_emb.unsqueeze(0).repeat(h, 1, 1), + y_emb.unsqueeze(1).repeat(1, w, 1), + ], dim=-1).permute(2, 0, 1).unsqueeze(0).repeat(x.shape[0], 1, 1, 1) + return pos + +class PositionEmbeddingLearned3d(nn.Module): + """ + Absolute pos embedding, learned. 
+ """ + def __init__(self, num_pos_feats=256,h_patch_num = 16, w_patch_num = 16,d_patch_num = 64): + super().__init__() + self.h_patch_num = h_patch_num + self.w_patch_num = w_patch_num + self.d_patch_num = d_patch_num + self.row_embed = nn.Embedding(h_patch_num, num_pos_feats) + self.col_embed = nn.Embedding(w_patch_num, num_pos_feats) + self.dep_embed = nn.Embedding(d_patch_num, num_pos_feats) + self.reset_parameters() + + def reset_parameters(self): + nn.init.uniform_(self.row_embed.weight) + nn.init.uniform_(self.col_embed.weight) + nn.init.uniform_(self.dep_embed.weight) + + def forward(self, B, h, w, d,x): + i = (torch.arange(h, device=x.device) + 1)* (self.h_patch_num // h) -1 + j = (torch.arange(w, device=x.device) + 1)* (self.w_patch_num // w) -1 + k = (torch.arange(d, device=x.device) + 1)* (self.d_patch_num // d) -1 + x_emb = self.row_embed(i).unsqueeze(1).unsqueeze(2).repeat(1,w,d,1) + y_emb = self.col_embed(j).unsqueeze(0).unsqueeze(2).repeat(h,1,d,1) + z_emb = self.dep_embed(k).unsqueeze(0).unsqueeze(1).repeat(h,w,1,1) + pos = torch.cat([x_emb,y_emb,z_emb,], dim=-1).unsqueeze(0).repeat(B, 1, 1, 1, 1) + pos = rearrange(pos,'b h w d c -> b (h w d) c') + return pos + +def build_position_encoding(args): + N_steps = args.hidden_dim // 2 + if args.position_embedding in ('v2', 'sine'): + # TODO find a better way of exposing other arguments + position_embedding = PositionEmbeddingSine(N_steps, normalize=True) + elif args.position_embedding in ('v3', 'learned'): + position_embedding = PositionEmbeddingLearned(N_steps) + else: + raise ValueError(f"not supported {args.position_embedding}") + + return position_embedding + +# Pos = PositionEmbeddingLearned3d() +# x = torch.randn((8,3,32,32,1)) +# print(Pos(8,16,16,1,x)) \ No newline at end of file diff --git a/Quick_demo/Model/RadFM/transformer_decoder.py b/Quick_demo/Model/RadFM/transformer_decoder.py new file mode 100644 index 0000000..b50a350 --- /dev/null +++ b/Quick_demo/Model/RadFM/transformer_decoder.py @@ -0,0 +1,160 @@ +""" +Code modified from DETR tranformer: +https://github.com/facebookresearch/detr +Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved +""" + +import copy +from typing import Optional, List +import pickle as cp + +import torch +import torch.nn.functional as F +from torch import nn, Tensor + + +class TransformerDecoder(nn.Module): + + def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False): + super().__init__() + self.layers = _get_clones(decoder_layer, num_layers) + self.num_layers = num_layers + self.norm = norm + self.return_intermediate = return_intermediate + + def forward(self, tgt, memory, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None): + output = tgt + T,B,C = memory.shape + intermediate = [] + atten_layers = [] + for n,layer in enumerate(self.layers): + + residual=True + output,ws = layer(output, memory, tgt_mask=tgt_mask, + memory_mask=memory_mask, + tgt_key_padding_mask=tgt_key_padding_mask, + memory_key_padding_mask=memory_key_padding_mask, + pos=pos, query_pos=query_pos,residual=residual) + atten_layers.append(ws) + if self.return_intermediate: + intermediate.append(self.norm(output)) + if self.norm is not None: + output = self.norm(output) + if self.return_intermediate: + intermediate.pop() + intermediate.append(output) + + if self.return_intermediate: + return torch.stack(intermediate) + return output,atten_layers + + + +class TransformerDecoderLayer(nn.Module): + + def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, + activation="relu", normalize_before=False): + super().__init__() + self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) + self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) + # Implementation of Feedforward model + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + self.norm3 = nn.LayerNorm(d_model) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + self.dropout3 = nn.Dropout(dropout) + + self.activation = _get_activation_fn(activation) + self.normalize_before = normalize_before + def with_pos_embed(self, tensor, pos: Optional[Tensor]): + return tensor if pos is None else tensor + pos + + def forward_post(self, tgt, memory, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None, + residual=True): + q = k = self.with_pos_embed(tgt, query_pos) + tgt2,ws = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask, + key_padding_mask=tgt_key_padding_mask) + tgt = self.norm1(tgt) + tgt2,ws = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos), + key=self.with_pos_embed(memory, pos), + value=memory, attn_mask=memory_mask, + key_padding_mask=memory_key_padding_mask) + + + # attn_weights [B,NUM_Q,T] + tgt = tgt + self.dropout2(tgt2) + tgt = self.norm2(tgt) + tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) + tgt = tgt + self.dropout3(tgt2) + tgt = self.norm3(tgt) + return tgt,ws + + def forward_pre(self, tgt, memory, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, 
+ pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None): + tgt2 = self.norm1(tgt) + q = k = self.with_pos_embed(tgt2, query_pos) + tgt2,ws = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask, + key_padding_mask=tgt_key_padding_mask) + tgt = tgt + self.dropout1(tgt2) + tgt2 = self.norm2(tgt) + tgt2,attn_weights = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos), + key=self.with_pos_embed(memory, pos), + value=memory, attn_mask=memory_mask, + key_padding_mask=memory_key_padding_mask) + tgt = tgt + self.dropout2(tgt2) + tgt2 = self.norm3(tgt) + tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2)))) + tgt = tgt + self.dropout3(tgt2) + return tgt,attn_weights + + def forward(self, tgt, memory, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None, + residual=True): + if self.normalize_before: + return self.forward_pre(tgt, memory, tgt_mask, memory_mask, + tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos) + return self.forward_post(tgt, memory, tgt_mask, memory_mask, + tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos,residual) + + +def _get_clones(module, N): + return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) + + + +def _get_activation_fn(activation): + """Return an activation function given a string""" + if activation == "relu": + return F.relu + if activation == "gelu": + return F.gelu + if activation == "glu": + return F.glu + raise RuntimeError(F"activation should be relu/gelu, not {activation}.") diff --git a/Quick_demo/Model/RadFM/utils.py b/Quick_demo/Model/RadFM/utils.py new file mode 100644 index 0000000..1624348 --- /dev/null +++ b/Quick_demo/Model/RadFM/utils.py @@ -0,0 +1,74 @@ +from .blocks import ModifiedResNet,PMC_CLIP_cfg +import torch +from torchvision import transforms +from PIL import Image +import torch.nn as nn +def extend_instance(obj, mixin): + """Apply mixins to a class instance after creation""" + base_cls = obj.__class__ + base_cls_name = obj.__class__.__name__ + obj.__class__ = type( + base_cls_name, (mixin, base_cls), {} + ) # mixin needs to go first for our forward() logic to work + + +def getattr_recursive(obj, att): + """ + Return nested attribute of obj + Example: getattr_recursive(obj, 'a.b.c') is equivalent to obj.a.b.c + """ + if att == "": + return obj + i = att.find(".") + if i < 0: + return getattr(obj, att) + else: + return getattr_recursive(getattr(obj, att[:i]), att[i + 1 :]) + + +def setattr_recursive(obj, att, val): + """ + Set nested attribute of obj + Example: setattr_recursive(obj, 'a.b.c', val) is equivalent to obj.a.b.c = val + """ + if "." 
in att: + obj = getattr_recursive(obj, ".".join(att.split(".")[:-1])) + setattr(obj, att.split(".")[-1], val) + + + +def get_visual_encoder(model_str): + """ + Args: + str (_type_): str_to_model_path + Return: + vision_model, visual_dim, img_preprocessor + """ + normalize = transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) + img_preprocessor = transforms.Compose([ + transforms.Resize((512,512), interpolation=Image.BICUBIC), + transforms.ToTensor(), + normalize, + ]) + if 'PMC-CLIP' in model_str: + #vision_cfg = json.load(open(model_args.visual_model_config,'r'))['vision_cfg'] + vision_cfg = PMC_CLIP_cfg() + vision_heads = vision_cfg.width * 32 // vision_cfg.head_width + vision_model = ModifiedResNet( + layers=vision_cfg.layers, + heads=vision_heads, + output_dim = 768, + image_size=vision_cfg.image_size, + width=vision_cfg.width + ) + vision_model = vision_load_pretrain(vision_model,model_str) + vision_model = nn.Sequential(*list(vision_model.children())[:-2]) + visual_dim = 1024 + return vision_model,visual_dim,img_preprocessor + +def vision_load_pretrain(resnet,model_path): + checkpoint = torch.load(model_path, map_location='cpu') + state_dict = checkpoint['state_dict'] + state_dict = {k.replace('module.visual.',''): v for k, v in state_dict.items() if '.visual' in k} + resnet.load_state_dict(state_dict) + return resnet diff --git a/Quick_demo/Model/RadFM/vit_3d.py b/Quick_demo/Model/RadFM/vit_3d.py new file mode 100644 index 0000000..1c36b2b --- /dev/null +++ b/Quick_demo/Model/RadFM/vit_3d.py @@ -0,0 +1,123 @@ +import torch +from torch import nn + +from einops import rearrange, repeat +from einops.layers.torch import Rearrange +from .position_encoding import PositionEmbeddingLearned3d + +# helpers + +def pair(t): + return t if isinstance(t, tuple) else (t, t) + +# classes + +class PreNorm(nn.Module): + def __init__(self, dim, fn): + super().__init__() + self.norm = nn.LayerNorm(dim) + self.fn = fn + def forward(self, x, **kwargs): + return self.fn(self.norm(x), **kwargs) + +class FeedForward(nn.Module): + def __init__(self, dim, hidden_dim, dropout = 0.): + super().__init__() + self.net = nn.Sequential( + nn.Linear(dim, hidden_dim), + nn.GELU(), + nn.Dropout(dropout), + nn.Linear(hidden_dim, dim), + nn.Dropout(dropout) + ) + def forward(self, x): + return self.net(x) + +class Attention(nn.Module): + def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.): + super().__init__() + inner_dim = dim_head * heads + project_out = not (heads == 1 and dim_head == dim) + + self.heads = heads + self.scale = dim_head ** -0.5 + + self.attend = nn.Softmax(dim = -1) + self.dropout = nn.Dropout(dropout) + + self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) + + self.to_out = nn.Sequential( + nn.Linear(inner_dim, dim), + nn.Dropout(dropout) + ) if project_out else nn.Identity() + + def forward(self, x): + qkv = self.to_qkv(x).chunk(3, dim = -1) + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv) + + dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale + + attn = self.attend(dots) + attn = self.dropout(attn) + + out = torch.matmul(attn, v) + out = rearrange(out, 'b h n d -> b n (h d)') + return self.to_out(out) + +class Transformer(nn.Module): + def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.): + super().__init__() + self.layers = nn.ModuleList([]) + for _ in range(depth): + self.layers.append(nn.ModuleList([ + PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)), + PreNorm(dim, 
FeedForward(dim, mlp_dim, dropout = dropout)) + ])) + def forward(self, x): + for attn, ff in self.layers: + x = attn(x) + x + x = ff(x) + x + return x + +class ViT(nn.Module): + def __init__(self, *, image_size, image_patch_size, frames, frame_patch_size, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.): + super().__init__() + image_height, image_width = pair(image_size) + patch_height, patch_width = pair(image_patch_size) + + assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.' + assert frames % frame_patch_size == 0, 'Frames must be divisible by frame patch size' + + self.patch_height = patch_height + self.patch_width = patch_width + self.frame_patch_size = frame_patch_size + + num_patches = (image_height // patch_height) * (image_width // patch_width) * (frames // frame_patch_size) + patch_dim = channels * patch_height * patch_width * frame_patch_size + + assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)' + + self.to_patch_embedding = nn.Sequential( + Rearrange('b c (h p1) (w p2) (f pf) -> b (h w f) (p1 p2 pf c)', p1 = patch_height, p2 = patch_width, pf = frame_patch_size), + nn.LayerNorm(patch_dim), + nn.Linear(patch_dim, dim), + nn.LayerNorm(dim), + ) + + self.pos_embedding = PositionEmbeddingLearned3d(dim // 3,(image_height // patch_height), (image_width // patch_width), (frames // frame_patch_size)) + self.dropout = nn.Dropout(emb_dropout) + + self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout) + + def forward(self, video): + B, C, H, W, D = video.shape + x = self.to_patch_embedding(video) + b, n, _ = x.shape + + pos = self.pos_embedding(B, H // self.patch_height, W // self.patch_width, D // self.frame_patch_size,x) + x += pos + x = self.dropout(x) + + x = self.transformer(x) + return x,pos diff --git a/Quick_demo/test.py b/Quick_demo/test.py new file mode 100644 index 0000000..62719ae --- /dev/null +++ b/Quick_demo/test.py @@ -0,0 +1,122 @@ +import tqdm.auto as tqdm +import torch.nn.functional as F +from typing import Optional, Dict, Sequence +from typing import List, Optional, Tuple, Union +import transformers +from dataclasses import dataclass, field +from Model.RadFM.multimodality_model import MultiLLaMAForCausalLM +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaTokenizer +from torchvision import transforms +from PIL import Image + +def get_tokenizer(tokenizer_path, max_img_size = 100, image_num = 32): + ''' + Initialize the image special tokens + max_img_size denotes the max image put length and image_num denotes how many patch embeddings the image will be encoded to + ''' + if isinstance(tokenizer_path,str): + image_padding_tokens = [] + text_tokenizer = LlamaTokenizer.from_pretrained( + tokenizer_path, + ) + special_token = {"additional_special_tokens": ["",""]} + for i in range(max_img_size): + image_padding_token = "" + + for j in range(image_num): + image_token = "" + image_padding_token = image_padding_token + image_token + special_token["additional_special_tokens"].append("") + image_padding_tokens.append(image_padding_token) + text_tokenizer.add_special_tokens( + special_token + ) + ## make sure the bos eos pad tokens are correct for LLaMA-like models + text_tokenizer.pad_token_id = 0 + text_tokenizer.bos_token_id = 1 + text_tokenizer.eos_token_id = 2 + + return text_tokenizer,image_padding_tokens + +def 
combine_and_preprocess(question,image_list,image_padding_tokens): + + transform = transforms.Compose([ + transforms.RandomResizedCrop([512,512],scale=(0.8, 1.0), interpolation=transforms.InterpolationMode.BICUBIC), + transforms.ToTensor(), + ]) + images = [] + new_qestions = [_ for _ in question] + padding_index = 0 + for img in image_list: + img_path = img['img_path'] + position = img['position'] + + + + image = Image.open(img_path).convert('RGB') + image = transform(image) + image = image.unsqueeze(0).unsqueeze(-1) # c,w,h,d + + ## pre-process the img first + target_H = 512 + target_W = 512 + target_D = 4 + # This can be different for 3D and 2D images. For demonstration we here set this as the default sizes for 2D images. + images.append(torch.nn.functional.interpolate(image, size = (target_H,target_W,target_D))) + + ## add img placeholder to text + new_qestions[position] = ""+ image_padding_tokens[padding_index] +"" + new_qestions[position] + padding_index +=1 + + vision_x = torch.cat(images,dim = 1).unsqueeze(0) #cat tensors and expand the batch_size dim + text = ''.join(new_qestions) + return [text], vision_x, + + +def main(): + + print("Setup tokenizer") + text_tokenizer,image_padding_tokens = get_tokenizer('./Language_files') + print("Finish loading tokenizer") + + ### Initialize a simple case for demo ### + print("Setup demo case") + question = "Can you identify any visible signs of Cardiomegaly in the image?" + image =[ + { + 'img_path': './view1_frontal.jpg', + 'position': 0, #indicate where to put the images in the text string, range from [0,len(question)-1] + }, # can add abitrary number of imgs + ] + + text,vision_x = combine_and_preprocess(question,image,image_padding_tokens) + + print("Finish loading demo case") + + print("Setup Model") + model = MultiLLaMAForCausalLM( + lang_model_path='./Language_files', ### Build up model based on LLaMa-13B config + ) + ckpt = torch.load('./pytorch_model.bin',map_location ='cpu') # Please dowloud our checkpoint from huggingface and Decompress the original zip file first + model.load_state_dict(ckpt) + print("Finish loading model") + + model = model.to('cuda') + model.eval() + with torch.no_grad(): + lang_x = text_tokenizer( + question, max_length=2048, truncation=True, return_tensors="pt" + )['input_ids'].to('cuda') + + vision_x = vision_x.to('cuda') + generation = model.generate(lang_x,vision_x) + generated_texts = text_tokenizer.batch_decode(generation, skip_special_tokens=True) + print('---------------------------------------------------') + print(question) + print(generated_texts[0]) + + +if __name__ == "__main__": + main() + \ No newline at end of file diff --git a/Quick_demo/view1_frontal.jpg b/Quick_demo/view1_frontal.jpg new file mode 100644 index 0000000..a11c4d0 Binary files /dev/null and b/Quick_demo/view1_frontal.jpg differ diff --git a/src/Model/RadFM/multimodality_model.py b/src/Model/RadFM/multimodality_model.py index 8d9d558..e88f258 100644 --- a/src/Model/RadFM/multimodality_model.py +++ b/src/Model/RadFM/multimodality_model.py @@ -9,7 +9,7 @@ from torch.autograd import Variable import numpy as np class MultiLLaMAForCausalLM(nn.Module): - def __init__(self, lang_model_path, vision_encoder_path): + def __init__(self, lang_model_path): super(MultiLLaMAForCausalLM, self).__init__() self.lang_model = LlamaForCausalLM.from_pretrained( lang_model_path, @@ -17,7 +17,7 @@ def __init__(self, lang_model_path, vision_encoder_path): self.lang_model.gradient_checkpointing_enable() self.lang_model.enable_input_require_grads() # 
self.lang_model.requires_grad_(False) - self.embedding_layer = MyEmbedding(vision_encoder_path) + self.embedding_layer = MyEmbedding() self.embedding_layer.weight = self.lang_model.get_input_embeddings().weight self.hidden_dim = 5120 self.voc_size = 32000 diff --git a/src/Model/RadFM/my_embedding_layer.py b/src/Model/RadFM/my_embedding_layer.py index a2c36f4..0c1b9b2 100644 --- a/src/Model/RadFM/my_embedding_layer.py +++ b/src/Model/RadFM/my_embedding_layer.py @@ -15,7 +15,7 @@ from transformers import AutoTokenizer, AutoModel class MyEmbedding(nn.Module): - def __init__(self, vision_encoder_path, num_embeddings=32000, embedding_dim=5120, perceiver_num=32,vis_dim = 768, patch_size=32, frame_patch_size = 4 ,seg_channel = 256): + def __init__(self, num_embeddings=32000, embedding_dim=5120, perceiver_num=32,vis_dim = 768, patch_size=32, frame_patch_size = 4 ,seg_channel = 256): super().__init__() self.num_embeddings = num_embeddings self.embedding_dim = embedding_dim diff --git a/src/test.py b/src/test.py index b59dc72..73ef782 100644 --- a/src/test.py +++ b/src/test.py @@ -27,7 +27,7 @@ def setup_seed(seed): class ModelArguments: lang_encoder_path: Optional[str] = field(default="/home/cs/leijiayu/wuchaoyi/book_pretrain/Results/Book_mix_2048_13B_full/checkpoint-45800") tokenizer_path: str = field(default='/home/cs/leijiayu/wuchaoyi/Finetune_LLAMA/LLAMA_Model/tokenizer', metadata={"help": "Path to the tokenizer data."}) - vision_encoder_path: str = field(default='/home/cs/leijiayu/wuchaoyi/multi_modal/src/PMC-CLIP/checkpoint.pt', metadata={"help": "Path to the vision_encoder."}) + #vision_encoder_path: str = field(default='/home/cs/leijiayu/wuchaoyi/multi_modal/src/PMC-CLIP/checkpoint.pt', metadata={"help": "Path to the vision_encoder."}) @dataclass @@ -117,7 +117,6 @@ def main(): print("Setup Model") model = MultiLLaMAForCausalLM( lang_model_path=model_args.lang_encoder_path, - vision_encoder_path=model_args.vision_encoder_path, ) ckpt = torch.load('/gpfs/home/cs/leijiayu/wuchaoyi/wangyingjie/src/Results/backup/checkpoint-17600/pytorch_model.bin',map_location ='cpu') # ckpt.pop('embedding_layer.figure_token_weight') diff --git a/src/train.py b/src/train.py index bb2cee6..12d324e 100644 --- a/src/train.py +++ b/src/train.py @@ -24,7 +24,7 @@ def compute_metrics(eval_preds): class ModelArguments: lang_encoder_path: Optional[str] = field(default="/home/cs/leijiayu/wuchaoyi/book_pretrain/Results/Book_mix_2048_13B_full/checkpoint-45800") tokenizer_path: str = field(default='/home/cs/leijiayu/wuchaoyi/Finetune_LLAMA/LLAMA_Model/tokenizer', metadata={"help": "Path to the tokenizer data."}) - vision_encoder_path: str = field(default='/home/cs/leijiayu/wuchaoyi/multi_modal/src/PMC-CLIP/checkpoint.pt', metadata={"help": "Path to the vision_encoder."}) + @dataclass @@ -107,7 +107,6 @@ def main(): model = MultiLLaMAForCausalLM( lang_model_path=model_args.lang_encoder_path, - vision_encoder_path=model_args.vision_encoder_path, ) trainer = Trainer(model=model,
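
Note on the embedding path introduced above (Quick_demo/Model/RadFM/my_embedding_layer.py together with Quick_demo/test.py): get_tokenizer extends the LLaMA vocabulary with image placeholder special tokens, and MyEmbedding.forward builds a per-sample embedding table by concatenating the 32000 text-token rows, the 2 extra figure-token rows, and the vision tokens produced by the ViT + PerceiverResampler + fc projection, then resolves the input ids with a one-hot matmul so that placeholder positions look up image features instead of text embeddings; the result is passed to the LLaMA decoder as inputs_embeds. The following is a minimal sketch of that lookup only, with toy sizes and hypothetical ids (the real model uses 32000 text tokens, 2 figure tokens, embedding_dim=5120 and 32 vision tokens per image, and the exact placeholder token strings are not shown here).

import torch
import torch.nn.functional as F

# Toy sizes standing in for num_embeddings=32000, the 2-row figure_token_weight,
# embedding_dim=5120 and the 32 perceiver latents per image.
num_text, num_figure, num_vision, dim = 10, 2, 4, 8
B, T = 1, 6

text_weight = torch.randn(num_text, dim)         # plays the role of MyEmbedding.weight
figure_weight = torch.randn(num_figure, dim)     # plays the role of figure_token_weight
vision_tokens = torch.randn(B, num_vision, dim)  # perceiver output after the fc projection

# Ids >= num_text + num_figure stand in for the image placeholder tokens that
# get_tokenizer appends to the vocabulary (hypothetical ids for illustration).
ids = torch.tensor([[1, 3, num_text, num_text + 1, num_text + num_figure, 5]])

table = torch.cat([text_weight, figure_weight], dim=0)    # (V_text + V_fig, dim)
table = table.unsqueeze(0).repeat(B, 1, 1)                # (B, V_text + V_fig, dim)
table = torch.cat([table, vision_tokens], dim=1)          # (B, V_total, dim)

one_hot = F.one_hot(ids, table.shape[1]).to(table.dtype)  # (B, T, V_total)
input_embedding = torch.matmul(one_hot, table)            # (B, T, dim), usable as inputs_embeds
print(input_embedding.shape)                              # torch.Size([1, 6, 8])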