From 4e220f0b873d7928aa559621a3b7db285a5d46cb Mon Sep 17 00:00:00 2001 From: zhaohu xing <920232796@qq.com> Date: Tue, 28 Mar 2023 21:24:36 +0800 Subject: [PATCH] add altclip-m18 Signed-off-by: zhaohu xing <920232796@qq.com> --- examples/AltCLIP-m18/README.md | 441 +++++ examples/AltCLIP-m18/altclip_evaluation.py | 89 + examples/AltCLIP-m18/altclip_finetuning.py | 65 + examples/AltCLIP-m18/altclip_inference.py | 41 + examples/AltCLIP-m18/dog.jpeg | Bin 0 -> 6215 bytes .../AltCLIP-m18/zeroshot_classification.py | 228 +++ flagai/auto_model/auto_loader.py | 2 + flagai/model/base_model.py | 3 + flagai/model/mm/AltCLIP.py | 2 +- flagai/model/mm/modeling_altclip.py | 1759 +++++++++++++++++ 10 files changed, 2629 insertions(+), 1 deletion(-) create mode 100644 examples/AltCLIP-m18/README.md create mode 100644 examples/AltCLIP-m18/altclip_evaluation.py create mode 100644 examples/AltCLIP-m18/altclip_finetuning.py create mode 100644 examples/AltCLIP-m18/altclip_inference.py create mode 100644 examples/AltCLIP-m18/dog.jpeg create mode 100644 examples/AltCLIP-m18/zeroshot_classification.py create mode 100644 flagai/model/mm/modeling_altclip.py diff --git a/examples/AltCLIP-m18/README.md b/examples/AltCLIP-m18/README.md new file mode 100644 index 00000000..d5ca7ea1 --- /dev/null +++ b/examples/AltCLIP-m18/README.md @@ -0,0 +1,441 @@ + +# AltCLIP + +## 简介/Overview + +我们提出了一个简单高效的方法去训练更加优秀的双语CLIP模型。命名为AltCLIP。AltCLIP基于 [OpenAI CLIP](https://github.com/openai/CLIP) 训练,训练数据来自 [WuDao数据集](https://data.baai.ac.cn/details/WuDaoCorporaText) 和 [LIAON](https://huggingface.co/datasets/ChristophSchuhmann/improved_aesthetics_6plus) + +AltCLIP模型可以为本项目中的AltDiffusion模型提供支持,关于AltDiffusion模型的具体信息可查看[此教程](https://github.com/FlagAI-Open/FlagAI/tree/master/examples/AltDiffusion/README.md) 。 + +模型代码已经在 [FlagAI](https://github.com/FlagAI-Open/FlagAI/tree/master/examples/AltCLIP) 上开源,权重位于我们搭建的 [modelhub](https://model.baai.ac.cn/model-detail/100075) 上。我们还提供了微调,推理,验证的脚本,欢迎试用。 + +首次运行AltCLIP时,下列权重将会自动从modelhub上下载。 + +| 模型名称 Model name | 大小 Size | 描述 Description | +| ------------------- | --------- | -------------------------------------------------- | +| [AltCLIP](https://model.baai.ac.cn/model-detail/100075) | 3.22G | 我们的双语AltCLIP模型;Our bilingual AltCLIP model | +| [AltCLIP-m9](https://model.baai.ac.cn/model-detail/100077) | 3.22G | support English(En), Chinese(Zh), Spanish(Es), French(Fr), Russian(Ru), Japanese(Ja), Korean(Ko), Arabic(Ar) and Italian(It) | + +Our AltCLIP support + +We propose a simple and efficient method to train a better multilingual CLIP model. Named AltCLIP. AltCLIP is trained based on [Stable Diffusiosn](https://github.com/CompVis/stable-diffusion) with training data from [WuDao dataset](https://data.baai.ac.cn/details/WuDaoCorporaText) and [Liaon](https://huggingface.co/datasets/laion/laion2B-en). + +The AltCLIP model can provide support for the AltDiffusion model in this project. Specific information on the AltDiffusion model can be found in [this tutorial](https://github.com/FlagAI-Open/FlagAI/tree/master/examples/AltDiffusion/README.md). + +The model code has been open sourced on [FlagAI](https://github.com/FlagAI-Open/FlagAI/tree/master/examples/AltCLIP) and the weights are located on [modelhub](https://model.baai.ac.cn/model-detail/100075). We also provide scripts for fine-tuning, inference, and validation, so feel free to try them out. + +## 引用 +关于AltCLIP,我们已经推出了相关报告,有更多细节可以查阅,如对您的工作有帮助,欢迎引用。 + +If you find this work helpful, please consider to cite +``` +@article{https://doi.org/10.48550/arxiv.2211.06679, + doi = {10.48550/ARXIV.2211.06679}, + url = {https://arxiv.org/abs/2211.06679}, + author = {Chen, Zhongzhi and Liu, Guang and Zhang, Bo-Wen and Ye, Fulong and Yang, Qinghong and Wu, Ledell}, + keywords = {Computation and Language (cs.CL), FOS: Computer and information sciences}, + title = {AltCLIP: Altering the Language Encoder in CLIP for Extended Language Capabilities}, + publisher = {arXiv}, + year = {2022}, + copyright = {arXiv.org perpetual, non-exclusive license} +} +``` + + +## 训练/Training + +训练共有两个阶段。 +在平行知识蒸馏阶段,我们只是使用平行语料文本来进行蒸馏(平行语料相对于图文对更容易获取且数量更大)。在双语对比学习阶段,我们使用少量的中-英图像-文本对(一共约2百万)来训练我们的文本编码器以更好地适应图像编码器。 + +There are two phases of training. +In the parallel knowledge distillation phase, we only use parallel corpus texts for distillation (parallel corpus is easier to obtain and larger in number compared to image text pairs). In the mltilingual comparison learning phase, we use a small number of Chinese-English image-text pairs (about 2 million in total) to train our text encoder to better fit the image encoder. + + + +## 下游效果/Performance +我们提出的模型与SOTA CLIP模型在双语跨模态基准(即Flickr30k的中英文版本)上的比较结果。这些模型中使用的图像编码器均为ViT-L,便于比较。 + +Comparison results between our proposed model and SOTA CLIP model on a bilingual cross-modal benchmark (i.e., the English and Chinese versions of Flickr30k.) The image encoders used in these models are ViT-L for easy comparison. + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
LanguageMethodText-to-Image RetrivalImage-to-Text RetrivalMR
R@1R@5R@10R@1R@5R@10
Flickr30k-EnglishCLIP65.0 87.1 92.2 85.1 97.3 99.2 87.6
Taiyi25.3 48.2 59.2 39.3 68.1 79.6 53.3
Wukong-------
R2D2-------
CN-CLIP49.5 76.9 83.8 66.5 91.2 96.0 77.3
AltCLIP72.5 91.6 95.4 86.0 98.0 99.1 90.4
Flickr30k-ChineseCLIP0.0 2.4 4.0 2.3 8.1 12.6 5.0
Taiyi53.7 79.8 86.6 63.8 90.5 95.9 78.4
Wukong51.7 78.9 86.3 76.1 94.8 97.5 80.9
R2D260.9 86.8 92.7 77.6 96.7 98.9 85.6
CN-CLIP68.0 89.7 94.4 80.2 96.6 98.2 87.9
AltCLIP69.8 89.9 94.7 84.8 97.4 98.8 89.2
+ +## 多语言性能/Multi-lingual performance +We achieve the SOTA zero-shot results on XTD. + +我们AltCLIP-m9在多语言的多模态检索数据集上的zero-shot性能。 +![](imgs/m9.png) + +## 可视化效果/Visualization effects + +基于AltCLIP,我们还开发了AltDiffusion模型,可视化效果如下。 + +Based on AltCLIP, we have also developed the AltDiffusion model, visualized as follows. + +![](https://raw.githubusercontent.com/920232796/test/master/image7.png) + +## 模型推理 Inference + +```python +import torch +from PIL import Image +from flagai.auto_model.auto_loader import AutoLoader + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +loader = AutoLoader( + task_name="txt_img_matching", + model_name="AltCLIP-XLMR-L", # Load the checkpoints from Modelhub(model.baai.ac.cn/models) + model_dir="./checkpoints" +) + +model = loader.get_model() +tokenizer = loader.get_tokenizer() +transform = loader.get_transform() + +model.eval() +model.to(device) +tokenizer = loader.get_tokenizer() + +def inference(): + image = Image.open("./dog.jpeg") + image = transform(image) + image = torch.tensor(image["pixel_values"]).to(device) + tokenizer_out = tokenizer(["a rat", "a dog", "a cat"], + padding=True, + truncation=True, + max_length=77, + return_tensors='pt') + + text = tokenizer_out["input_ids"].to(device) + attention_mask = tokenizer_out["attention_mask"].to(device) + with torch.no_grad(): + image_features = model.get_image_features(image) + text_features = model.get_text_features(text, attention_mask=attention_mask) + text_probs = (image_features @ text_features.T).softmax(dim=-1) + + print(text_probs.cpu().numpy()[0].tolist()) + +if __name__=="__main__": + inference() +``` + +## CLIP微调/Finetuning + +微调采用cifar10数据集,并使用FlagAI的Trainer快速开始训练过程。 + +Fine-tuning was done using the cifar10 dataset and using FlagAI's Trainer to quickly start the training process. + +```python +# Copyright © 2022 BAAI. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License") +import torch +from flagai.auto_model.auto_loader import AutoLoader +import os +from flagai.trainer import Trainer +from torchvision.datasets import ( + CIFAR10 +) + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +dataset_root = "./clip_benchmark_datasets" +dataset_name = "cifar10" + +batch_size = 4 +classes = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck'] + +auto_loader = AutoLoader( + task_name="txt_img_matching", + model_dir="./checkpoints", + model_name="AltCLIP-XLMR-L" # Load the checkpoints from Modelhub(model.baai.ac.cn/models) +) + +model = auto_loader.get_model() +model.to(device) +model.eval() +tokenizer = auto_loader.get_tokenizer() +transform = auto_loader.get_transform() + +trainer = Trainer(env_type="pytorch", + pytorch_device=device, + experiment_name="clip_finetuning", + batch_size=4, + lr=1e-4, + epochs=10, + log_interval=10) + +dataset = CIFAR10(root=os.path.join(dataset_root, dataset_name), + transform=transform, + download=True) + +def cifar10_collate_fn(batch): + # image shape is (batch, 3, 224, 224) + images = torch.tensor([b[0]["pixel_values"][0] for b in batch]) + # text_id shape is (batch, n) + input_ids = torch.tensor([tokenizer(f"a photo of a {b[1]}", + padding=True, + truncation=True, + max_length=77)["input_ids"] for b in batch]) + + attention_mask = torch.tensor([tokenizer(f"a photo of a {b[1]}", + padding=True, + truncation=True, + max_length=77)["attention_mask"] for b in batch]) + + return { + "pixel_values": images, + "input_ids": input_ids, + "attention_mask": attention_mask, + } + +if __name__ == "__main__": + trainer.train(model=model, train_dataset=dataset, collate_fn=cifar10_collate_fn) +``` + +## 模型验证/Evaluation + +我们提供了可以直接运行的验证脚本,在cifar10数据集上进行验证。 + +期待的输出为:```{'dataset': 'cifar10', 'metrics': {'acc1': 0.95402, 'acc5': 0.99616, 'mean_per_class_recall': 0.9541200000000002}}``` + +We provide validation scripts that can be run directly on the cifar10 dataset. + +```python +# Copyright © 2022 BAAI. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License") +import torch +from flagai.auto_model.auto_loader import AutoLoader +from metrics import zeroshot_classification +import json +import os +from torchvision.datasets import CIFAR10 + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +maxlen = 256 + +dataset_root = "./clip_benchmark_datasets" +dataset_name = "cifar10" + +auto_loader = AutoLoader( + task_name="txt_img_matching", + model_dir="./checkpoints/", + model_name="AltCLIP-XLMR-L" +) + +model = auto_loader.get_model() +model.to(device) +model.eval() +tokenizer = auto_loader.get_tokenizer() +transform = auto_loader.get_transform() + +dataset = CIFAR10(root=os.path.join(dataset_root, dataset_name), + transform=transform, + download=True) +batch_size = 128 +num_workers = 4 + +template = {"cifar10": [ + "a photo of a {c}.", + "a blurry photo of a {c}.", + "a black and white photo of a {c}.", + "a low contrast photo of a {c}.", + "a high contrast photo of a {c}.", + "a bad photo of a {c}.", + "a good photo of a {c}.", + "a photo of a small {c}.", + "a photo of a big {c}.", + "a photo of the {c}.", + "a blurry photo of the {c}.", + "a black and white photo of the {c}.", + "a low contrast photo of the {c}.", + "a high contrast photo of the {c}.", + "a bad photo of the {c}.", + "a good photo of the {c}.", + "a photo of the small {c}.", + "a photo of the big {c}." + ], +} +def evaluate(): + if dataset: + dataloader = torch.utils.data.DataLoader( + dataset, + batch_size=batch_size, + shuffle=False, + num_workers=num_workers, + ) + classnames = dataset.classes if hasattr(dataset, "classes") else None + + zeroshot_templates = template["cifar10"] + metrics = zeroshot_classification.evaluate( + model, + dataloader, + tokenizer, + classnames, + zeroshot_templates, + device=device, + amp=True, + ) + + dump = { + "dataset": dataset_name, + "metrics": metrics + } + + print(dump) + with open("./result.txt", "w") as f: + json.dump(dump, f) + return metrics + +if __name__ == "__main__": + evaluate() + +``` +# Huggingface Version + +我们已经上传了模型权重到 `transformers` ,只需要几行代码就能快速使用我们的模型! [Huggingface Model Card](https://huggingface.co/BAAI/AltCLIP) + +we have uploaded our model to `transformers`. you can use our model by a few lines of code. If you find it useful, feel free to star🌟! + +更多信息可查看 `hf_altclip/` + +more details please refer directory `hf_altclip/` diff --git a/examples/AltCLIP-m18/altclip_evaluation.py b/examples/AltCLIP-m18/altclip_evaluation.py new file mode 100644 index 00000000..fb281e04 --- /dev/null +++ b/examples/AltCLIP-m18/altclip_evaluation.py @@ -0,0 +1,89 @@ +# Copyright © 2022 BAAI. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License") +import torch +from flagai.auto_model.auto_loader import AutoLoader +import zeroshot_classification +import json +import os +from torchvision.datasets import CIFAR10 + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +maxlen = 256 + +dataset_root = "./clip_benchmark_datasets/" +dataset_name = "cifar10" + +auto_loader = AutoLoader( + task_name="txt_img_matching", + model_dir="./checkpoints/", + model_name="AltCLIP-XLMR-L-m18" # Load the checkpoints from Modelhub(model.baai.ac.cn/models) +) + +model = auto_loader.get_model() +model.to(device) +model.eval() +tokenizer = auto_loader.get_tokenizer() +transform = auto_loader.get_transform() + +dataset = CIFAR10(root=os.path.join(dataset_root, dataset_name), + transform=transform, + download=True) +batch_size = 128 +num_workers = 4 + +template = {"cifar10": [ + "a photo of a {c}.", + "a blurry photo of a {c}.", + "a black and white photo of a {c}.", + "a low contrast photo of a {c}.", + "a high contrast photo of a {c}.", + "a bad photo of a {c}.", + "a good photo of a {c}.", + "a photo of a small {c}.", + "a photo of a big {c}.", + "a photo of the {c}.", + "a blurry photo of the {c}.", + "a black and white photo of the {c}.", + "a low contrast photo of the {c}.", + "a high contrast photo of the {c}.", + "a bad photo of the {c}.", + "a good photo of the {c}.", + "a photo of the small {c}.", + "a photo of the big {c}." + ], +} +def evaluate(): + if dataset: + dataloader = torch.utils.data.DataLoader( + dataset, + batch_size=batch_size, + shuffle=False, + num_workers=num_workers, + ) + + zeroshot_templates = template["cifar10"] + classnames = dataset.classes if hasattr(dataset, "classes") else None + + metrics = zeroshot_classification.evaluate( + model, + dataloader, + tokenizer, + classnames, + zeroshot_templates, + device=device, + amp=True, + ) + + dump = { + "dataset": dataset_name, + "metrics": metrics + } + + print(dump) + with open("./result.txt", "w") as f: + json.dump(dump, f) + return metrics + +if __name__ == "__main__": + evaluate() diff --git a/examples/AltCLIP-m18/altclip_finetuning.py b/examples/AltCLIP-m18/altclip_finetuning.py new file mode 100644 index 00000000..2eb7823d --- /dev/null +++ b/examples/AltCLIP-m18/altclip_finetuning.py @@ -0,0 +1,65 @@ +# Copyright © 2022 BAAI. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License") +import torch +from flagai.auto_model.auto_loader import AutoLoader +import os +from flagai.trainer import Trainer +from torchvision.datasets import ( + CIFAR10 +) + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +dataset_root = "./clip_benchmark_datasets" +dataset_name = "cifar10" + +batch_size = 4 +classes = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck'] + +auto_loader = AutoLoader( + task_name="txt_img_matching", + model_dir="./checkpoints", + model_name="AltCLIP-XLMR-L-m18" # Load the checkpoints from Modelhub(model.baai.ac.cn/models) +) + +model = auto_loader.get_model() +model.to(device) +model.eval() +tokenizer = auto_loader.get_tokenizer() +transform = auto_loader.get_transform() + +trainer = Trainer(env_type="pytorch", + pytorch_device=device, + experiment_name="clip_finetuning", + batch_size=4, + lr=1e-4, + epochs=10, + log_interval=10) + +dataset = CIFAR10(root=os.path.join(dataset_root, dataset_name), + transform=transform, + download=True) + +def cifar10_collate_fn(batch): + # image shape is (batch, 3, 224, 224) + images = torch.tensor([b[0]["pixel_values"][0] for b in batch]) + # text_id shape is (batch, n) + input_ids = torch.tensor([tokenizer(f"a photo of a {b[1]}", + padding=True, + truncation=True, + max_length=77)["input_ids"] for b in batch]) + + attention_mask = torch.tensor([tokenizer(f"a photo of a {b[1]}", + padding=True, + truncation=True, + max_length=77)["attention_mask"] for b in batch]) + + return { + "pixel_values": images, + "input_ids": input_ids, + "attention_mask": attention_mask, + } + +if __name__ == "__main__": + trainer.train(model=model, train_dataset=dataset, collate_fn=cifar10_collate_fn) \ No newline at end of file diff --git a/examples/AltCLIP-m18/altclip_inference.py b/examples/AltCLIP-m18/altclip_inference.py new file mode 100644 index 00000000..7be92f1e --- /dev/null +++ b/examples/AltCLIP-m18/altclip_inference.py @@ -0,0 +1,41 @@ +import torch +from PIL import Image +from flagai.auto_model.auto_loader import AutoLoader + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +loader = AutoLoader( + task_name="txt_img_matching", + model_name="AltCLIP-XLMR-L-m18", # Load the checkpoints from Modelhub(model.baai.ac.cn/models) + model_dir="./checkpoints" +) + +model = loader.get_model() +tokenizer = loader.get_tokenizer() +transform = loader.get_transform() + +model.eval() +model.to(device) +tokenizer = loader.get_tokenizer() + +def inference(): + image = Image.open("./dog.jpeg") + image = transform(image) + image = torch.tensor(image["pixel_values"]).to(device) + tokenizer_out = tokenizer(["a rat", "a dog", "a cat"], + padding=True, + truncation=True, + max_length=77, + return_tensors='pt') + + text = tokenizer_out["input_ids"].to(device) + attention_mask = tokenizer_out["attention_mask"].to(device) + with torch.no_grad(): + image_features = model.get_image_features(image) + text_features = model.get_text_features(text, attention_mask=attention_mask) + text_probs = (image_features @ text_features.T).softmax(dim=-1) + + print(text_probs.cpu().numpy()[0].tolist()) + +if __name__=="__main__": + inference() \ No newline at end of file diff --git a/examples/AltCLIP-m18/dog.jpeg b/examples/AltCLIP-m18/dog.jpeg new file mode 100644 index 0000000000000000000000000000000000000000..3e79ea589c086e5608b8aac6aa250f2450ae03d7 GIT binary patch literal 6215 zcmZvAbyU<1)AkQox)-E&>F!#(o26@Ml$7oTL6q*6Mx_^6N+}f(lw5G>kdl^`z7d2Q z<@5RCJ>PrId(D}1rmmU!cjkWaegmM=(}rsUI5+@+^H+fTZ$LaiLP$hM_kiI)`hUR4 z!1NzsVq#!mVr2XqgR+ZoLs4EP#pGO`2WoRiv179eJ?;w1eD<@;o)!qxYRg!)HwHp00aPV@o@gL>i@%E7laQG z5dKZVr~n*1Tzp&*0Um&hi}$wz=PwhF9RwCvGj$5Y=ZJYhBT>=PcSbP1NvrOB&KWfG zm&J$!{P&K3X}EYG{J$xF>c4$(adAMn0O9|ghl5KEX2%m(1JQ6e(V9xA2gX$No!>75 zWO)C37B!#(EP5Dt*yFmW|3f5K!Gh@JSA{aw(CX!hPqQ=GR}$kBbgG)AF+)-KM9W3! zeN&=pG8AW=MSpgU2}pfklO)!hgn(K-#jrwc@^;Ny{c~y2 zu`NY6R%A+wSYuef{k#&TSIS9rWuF4g0OxVV3{_Ux3Hf435w)n5WuQI4NfwNEd1TO;~ zsteMyp|kx~ZtD|w%HT&~0Y6%8Cx9RaNf}WJThXDnf)wVIa2}it$KoxzLSi2*Go-Ni zl4sSQe;%+gD-m1X22(^^KRxR~=v`_UbZCP8QN;Y=5L>>j;uk~R0y~Epk68P=f#qH| zYEfiv1YAaGR-nz zz;2m+$#qGabBB+W=ehO|@0r-=ErCh~4mbg_o)W7_qpBi|6*912rkBmTLPzQh?6szV#>GEpprwP+$8OnGOF|^l&IR zzY057zHz)7cD?j`bl|L_!(Q5yS}Q$(z_F#=`*X-ew@QL#gTXwCiD{OC@*;=VS(t|O z;2~c6c%nBks5ynx1wwoC+TOFML(#^U^aZ#;ZzH#=DBFsvvx%>@KE1Kg`W@Du7RXhv zBr@sT!3hc{f>8?CAw`KIn1S+6;vgz;J(U)Ls}Y}*lt(95m#}LBe?AGoEv1rC^6MnU zd#*;}C77L?>r0+3Cx?w!I|j>|gxulEZt;P85t`?Fju}?L9e2OMu3WUW^tAN(6%snc z`sn+J^}KijE{hqqayyJ_HS=q%sgj8!-ejndyPAeV>uNO{dJ}}%MFXpSRKUXfXntB9 z3d8^jUav_lza*j$pNTIOKRk-QN$PtnVfox!?{kRYL=@ZE6UDv9ZOZ*oL$ky3g(Kqu z=zD0_7hYD-9NRJZe6vqOAA5yyJho1u2ik5aEvvh7F*T+S$J zE*^5x(;G{K@BUa)dKE7v(XEkGfnanV$!Wtn9kk+j_jQg9Ls2kdrn})s0qu~Dm1h~zTTlo9tGt9Tnywq_{;Q~ZuG5BXO*r<2UhMx`&`|} zByodoLTDy`z3Ue2k0d1xeq6*`F+*sm+%OSEcdlrt(^@v*qM1G^0pRX}0u|o+?F=6x zq*AJD>5_dxlmiVyPbKKdYNi7=(o>JhoyfHXqo~yDYz~HQpue0yNsDe)(NdeCWh_~( z)wxI35?F|a-FT;nDasp_E-ZRB_{_bvJ{uDIlDa{-oEaKlrWk|Eps14O?ajChjs|`X zUr2+2D`|YK1&y1R-`IL>YHEvAz*&C2$EWfHt;tT~Ro6cP{sxK{dM_z@zOtL`mzW&Y z*dQ}Uwywril5;My1#O5MUz)Y4q+QrnThcHi+|BDE8vJ9amb>Ug-%V}(g#E~Fa6f!y z5Sv&v%gx{OChcetm+H>Va4??y>EyQ$kN=zqx_v<8;r~X^bz|RS^k!*aly&V0Q`91? zgZoDt9Z{*f77^io1MBI&5MOqy`k(Op74w5>-Vj6dm11i~xxL=UXusg+_*k3-#hPd7 zin*i@r5N?ntNFE9o?sNkGY6H+sU)JOhk4JM%i?37idrR0kYLc`K1x=0sck~9_*(|z zja`^goG=&HvLxb;WOM#lgNJMpP3eTOvr*%0KiteowGEo=@dE>ta#A7gVN`zA}+TFX@wv{<9Kl{zpIA15qEQrRo#wvl{hM zFO7cG&1An3|2i^RJYUW@$Se7Pq@^0)Iuh!Z_tMN(FhgHYH(gsm{UKv3H$K65YJjUp zOD^2(=&&z*A8PnPk2gVf7raPpVL&A4p`fkUEqG&HzR6!x4j=3VW#sE16odo)ibwg) z1D}e*a$aJiz1lpSbyId8qOWXys1wL`Z$>P8K{tvm)<*fWnnkzIZ<*# zK8@t2UAzg0ra#RRC{c@SfL~$`OKhLCuB9L8*2*(aEJ!z}=x>Gmx(vq% zW;ScxlERX+!W=chnPz>WZCXVAgOf?N+a{m-vB6j;`8RY9T&OKiH}KbOy%2$((%97D zw3Z9wIVIOL7&8Z3#@bl5!OWya??MaU^rW(upnuidQvAFFiDh|CzXa@rm@Nu5NGA;S>uCrO$F<7H zU&xwgNIVksN|cDx{Tj?L<>CsP_^6w=p0q=tWHNh(seZ65JLA-BWmFcbets$n<9H=w*)jU2F*^44^=uUE_qssHdqZHfKFuTA`@2x&P;k0-{N zoG1Hqm6CPlNSnjfamD}IqHbZ!m`_8-V+-&p+4G!*BbH3#)5bAn%T1;?HZF~mxCaaZ zv(&8+P=q4+qw}XicwY#@9@693r!f}4TWC5eppWAF&i$b^4S%k)H&#abIxXB;9Zw${uAXt6 z=;;YyO1K$ADlJS0WK)>D^s3T_s4rkSK6NlNY!uySpmg4$qdL}21~!^;?DwE{K|K;) z2u?!{&JAF$dDI{3wlF78p`Yt`h0GfI6%VVc(`ejpt!k>@5ShM6y2NlI~g_cy6pRzllf9sGmpxMjy-16-prR(_6 zB!RRILh%#~>pZS_ahP)(Pbifci_AC$^InZIs^21~XDJNYzt`@jb!?5eM5hFRu6IaP zxiViK)muiCuH=RTKj$f@L>~Cyv=IC@-nXJHC?3YnJzoB?|D1sCo3!tw(Qrceujrs? zA#@zw{H$+?x)%*#PhEGV&99&)V#VDBYn3}wo#_*J!+0dIyZC1XwfTdt{iabfrs+1| z<(OSl``th-DWjf|MM1{qFq!ZBvVFRiJ?-2j2%m~I+wcMLI+X5R`o;LF(nwbAG@f{P zu%uM&X=7lRkz$UqhjHSAP)MUF+kUa;ge8@T*3b8~N9YesTl>NVMzu4)Pg^@TC%w~N za}kWEE0fXLOojv&Wr(pk9bRQ`q?#pRU2m zHePL4b0F}o4tI7*Hlx&ikTLJEX9~Q(&7A&Qr_s5$8^JK_8WLY%vGPE3$!TY(6Ul0N z_#(!WGGOCRzs$k~O$WZ|8FuI-oS8|58x}jM6UWufC?kAo8Z zCS7uT4J2023buR5SC4E$SX$11j)Y}=Q^wE)ET*iQMQ}IKE1HuCh$ZKMGJo`^4t0Uq zRz#m+8~UTzx$|GoH&fXYwXEG<d#h?ce&>MG=!ek1RgYsg7z4aX~8zf4HUSE#<9 z{rF?WVGU<^S!>d+bT}xRm#_JigJ(F`ta#wn$>{iySO@Hs#ot;dRoHvPJ0^|#ZP$3M zDxoaQmMbHgj*7P1Vd%YdNlMS}Hzg`OL3v)4X=CWK8C*De{?(5lCGg@u)br2Z!O>T* z2D3zJ{P4eDyY9n33$<>LO}rO4R#al>nvgnH{P+5(U`#b_*m&Jdg#2Gs}MYV(RAccPz_zhL$~$rTp&so_{Ix?4vspNg&jOE6}%p-Q0`p4HB0YWMhU{*38Y zd|ms|XgvM}b;D2SBAlD(7S2ysS%`kl(Vp-XgbEIUK!F$C3LIp1F1uw{iDC zvb}ElLaa^^9vDQ66nlH~tSOqqu*N9f5cS;5`Yto%OSFE72a|o@8cp-&@*$pI*v8MU zVmX<5SZd;%BaM2wR!()^nnAt7tz%%dS@%^Y(Z4lcos_-TK*m%!r@yOQApcNIg@V=( zRovts>iX8!-UB3}LrR&TUcV=VJze>8e7n&l?s{7M2;GOl7x5G<22^G$=15?a< z$B$!2L%n#_cRiatrQ8+$eQlg#1e6hLR|tj%1N4Oc_l7Q&lPwHHV z!bWWU19Yg_#}!w2HY8n>exMbT7b=%8y{cPxp`IUhyryH7=v;{27v2e4do8ebJLu(0 ztiruO)!2Et3pyOLz7FYT5e zYnRwR%F8kiiTZU6t(dKosNWEi816a0AWU3S9h0P)Ik~bFLvU>NekkEh3vQ*sZ(IAN zL7>J|{L}82Q^j|?w^DIo-jIkFJimpb>Subxk_moWniY&Yw|j?W2;C4Y@V3e`AWfh+ zU>&?8D@*j&TisXIazNGS>?PJg(DQNd(3xPB z;*P5zn~YNGu1#_<$l%*W=Zg6(_YUJD^?GxUh?Oa=lNpWou_9@kV&o27xSSe<#>1^) zZwhZg5(bR1Ju6sSuZmo15%v#zx=7!-l>|^OoKN%Z<0y6C64@&azmQ-}qJg zdyhP7O;3FaK8Q&Kt(|cdFR>1^C|Lgl;p(bHY4-mj=eS@qY-b~8_TR7gdKY`nJt4w` zQ#{TU?2p{{3(cuiW^VEx8g|g(*rWqTbHbLUoslwS9SDrF?UoFq9i~aSi#4F|5$!at zD#Rdw_hD;Rl|)seZA6*A!S4qJpU6_;=q|Of!ZQ)NdnOf^ycEXX3pbl<}2zSmpW(V+5S||oA~Eg!{oVD z_C|S)Lqw0HA=VBk<|N}KHPd6?%70v1(uOUNb zlaFJKGUqWc{g2bLtlmty zp;iybO8PQRTz!XsrTnb!BBicw6w$FsP$83Ba=lO?`NW>d%O^&@7pZ(W5h{`I&d~17 zdxXFmUCUE58BE&cG02O7is#*J?6;WImfcr*>JnaG#p+^(RSq3&g}wkGndKv#OX3sF z93hg2)kEcAIIrI=_j^ni)vfQ7qBKpU^c2s#SfUbuoY3iWWdK{Na5GE5H8i1dTk&S#0G=mFTv6hAidw>qa z9E)eBFD)j3^g$}Og?Za)7eB7w`FG}+kCf~E)bxvkz56A^U#6}8u#13tT@QJ5uplHG z%glqrxgsMgF~*A(0GLW<@=)NOt=A7$y-4r!N8h2k>J~Z=gJco2LwG^PiQ$0A3OQ>OTh$f+k+BpZ_whw8*+&`d_&X BI*= 5: + acc1, acc5 = accuracy(logits, target, topk=(1, 5)) + else: + acc1, = accuracy(logits, target, topk=(1,)) + acc5 = float("nan") + mean_per_class_recall = balanced_accuracy_score(target, pred) + if verbose: + print(classification_report(target, pred, digits=3)) + return {"acc1": acc1, "acc5": acc5, "mean_per_class_recall": mean_per_class_recall} diff --git a/flagai/auto_model/auto_loader.py b/flagai/auto_model/auto_loader.py index 71ac1c44..8ecebd6c 100644 --- a/flagai/auto_model/auto_loader.py +++ b/flagai/auto_model/auto_loader.py @@ -135,6 +135,8 @@ def __getattr__(self, name): "AltCLIPProcess"], "altclip-xlmr-l-m9": ["flagai.models.mm.AltCLIP", "AltCLIP", "altclip", "mm", "flagai.model.mm.AltCLIP", "AltCLIPProcess"], + "altclip-xlmr-l-m18": ["flagai.models.mm.AltCLIP", "AltCLIP", "altclip", "mm", "flagai.model.mm.AltCLIP", + "AltCLIPProcess"], "altclip-bert-b": ["flagai.models.mm.AltCLIP", "AltCLIP", "altclip", "mm", "flagai.model.mm.AltCLIP", "AltCLIPProcessBert"], "eva-clip": ["flagai.model.mm.eva_clip_model", "EVA_CLIP", "evaclip", "mm"], diff --git a/flagai/model/base_model.py b/flagai/model/base_model.py index c385c52a..2399ed3a 100644 --- a/flagai/model/base_model.py +++ b/flagai/model/base_model.py @@ -213,6 +213,7 @@ def load_diffusion_local(yaml_path, only_download_config=False, **kwargs): def download(cls, download_path='./checkpoints/', model_name='RoBERTa-base-ch', + only_download_config=False, **kwargs): try: model_id = _get_model_id(model_name) @@ -227,4 +228,6 @@ def download(cls, if not file_name.endswith("bin"): _get_vocab_path(os.path.join(download_path, model_name), file_name, model_id) else : + if only_download_config: + continue _get_checkpoint_path(os.path.join(download_path, model_name), file_name, model_id) \ No newline at end of file diff --git a/flagai/model/mm/AltCLIP.py b/flagai/model/mm/AltCLIP.py index 9757e2ff..a50713a4 100644 --- a/flagai/model/mm/AltCLIP.py +++ b/flagai/model/mm/AltCLIP.py @@ -446,7 +446,7 @@ def from_pretrain(cls, only_download_config=False, device="cpu", **kwargs): - super().download(download_path, model_name) + # super().download(download_path, model_name, only_download_config=only_download_config) pretrained_model_name_or_path = os.path.join(download_path, model_name) print(pretrained_model_name_or_path) return CLIPHF.from_pretrained(pretrained_model_name_or_path) diff --git a/flagai/model/mm/modeling_altclip.py b/flagai/model/mm/modeling_altclip.py new file mode 100644 index 00000000..ffc9d5ab --- /dev/null +++ b/flagai/model/mm/modeling_altclip.py @@ -0,0 +1,1759 @@ +# coding=utf-8 +# Copyright 2022 The BAAI Teams Authors and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch AltCLIP model.""" +import math +from dataclasses import dataclass +from typing import Any, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.utils.checkpoint + +from transformers.activations import ACT2FN +from transformers.modeling_outputs import ( + BaseModelOutput, + BaseModelOutputWithPastAndCrossAttentions, + BaseModelOutputWithPooling, + BaseModelOutputWithPoolingAndCrossAttentions, +) +from transformers.modeling_utils import PreTrainedModel +from transformers.pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer +from transformers.utils import ModelOutput, add_start_docstrings_to_model_forward, logging, replace_return_docstrings +from .configuration_altclip import AltCLIPConfig, AltCLIPTextConfig, AltCLIPVisionConfig + + +@dataclass +class BaseModelOutputWithPoolingAndProjection(ModelOutput): + """ + Base class for model's outputs that also contains a pooling of the last hidden states. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`): + Last layer hidden-state of the first token of the sequence (classification token) after further processing + through the layers used for the auxiliary pretraining task. E.g. for BERT-family of models, this returns + the classification token after processing through a linear layer and a tanh activation function. The linear + layer weights are trained from the next sentence prediction (classification) objective during pretraining. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + projection_state (`tuple(torch.FloatTensor)`, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` of shape `(batch_size,config.project_dim)`. + + Text embeddings before the projection layer, used to mimic the last hidden state of the teacher encoder. + """ + + last_hidden_state: torch.FloatTensor = None + penultimate_hidden_state: torch.FloatTensor = None + pooler_output: torch.FloatTensor = None + pooler_output2: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + projection_state: Optional[Tuple[torch.FloatTensor]] = None + +logger = logging.get_logger(__name__) + +_TOKENIZER_FOR_DOC = "XLMRobertaTokenizer" +_CHECKPOINT_FOR_DOC = "BAAI/AltCLIP" +_CONFIG_FOR_DOC = "AltCLIPConfig" + +ALTCLIP_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "BAAI/AltCLIP", + # See all AltCLIP models at https://huggingface.co/models?filter=altclip +] + + +ALTCLIP_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`CLIPConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +ALTCLIP_TEXT_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`XLMRobertaTokenizerFast`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + + [What are position IDs?](../glossary#position-ids) + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + +ALTCLIP_VISION_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using + [`CLIPFeatureExtractor`]. See [`CLIPFeatureExtractor.__call__`] for details. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + +ALTCLIP_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`XLMRobertaTokenizerFast`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + + [What are position IDs?](../glossary#position-ids) + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using + [`CLIPFeatureExtractor`]. See [`CLIPFeatureExtractor.__call__`] for details. + return_loss (`bool`, *optional*): + Whether or not to return the contrastive loss. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +# contrastive loss function, adapted from +# https://sachinruk.github.io/blog/pytorch/pytorch%20lightning/loss%20function/gpu/2021/03/07/CLIP.html +def contrastive_loss(logits: torch.Tensor) -> torch.Tensor: + return nn.functional.cross_entropy(logits, torch.arange(len(logits), device=logits.device)) + + +def clip_loss(similarity: torch.Tensor) -> torch.Tensor: + caption_loss = contrastive_loss(similarity) + image_loss = contrastive_loss(similarity.t()) + return (caption_loss + image_loss) / 2.0 + + +@dataclass +# Copied from transformers.models.clip.modeling_clip.CLIPOutput with CLIP->AltCLIP +class AltCLIPOutput(ModelOutput): + """ + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`): + Contrastive loss for image-text similarity. + logits_per_image:(`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`): + The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text + similarity scores. + logits_per_text:(`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`): + The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image + similarity scores. + text_embeds(`torch.FloatTensor` of shape `(batch_size, output_dim`): + The text embeddings obtained by applying the projection layer to the pooled output of [`AltCLIPTextModel`]. + image_embeds(`torch.FloatTensor` of shape `(batch_size, output_dim`): + The image embeddings obtained by applying the projection layer to the pooled output of + [`AltCLIPVisionModel`]. + text_model_output(`BaseModelOutputWithPooling`): + The output of the [`AltCLIPTextModel`]. + vision_model_output(`BaseModelOutputWithPooling`): + The output of the [`AltCLIPVisionModel`]. + """ + + loss: Optional[torch.FloatTensor] = None + logits_per_image: torch.FloatTensor = None + logits_per_text: torch.FloatTensor = None + text_embeds: torch.FloatTensor = None + image_embeds: torch.FloatTensor = None + text_model_output: BaseModelOutputWithPooling = None + vision_model_output: BaseModelOutputWithPooling = None + + def to_tuple(self) -> Tuple[Any]: + return tuple( + self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple() + for k in self.keys() + ) + + +# Copied from transformers.models.roberta.modeling_roberta.RobertaEmbeddings with Roberta->AltRoberta +class AltRobertaEmbeddings(nn.Module): + """ + Same as BertEmbeddings with a tiny tweak for positional embeddings indexing. + """ + + # Copied from transformers.models.bert.modeling_bert.BertEmbeddings.__init__ + def __init__(self, config): + super().__init__() + self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) + self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) + self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) + + # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load + # any TensorFlow checkpoint file + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") + self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1))) + self.register_buffer( + "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False + ) + + # End copy + self.padding_idx = config.pad_token_id + self.position_embeddings = nn.Embedding( + config.max_position_embeddings, config.hidden_size, padding_idx=self.padding_idx + ) + + def forward( + self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0 + ): + if position_ids is None: + if input_ids is not None: + # Create the position ids from the input token ids. Any padded tokens remain padded. + position_ids = create_position_ids_from_input_ids(input_ids, self.padding_idx, past_key_values_length) + else: + position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds) + + if input_ids is not None: + input_shape = input_ids.size() + else: + input_shape = inputs_embeds.size()[:-1] + + seq_length = input_shape[1] + + # Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs + # when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves + # issue #5664 + if token_type_ids is None: + if hasattr(self, "token_type_ids"): + buffered_token_type_ids = self.token_type_ids[:, :seq_length] + buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length) + token_type_ids = buffered_token_type_ids_expanded + else: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device) + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + token_type_embeddings = self.token_type_embeddings(token_type_ids) + + embeddings = inputs_embeds + token_type_embeddings + if self.position_embedding_type == "absolute": + position_embeddings = self.position_embeddings(position_ids) + embeddings += position_embeddings + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + def create_position_ids_from_inputs_embeds(self, inputs_embeds): + """ + We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids. + + Args: + inputs_embeds: torch.Tensor + + Returns: torch.Tensor + """ + input_shape = inputs_embeds.size()[:-1] + sequence_length = input_shape[1] + + position_ids = torch.arange( + self.padding_idx + 1, sequence_length + self.padding_idx + 1, dtype=torch.long, device=inputs_embeds.device + ) + return position_ids.unsqueeze(0).expand(input_shape) + + +# Copied from transformers.models.roberta.modeling_roberta.RobertaSelfAttention with Roberta->AltRoberta +class AltRobertaSelfAttention(nn.Module): + def __init__(self, config, position_embedding_type=None): + super().__init__() + if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): + raise ValueError( + f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention " + f"heads ({config.num_attention_heads})" + ) + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(config.hidden_size, self.all_head_size) + self.key = nn.Linear(config.hidden_size, self.all_head_size) + self.value = nn.Linear(config.hidden_size, self.all_head_size) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + self.position_embedding_type = position_embedding_type or getattr( + config, "position_embedding_type", "absolute" + ) + if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": + self.max_position_embeddings = config.max_position_embeddings + self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) + + self.is_decoder = config.is_decoder + + def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: + new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) + x = x.view(new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor]: + mixed_query_layer = self.query(hidden_states) + + # If this is instantiated as a cross-attention module, the keys + # and values come from an encoder; the attention mask needs to be + # such that the encoder's padding tokens are not attended to. + is_cross_attention = encoder_hidden_states is not None + + if is_cross_attention and past_key_value is not None: + # reuse k,v, cross_attentions + key_layer = past_key_value[0] + value_layer = past_key_value[1] + attention_mask = encoder_attention_mask + elif is_cross_attention: + key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) + value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) + attention_mask = encoder_attention_mask + elif past_key_value is not None: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + key_layer = torch.cat([past_key_value[0], key_layer], dim=2) + value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + else: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + + query_layer = self.transpose_for_scores(mixed_query_layer) + + use_cache = past_key_value is not None + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_layer, value_layer) + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + + if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": + query_length, key_length = query_layer.shape[2], key_layer.shape[2] + if use_cache: + position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view( + -1, 1 + ) + else: + position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1) + position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1) + distance = position_ids_l - position_ids_r + + positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1) + positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility + + if self.position_embedding_type == "relative_key": + relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) + attention_scores = attention_scores + relative_position_scores + elif self.position_embedding_type == "relative_key_query": + relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) + relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding) + attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key + + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + if attention_mask is not None: + # Apply the attention mask is (precomputed for all layers in AltRobertaModel forward() function) + attention_scores = attention_scores + attention_mask + + # Normalize the attention scores to probabilities. + attention_probs = nn.functional.softmax(attention_scores, dim=-1) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs = self.dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + attention_probs = attention_probs * head_mask + + context_layer = torch.matmul(attention_probs, value_layer) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(new_context_layer_shape) + + outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) + + if self.is_decoder: + outputs = outputs + (past_key_value,) + return outputs + + +# Copied from transformers.models.roberta.modeling_roberta.RobertaSelfOutput +class AltRobertaSelfOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +# Copied from transformers.models.roberta.modeling_roberta.RobertaAttention with Roberta->AltRoberta +class AltRobertaAttention(nn.Module): + def __init__(self, config, position_embedding_type=None): + super().__init__() + self.self = AltRobertaSelfAttention(config, position_embedding_type=position_embedding_type) + self.output = AltRobertaSelfOutput(config) + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads + ) + + # Prune linear layers + self.self.query = prune_linear_layer(self.self.query, index) + self.self.key = prune_linear_layer(self.self.key, index) + self.self.value = prune_linear_layer(self.self.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.self.num_attention_heads = self.self.num_attention_heads - len(heads) + self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads + self.pruned_heads = self.pruned_heads.union(heads) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor]: + self_outputs = self.self( + hidden_states, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + attention_output = self.output(self_outputs[0], hidden_states) + outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + return outputs + + +# Copied from transformers.models.roberta.modeling_roberta.RobertaIntermediate with Roberta->AltRoberta +class AltRobertaIntermediate(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +# Copied from transformers.models.roberta.modeling_roberta.RobertaOutput +class AltRobertaOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +# Copied from transformers.models.roberta.modeling_roberta.RobertaLayer with Roberta->AltRoberta +class AltRobertaLayer(nn.Module): + def __init__(self, config): + super().__init__() + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.attention = AltRobertaAttention(config) + self.is_decoder = config.is_decoder + self.add_cross_attention = config.add_cross_attention + if self.add_cross_attention: + if not self.is_decoder: + raise ValueError(f"{self} should be used as a decoder model if cross attention is added") + self.crossattention = AltRobertaAttention(config, position_embedding_type="absolute") + self.intermediate = AltRobertaIntermediate(config) + self.output = AltRobertaOutput(config) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor]: + # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 + self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None + self_attention_outputs = self.attention( + hidden_states, + attention_mask, + head_mask, + output_attentions=output_attentions, + past_key_value=self_attn_past_key_value, + ) + attention_output = self_attention_outputs[0] + + # if decoder, the last output is tuple of self-attn cache + if self.is_decoder: + outputs = self_attention_outputs[1:-1] + present_key_value = self_attention_outputs[-1] + else: + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + + cross_attn_present_key_value = None + if self.is_decoder and encoder_hidden_states is not None: + if not hasattr(self, "crossattention"): + raise ValueError( + f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers" + " by setting `config.add_cross_attention=True`" + ) + + # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple + cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None + cross_attention_outputs = self.crossattention( + attention_output, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + cross_attn_past_key_value, + output_attentions, + ) + attention_output = cross_attention_outputs[0] + outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights + + # add cross-attn cache to positions 3,4 of present_key_value tuple + cross_attn_present_key_value = cross_attention_outputs[-1] + present_key_value = present_key_value + cross_attn_present_key_value + + layer_output = apply_chunking_to_forward( + self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output + ) + outputs = (layer_output,) + outputs + + # if decoder, return the attn key/values as the last output + if self.is_decoder: + outputs = outputs + (present_key_value,) + + return outputs + + def feed_forward_chunk(self, attention_output): + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + return layer_output + + +# Copied from transformers.models.roberta.modeling_roberta.RobertaEncoder with Roberta->AltRoberta +class AltRobertaEncoder(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.layer = nn.ModuleList([AltRobertaLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = False, + output_hidden_states: Optional[bool] = False, + return_dict: Optional[bool] = True, + ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]: + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None + + next_decoder_cache = () if use_cache else None + for i, layer_module in enumerate(self.layer): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_head_mask = head_mask[i] if head_mask is not None else None + past_key_value = past_key_values[i] if past_key_values is not None else None + + if self.gradient_checkpointing and self.training: + + if use_cache: + logger.warning( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs, past_key_value, output_attentions) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(layer_module), + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + ) + else: + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + + hidden_states = layer_outputs[0] + if use_cache: + next_decoder_cache += (layer_outputs[-1],) + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + if self.config.add_cross_attention: + all_cross_attentions = all_cross_attentions + (layer_outputs[2],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + next_decoder_cache, + all_hidden_states, + all_self_attentions, + all_cross_attentions, + ] + if v is not None + ) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_decoder_cache, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) + + +# Copied from transformers.models.roberta.modeling_roberta.RobertaPooler +class AltRobertaPooler(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +# Copied from transformers.models.clip.modeling_clip.CLIPAttention with CLIP->AltCLIP +class AltCLIPAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" + f" {self.num_heads})." + ) + self.scale = self.head_dim**-0.5 + self.dropout = config.attention_dropout + + self.k_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.v_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.q_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.out_proj = nn.Linear(self.embed_dim, self.embed_dim) + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + causal_attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + bsz, tgt_len, embed_dim = hidden_states.size() + + # get query proj + query_states = self.q_proj(hidden_states) * self.scale + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + + proj_shape = (bsz * self.num_heads, -1, self.head_dim) + query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) + key_states = key_states.view(*proj_shape) + value_states = value_states.view(*proj_shape) + + src_len = key_states.size(1) + attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) + + if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): + raise ValueError( + f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" + f" {attn_weights.size()}" + ) + + # apply the causal_attention_mask first + if causal_attention_mask is not None: + if causal_attention_mask.size() != (bsz, 1, tgt_len, src_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is" + f" {causal_attention_mask.size()}" + ) + attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + causal_attention_mask + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, tgt_len, src_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + if output_attentions: + # this operation is a bit akward, but it's required to + # make sure that attn_weights keeps its gradient. + # In order to do so, attn_weights have to reshaped + # twice and have to be reused in the following + attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) + else: + attn_weights_reshaped = None + + attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + + attn_output = torch.bmm(attn_probs, value_states) + + if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) + attn_output = attn_output.transpose(1, 2) + attn_output = attn_output.reshape(bsz, tgt_len, embed_dim) + + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights_reshaped + + +# Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->AltCLIP +class AltCLIPMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.activation_fn = ACT2FN[config.hidden_act] + self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size) + self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states = self.fc2(hidden_states) + return hidden_states + + +# Copied from transformers.models.clip.modeling_clip.CLIPEncoderLayer with CLIP->AltCLIP +class AltCLIPEncoderLayer(nn.Module): + def __init__(self, config: AltCLIPConfig): + super().__init__() + self.embed_dim = config.hidden_size + self.self_attn = AltCLIPAttention(config) + self.layer_norm1 = nn.LayerNorm(self.embed_dim) + self.mlp = AltCLIPMLP(config) + self.layer_norm2 = nn.LayerNorm(self.embed_dim) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + causal_attention_mask: torch.Tensor, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.FloatTensor]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + `(config.encoder_attention_heads,)`. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + """ + residual = hidden_states + + hidden_states = self.layer_norm1(hidden_states) + hidden_states, attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + causal_attention_mask=causal_attention_mask, + output_attentions=output_attentions, + ) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_weights,) + + return outputs + + +# Copied from transformers.models.clip.modeling_clip.CLIPEncoder with CLIP->AltCLIP +class AltCLIPEncoder(nn.Module): + """ + Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a + [`AltCLIPEncoderLayer`]. + + Args: + config: AltCLIPConfig + """ + + def __init__(self, config: AltCLIPConfig): + super().__init__() + self.config = config + self.layers = nn.ModuleList([AltCLIPEncoderLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False + + def forward( + self, + inputs_embeds, + attention_mask: Optional[torch.Tensor] = None, + causal_attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutput]: + r""" + Args: + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + causal_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Causal mask for the text model. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + hidden_states = inputs_embeds + for idx, encoder_layer in enumerate(self.layers): + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs, output_attentions) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(encoder_layer), + hidden_states, + attention_mask, + causal_attention_mask, + ) + else: + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + causal_attention_mask, + output_attentions=output_attentions, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions + ) + + +# Copied from transformers.models.clip.modeling_clip.CLIPVisionEmbeddings with CLIP->AltCLIP +class AltCLIPVisionEmbeddings(nn.Module): + def __init__(self, config: AltCLIPVisionConfig): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.image_size = config.image_size + self.patch_size = config.patch_size + + self.class_embedding = nn.Parameter(torch.randn(self.embed_dim)) + + self.patch_embedding = nn.Conv2d( + in_channels=config.num_channels, + out_channels=self.embed_dim, + kernel_size=self.patch_size, + stride=self.patch_size, + bias=False, + ) + + self.num_patches = (self.image_size // self.patch_size) ** 2 + self.num_positions = self.num_patches + 1 + self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim) + self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1))) + + def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor: + batch_size = pixel_values.shape[0] + patch_embeds = self.patch_embedding(pixel_values) # shape = [*, width, grid, grid] + patch_embeds = patch_embeds.flatten(2).transpose(1, 2) + + class_embeds = self.class_embedding.expand(batch_size, 1, -1) + embeddings = torch.cat([class_embeds, patch_embeds], dim=1) + embeddings = embeddings + self.position_embedding(self.position_ids) + return embeddings + + +class AltCLIPPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = AltCLIPConfig + base_model_prefix = "altclip" + supports_gradient_checkpointing = True + _keys_to_ignore_on_load_missing = [r"position_ids"] + + def _init_weights(self, module): + """Initialize the weights""" + factor = self.config.initializer_factor + if isinstance(module, AltCLIPVisionEmbeddings): + factor = self.config.initializer_factor + nn.init.normal_(module.class_embedding, mean=0.0, std=module.embed_dim**-0.5 * factor) + nn.init.normal_(module.patch_embedding.weight, std=module.config.initializer_range * factor) + nn.init.normal_(module.position_embedding.weight, std=module.config.initializer_range * factor) + elif isinstance(module, AltCLIPAttention): + factor = self.config.initializer_factor + in_proj_std = (module.embed_dim**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor + out_proj_std = (module.embed_dim**-0.5) * factor + nn.init.normal_(module.q_proj.weight, std=in_proj_std) + nn.init.normal_(module.k_proj.weight, std=in_proj_std) + nn.init.normal_(module.v_proj.weight, std=in_proj_std) + nn.init.normal_(module.out_proj.weight, std=out_proj_std) + elif isinstance(module, AltCLIPMLP): + factor = self.config.initializer_factor + in_proj_std = ( + (module.config.hidden_size**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor + ) + fc_std = (2 * module.config.hidden_size) ** -0.5 * factor + nn.init.normal_(module.fc1.weight, std=fc_std) + nn.init.normal_(module.fc2.weight, std=in_proj_std) + elif isinstance(module, AltCLIPModel): + nn.init.normal_( + module.text_projection.weight, + std=module.text_embed_dim**-0.5 * self.config.initializer_factor, + ) + nn.init.normal_( + module.visual_projection.weight, + std=module.vision_embed_dim**-0.5 * self.config.initializer_factor, + ) + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + elif isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_factor) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_factor) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, AltCLIPEncoder): + module.gradient_checkpointing = value + if isinstance(module, AltRobertaEncoder): + module.gradient_checkpointing = value + + +# Copied from transformers.models.clip.modeling_clip.CLIPVisionTransformer with CLIPVisionTransformer->AltCLIPVisionTransformer,CLIPVisionConfig->AltCLIPVisionConfig,CLIPVisionEmbeddings->AltCLIPVisionEmbeddings,CLIPEncoder->AltCLIPEncoder,CLIP_VISION_INPUTS_DOCSTRING->ALTCLIP_VISION_INPUTS_DOCSTRING +class AltCLIPVisionTransformer(nn.Module): + def __init__(self, config: AltCLIPVisionConfig): + super().__init__() + self.config = config + embed_dim = config.hidden_size + + self.embeddings = AltCLIPVisionEmbeddings(config) + self.pre_layrnorm = nn.LayerNorm(embed_dim) + self.encoder = AltCLIPEncoder(config) + self.post_layernorm = nn.LayerNorm(embed_dim) + + @add_start_docstrings_to_model_forward(ALTCLIP_VISION_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=AltCLIPVisionConfig) + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPooling]: + r""" + Returns: + + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + hidden_states = self.embeddings(pixel_values) + hidden_states = self.pre_layrnorm(hidden_states) + + encoder_outputs = self.encoder( + inputs_embeds=hidden_states, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + last_hidden_state = encoder_outputs[0] + pooled_output = last_hidden_state[:, 0, :] + pooled_output = self.post_layernorm(pooled_output) + + if not return_dict: + return (last_hidden_state, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPooling( + last_hidden_state=last_hidden_state, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +class AltCLIPVisionModel(AltCLIPPreTrainedModel): + config_class = AltCLIPVisionConfig + main_input_name = "pixel_values" + + def __init__(self, config: AltCLIPVisionConfig): + super().__init__(config) + self.vision_model = AltCLIPVisionTransformer(config) + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self) -> nn.Module: + return self.vision_model.embeddings.patch_embedding + + @add_start_docstrings_to_model_forward(ALTCLIP_VISION_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=AltCLIPVisionConfig) + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPooling]: + r""" + Returns: + + Examples: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AltCLIPProcessor, AltCLIPVisionModel + + >>> model = AltCLIPVisionModel.from_pretrained("BAAI/AltCLIP") + >>> processor = AltCLIPProcessor.from_pretrained("BAAI/AltCLIP") + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> inputs = processor(images=image, return_tensors="pt") + + >>> outputs = model(**inputs) + >>> last_hidden_state = outputs.last_hidden_state + >>> pooled_output = outputs.pooler_output # pooled CLS states + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + return self.vision_model( + pixel_values=pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + +class AltRobertaModel(AltCLIPPreTrainedModel): + """ + + The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of + cross-attention is added between the self-attention layers, following the architecture described in *Attention is + all you need*_ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz + Kaiser and Illia Polosukhin. + + To behave as an decoder the model needs to be initialized with the `is_decoder` argument of the configuration set + to `True`. To be used in a Seq2Seq model, the model needs to initialized with both `is_decoder` argument and + `add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass. + + .. _*Attention is all you need*: https://arxiv.org/abs/1706.03762 + + """ + + config_class = AltCLIPTextConfig + + # Copied from transformers.models.bert.modeling_bert.BertModel.__init__ with Bert->AltRoberta + def __init__(self, config, add_pooling_layer=True): + super().__init__(config) + self.config = config + + self.embeddings = AltRobertaEmbeddings(config) + self.encoder = AltRobertaEncoder(config) + + self.pooler = AltRobertaPooler(config) if add_pooling_layer else None + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, value): + self.embeddings.word_embeddings = value + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + # Copied from transformers.models.bert.modeling_bert.BertModel.forward + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]: + r""" + encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if self.config.is_decoder: + use_cache = use_cache if use_cache is not None else self.config.use_cache + else: + use_cache = False + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_shape = input_ids.size() + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + batch_size, seq_length = input_shape + device = input_ids.device if input_ids is not None else inputs_embeds.device + + # past_key_values_length + past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + + if attention_mask is None: + attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) + + if token_type_ids is None: + if hasattr(self.embeddings, "token_type_ids"): + buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length] + buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length) + token_type_ids = buffered_token_type_ids_expanded + else: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape) + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if self.config.is_decoder and encoder_hidden_states is not None: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + if encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) + encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_extended_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + embedding_output = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + token_type_ids=token_type_ids, + inputs_embeds=inputs_embeds, + past_key_values_length=past_key_values_length, + ) + encoder_outputs = self.encoder( + embedding_output, + attention_mask=extended_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = encoder_outputs[0] + pooled_output = self.pooler(sequence_output) if self.pooler is not None else None + + if not return_dict: + return (sequence_output, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPoolingAndCrossAttentions( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + past_key_values=encoder_outputs.past_key_values, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + cross_attentions=encoder_outputs.cross_attentions, + ) + + +class AltCLIPTextModel(AltCLIPPreTrainedModel): + config_class = AltCLIPTextConfig + + def __init__(self, config): + super().__init__(config) + self.roberta = AltRobertaModel(config, add_pooling_layer=False) + self.transformation = nn.Linear(config.hidden_size, config.project_dim) + self.transformation_pre = nn.Linear(config.hidden_size, config.project_dim) + self.pre_LN = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.post_init() + + def get_input_embeddings(self) -> nn.Module: + return self.roberta.embeddings.word_embeddings + + def set_input_embeddings(self, value: nn.Embedding) -> None: + self.roberta.embeddings.word_embeddings = value + + def resize_token_embeddings(self, new_num_tokens: Optional[int] = None) -> nn.Embedding: + return super().resize_token_embeddings(new_num_tokens) + + @add_start_docstrings_to_model_forward(ALTCLIP_TEXT_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=BaseModelOutputWithPoolingAndProjection, config_class=AltCLIPTextConfig) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + return_dict: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + ): + r""" + Returns: + + Examples: + + ```python + >>> from transformers import AltCLIPProcessor, AltCLIPTextModel + + >>> model = AltCLIPTextModel.from_pretrained("BAAI/AltCLIP") + >>> processor = AltCLIPProcessor.from_pretrained("BAAI/AltCLIP") + + >>> texts = ["it's a cat", "it's a dog"] + + >>> inputs = processor(text=texts, padding=True, return_tensors="pt") + + >>> outputs = model(**inputs) + >>> last_hidden_state = outputs.last_hidden_state + >>> pooled_output = outputs.pooler_output # pooled CLS states + ```""" + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.roberta( + input_ids=input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + output_hidden_states=True, + return_dict=return_dict, + ) + + # last module outputs + sequence_output = outputs[0] + + # project the last outputs + sequence_output = self.pre_LN(sequence_output) + + # pooler + projection_state = self.transformation(sequence_output) + pooler_output = projection_state[:, 0] + + sequence_output2 = outputs[1][-2] + + # project every module + sequence_output2 = self.pre_LN(sequence_output2) + + # pooler + projection_state2 = self.transformation_pre(sequence_output2) + pooler_output2 = projection_state2[:, 0] + if not return_dict: + return (projection_state, pooler_output) + outputs[2:4] + + return BaseModelOutputWithPoolingAndProjection( + last_hidden_state=projection_state, + penultimate_hidden_state=projection_state2, + pooler_output=pooler_output, + pooler_output2 = pooler_output2, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +class AltCLIPModel(AltCLIPPreTrainedModel): + config_class = AltCLIPConfig + + def __init__(self, config: AltCLIPConfig): + super().__init__(config) + + if not isinstance(config.vision_config, AltCLIPVisionConfig): + raise ValueError( + "config.vision_config is expected to be of type AltCLIPVisionConfig but is of type" + f" {type(config.vision_config)}." + ) + if not isinstance(config.text_config, AltCLIPTextConfig): + raise ValueError( + "config.text_config is expected to be of type AltCLIPTextConfig but is of type" + f" {type(config.text_config)}." + ) + + text_config = config.text_config + vision_config = config.vision_config + + self.projection_dim = config.projection_dim + self.text_embed_dim = text_config.project_dim + self.vision_embed_dim = vision_config.hidden_size + + self.text_model = AltCLIPTextModel(text_config) + self.vision_model = AltCLIPVisionTransformer(vision_config) + + self.visual_projection = nn.Linear(self.vision_embed_dim, self.projection_dim, bias=False) + self.text_projection = nn.Linear(self.text_embed_dim, self.projection_dim, bias=False) + self.logit_scale = nn.Parameter(torch.ones([]) * self.config.logit_scale_init_value) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(ALTCLIP_TEXT_INPUTS_DOCSTRING) + def get_text_features( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + token_type_ids=None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> torch.FloatTensor: + r""" + Returns: + text_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The text embeddings obtained by + applying the projection layer to the pooled output of [`AltCLIPTextModel`]. + + Examples: + + ```python + >>> from transformers import AltCLIPProcessor, AltCLIPModel + + >>> model = AltCLIPModel.from_pretrained("BAAI/AltCLIP") + >>> processor = AltCLIPProcessor.from_pretrained("BAAI/AltCLIP") + >>> inputs = processor(text=["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt") + >>> text_features = model.get_text_features(**inputs) + ```""" + # Use AltCLIP model's config for some fields (if specified) instead of those of vision & text components. + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + text_outputs = self.text_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + token_type_ids=token_type_ids, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + pooled_output = text_outputs[1] + text_features = self.text_projection(pooled_output) + + return text_features + + @add_start_docstrings_to_model_forward(ALTCLIP_VISION_INPUTS_DOCSTRING) + def get_image_features( + self, + pixel_values: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> torch.FloatTensor: + r""" + Returns: + image_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The image embeddings obtained by + applying the projection layer to the pooled output of [`AltCLIPVisionModel`]. + + Examples: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AltCLIPProcessor, AltCLIPModel + + >>> model = AltCLIPModel.from_pretrained("BAAI/AltCLIP") + >>> processor = AltCLIPProcessor.from_pretrained("BAAI/AltCLIP") + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + >>> inputs = processor(images=image, return_tensors="pt") + >>> image_features = model.get_image_features(**inputs) + ```""" + # Use AltCLIP model's config for some fields (if specified) instead of those of vision & text components. + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + vision_outputs = self.vision_model( + pixel_values=pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + pooled_output = vision_outputs[1] # pooled_output + image_features = self.visual_projection(pooled_output) + + return image_features + + @add_start_docstrings_to_model_forward(ALTCLIP_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=AltCLIPOutput, config_class=AltCLIPConfig) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + token_type_ids=None, + return_loss: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, AltCLIPOutput]: + r""" + Returns: + + Examples: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AltCLIPProcessor, AltCLIPModel + + >>> model = AltCLIPModel.from_pretrained("BAAI/AltCLIP") + >>> processor = AltCLIPProcessor.from_pretrained("BAAI/AltCLIP") + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + >>> inputs = processor( + ... text=["a photo of a cat", "a photo of a dog"], images=image, return_tensors="pt", padding=True + ... ) + >>> outputs = model(**inputs) + >>> logits_per_image = outputs.logits_per_image # this is the image-text similarity score + >>> probs = logits_per_image.softmax(dim=1) # we can take the softmax to get the label probabilities + ```""" + # Use AltCLIP model's config for some fields (if specified) instead of those of vision & text components. + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + text_outputs = self.text_model( + input_ids=input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + vision_outputs = self.vision_model( + pixel_values=pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + image_embeds = vision_outputs[1] + image_embeds = self.visual_projection(image_embeds) + + text_embeds = text_outputs[1] + text_embeds = self.text_projection(text_embeds) + + # normalized features + image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True) + text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True) + + # cosine similarity as logits + logit_scale = self.logit_scale.exp() + logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale + logits_per_image = logits_per_text.T + + loss = None + if return_loss: + loss = clip_loss(logits_per_text) + + if not return_dict: + output = (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs) + return ((loss,) + output) if loss is not None else output + + return AltCLIPOutput( + loss=loss, + logits_per_image=logits_per_image, + logits_per_text=logits_per_text, + text_embeds=text_embeds, + image_embeds=image_embeds, + text_model_output=text_outputs, + vision_model_output=vision_outputs, + ) + + +# Copied from transformers.models.roberta.modeling_roberta.create_position_ids_from_input_ids +def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0): + """ + Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols + are ignored. This is modified from fairseq's `utils.make_positions`. + + Args: + x: torch.Tensor x: + + Returns: torch.Tensor + """ + # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA. + mask = input_ids.ne(padding_idx).int() + incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length) * mask + return incremental_indices.long() + padding_idx