diff --git a/README.md b/README.md
index 1d2990b0..614e9fb0 100644
--- a/README.md
+++ b/README.md
@@ -7,7 +7,7 @@
--------------------------------------------------------------------------------
-FlagAI (Fast LArge-scale General AI models) is an fast, easy-to-use and extensible toolkit for large-scale model. Our goal is to support training, fine-tuning, and deployment of large-scale models on various downstream tasks with multi-modality. Currently, we are focusing on NLP models and tasks. In near futher, we will support for other modalities.
+FlagAI (Fast LArge-scale General AI models) is a fast, easy-to-use and extensible toolkit for large-scale models. Our goal is to support training, fine-tuning, and deployment of large-scale models on various downstream tasks with multi-modality. Currently, we are focusing on NLP models and tasks. In the near future, we will support other modalities.
* Now it supports **WuDao GLM** with a maximum of 10 billion parameters (see [Introduction to GLM](/docs/GLM.md)). It also supports **BERT**, **RoBERTa**, **GPT2**, **T5**, and models from Huggingface Transformers.
@@ -18,7 +18,7 @@ FlagAI (Fast LArge-scale General AI models) is an fast, easy-to-use and extensib
* FlagAI is backed by the three most popular data/model parallel libraries — [PyTorch](https://pytorch.org/)/[Deepspeed](https://www.deepspeed.ai/)/[Megatron-LM](https://github.com/NVIDIA/Megatron-LM) — with seamless integration between them. Users can parallel their training/testing process with less than ten lines of code.
-The code is partially based on [GLM](https://github.com/THUDM/GLM), [Transformers](https://github.com/huggingface/transformers) and [DeepSpeedExamples](https://github.com/microsoft/DeepSpeedExamples/tree/master/Megatron-LM).
+The code is partially based on [GLM](https://github.com/THUDM/GLM), [Transformers](https://github.com/huggingface/transformers), [timm](https://github.com/rwightman/pytorch-image-models) and [DeepSpeedExamples](https://github.com/microsoft/DeepSpeedExamples/tree/master/Megatron-LM).
@@ -156,7 +156,6 @@ start with our [contributor guidelines](CONTRIBUTING.md) and then
check these [open issues](https://github.com/BAAI-Open/FlagAI/issues) for specific tasks.
## Contact us
-Scan wechat QR code
diff --git a/README_zh.md b/README_zh.md
index 88def64d..e8dc5a3f 100644
--- a/README_zh.md
+++ b/README_zh.md
@@ -18,7 +18,7 @@
* 飞智由三个最流行的数据/模型并行库([PyTorch](https://pytorch.org/)/[Deepspeed](https://www.deepspeed.ai/)/[Megatron-LM](https://github.com/NVIDIA/Megatron-LM))提供支持,它们之间实现了无缝集成。 你可以用不到十行代码来并行你的训练/测试过程。
-本项目的部分代码基于[GLM](https://github.com/THUDM/GLM),[Transformers](https://github.com/huggingface/transformers) 和 [DeepSpeedExamples](https://github.com/microsoft/DeepSpeedExamples/tree/master/Megatron-LM).
+本项目的部分代码基于 [GLM](https://github.com/THUDM/GLM),[Transformers](https://github.com/huggingface/transformers),[timm](https://github.com/rwightman/pytorch-image-models) 和 [DeepSpeedExamples](https://github.com/microsoft/DeepSpeedExamples/tree/master/Megatron-LM).
diff --git a/doc_zh/TUTORIAL_4_TRAINER.md b/doc_zh/TUTORIAL_4_TRAINER.md
index c70c7871..0e83ffb9 100644
--- a/doc_zh/TUTORIAL_4_TRAINER.md
+++ b/doc_zh/TUTORIAL_4_TRAINER.md
@@ -13,7 +13,7 @@
- [deepspeed](#deepspeed)
- [pytorchDDP](#pytorchddp)
- [deepspeed + megatron-lm](#deepspeed--megatron-lm)
-
+- [EnvTrainer](#EnvTrainer)
+
Trainer 类提供了API用于多种并行框架的训练。API 支持在多个 GPU上使用Pytorch DDP/Deepspeed进行分布式训练,同时支持Megatron-LM+Deepspeed的混合并行分布式训练,同时也通过 NVIDIA Apex 实现混合精度。
## 入门
@@ -335,3 +335,72 @@ trainer = MyTrainer(
)
```
+# EnvTrainer
+
+为了更方便地输入参数,我们提供了 EnvTrainer 来代替原来的 Trainer。
+例如:
+```python
+# train.py
+import torch
+from flagai.env_args import EnvArgs
+from flagai.env_trainer import EnvTrainer
+
+lr = 2e-5
+n_epochs = 50
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+env_args = EnvArgs(
+ env_type="pytorch",
+ experiment_name="vit-cifar100-single_gpu",
+ batch_size=150,
+ num_gpus=1,
+ gradient_accumulation_steps=1,
+ lr=lr,
+ weight_decay=1e-5,
+ epochs=n_epochs,
+ log_interval=100,
+ eval_interval=1000,
+ load_dir=None,
+ pytorch_device=device,
+ save_dir="checkpoints_vit_cifar100_single_gpu",
+ save_interval=1000,
+ num_checkpoints=1,
+)
+
+env_args.add_arg(arg_name="test1", default=0, type=int, )
+env_args_parse = env_args.parse_args()
+trainer = EnvTrainer(env_args_parse)
+```
+
+运行train.py文件时,可以通过命令行修改输入参数。
+```commandline
+python train.py --batch_size=8 --epochs=10
+```
+如果你需要添加额外的参数,你可以调用这个函数:
+```python
+env_args.add_arg(arg_name="test1", default=0, type=int, )
+```
+然后你可以用如下命令运行 train.py 文件:
+```commandline
+python train.py --test1=1
+```
+更多的例子可以查看:
+
+1. [vit-env-trainer](https://github.com/BAAI-Open/FlagAI/tree/master/examples/vit_cifar100/train_env_trainer.py)
+
+2. [glm-title-generation-env-trainer](https://github.com/BAAI-Open/FlagAI/tree/master/examples/glm_title_generation/train_env_trainer.py)
+
+
+# 使用 pytorchDDP launcher 或 deepspeed launcher 运行
+如果你使用多个GPU来训练模型,你可以直接运行train.py来调用FlagAI训练器中的启动器。
+```commandline
+python train.py
+```
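+
+自动启动所需的分布式设置都来自 `EnvArgs` 本身。下面是相对于上面单卡示例通常需要修改的参数的最小示意(hostfile 路径、GPU 数量、脚本名等均为需要按实际环境调整的占位值):
+```python
+env_args = EnvArgs(
+    env_type="pytorchDDP",        # or "deepspeed"
+    experiment_name="vit-cifar100-ddp",
+    num_nodes=1,
+    num_gpus=2,                   # GPUs to launch per node
+    hostfile="./hostfile",        # lines of the form "ip slots=N"
+    master_ip="localhost",
+    master_port=17750,
+    training_script="train.py",   # the script the launcher re-invokes on each worker
+)
+env_args_parse = env_args.parse_args()
+trainer = EnvTrainer(env_args_parse)   # spawns the workers, each rerun with --not_call_launch
+```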
+另外,你也可以使用pytorchDDP和deepspeed启动器来运行,例如:
+### pytorchDDP
+```commandline
+python -m torch.distributed.launch --nproc_per_node 2 --nnodes 1 --node_rank 0 --master_addr localhost --master_port 17750 train_env_trainer.py --not_call_launch
+```
+### deepspeed
+```commandline
+python -m deepspeed.launcher.launch --master_addr=172.31.125.121 --master_port=17500 train.py --not_call_launch
+```
diff --git a/docs/TUTORIAL_4_TRAINER.md b/docs/TUTORIAL_4_TRAINER.md
index f78526b4..2ec7d235 100644
--- a/docs/TUTORIAL_4_TRAINER.md
+++ b/docs/TUTORIAL_4_TRAINER.md
@@ -13,6 +13,9 @@
- [deepspeed](#deepspeed)
- [pytorchDDP](#pytorchddp)
- [deepspeed + megatron-lm](#deepspeed--megatron-lm)
+- [EnvTrainer](#EnvTrainer)
+
+
The Trainer class provides APIs for training with multiple parallel frameworks. The API supports distributed training with Pytorch DDP/Deepspeed on multiple GPUs, as well as mixed parallel distributed training with Megatron-LM+Deepspeed, and mixed precision via NVIDIA Apex.
## Getting Started
@@ -341,3 +344,76 @@ trainer = MyTrainer(
model_paralle_size = 2
)
```
+
+# EnvTrainer
+
+To make it easier to pass in parameters, we provide the EnvTrainer as a replacement for the original Trainer.
+
+Take the following code as an example:
+```python
+# train.py
+import torch
+from flagai.env_args import EnvArgs
+from flagai.env_trainer import EnvTrainer
+
+lr = 2e-5
+n_epochs = 50
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+env_args = EnvArgs(
+ env_type="pytorch",
+ experiment_name="vit-cifar100-single_gpu",
+ batch_size=150,
+ num_gpus=1,
+ gradient_accumulation_steps=1,
+ lr=lr,
+ weight_decay=1e-5,
+ epochs=n_epochs,
+ log_interval=100,
+ eval_interval=1000,
+ load_dir=None,
+ pytorch_device=device,
+ save_dir="checkpoints_vit_cifar100_single_gpu",
+ save_interval=1000,
+ num_checkpoints=1,
+)
+
+env_args.add_arg(arg_name="test1", default=0, type=int, )
+env_args_parse = env_args.parse_args()
+trainer = EnvTrainer(env_args_parse)
+```
+
+When you run the train.py file, you can modify the input parameters through the command line.
+```commandline
+python train.py --batch_size=8 --epochs=10
+```
+If you need to add additional parameters, you can call this function:
+```python
+env_args.add_arg(arg_name="test1", default=0, type=int, )
+```
+Then you can run the train.py file with the following command:
+```commandline
+python train.py --test1=1
+```
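+
+The parsed namespace returned by `parse_args()` exposes every registered argument as an attribute, so the extra parameter can be read directly inside train.py. A minimal sketch, continuing the example above:
+```python
+env_args.add_arg(arg_name="test1", default=0, type=int, )
+env_args_parse = env_args.parse_args()
+test1_value = env_args_parse.test1   # launched with --test1=1 -> 1, otherwise the default 0
+print(f"test1 = {test1_value}")
+```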
+
+More examples can be found in:
+
+1. [vit-env-trainer](https://github.com/BAAI-Open/FlagAI/tree/master/examples/vit_cifar100/train_env_trainer.py)
+
+2. [glm-title-generation-env-trainer](https://github.com/BAAI-Open/FlagAI/tree/master/examples/glm_title_generation/train_env_trainer.py)
+
+
+# Run with pytorchDDP launcher or deepspeed launcher
+If you use multiple GPUs to train models, you can run train.py directly, which calls the launcher built into the FlagAI trainer.
+```commandline
+python train.py
+```
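+
+For this auto-launch path, the distributed settings are taken from `EnvArgs` itself. A minimal sketch of what typically changes relative to the single-GPU example above (the host file path, GPU count and script name are placeholders to adapt to your setup):
+```python
+env_args = EnvArgs(
+    env_type="pytorchDDP",        # or "deepspeed"
+    experiment_name="vit-cifar100-ddp",
+    num_nodes=1,
+    num_gpus=2,                   # GPUs to launch per node
+    hostfile="./hostfile",        # lines of the form "ip slots=N"
+    master_ip="localhost",
+    master_port=17750,
+    training_script="train.py",   # the script the launcher re-invokes on each worker
+)
+env_args_parse = env_args.parse_args()
+trainer = EnvTrainer(env_args_parse)   # spawns the workers, each rerun with --not_call_launch
+```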
+In addition, you can also launch it with the pytorchDDP or deepspeed launcher, for example:
+
+### pytorchDDP
+```commandline
+python -m torch.distributed.launch --nproc_per_node 2 --nnodes 1 --node_rank 0 --master_addr localhost --master_port 17750 train_env_trainer.py --not_call_launch
+```
+### deepspeed
+```commandline
+python -m deepspeed.launcher.launch --master_addr=172.31.125.121 --master_port=17500 train.py --not_call_launch
+```
\ No newline at end of file
diff --git a/examples/glm_title_generation/train_env_trainer.py b/examples/glm_title_generation/train_env_trainer.py
new file mode 100644
index 00000000..39c4524f
--- /dev/null
+++ b/examples/glm_title_generation/train_env_trainer.py
@@ -0,0 +1,144 @@
+# Copyright © 2022 BAAI. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License")
+import os
+import numpy as np
+import torch
+from torch.utils.data import Dataset
+from flagai.auto_model.auto_loader import AutoLoader
+from flagai.env_trainer import EnvTrainer
+from flagai.env_args import EnvArgs
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+# You can input all parameters by the command line.
+# For example: python train_env_trainer.py --epochs=300 --batch_size=4 --env_type=pytorch
+env_args = EnvArgs()
+env_args = env_args.parse_args()
+trainer = EnvTrainer(env_args)
+
+cur_dir = os.path.dirname(os.path.abspath(__file__))
+src_dir = cur_dir + '/data/train.src'
+tgt_dir = cur_dir + '/data/train.tgt'
+
+maxlen = 256
+auto_loader = AutoLoader("lm",
+ model_name="GLM-large-ch",
+ model_dir="./state_dict/")
+model = auto_loader.get_model()
+tokenizer = auto_loader.get_tokenizer()
+
+def read_file():
+ src = []
+ tgt = []
+
+ with open(src_dir, 'r', encoding='utf-8') as f:
+ lines = f.readlines()
+ for line in lines:
+ src.append(line.strip('\n').lower())
+
+ with open(tgt_dir, 'r', encoding='utf-8') as f:
+ lines = f.readlines()
+ for line in lines:
+ tgt.append(line.strip('\n').lower())
+
+ return src, tgt
+
+
+class GLMSeq2seqDataset(Dataset):
+
+ def __init__(self,
+ sents_src,
+ sents_tgt,
+ tokenizer,
+ max_src_length=300,
+ max_tgt_length=200):
+ super(GLMSeq2seqDataset, self).__init__()
+ self.sents_src = sents_src
+ self.sents_tgt = sents_tgt
+ self.tokenizer = tokenizer
+ self.max_src_length = max_src_length
+ self.max_tgt_length = max_tgt_length
+ self.no_block_position = False
+
+ def __getitem__(self, i):
+ source_text = self.sents_src[i]
+ target_text = self.sents_tgt[i]
+ data = self.tokenizer.encode_plus(source_text, target_text)
+
+ return data
+
+ def __len__(self):
+
+ return len(self.sents_src)
+
+
+class GLMPoetryDynamicCollateFN(): #padding process in each batch
+
+ def __init__(self, pad_id):
+ self.pad_id = pad_id
+
+ def pad_token(self, tokens, max_length):
+ pad_len = max_length - len(tokens)
+ tokens += [self.pad_id] * pad_len
+ return tokens
+
+ def pad_position_ids(self, position_ids, max_length):
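+        # position_ids is 2 x seq_len (GLM-style): row 0 holds token positions,
+        # row 1 holds block positions; both rows are padded out to max_length.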
+ pad_len = max_length - len(position_ids[0])
+ position_ids[0] += [len(position_ids[0]) + x for x in range(pad_len)]
+ position_ids[1] += [1] * pad_len
+ return position_ids
+
+ def pad_loss_mask(self, loss_mask, max_length):
+ pad_len = max_length - len(loss_mask)
+ loss_mask += [0] * pad_len
+ return loss_mask
+
+ def __call__(self, batch):
+ input_ids = [data["input_ids"] for data in batch]
+ target_ids = [data["target_ids"] for data in batch]
+ position_ids = [data["position_ids"] for data in batch]
+ attention_mask = [data['attention_mask'] for data in batch]
+ loss_mask = [data['loss_mask'] for data in batch]
+
+ max_length = max([len(t) for t in input_ids])
+ for i in range(len(input_ids)):
+ input_ids[i] = self.pad_token(input_ids[i], max_length)
+ target_ids[i] = self.pad_token(target_ids[i], max_length)
+ position_ids[i] = self.pad_position_ids(position_ids[i],
+ max_length)
+ loss_mask[i] = self.pad_loss_mask(loss_mask[i], max_length)
+ return {
+ 'input_ids': torch.LongTensor(input_ids),
+ 'labels': torch.LongTensor(target_ids),
+ 'position_ids': torch.LongTensor(position_ids),
+ 'attention_mask': torch.LongTensor(attention_mask),
+ 'loss_mask': torch.LongTensor(loss_mask)
+ }
+
+
+sents_src, sents_tgt = read_file()
+my_collate_fn = GLMPoetryDynamicCollateFN(
+ pad_id=tokenizer.get_command('pad').Id)
+
+data_len = len(sents_tgt)
+train_size = int(data_len * 0.8)
+train_src = sents_src[:train_size][:2000]
+train_tgt = sents_tgt[:train_size][:2000]
+
+val_src = sents_src[train_size:]
+val_tgt = sents_tgt[train_size:]
+
+train_dataset = GLMSeq2seqDataset(train_src,
+ train_tgt,
+ tokenizer=tokenizer,
+ max_src_length=300,
+ max_tgt_length=200)
+val_dataset = GLMSeq2seqDataset(val_src,
+ val_tgt,
+ tokenizer=tokenizer,
+ max_src_length=300,
+ max_tgt_length=200)
+
+trainer.train(model,
+ train_dataset=train_dataset,
+ valid_dataset=val_dataset,
+ collate_fn=my_collate_fn)
diff --git a/examples/vit_cifar100/README.md b/examples/vit_cifar100/README.md
new file mode 100644
index 00000000..1cd35edd
--- /dev/null
+++ b/examples/vit_cifar100/README.md
@@ -0,0 +1,163 @@
+# ViT for classification on the CIFAR-100 dataset
+
+The Vision Transformer (ViT) is becoming increasingly popular in the field of
+computer vision (CV). More and more tasks use ViT to achieve state-of-the-art results.
+
+The paper is available at https://arxiv.org/pdf/2010.11929.pdf.
+
+The official code is at https://github.com/google-research/vision_transformer.
+
+## How to use
+We can easily use ViT to fine-tune on the CIFAR-100 dataset.
+### Training
+```python
+import torch
+from torchvision import transforms
+from torchvision.datasets import CIFAR100
+import ssl
+ssl._create_default_https_context = ssl._create_unverified_context
+from flagai.trainer import Trainer
+from flagai.auto_model.auto_loader import AutoLoader
+
+lr = 2e-5
+n_epochs = 50
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+trainer = Trainer(
+ env_type="pytorch",
+ experiment_name="vit-cifar100",
+ batch_size=64,
+ gradient_accumulation_steps=1,
+ lr=lr,
+ weight_decay=1e-5,
+ epochs=n_epochs,
+ log_interval=100,
+ eval_interval=1000,
+ load_dir=None,
+ pytorch_device=device,
+ save_dir="checkpoints_vit_cifar100",
+ save_interval=1000,
+ num_checkpoints=1,
+)
+
+def build_cifar():
+ transform_train = transforms.Compose([
+ transforms.RandomCrop(32, padding=4),
+ transforms.Resize(224),
+ transforms.AutoAugment(policy=transforms.AutoAugmentPolicy.CIFAR10),
+ transforms.ToTensor(),
+ transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
+ ])
+ transform_test = transforms.Compose([
+ transforms.Resize(224),
+ transforms.ToTensor(),
+ transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
+ ])
+
+ train_dataset = CIFAR100(root="./cifar100", train=True, download=True, transform=transform_train)
+ test_dataset = CIFAR100(root="./cifar100", train=False, download=True, transform=transform_test)
+ return train_dataset, test_dataset
+
+def collate_fn(batch):
+ images = torch.stack([b[0] for b in batch])
+ labels = [b[1] for b in batch]
+ labels = torch.tensor(labels).long()
+ return {"images": images, "labels": labels}
+
+
+def validate(logits, labels, meta=None):
+ _, predicted = logits.max(1)
+ total = labels.size(0)
+ correct = predicted.eq(labels).sum().item()
+ return correct / total
+
+
+if __name__ == '__main__':
+ loader = AutoLoader(task_name="backbone",
+                        model_name="vit-base-p16-224",
+ num_classes=100)
+
+ model = loader.get_model()
+ optimizer = torch.optim.Adam(model.parameters(), lr=lr)
+ scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, n_epochs)
+
+ train_dataset, val_dataset = build_cifar()
+ trainer.train(model,
+ optimizer=optimizer,
+ lr_scheduler=scheduler,
+ train_dataset=train_dataset,
+ valid_dataset=val_dataset,
+ metric_methods=[["accuracy", validate]],
+ collate_fn=collate_fn)
+```
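+
+The same training loop also runs on multiple GPUs by changing only the `Trainer` construction; `train_DDP.py` and `train_deepspeed.py` in this directory are complete examples. A minimal sketch of the arguments that differ (GPU count and paths are placeholders for your environment):
+```python
+trainer = Trainer(
+    env_type="deepspeed",          # or "pytorchDDP"
+    experiment_name="vit-cifar100-deepspeed",
+    batch_size=150,
+    num_gpus=8,
+    fp16=True,                     # the deepspeed example trains in fp16
+    lr=lr,
+    epochs=n_epochs,
+    hostfile="./hostfile",         # e.g. "127.0.0.1 slots=2"
+    training_script="train_deepspeed.py",
+    save_dir="checkpoints_vit_cifar100_deepspeed",
+)
+```
+The deepspeed variant also uses the `deepspeed.json` config shipped in this directory.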
+
+### Validation
+If you have trained a model, you can validate it with the following code.
+```python
+import torch
+from torchvision import transforms
+from torch.utils.data import DataLoader
+from torchvision.datasets import CIFAR100
+import ssl
+ssl._create_default_https_context = ssl._create_unverified_context
+from flagai.auto_model.auto_loader import AutoLoader
+import os
+from tqdm import tqdm
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+def build_cifar():
+
+ transform_test = transforms.Compose([
+ transforms.Resize(224),
+ transforms.ToTensor(),
+ transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
+ ])
+
+ test_dataset = CIFAR100(root="./cifar100", train=False, download=True, transform=transform_test)
+ return test_dataset
+
+def collate_fn(batch):
+ images = torch.stack([b[0] for b in batch])
+ labels = [b[1] for b in batch]
+ labels = torch.tensor(labels).long()
+ return {"images": images, "labels": labels}
+
+def validate(logits, labels, meta=None):
+ _, predicted = logits.max(1)
+ total = labels.size(0)
+ correct = predicted.eq(labels).sum().item()
+ return correct / total
+
+if __name__ == '__main__':
+
+ model_save_dir = "./checkpoints_vit_cifar100"
+ print(f"loadding model in :{model_save_dir}")
+ loader = AutoLoader(task_name="backbone",
+                        model_name="vit-base-p16-224",
+ num_classes=100)
+
+ model = loader.get_model()
+
+ model.load_state_dict(torch.load(os.path.join(model_save_dir, "38000", "pytorch_model.bin"), map_location=device)["module"])
+ print(f"model load success.......")
+ model.to(device)
+
+ val_dataset = build_cifar()
+
+ val_dataloader = DataLoader(val_dataset,
+ batch_size=1,
+ shuffle=False,
+ collate_fn=collate_fn)
+ index = 0
+ accuracy = 0.0
+ for data in tqdm(val_dataloader, total=len(val_dataloader)):
+ index += 1
+ data = {k: v.to(device) for k, v in data.items()}
+ labels = data["labels"]
+ pred = model(**data)["logits"]
+ acc = validate(pred, labels)
+ accuracy += acc
+
+ print(f"accuracy is {accuracy / index}")
+```
diff --git a/examples/vit_cifar100/deepspeed.json b/examples/vit_cifar100/deepspeed.json
new file mode 100644
index 00000000..f2339ca3
--- /dev/null
+++ b/examples/vit_cifar100/deepspeed.json
@@ -0,0 +1,48 @@
+{
+ "train_micro_batch_size_per_gpu": 64,
+ "gradient_accumulation_steps": 1,
+ "steps_per_print": 100,
+ "gradient_clipping": 1.0,
+ "zero_optimization": {
+ "stage": 2,
+ "contiguous_gradients": false,
+ "overlap_comm": true,
+ "reduce_scatter": true,
+ "reduce_bucket_size": 5e7,
+ "allgather_bucket_size": 5e7,
+ "cpu_offload": true
+ },
+ "scheduler": {
+ "type": "WarmupLR",
+ "params": {
+ "warmup_min_lr": 0,
+ "warmup_max_lr": 1e-5,
+ "warmup_num_steps": 2000
+ }
+ },
+ "zero_allow_untested_optimizer": true,
+ "fp16": {
+ "enabled": true,
+ "loss_scale": 0,
+ "loss_scale_window": 1000,
+ "hysteresis": 2,
+ "min_loss_scale": 1
+ },
+ "optimizer": {
+ "type": "Adam",
+ "params": {
+ "lr": 1e-5,
+ "weight_decay": 0.1,
+ "betas": [
+ 0.9,
+ 0.98
+ ],
+ "eps": 1e-6
+ }
+ },
+ "activation_checkpointing": {
+ "partition_activations": true,
+ "contiguous_memory_optimization": false
+ },
+ "wall_clock_breakdown": false
+ }
diff --git a/examples/vit_cifar100/hostfile b/examples/vit_cifar100/hostfile
new file mode 100644
index 00000000..51356577
--- /dev/null
+++ b/examples/vit_cifar100/hostfile
@@ -0,0 +1 @@
+127.0.0.1 slots=2
\ No newline at end of file
diff --git a/examples/vit_cifar100/train_DDP.py b/examples/vit_cifar100/train_DDP.py
new file mode 100644
index 00000000..35c997b1
--- /dev/null
+++ b/examples/vit_cifar100/train_DDP.py
@@ -0,0 +1,87 @@
+import torch
+from torchvision import transforms
+from torchvision.datasets import CIFAR100
+import ssl
+ssl._create_default_https_context = ssl._create_unverified_context
+from flagai.trainer import Trainer
+from flagai.auto_model.auto_loader import AutoLoader
+
+lr = 2e-5
+n_epochs = 50
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+env_type = "pytorchDDP"
+trainer = Trainer(
+ env_type=env_type,
+ experiment_name="vit-cifar100-8gpu",
+ batch_size=150,
+ num_gpus=8,
+ gradient_accumulation_steps=1,
+ lr=lr,
+ weight_decay=1e-5,
+ epochs=n_epochs,
+ log_interval=100,
+ eval_interval=1000,
+ load_dir=None,
+ pytorch_device=device,
+ save_dir="checkpoints_vit_cifar100_8gpu",
+ save_interval=1000,
+ num_checkpoints=1,
+ hostfile="./hostfile",
+ training_script="train_DDP.py"
+)
+
+def build_cifar():
+ transform_train = transforms.Compose([
+ transforms.RandomCrop(32, padding=4),
+ transforms.Resize(224),
+ transforms.AutoAugment(policy=transforms.AutoAugmentPolicy.CIFAR10),
+ transforms.ToTensor(),
+ transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
+ ])
+ transform_test = transforms.Compose([
+ transforms.Resize(224),
+ transforms.ToTensor(),
+ transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
+ ])
+
+ train_dataset = CIFAR100(root="./data/cifar100", train=True, download=True, transform=transform_train)
+ test_dataset = CIFAR100(root="./data/cifar100", train=False, download=True, transform=transform_test)
+ return train_dataset, test_dataset
+
+def collate_fn(batch):
+ images = torch.stack([b[0] for b in batch])
+ if trainer.fp16:
+ images = images.half()
+ labels = [b[1] for b in batch]
+ labels = torch.tensor(labels).long()
+ return {"images": images, "labels": labels}
+
+def validate(logits, labels, meta=None):
+ _, predicted = logits.max(1)
+ total = labels.size(0)
+ correct = predicted.eq(labels).sum().item()
+ return correct / total
+
+if __name__ == '__main__':
+ loader = AutoLoader(task_name="classification",
+ model_name="vit-base-p16-224",
+ num_classes=100)
+
+ model = loader.get_model()
+ optimizer = torch.optim.Adam(model.parameters(), lr=lr)
+ scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, n_epochs)
+ train_dataset, val_dataset = build_cifar()
+
+ trainer.train(model,
+ optimizer=optimizer,
+ lr_scheduler=scheduler,
+ train_dataset=train_dataset,
+ valid_dataset=val_dataset,
+ metric_methods=[["accuracy", validate]],
+ collate_fn=collate_fn)
+
+
+
+
+
diff --git a/examples/vit_cifar100/train_deepspeed.py b/examples/vit_cifar100/train_deepspeed.py
new file mode 100644
index 00000000..9d44b1df
--- /dev/null
+++ b/examples/vit_cifar100/train_deepspeed.py
@@ -0,0 +1,88 @@
+import torch
+from torchvision import transforms
+from torchvision.datasets import CIFAR100
+import ssl
+ssl._create_default_https_context = ssl._create_unverified_context
+from flagai.trainer import Trainer
+from flagai.auto_model.auto_loader import AutoLoader
+
+lr = 2e-5
+n_epochs = 50
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+env_type = "deepspeed"
+trainer = Trainer(
+ env_type=env_type,
+ experiment_name="vit-cifar100-deepspeed",
+ batch_size=150,
+ num_gpus=8,
+ fp16=True,
+ gradient_accumulation_steps=1,
+ lr=lr,
+ weight_decay=1e-5,
+ epochs=n_epochs,
+ log_interval=100,
+ eval_interval=1000,
+ load_dir=None,
+ pytorch_device=device,
+ save_dir="checkpoints_vit_cifar100_deepspeed",
+ save_interval=1000,
+ num_checkpoints=1,
+ hostfile="./hostfile",
+ training_script="train_deepspeed.py"
+)
+
+def build_cifar():
+ transform_train = transforms.Compose([
+ transforms.RandomCrop(32, padding=4),
+ transforms.Resize(224),
+ transforms.AutoAugment(policy=transforms.AutoAugmentPolicy.CIFAR10),
+ transforms.ToTensor(),
+ transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
+ ])
+ transform_test = transforms.Compose([
+ transforms.Resize(224),
+ transforms.ToTensor(),
+ transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
+ ])
+
+ train_dataset = CIFAR100(root="./data/cifar100", train=True, download=True, transform=transform_train)
+ test_dataset = CIFAR100(root="./data/cifar100", train=False, download=True, transform=transform_test)
+ return train_dataset, test_dataset
+
+def collate_fn(batch):
+ images = torch.stack([b[0] for b in batch])
+ if trainer.fp16:
+ images = images.half()
+ labels = [b[1] for b in batch]
+ labels = torch.tensor(labels).long()
+ return {"images": images, "labels": labels}
+
+def validate(logits, labels, meta=None):
+ _, predicted = logits.max(1)
+ total = labels.size(0)
+ correct = predicted.eq(labels).sum().item()
+ return correct / total
+
+if __name__ == '__main__':
+ loader = AutoLoader(task_name="classification",
+ model_name="vit-base-p16-224",
+ num_classes=100)
+
+ model = loader.get_model()
+ optimizer = torch.optim.Adam(model.parameters(), lr=lr)
+ scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, n_epochs)
+ train_dataset, val_dataset = build_cifar()
+
+ trainer.train(model,
+ optimizer=optimizer,
+ lr_scheduler=scheduler,
+ train_dataset=train_dataset,
+ valid_dataset=val_dataset,
+ metric_methods=[["accuracy", validate]],
+ collate_fn=collate_fn)
+
+
+
+
+
diff --git a/examples/vit_cifar100/train_env_trainer.py b/examples/vit_cifar100/train_env_trainer.py
new file mode 100644
index 00000000..fcb153cb
--- /dev/null
+++ b/examples/vit_cifar100/train_env_trainer.py
@@ -0,0 +1,90 @@
+import torch
+from torchvision import transforms
+from torchvision.datasets import CIFAR100
+import ssl
+ssl._create_default_https_context = ssl._create_unverified_context
+from flagai.env_trainer import EnvTrainer
+from flagai.auto_model.auto_loader import AutoLoader
+import argparse
+
+lr = 2e-5
+n_epochs = 50
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+from flagai.env_args import EnvArgs
+
+env_args = EnvArgs(
+ env_type="pytorch",
+ experiment_name="vit-cifar100-single_gpu",
+ batch_size=64,
+ num_gpus=1,
+ gradient_accumulation_steps=1,
+ lr=lr,
+ weight_decay=1e-5,
+ epochs=n_epochs,
+ log_interval=100,
+ eval_interval=1000,
+ load_dir=None,
+ pytorch_device=device,
+ save_dir="checkpoints_vit_cifar100_single_gpu",
+ save_interval=1000,
+ num_checkpoints=1,
+)
+
+env_args.add_arg(arg_name="test_args", default=0, type=int, )
+env_args = env_args.parse_args()
+trainer = EnvTrainer(env_args)
+
+def build_cifar():
+ transform_train = transforms.Compose([
+ transforms.RandomCrop(32, padding=4),
+ transforms.Resize(224),
+ transforms.AutoAugment(policy=transforms.AutoAugmentPolicy.CIFAR10),
+ transforms.ToTensor(),
+ transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
+ ])
+ transform_test = transforms.Compose([
+ transforms.Resize(224),
+ transforms.ToTensor(),
+ transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
+ ])
+
+ train_dataset = CIFAR100(root="./data/cifar100", train=True, download=True, transform=transform_train)
+ test_dataset = CIFAR100(root="./data/cifar100", train=False, download=True, transform=transform_test)
+ return train_dataset, test_dataset
+
+def collate_fn(batch):
+ images = torch.stack([b[0] for b in batch])
+ if trainer.fp16:
+ images = images.half()
+ labels = [b[1] for b in batch]
+ labels = torch.tensor(labels).long()
+ return {"images": images, "labels": labels}
+
+def validate(logits, labels, meta=None):
+ _, predicted = logits.max(1)
+ total = labels.size(0)
+ correct = predicted.eq(labels).sum().item()
+ return correct / total
+
+if __name__ == '__main__':
+ loader = AutoLoader(task_name="classification",
+ model_name="vit-base-p16-224",
+ num_classes=100)
+
+ model = loader.get_model()
+ optimizer = torch.optim.Adam(model.parameters(), lr=lr)
+ scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, n_epochs)
+ train_dataset, val_dataset = build_cifar()
+
+ trainer.train(model,
+ optimizer=optimizer,
+ lr_scheduler=scheduler,
+ train_dataset=train_dataset,
+ valid_dataset=val_dataset,
+ metric_methods=[["accuracy", validate]],
+ collate_fn=collate_fn)
+
+
+
+
+
diff --git a/examples/vit_cifar100/train_single_gpu.py b/examples/vit_cifar100/train_single_gpu.py
new file mode 100644
index 00000000..ef7e1356
--- /dev/null
+++ b/examples/vit_cifar100/train_single_gpu.py
@@ -0,0 +1,85 @@
+import torch
+from torchvision import transforms
+from torchvision.datasets import CIFAR100
+import ssl
+ssl._create_default_https_context = ssl._create_unverified_context
+from flagai.trainer import Trainer
+from flagai.auto_model.auto_loader import AutoLoader
+
+lr = 2e-5
+n_epochs = 50
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+env_type = "pytorch"
+trainer = Trainer(
+ env_type=env_type,
+ experiment_name="vit-cifar100-single_gpu",
+ batch_size=150,
+ num_gpus=1,
+ gradient_accumulation_steps=1,
+ lr=lr,
+ weight_decay=1e-5,
+ epochs=n_epochs,
+ log_interval=100,
+ eval_interval=1000,
+ load_dir=None,
+ pytorch_device=device,
+ save_dir="checkpoints_vit_cifar100_single_gpu",
+ save_interval=1000,
+ num_checkpoints=1,
+)
+
+def build_cifar():
+ transform_train = transforms.Compose([
+ transforms.RandomCrop(32, padding=4),
+ transforms.Resize(224),
+ transforms.AutoAugment(policy=transforms.AutoAugmentPolicy.CIFAR10),
+ transforms.ToTensor(),
+ transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
+ ])
+ transform_test = transforms.Compose([
+ transforms.Resize(224),
+ transforms.ToTensor(),
+ transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
+ ])
+
+ train_dataset = CIFAR100(root="./data/cifar100", train=True, download=True, transform=transform_train)
+ test_dataset = CIFAR100(root="./data/cifar100", train=False, download=True, transform=transform_test)
+ return train_dataset, test_dataset
+
+def collate_fn(batch):
+ images = torch.stack([b[0] for b in batch])
+ if trainer.fp16:
+ images = images.half()
+ labels = [b[1] for b in batch]
+ labels = torch.tensor(labels).long()
+ return {"images": images, "labels": labels}
+
+def validate(logits, labels, meta=None):
+ _, predicted = logits.max(1)
+ total = labels.size(0)
+ correct = predicted.eq(labels).sum().item()
+ return correct / total
+
+if __name__ == '__main__':
+ loader = AutoLoader(task_name="classification",
+ model_name="vit-base-p16-224",
+ num_classes=100)
+
+ model = loader.get_model()
+ optimizer = torch.optim.Adam(model.parameters(), lr=lr)
+ scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, n_epochs)
+ train_dataset, val_dataset = build_cifar()
+
+ trainer.train(model,
+ optimizer=optimizer,
+ lr_scheduler=scheduler,
+ train_dataset=train_dataset,
+ valid_dataset=val_dataset,
+ metric_methods=[["accuracy", validate]],
+ collate_fn=collate_fn)
+
+
+
+
+
diff --git a/examples/vit_cifar100/validate.py b/examples/vit_cifar100/validate.py
new file mode 100644
index 00000000..e52eb113
--- /dev/null
+++ b/examples/vit_cifar100/validate.py
@@ -0,0 +1,76 @@
+import torch
+from torchvision import transforms
+from torch.utils.data import DataLoader
+from torchvision.datasets import CIFAR100
+import ssl
+ssl._create_default_https_context = ssl._create_unverified_context
+from flagai.auto_model.auto_loader import AutoLoader
+import os
+from tqdm import tqdm
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+def build_cifar():
+
+ transform_test = transforms.Compose([
+ transforms.Resize(224),
+ transforms.ToTensor(),
+ transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
+ ])
+
+ test_dataset = CIFAR100(root="./cifar100", train=False, download=True, transform=transform_test)
+ return test_dataset
+
+def collate_fn(batch):
+ images = torch.stack([b[0] for b in batch])
+ labels = [b[1] for b in batch]
+ labels = torch.tensor(labels).long()
+ return {"images": images, "labels": labels}
+
+def validate(logits, labels, meta=None):
+ _, predicted = logits.max(1)
+ total = labels.size(0)
+ correct = predicted.eq(labels).sum().item()
+ return correct / total
+
+if __name__ == '__main__':
+
+ model_save_dir = "./checkpoints_vit_cifar100"
+ print(f"loadding model in :{model_save_dir}")
+ loader = AutoLoader(task_name="backbone",
+ model_name="vit-base-p16-224",
+ num_classes=100)
+
+ model = loader.get_model()
+
+ model.load_state_dict(torch.load(os.path.join(model_save_dir, "38000", "pytorch_model.bin"), map_location=device)["module"])
+ print(f"model load success.......")
+ model.to(device)
+
+ val_dataset = build_cifar()
+
+ val_dataloader = DataLoader(val_dataset,
+ batch_size=1,
+ shuffle=False,
+ collate_fn=collate_fn)
+ index = 0
+ accuracy = 0.0
+ for data in tqdm(val_dataloader, total=len(val_dataloader)):
+ index += 1
+ data = {k: v.to(device) for k, v in data.items()}
+ labels = data["labels"]
+ pred = model(**data)["logits"]
+ acc = validate(pred, labels)
+ accuracy += acc
+
+ print(f"accuracy is {accuracy / index}")
+
+
+
+
+
+
+
+
+
+
diff --git a/flagai/auto_model/auto_loader.py b/flagai/auto_model/auto_loader.py
index 37d7ca90..17d74d3d 100644
--- a/flagai/auto_model/auto_loader.py
+++ b/flagai/auto_model/auto_loader.py
@@ -54,26 +54,37 @@ def __getattr__(self, name):
"glm_title-generation": ["flagai.model.glm_model", "GLMForSeq2Seq"],
"opt_seq2seq": ("flagai.model.opt_model","OPTModel"),
"opt_lm": ("flagai.model.opt_model","OPTModel"),
+ "vit_classification": ("flagai.model.vision.vit", "VisionTransformer")
+
}
MODEL_DICT = {
- "bert-base-en": ["flagai.model.bert_model", "BertModel", "bert"],
- "roberta-base-ch": ["flagai.model.bert_model", "BertModel", "bert"],
- "t5-base-en": ["flagai.model.t5_model", "T5Model", "t5"],
- "t5-base-ch": ["flagai.model.t5_model", "T5Model", "t5"],
- "glm-large-ch": ["flagai.model.glm_model", "GLMModel", "glm"],
- "glm-large-en": ["flagai.model.glm_model", "GLMModel", "glm"],
- "gpt2-base-ch": ["flagai.model.gpt2_model", "GPT2Model", "gpt2"],
- "cpm-large-ch": ["flagai.model.gpt2_model", "GPT2Model", "cpm"],
- "opt-125m-en": ["flagai.model.opt_model","OPTModel", "opt"],
- "opt-350m-en": ["flagai.model.opt_model","OPTModel", "opt"],
- "opt-1.3b-en": ["flagai.model.opt_model","OPTModel", "opt"],
- "opt-2.7b-en": ["flagai.model.opt_model","OPTModel", "opt"],
- "opt-6.7b-en": ["flagai.model.opt_model","OPTModel", "opt"],
- "opt-13b-en": ["flagai.model.opt_model","OPTModel", "opt"],
- "opt-30b-en": ["flagai.model.opt_model","OPTModel", "opt"],
- "opt-66b-en": ["flagai.model.opt_model","OPTModel", "opt"],
- "glm-10b-ch": ["flagai.model.glm_model", "GLMModel", "glm"],
+ "bert-base-en": ["flagai.model.bert_model", "BertModel", "bert", "nlp"],
+ "roberta-base-ch": ["flagai.model.bert_model", "BertModel", "bert", "nlp"],
+ "t5-base-en": ["flagai.model.t5_model", "T5Model", "t5", "nlp"],
+ "t5-base-ch": ["flagai.model.t5_model", "T5Model", "t5", "nlp"],
+ "glm-large-ch": ["flagai.model.glm_model", "GLMModel", "glm", "nlp"],
+ "glm-large-en": ["flagai.model.glm_model", "GLMModel", "glm", "nlp"],
+ "gpt2-base-ch": ["flagai.model.gpt2_model", "GPT2Model", "gpt2", "nlp"],
+ "cpm-large-ch": ["flagai.model.gpt2_model", "GPT2Model", "cpm", "nlp"],
+ "opt-125m-en": ["flagai.model.opt_model","OPTModel", "opt", "nlp"],
+ "opt-350m-en": ["flagai.model.opt_model","OPTModel", "opt", "nlp"],
+ "opt-1.3b-en": ["flagai.model.opt_model","OPTModel", "opt", "nlp"],
+ "opt-2.7b-en": ["flagai.model.opt_model","OPTModel", "opt", "nlp"],
+ "opt-6.7b-en": ["flagai.model.opt_model","OPTModel", "opt", "nlp"],
+ "opt-13b-en": ["flagai.model.opt_model","OPTModel", "opt", "nlp"],
+ "opt-30b-en": ["flagai.model.opt_model","OPTModel", "opt", "nlp"],
+ "opt-66b-en": ["flagai.model.opt_model","OPTModel", "opt", "nlp"],
+ "glm-10b-ch": ["flagai.model.glm_model", "GLMModel", "glm", "nlp"],
+
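+    # Vision entries carry "vision" as the fourth field (modality); AutoLoader
+    # skips tokenizer loading for them and sets self.tokenizer to None.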
+ "vit-base-p16-224":["flagai.model.vision.vit", "VisionTransformer", "vit", "vision"],
+ "vit-base-p16-384":["flagai.model.vision.vit", "VisionTransformer", "vit", "vision"],
+ "vit-base-p32-224":["flagai.model.vision.vit", "VisionTransformer", "vit", "vision"],
+ "vit-base-p32-384":["flagai.model.vision.vit", "VisionTransformer", "vit", "vision"],
+ "vit-large-p16-224":["flagai.model.vision.vit", "VisionTransformer", "vit", "vision"],
+ "vit-large-p16-384":["flagai.model.vision.vit", "VisionTransformer", "vit", "vision"],
+ "vit-large-p32-224":["flagai.model.vision.vit", "VisionTransformer", "vit", "vision"],
+ "vit-large-p32-384":["flagai.model.vision.vit", "VisionTransformer", "vit", "vision"],
}
TOKENIZER_DICT = {
@@ -103,10 +114,8 @@ def __getattr__(self, name):
"opt-13b-en": ["flagai.data.tokenizer.opt.opt_en_tokenizer","OPTTokenizer"],
"opt-30b-en": ["flagai.data.tokenizer.opt.opt_en_tokenizer","OPTTokenizer"],
"opt-66b-en": ["flagai.data.tokenizer.opt.opt_en_tokenizer","OPTTokenizer"],
-
}
-
class AutoLoader:
def __init__(self,
@@ -153,6 +162,8 @@ def __init__(self,
return
brief_model_name = MODEL_DICT[model_name][2]
+ model_type = MODEL_DICT[model_name][3]
+
# The dir to save config, vocab and model.
self.model_name = ALL_TASK.get(f"{brief_model_name}_{task_name}", None)
@@ -184,38 +195,41 @@ def __init__(self,
model_id = _get_model_id(model_name)
print("*"*20, task_name, model_id, model_name)
- if "glm" in model_name and "ch" in model_name:
- vocab_file = os.path.join(download_path,'cog-pretrained.model')
- if not os.path.exists(vocab_file):
- vocab_file = _get_vocab_path(download_path, "cog-pretrain.model", model_id)
- elif "glm" in model_name and "en" in model_name:
- vocab_file = "GLM-large-en"
- elif model_name == "cpm-large-ch":
- # two files to load
- vocab_file_1 = os.path.join(download_path, "vocab.json")
- vocab_file_2 = os.path.join(download_path, "chinese_vocab.model")
- if not os.path.exists(vocab_file_1):
- vocab_file_1 = _get_vocab_path(download_path, "vocab.json",
- model_id)
- if not os.path.exists(vocab_file_2):
- vocab_file_2 = _get_vocab_path(download_path,
- "chinese_vocab.model", model_id)
- else:
- vocab_file = os.path.join(download_path, 'vocab.txt')
- if not os.path.exists(vocab_file):
- vocab_file = _get_vocab_path(download_path, "vocab.txt",
- model_id)
- tokenizer_class = TOKENIZER_DICT[model_name]
- tokenizer_class = getattr(LazyImport(tokenizer_class[0]),
- tokenizer_class[1])
- if model_name == "cpm-large-ch":
- self.tokenizer = tokenizer_class(vocab_file_1, vocab_file_2)
- elif brief_model_name == "opt":
- self.tokenizer = tokenizer_class("facebook/opt-350m")
- elif model_name in ["glm-large-en", "glm-large-ch"]:
- self.tokenizer = tokenizer_class()
- else :
- self.tokenizer = tokenizer_class(vocab_file)
+ if model_type == "nlp":
+ if "glm" in model_name and "ch" in model_name:
+ vocab_file = os.path.join(download_path,'cog-pretrained.model')
+ if not os.path.exists(vocab_file):
+ vocab_file = _get_vocab_path(download_path, "cog-pretrain.model", model_id)
+ elif "glm" in model_name and "en" in model_name:
+ vocab_file = "GLM-large-en"
+ elif model_name == "cpm-large-ch":
+ # two files to load
+ vocab_file_1 = os.path.join(download_path, "vocab.json")
+ vocab_file_2 = os.path.join(download_path, "chinese_vocab.model")
+ if not os.path.exists(vocab_file_1):
+ vocab_file_1 = _get_vocab_path(download_path, "vocab.json",
+ model_id)
+ if not os.path.exists(vocab_file_2):
+ vocab_file_2 = _get_vocab_path(download_path,
+ "chinese_vocab.model", model_id)
+ else:
+ vocab_file = os.path.join(download_path, 'vocab.txt')
+ if not os.path.exists(vocab_file):
+ vocab_file = _get_vocab_path(download_path, "vocab.txt",
+ model_id)
+ tokenizer_class = TOKENIZER_DICT[model_name]
+ tokenizer_class = getattr(LazyImport(tokenizer_class[0]),
+ tokenizer_class[1])
+ if model_name == "cpm-large-ch":
+ self.tokenizer = tokenizer_class(vocab_file_1, vocab_file_2)
+ elif brief_model_name == "opt":
+ self.tokenizer = tokenizer_class("facebook/opt-350m")
+ elif model_name in ["glm-large-en", "glm-large-ch"]:
+ self.tokenizer = tokenizer_class()
+ else :
+ self.tokenizer = tokenizer_class(vocab_file)
+ elif model_type == "vision":
+ self.tokenizer = None
def get_task_name(self, brief_model_name):
all_model_task = list(ALL_TASK.keys())
diff --git a/flagai/env_args.py b/flagai/env_args.py
new file mode 100644
index 00000000..49c5ce29
--- /dev/null
+++ b/flagai/env_args.py
@@ -0,0 +1,110 @@
+import argparse
+
+def save_best(best_score, eval_dict):
+ return best_score if best_score < eval_dict['loss'] else eval_dict['loss']
+
+def str2bool(v):
+    if isinstance(v, bool):
+        return v
+    if v == 'True':
+        return True
+    if v == 'False':
+        return False
+    raise argparse.ArgumentTypeError("Boolean value ('True'/'False') expected, got %r" % v)
+
+class EnvArgs:
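+    """Collect training and environment arguments for FlagAI trainers.
+
+    Every keyword accepted by __init__ becomes a command-line flag with the same
+    name and default, so any value can be overridden at launch time, e.g.
+    ``python train.py --batch_size=8 --epochs=10``. Use ``add_arg`` to register
+    extra flags and ``parse_args`` to obtain the parsed namespace.
+    """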
+ def __init__(self,
+ env_type="pytorch",
+ experiment_name="test_experiment",
+ epochs=1,
+ batch_size=1,
+ lr=1e-5,
+ seed=1234,
+
+ fp16=False,
+ pytorch_device="cpu",
+ clip_grad=1.0,
+ checkpoint_activations=False,
+ gradient_accumulation_steps=1,
+ weight_decay=1e-5,
+ warm_up=0.1,
+
+ log_interval=100,
+ eval_interval=1000,
+ save_interval=1000,
+
+ save_dir=None,
+ load_dir=None,
+ save_optim=False, # save current optimizer.')
+ save_rng=False, # save current rng state.')
+ load_type='latest', # latest, best
+ load_optim=False, # not load optimizer when loading checkpoint.')
+ load_rng=False,
+ tensorboard_dir="tensorboard_summary",
+
+ # distribute settings
+ deepspeed_activation_checkpointing=False,
+ num_checkpoints=1,
+ master_ip='localhost',
+ master_port=17750,
+ num_nodes=1,
+ num_gpus=1,
+ hostfile="./hostfile",
+ deepspeed_config="./deepspeed.json",
+ model_parallel_size=1,
+ training_script="train.py",
+ ):
+
+ self.parser = argparse.ArgumentParser(description='Env args parser')
+        self.parser.add_argument('--env_type', default=env_type, help='training environment: pytorch, pytorchDDP, deepspeed or deepspeed+mpu')
+        self.parser.add_argument('--experiment_name', default=experiment_name, help='name of the experiment, used for logging and checkpoint directories')
+        self.parser.add_argument('--epochs', default=epochs, type=int, help='number of training epochs')
+        self.parser.add_argument('--batch_size', default=batch_size, type=int, help='batch size per GPU')
+        self.parser.add_argument('--lr', default=lr, type=float, help='initial learning rate')
+        self.parser.add_argument('--seed', default=seed, type=int, help='random seed for reproducibility')
+        self.parser.add_argument('--fp16', default=fp16, type=str2bool, help='whether to train with fp16')
+        self.parser.add_argument('--pytorch_device', default=pytorch_device, help='device for env_type pytorch, e.g. cpu or cuda')
+        self.parser.add_argument('--clip_grad', default=clip_grad, type=float, help='gradient clipping value')
+        self.parser.add_argument('--checkpoint_activations', default=checkpoint_activations, type=str2bool, help='whether to use activation checkpointing')
+        self.parser.add_argument('--gradient_accumulation_steps', default=gradient_accumulation_steps, type=int, help='number of gradient accumulation steps')
+        self.parser.add_argument('--weight_decay', default=weight_decay, type=float, help='weight decay for the optimizer')
+        self.parser.add_argument('--warm_up', default=warm_up, type=float, help='fraction of total steps used for learning-rate warm-up')
+        self.parser.add_argument('--log_interval', default=log_interval, type=int, help='number of steps between log outputs')
+        self.parser.add_argument('--eval_interval', default=eval_interval, type=int, help='number of steps between evaluations')
+        self.parser.add_argument('--save_interval', default=save_interval, type=int, help='number of steps between checkpoint saves')
+        self.parser.add_argument('--save_dir', default=save_dir, help='directory to save checkpoints')
+        self.parser.add_argument('--load_dir', default=load_dir, help='directory to load a checkpoint from')
+        self.parser.add_argument('--save_optim', default=save_optim, type=str2bool, help='whether to save the optimizer state')
+        self.parser.add_argument('--save_rng', default=save_rng, type=str2bool, help='whether to save the rng state')
+        self.parser.add_argument('--load_type', default=load_type, type=str, help='which checkpoint to load: latest or best')
+        self.parser.add_argument('--load_optim', default=load_optim, type=str2bool, help='whether to load the optimizer state from the checkpoint')
+        self.parser.add_argument('--load_rng', default=load_rng, type=str2bool, help='whether to load the rng state from the checkpoint')
+        self.parser.add_argument('--tensorboard_dir', default=tensorboard_dir, help='directory for tensorboard summaries')
+        self.parser.add_argument('--deepspeed_activation_checkpointing', default=deepspeed_activation_checkpointing, type=str2bool, help='whether to use deepspeed activation checkpointing')
+        self.parser.add_argument('--num_checkpoints', default=num_checkpoints, type=int, help='number of activation checkpoints')
+        self.parser.add_argument('--deepspeed_config', default=deepspeed_config, help='path to the deepspeed config json')
+        self.parser.add_argument('--model_parallel_size', default=model_parallel_size, type=int, help='model parallel size for deepspeed+mpu')
+        self.parser.add_argument('--training_script', default=training_script, help='training script to be launched on each worker')
+
+        self.parser.add_argument('--hostfile', default=hostfile, help='hostfile for multi-node training')
+        self.parser.add_argument('--master_ip', default=master_ip, help='master node IP address')
+        self.parser.add_argument('--master_port', default=master_port, type=int, help='master node port')
+        self.parser.add_argument('--num_nodes', default=num_nodes, type=int, help='number of nodes')
+        self.parser.add_argument('--num_gpus', default=num_gpus, type=int, help='number of GPUs per node')
+        self.parser.add_argument('--not_call_launch', action="store_true", help='do not call the launcher; set automatically for worker processes')
+        self.parser.add_argument('--local_rank', default=0, type=int, help='local rank of the process, set by the launcher')
+
+    def add_arg(self, arg_name, default=None, type=str, help="", store_true=False):
+        """Register an extra command-line flag named --{arg_name}."""
+        if store_true:
+            # action="store_true" does not accept a `type` keyword
+            self.parser.add_argument(f"--{arg_name}", default=default, action="store_true", help=help)
+        else:
+            self.parser.add_argument(f"--{arg_name}", default=default, type=type, help=help)
+
+
+ def parse_args(self):
+ args = self.parser.parse_args()
+ if args.env_type == "pytorch":
+            # a plain pytorch run does not need to call the launcher
+ args.not_call_launch = True
+
+ return args
+
diff --git a/flagai/env_trainer.py b/flagai/env_trainer.py
new file mode 100644
index 00000000..c7ef8678
--- /dev/null
+++ b/flagai/env_trainer.py
@@ -0,0 +1,920 @@
+# Copyright © 2022 BAAI. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License")
+# Arguments for training
+try:
+ import deepspeed.utils
+ import deepspeed
+except:
+ pass
+try:
+ from flagai import mpu
+except Exception:
+ pass
+
+import torch
+import argparse
+import os
+import random
+import numpy as np
+import torch.distributed as dist
+from flagai.logger import log_dist
+from torch.utils.tensorboard import SummaryWriter
+from flagai.utils import load_checkpoint, save_checkpoint, load_optim, load_rng
+from flagai.schedulers import AnnealingLR
+from flagai.optimizers import get_optimizer, get_optimizer_param_groups
+from flagai.fp16 import FP16_Module
+from flagai.utils import Timers
+from flagai.launch import launch_dist
+from torch.nn.parallel import DistributedDataParallel as DDP
+from flagai.fp16 import DynamicLossScaler
+"""
+The EnvTrainer class, to easily train a pytorch model on a new task.
+"""
+def save_best(best_score, eval_dict):
+ return best_score if best_score < eval_dict['loss'] else eval_dict['loss']
+
+def get_args_list(env_args):
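+    # Serialize env_args attributes into "--name value" pairs so the launcher can
+    # re-invoke the training script with the same settings; launcher-only arguments
+    # listed in not_need_to_launch_args are excluded.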
+ not_need_to_launch_args = ["not_call_launch", "local_rank", "master_port", "master_ip", "hostfile", "num_gpus", "num_nodes"]
+ args_list = []
+ args = dir(env_args)
+ for arg in args:
+ if not arg.startswith("__") and not arg.startswith("_") and arg not in not_need_to_launch_args:
+ args_list.append(f"--{arg}")
+ args_list.append(str(getattr(env_args, arg)))
+
+ print(f"args list is {args_list}")
+ return args_list
+
+class EnvTrainer():
+ def __init__(self,
+ env_args,
+ ):
+ self.timers = Timers()
+ self.env_type = env_args.env_type
+ if self.env_type not in set(
+ ["deepspeed", 'pytorch', 'pytorchDDP', 'deepspeed+mpu']):
+ raise Exception("Not supported env_type!!!!")
+ os.environ["ENV_TYPE"] = env_args.env_type
+ self.experiment_name = env_args.experiment_name
+ self.batch_size = env_args.batch_size
+ self.gradient_accumulation_steps = env_args.gradient_accumulation_steps
+ self.lr = env_args.lr
+ self.weight_decay = env_args.weight_decay
+ self.epochs = env_args.epochs
+ self.clip_grad = env_args.clip_grad
+ self.seed = env_args.seed
+ self.fp16 = env_args.fp16
+ self.warm_up = env_args.warm_up
+
+ self.log_interval = env_args.log_interval
+ self.eval_interval = env_args.eval_interval
+
+ # model checkpointing
+ self.save_dir = env_args.save_dir
+ self.save_interval = env_args.save_interval
+ self.save_optim = env_args.save_optim
+ self.save_rng = env_args.save_rng
+ self.save_best = save_best
+ self.load_dir = env_args.load_dir
+ self.load_type = env_args.load_type
+ self.load_optim = env_args.load_optim
+ self.load_rng = env_args.load_rng
+ self.tb_writer = SummaryWriter(
+ os.path.join(env_args.tensorboard_dir, env_args.experiment_name))
+
+ # distribute settings
+ self.pytorch_device = env_args.pytorch_device
+ self.checkpoint_activations = env_args.checkpoint_activations
+ self.deepspeed_activation_checkpointing = env_args.deepspeed_activation_checkpointing
+ self.num_checkpoints = env_args.num_checkpoints
+ self.env_type = env_args.env_type
+ self.not_call_launch = env_args.not_call_launch
+ self.deepspeed_config = env_args.deepspeed_config
+ self.model_parallel_size = env_args.model_parallel_size
+ self.num_nodes = env_args.num_nodes
+ self.num_gpus = env_args.num_gpus
+ self.master_ip = env_args.master_ip
+ self.master_port = env_args.master_port
+ self.hostfile = env_args.hostfile
+ self.training_script = env_args.training_script
+
+ if 'deepspeed' in self.env_type or self.env_type == 'pytorchDDP':
+ training_paras = get_args_list(env_args)
+ self.rank = int(os.environ.get('RANK', 0))
+ self.world_size = int(os.environ.get('WORLD_SIZE', 1))
+ self.local_rank = env_args.local_rank
+ log_dist("not_call_launch: {}".format(self.not_call_launch))
+ # Implement for AutoLaunch
+            # >>> python train.py  # will call get_dist_args()
+            # `--not_call_launch` defaults to False,
+            # so if `env_type` is `pytorch`, the trainer will not call launch_dist().
+            # Otherwise, launch_dist() is called to relaunch 'train.py' with `--not_call_launch`.
+ if not self.not_call_launch:
+ launch_dist(launcher='distributed_deepspeed' if 'deepspeed'
+ in self.env_type else 'distributed_torch',
+ num_nodes=self.num_nodes,
+ gpus_per_node=self.num_gpus,
+ master_addr=self.master_ip,
+ master_port=self.master_port,
+ hostfile=self.hostfile,
+ training_script=self.training_script,
+ training_paras=training_paras)
+ os._exit(1)
+ self.initialize_distributed()
+
+ def set_seed(self, seed=1234):
+ """Set random seed for reproducability."""
+ if seed is not None and seed > 0:
+ random.seed(seed)
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+ if self.env_type == 'deepspeed+mpu':
+ mpu.model_parallel_cuda_manual_seed(seed)
+
+ def initialize_distributed(self):
+ """Initialize torch.distributed."""
+ if self.env_type == 'pytorch':
+ log_dist('No need to initialize')
+ return
+ if self.env_type in ['deepspeed', 'deepspeed+mpu', 'pytorchDDP']:
+ torch.backends.cudnn.enabled = False
+ # Manually set the device ids.
+ device = self.rank % torch.cuda.device_count()
+ if self.local_rank is not None:
+ device = self.local_rank
+ torch.cuda.set_device(device)
+ # Call the init process
+ init_method = 'tcp://'
+ self.master_ip = os.getenv('MASTER_ADDR', 'localhost')
+ self.master_port = os.getenv('MASTER_PORT', '6000')
+
+ init_method += self.master_ip + ':' + self.master_port
+ log_dist(
+ "init method {}, rank {}, device {}, local_rank {}.".format(
+ init_method, self.rank, device, self.local_rank))
+ torch.distributed.init_process_group(
+ backend='nccl', # gloo
+ world_size=self.world_size,
+ rank=self.rank,
+ init_method=init_method)
+ # Set the model-parallel / data-parallel communicators.
+ if self.env_type == 'deepspeed+mpu':
+ os.environ["MODEL_PARALLEL_SIZE"] = str(self.model_parallel_size)
+ try:
+ mpu.initialize_model_parallel(self.model_parallel_size)
+ if 'deepspeed' in self.env_type and self.deepspeed_activation_checkpointing:
+ deepspeed.checkpointing.configure(
+ mpu,
+ deepspeed_config=self.deepspeed_config,
+ num_checkpoints=self.num_checkpoints)
+ mpu.checkpoint = deepspeed.checkpointing.checkpoint
+ mpu.get_cuda_rng_tracker = deepspeed.checkpointing.get_cuda_rng_tracker
+ mpu.model_parallel_cuda_manual_seed = deepspeed.checkpointing.model_parallel_cuda_manual_seed
+ except Exception as e:
+ log_dist(e)
+ log_dist("No mpu is installed! No model parallel is used")
+ log_dist("initialize eviroments succesed")
+ self.set_seed(self.seed)
+
+ def get_dataloader(self, dataset, collate_fn, shuffle=False):
+ """ initilize the dataloader"""
+ if dataset is None:
+ return None
+ if self.env_type == 'pytorch':
+ return torch.utils.data.DataLoader(dataset,
+ batch_size=self.batch_size,
+ collate_fn=collate_fn,
+ num_workers=4,
+ prefetch_factor=4,
+ pin_memory=True,
+ drop_last=False,
+ shuffle=shuffle)
+ else:
+ if self.env_type == 'deepspeed+mpu':
+ # num_replicas = self.world_size // mpu.get_model_parallel_world_size(
+ # )
+ # rank = self.rank // mpu.get_model_parallel_world_size()
+ # rank = mpu.get_model_parallel_rank()
+ rank = mpu.get_model_parallel_src_rank()
+ print("*"*80)
+ print("local rank",self.rank, "model rank", rank)
+ print("*"*80)
+ sampler = torch.utils.data.distributed.DistributedSampler(
+ dataset,
+ # num_replicas=num_replicas,
+ rank=rank,
+ shuffle=shuffle)
+ else:
+ num_replicas = self.world_size
+ rank = self.rank
+ sampler = torch.utils.data.distributed.DistributedSampler(
+ dataset, rank=rank, shuffle=shuffle)
+ return torch.utils.data.DataLoader(dataset,
+ batch_size=self.batch_size,
+ sampler=sampler,
+ num_workers=4,
+ drop_last=False,
+ pin_memory=False,
+ prefetch_factor=4,
+ collate_fn=collate_fn)
+
+ def train(self,
+ model=None,
+ optimizer=None,
+ lr_scheduler=None,
+ train_dataset=None,
+ valid_dataset=None,
+ metric_methods=[],
+ collate_fn=None):
+ """Training Loops"""
+ """
+        Trainer is a simple but unified training and eval loop for PyTorch/Deepspeed/Megatron-LM.
+ Args:
+ model (`torch.nn.Module`, *optional*):
+ The model to train, evaluate or use for predictions.
+ args ([`env_type`]):
+                The environment type for training. Will default to 'pytorch'.
+ env_type: `pytorch`, `pytorchDDP`, `deepspeed`, `deepspeed+mpu`
+ pytorch: single node cpu/gpu
+ pytorchDDP: single-/multi- node gpu
+ deepspeed: single-/multi- node gpu
+ deepspeed+mpu: single-/multi- node gpu
+ train_dataset (`torch.utils.data.Dataset` or `torch.utils.data.DataLoader`, *optional*):
+ The dataset to use for training.
+ If it is an `Dataset`, we will create a `DataLoader` with the provided `Dataset` and `collate_fn' for the selected `env_type`.
+                A `Dataset` is preferred; it should iteratively return a sample as follows:
+                >>> {'text': 'I like big model.', 'label': 'positive'}
+                If it is a `DataLoader`, we will use it directly.
+                Important: Columns not accepted by the `model.forward()` method are automatically dropped.
+            eval_dataset (`torch.utils.data.Dataset` or `torch.utils.data.DataLoader`, *optional*):
+                The dataset to use for evaluation. Similar to `train_dataset`.
+            collate_fn (`DataCollator` or `function`, *optional*):
+                The function to use to form a batch from a list of elements of `train_dataset` or `eval_dataset`.
+            metrics (`function`, *optional*):
+                The function that will be used to compute metrics at evaluation. Must return
+                a dictionary mapping metric names to metric values.
+            optimizers (`torch.optim.Optimizer`, *optional*): An optimizer to use. Will default to an instance of
+ [`AdamW`] on your model.
+ lr_scheduler (`torch.optim.lr_scheduler`, *optional*): A lr_scheduler to use. Will default to an instance of
+ [`AnnealingLR`].
+ """
+ if not isinstance(train_dataset, torch.utils.data.DataLoader):
+ train_dataloader = self.get_dataloader(train_dataset, collate_fn,
+ True)
+ else:
+ train_dataloader = train_dataset
+
+ if not isinstance(valid_dataset, torch.utils.data.DataLoader):
+
+ valid_dataloader = self.get_dataloader(valid_dataset, collate_fn,
+ False)
+ else:
+ valid_dataloader = valid_dataset
+
+ if self.load_dir:
+ log_dist("loading checkpoints form {}".format(self.load_dir))
+ sd = load_checkpoint(model,
+ load_dir=self.load_dir,
+ load_type=self.load_type)
+ """Train the model."""
+ # Turn on training mode which enables dropout.
+ model.train()
+ if self.fp16 and self.env_type == 'pytorchDDP':
+ log_dist(
+ "Warning: The pytorchDDP plus FP16 may not working togather!!!"
+ )
+ if self.fp16:
+ model.half()
+ if self.checkpoint_activations:
+ model.config[
+ 'checkpoint_activations'] = self.checkpoint_activations
+
+ if self.env_type == 'pytorchDDP':
+ model.to(torch.device('cuda', self.local_rank))
+ model = DDP(model,
+ device_ids=[self.local_rank],
+ find_unused_parameters=True)
+
+ elif self.env_type == 'pytorch':
+ model.to(self.pytorch_device)
+ else:
+ model.cuda(torch.device('cuda', self.local_rank))
+ if self.fp16:
+ model = FP16_Module(model)
+
+ param_groups = get_optimizer_param_groups(model)
+
+ if hasattr(param_groups[0], 'params'):
+ # for T5 Model
+ param_groups = param_groups[0]['params']
+
+ if optimizer is None and 'deepspeed' not in self.env_type and self.epochs > 0:
+ optimizer = get_optimizer(
+ param_groups=param_groups,
+ lr=self.lr,
+ weight_decay=self.weight_decay,
+ cpu_optimizer=False,
+ cpu_torch_adam=False,
+ fp16=self.fp16,
+ optimizer='adam') # if not self.fp16 else 'adafactor')
+
+        if lr_scheduler is None and optimizer is not None and self.warm_up > 0 and 'deepspeed' not in self.env_type and self.epochs > 0:
+
+ lr_scheduler = AnnealingLR(
+ optimizer,
+ start_lr=self.lr,
+ warmup_iter=int(self.warm_up * self.epochs *
+ len(train_dataloader)),
+ decay_style='linear',
+ num_iters=self.epochs * len(train_dataloader))
+
+ if 'deepspeed' in self.env_type:
+ # initialize the deepspeed
+ model, optimizer, _, lr_scheduler = deepspeed.initialize(
+ model=model,
+ # if huggingface t5: param_groups[0]['params']
+ model_parameters=param_groups,
+ optimizer=optimizer,
+ lr_scheduler=lr_scheduler,
+ mpu=mpu if self.env_type == 'deepspeed+mpu' else None,
+ config=self.deepspeed_config,
+ dist_init_required=True)
+ if self.load_optim:
+ print(self.load_optim)
+ print(type(self.load_optim))
+ load_optim(optimizer, lr_scheduler, sd)
+ if self.load_rng:
+ load_rng(sd)
+ # Tracking loss.
+ total_lm_loss = 0.0
+ self.iteration = 0
+ self.accumulate_count = 0
+ best_iteration = 0
+ best_loss = float('inf')
+ # For each remaining epoch
+ self.timers('interval time').start()
+ # self.eval_metrics = eval_metrics
+ # self.do_eval = valid_dataset!=None
+ self.metric_methods = metric_methods
+ best_score = float('inf')
+ if len(self.metric_methods) > 0:
+ best_score = -best_score
+
+ for epoch in range(self.epochs):
+ # log_dist('working on epoch {} ...'.format(epoch), [0])
+ # Set the data loader epoch to shuffle the index iterator.
+ # if self.env_type == 'deepspeed+mpu':
+ # if mpu.get_model_parallel_rank() == 0:
+ # train_dataloader.sampler.set_epoch(epoch + self.world_size)
+ if self.env_type != 'pytorch':
+ train_dataloader.sampler.set_epoch(epoch + self.world_size)
+
+ # For all the batches in the dataset.
+ for iteration_, batch in enumerate(train_dataloader):
+ # Train for one step.
+ if 'deepspeed' in self.env_type or self.env_type == 'pytorchDDP':
+ batch = {
+ x: batch[x].to(torch.device('cuda', self.local_rank))
+ for x in batch if x not in ['uid', 'meta', 'mode']
+ }
+ elif 'pytorch' == self.env_type:
+ batch = {
+ x: batch[x].to(torch.device(self.pytorch_device))
+ for x in batch if x not in ['uid', 'meta', 'mode']
+ }
+ if self.env_type == 'pytorchDDP':
+ lm_loss, _ = self.train_step_pytorchDDP(
+ batch, model, optimizer, lr_scheduler)
+ dist.barrier()
+
+ elif self.env_type == 'pytorch':
+ lm_loss, _ = self.train_step_pytorch(
+ batch, model, optimizer, lr_scheduler)
+ else:
+ lm_loss, _ = self.train_step_deepspeed(batch,
+ model,
+ optimizer,
+ lr_scheduler,
+ single_step=True)
+ dist.barrier()
+ if lm_loss is not None:
+ total_lm_loss += lm_loss.data.detach().float()
+
+ # Logging.
+ if (self.iteration + 1) % self.log_interval == 0:
+ if optimizer is not None:
+ learning_rate = optimizer.param_groups[0]['lr']
+ else:
+ learning_rate = model.optimizer.param_groups[0]['lr']
+ avg_lm_loss = total_lm_loss.item() / self.log_interval
+ elapsed_time = self.timers('interval time').elapsed()
+ self.report_iteration_metrics(
+ optimizer, learning_rate, avg_lm_loss,
+ elapsed_time * 1000.0 / self.log_interval,
+ self.iteration + 1,
+ self.epochs * len(train_dataloader))
+ self.tb_writer.add_scalar('train/loss', avg_lm_loss,
+ self.iteration + 1)
+ self.tb_writer.add_scalar('lr', learning_rate,
+ self.iteration + 1)
+ total_lm_loss = 0.0
+ # Evaluation #todo add train_args
+ if self.eval_interval and (
+ self.iteration + 1
+ ) % self.eval_interval == 0 and valid_dataloader is not None:
+ self.timers.log(['forward', 'backward', 'optimizer'],
+ normalizer=self.eval_interval)
+ prefix = 'epoch {}'.format(epoch)
+ eval_dict = self.evaluate_and_print_results(
+ prefix=prefix,
+ data_loader=valid_dataloader,
+ model=model,
+ forward_step_func=self.forward_step,
+ verbose=False)
+ if eval_dict is not None:
+ eval_loss = eval_dict.get("loss", 0.0)
+ self.tb_writer.add_scalar('eval/loss', eval_loss,
+ self.iteration + 1)
+ for i in range(len(self.metric_methods)):
+ name = self.metric_methods[i][0]
+ score = eval_dict.get(name, 0)
+ self.tb_writer.add_scalar(
+ 'eval_metrics/%s' % (name), score,
+ self.iteration + 1)
+
+ if self.save_best is not None and self.save_best(best_score, eval_dict) != best_score:
+ best_score = self.save_best(best_score, eval_dict)
+ log_dist("saving best model with score {:.4f}".format(best_score))
+ best_iteration = self.iteration
+ save_checkpoint(self.iteration+1,
+ best_iteration+1,
+ model,
+ optimizer,
+ lr_scheduler,
+ save_optim=self.save_optim,
+ save_dir=self.save_dir,
+ save_rng=self.save_rng)
+ if self.save_dir and (self.iteration + 1) % self.save_interval == 0 and \
+ self.iteration != best_iteration:
+ save_checkpoint(self.iteration+1,
+ best_iteration+1,
+ model,
+ optimizer,
+ lr_scheduler,
+ save_optim=self.save_optim,
+ save_dir=self.save_dir,
+ save_rng=self.save_rng)
+ self.iteration += 1
+
+ # Checkpointing at the end of each epoch.
+
+ # Evaluation #todo add train_args
+ if ((self.epochs == 0) or (self.eval_interval and
+ (self.iteration ) % self.eval_interval != 0)
+ ) and valid_dataloader is not None:
+ prefix = 'final evaluate'
+ self.evaluate_and_print_results(
+ prefix=prefix,
+ data_loader=valid_dataloader,
+ model=model,
+ forward_step_func=self.forward_step,
+ verbose=False)
+
+ def train_step_pytorch(self,
+ data,
+ model,
+ optimizer,
+ lr_scheduler,
+ mems=None):
+ """Single training step."""
+ # Forward model for one step.
+ self.timers('forward').start()
+ step_output = self.forward_step(data, model, mems)
+ self.timers('forward').stop()
+ # accumulate gradients
+ lm_loss = step_output['loss']
+ lm_loss /= self.gradient_accumulation_steps
+ reduced_loss = lm_loss.detach().clone().view(1)
+ # skip the iter while loss has NAN
+ if not DynamicLossScaler._has_inf_or_nan(reduced_loss):
+ # Calculate gradients, reduce across processes, and clip.
+ self.timers('backward').start()
+ if self.fp16 and hasattr(optimizer, 'backward'):
+ optimizer.backward(lm_loss,
+ update_master_grads=False,
+ retain_graph=True)
+ else:
+ lm_loss.backward()
+ torch.nn.utils.clip_grad_norm_(model.parameters(), self.clip_grad)
+ self.timers('backward').stop()
+
+ # Update parameters.
+ self.timers('optimizer').start()
+ if (self.accumulate_count +
+ 1) % self.gradient_accumulation_steps == 0:
+ if self.fp16:
+ # optimizer.update_master_grads()
+ optimizer.step()
+ optimizer.zero_grad()
+ else:
+ optimizer.step()
+ # optimizer.zero_grad()
+ self.accumulate_count = 0
+ else:
+ self.accumulate_count += 1
+ if lr_scheduler:
+ lr_scheduler.step()
+ self.timers('optimizer').stop()
+
+ else:
+ log_dist("Found NaN loss, skip backward", [0])
+ del lm_loss, reduced_loss
+ mems = None
+ reduced_loss = None
+ return reduced_loss, mems
+
+ def train_step_pytorchDDP(self,
+ data,
+ model,
+ optimizer,
+ lr_scheduler,
+ mems=None):
+ """Single training step."""
+
+ from contextlib import nullcontext
+ if self.fp16:
+ no_sync = model.module.no_sync
+ else:
+ no_sync = model.no_sync
+
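+        # Use DDP's no_sync() on intermediate accumulation steps so gradients
+        # are only all-reduced on the step that will call optimizer.step().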
+ mycontext = no_sync if (
+ self.accumulate_count +
+ 1) != self.gradient_accumulation_steps else nullcontext
+
+ with mycontext():
+ # Forward model for one step.
+ self.timers('forward').start()
+ step_output = self.forward_step(data, model, mems)
+ self.timers('forward').stop()
+
+ # accumulate gradients
+ lm_loss = step_output['loss']
+ lm_loss /= self.gradient_accumulation_steps
+ # reduce sum of losses
+ reduced_loss = lm_loss.detach().clone().view(1)
+ # dist.all_reduce(reduced_loss.data)
+ # reduced_loss.data = reduced_loss.data / self.world_size
+
+ # skip the iter while loss has NAN
+ if not DynamicLossScaler._has_inf_or_nan(reduced_loss):
+ # Calculate gradients, reduce across processes, and clip.
+ self.timers('backward').start()
+
+ if self.fp16 and hasattr(optimizer, 'backward'):
+ log_dist("The optimizer has backward function")
+ optimizer.backward(lm_loss,
+ update_master_grads=False,
+ retain_graph=True)
+ else:
+ lm_loss.backward()
+
+ torch.nn.utils.clip_grad_norm_(model.module.parameters(),
+ self.clip_grad)
+ self.timers('backward').stop()
+
+ # Update parameters.
+ self.timers('optimizer').start()
+ if (self.accumulate_count +
+ 1) % self.gradient_accumulation_steps == 0:
+ if self.fp16:
+ optimizer.update_master_grads()
+ optimizer.step()
+ optimizer.zero_grad()
+ else:
+ optimizer.step()
+ # model.zero_grad()
+
+ self.accumulate_count = 0
+ else:
+ self.accumulate_count += 1
+ if lr_scheduler:
+ lr_scheduler.step()
+ self.timers('optimizer').stop()
+ dist.barrier()
+
+ else:
+ log_dist("Found NaN loss, skip backward", [0])
+ del lm_loss, reduced_loss
+ mems = None
+ reduced_loss = None
+ return reduced_loss, mems
+
+ def train_step_deepspeed(self,
+ data,
+ model,
+ optimizer,
+ lr_scheduler,
+ mems=None,
+ single_step=False):
+ """Single training step."""
+
+        # Tell the DeepSpeed engine whether this micro-step completes a
+        # gradient-accumulation window (it syncs and steps at the boundary).
+ if (self.accumulate_count + 1) % self.gradient_accumulation_steps == 0:
+ model.set_gradient_accumulation_boundary(True)
+ else:
+ model.set_gradient_accumulation_boundary(False)
+ self.timers('forward').start()
+ step_output = self.forward_step(data, model, mems)
+ self.timers('forward').stop()
+ lm_loss = step_output['loss']
+ reduced_loss = lm_loss.detach().clone().view(1)
+
+ if self.env_type == 'deepspeed+mpu':
+ torch.distributed.all_reduce(reduced_loss.data,
+ group=mpu.get_data_parallel_group())
+ elif self.env_type == 'deepspeed':
+ torch.distributed.all_reduce(reduced_loss.data)
+ if 'deepspeed' in self.env_type:
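+            # Average the all-reduced (summed) loss over the number of
+            # data-parallel replicas (world_size / model_parallel_size).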
+ reduced_loss.data = reduced_loss.data / \
+ (self.world_size / self.model_parallel_size)
+ if not DynamicLossScaler._has_inf_or_nan(reduced_loss):
+ # Calculate gradients, reduce across processes, and clip.
+ self.timers('backward').start()
+ model.backward(lm_loss)
+ self.timers('backward').stop()
+ # Update parameters.
+ self.timers('optimizer').start()
+ model.step()
+ if lr_scheduler:
+ lr_scheduler.step()
+ self.timers('optimizer').stop()
+ if (self.accumulate_count +
+ 1) % self.gradient_accumulation_steps == 0:
+ self.accumulate_count = 0
+ else:
+ self.accumulate_count += 1
+ dist.barrier()
+ else:
+ log_dist("Found NaN loss, skip backward", [0])
+ del lm_loss, reduced_loss
+ mems = []
+ reduced_loss = None
+ return reduced_loss, mems
+
+ def forward_step(self, data, model, mems=None):
+ """Simple forward step. """
+ data['mems'] = mems
+ model_output = model(**data)
+ logits = model_output['logits']
+ loss = model_output['loss']
+ hidden_states = None
+ if 'hidden_states' in model_output:
+ hidden_states = model_output['hidden_states']
+ elif 'encoder_hidden_states' in model_output:
+ hidden_states = model_output['encoder_hidden_states']
+
+ return {
+ 'loss': loss,
+ 'hidden_states': hidden_states,
+ 'logits': logits.contiguous().float()
+ }
+
+ def backward_step(self, optimizer, model, lm_loss):
+ """Backward step."""
+
+ # Total loss.
+ loss = lm_loss
+ # Backward pass.
+ # if self.train_args.deepspeed:
+ if 'deepspeed' in self.env_type:
+ model.backward(loss)
+ else:
+ # optimizer.zero_grad()
+ if hasattr(optimizer, 'backward'):
+ optimizer.backward(loss, update_master_grads=False)
+ else:
+ loss.backward()
+ if self.env_type == 'pytorchDDP':
+ optimizer.step()
+
+ # if self.train_args.deepspeed or self.train_args.DDP_impl == 'torch':
+ self.timers('allreduce').reset()
+ if self.env_type == 'pytorch':
+ torch.nn.utils.clip_grad_norm_(model.parameters(), self.clip_grad)
+ return lm_loss
+
+ def _gather_all(self, input_):
+
+ # Bypass the function if we are using only 1 GPU.
+ if torch.distributed.get_world_size() == 1:
+ return input_
+ # Size and dimension.
+ last_dim = input_.dim() - 1
+ rank = torch.distributed.get_rank()
+ world_size = torch.distributed.get_world_size()
+
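+        # Gather a copy of `input_` from every rank so that each rank can
+        # compute metrics/losses over the full evaluation set.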
+ tensor_list = [
+ torch.empty_like(input_, device=input_.device)
+ for _ in range(world_size)
+ ]
+ tensor_list[rank] = input_
+
+ torch.distributed.all_gather(tensor_list, input_)
+
+ # Note: torch.cat already creates a contiguous tensor.
+ if last_dim >= 0:
+ output = torch.cat(tensor_list, dim=0).contiguous()
+ else:
+ output = torch.mean(torch.FloatTensor(tensor_list))
+
+ return output
+
+ def _gather_all_mpu(self, input_):
+ group = mpu.get_model_parallel_group()
+
+ # Bypass the function if we are using only 1 GPU.
+ if torch.distributed.get_world_size(group=group) == 1:
+ return input_
+ # Size and dimension.
+ last_dim = input_.dim() - 1
+ rank = torch.distributed.get_rank(group=group)
+ world_size = torch.distributed.get_world_size(group=group)
+
+ tensor_list = [
+ torch.empty_like(input_, device=input_.device)
+ for _ in range(world_size)
+ ]
+ tensor_list[rank] = input_
+ torch.distributed.all_gather(tensor_list, input_, group=group)
+
+ # Note: torch.cat already creates a contiguous tensor.
+ output = torch.cat(tensor_list, dim=last_dim).contiguous()
+
+ return output
+
+ def evaluate(self,
+ data_loader=None,
+ model=None,
+ forward_step_func=None,
+ verbose=False):
+ """Evaluation."""
+
+ # Turn off checkpoint_activations
+ tmp_checkpoint_activations = None
+ tmp_model = model
+ while hasattr(tmp_model, 'module'):
+ tmp_model = tmp_model.module
+ # Turn on evaluation mode which disables dropout.
+ tmp_model.eval()
+ if hasattr(tmp_model,
+ 'config') and 'checkpoint_activations' in tmp_model.config:
+ tmp_checkpoint_activations = tmp_model.config[
+ 'checkpoint_activations']
+ tmp_model.config['checkpoint_activations'] = False
+
+ mems = None
+ metrics = [0. for _ in range(len(self.metric_methods))]
+
+ with torch.no_grad():
+            assert data_loader is not None, "val data_loader must not be None."
+ all_logits = []
+ all_labels = []
+ all_losses = []
+ for data_iterator in data_loader:
+ # Forward evaluation.
+
+ meta = data_iterator.get('meta', None)
+
+ if 'deepspeed' in self.env_type or 'DDP' in self.env_type:
+ data_iterator = {
+ x: data_iterator[x].to(
+ torch.device('cuda', self.local_rank))
+ for x in data_iterator
+ if x not in ['uid', 'meta', 'mode']
+ }
+ elif torch.cuda.is_available():
+
+ data_iterator = {
+ x:
+ data_iterator[x].to(torch.device(self.pytorch_device))
+ for x in data_iterator
+ if x not in ['uid', 'meta', 'mode']
+ }
+ step_output = forward_step_func(data_iterator, model, mems)
+                '''When contiguous memory optimizations are enabled, the buffers
+                allocated by the optimizations are deallocated during the backward pass.
+                In the absence of a backward pass, the buffers should be reset after
+                each forward pass.'''
+ if 'deepspeed' in self.env_type and self.deepspeed_activation_checkpointing:
+ deepspeed.checkpointing.reset()
+ logits = step_output['logits']
+ lm_loss = step_output['loss']
+
+ if 'labels' in data_iterator:
+ labels = data_iterator['labels']
+ else:
+ labels = data_iterator['target_ids']
+
+ all_logits.append(logits)
+ all_labels.append(labels)
+ all_losses.append(lm_loss.view(1))
+
+ if len(self.metric_methods) != 0:
+ all_logits = torch.cat(all_logits, dim=0)
+ all_labels = torch.cat(all_labels, dim=0)
+
+ all_losses = torch.cat(all_losses, dim=0)
+
+ if self.env_type == 'pytorchDDP' or self.env_type == 'deepspeed':
+ if len(self.metric_methods) != 0:
+ all_logits = self._gather_all(all_logits)
+ all_labels = self._gather_all(all_labels)
+ all_losses = self._gather_all(all_losses)
+
+ elif self.env_type == 'deepspeed+mpu':
+ if len(self.metric_methods) != 0:
+ all_logits = self._gather_all_mpu(all_logits)
+ all_labels = self._gather_all_mpu(all_labels)
+ all_losses = self._gather_all_mpu(all_losses)
+
+ if all_losses.device != torch.device('cpu'):
+ all_losses = all_losses.cpu().detach().numpy()[0]
+
+ for i in range(len(self.metric_methods)):
+ eval_method = self.metric_methods[i][1]
+ metrics[i] += eval_method(all_logits, all_labels, meta=meta)
+
+ # Move model back to the train mode.
+
+ # model.train()
+ tmp_model.train()
+ # recover the settings for checkpoint_activations
+ if hasattr(tmp_model,
+ 'config') and 'checkpoint_activations' in tmp_model.config:
+ tmp_model.config[
+ 'checkpoint_activations'] = tmp_checkpoint_activations
+ metric_dct = {}
+ for i in range(len(self.metric_methods)):
+ metric_name = self.metric_methods[i][0]
+ metric_dct.update({metric_name: metrics[i]})
+ metric_dct.update({"loss": all_losses})
+ return metric_dct
+
+ def report_iteration_metrics(self, optimizer, lr, loss, elapsed_time, step,
+ total_step):
+ log_string = ' iteration {:8d}/{:8d} |'.format(step, total_step)
+ log_string += ' elapsed time per iteration (ms): {:.1f} |'.format(
+ elapsed_time)
+ log_string += ' learning rate {:.3E} |'.format(lr)
+ log_string += ' loss {:.6E} |'.format(loss)
+ if self.fp16:
+ log_string += ' loss scale {:.1f} |'.format(
+ optimizer.cur_scale if 'deepspeed' in self.env_type else
+ hasattr(optimizer, 'loss_scale') and optimizer.loss_scale)
+ # log_string += ' gradient_accumulation {}/{}'.format(self.accumulate_count, self.gradient_accumulation_steps)
+ log_dist(log_string, [0])
+
+ def report_evaluate_metrics(self, prefix, loss, ppl, gpt_loss, bert_loss,
+ sent_loss, multi_loss, step):
+ string = ' validation loss at {}'.format(prefix)
+ string += ' | LM loss: {:.6E}'.format(loss)
+ string += ' | LM PPL: {:.6E}'.format(ppl)
+ length = len(string) + 1
+ log_dist('-' * 100, [0])
+ log_dist('-' * length, [0])
+ log_dist(string, [0])
+ log_dist('-' * length, [0])
+
+ def evaluate_and_print_results(
+ self,
+ prefix=None,
+ forward_step_func=None,
+ data_loader=None,
+ model=None,
+ verbose=False,
+ ):
+ """Helper function to evaluate and dump results on screen."""
+ eval_dict = self.evaluate(forward_step_func=forward_step_func,
+ data_loader=data_loader,
+ model=model,
+ verbose=verbose)
+ if eval_dict.get("loss", None) is not None:
+ string = ' validation loss at {} | {:.4f}, '.format(
+ prefix, eval_dict["loss"])
+ # with open("results.txt", "a") as myfile:
+ # myfile.write(string)
+ if self.metric_methods is None:
+ return eval_dict
+
+ for i in range(len(self.metric_methods)):
+ name = self.metric_methods[i][0]
+ string += ", {} {:.3f}".format(name, eval_dict[name])
+ # string = ' validation loss at {} | {:.4f}, Acc {:.2f}'.format(
+ # prefix, eval_dict["loss"], eval_dict["metrics"])
+ length = len(string) + 1
+ log_dist('-' * length, [0])
+ log_dist(string, [0])
+ log_dist('-' * length, [0])
+ return eval_dict
\ No newline at end of file
diff --git a/flagai/launch.py b/flagai/launch.py
index 3dcfe22b..ecba3254 100644
--- a/flagai/launch.py
+++ b/flagai/launch.py
@@ -74,7 +74,8 @@ def launch_dist(launcher='distributed_deepspeed',
hostfile='hostfile',
nccl_info=False,
training_script='train.py',
- training_script_paras=None):
+ training_script_paras=None,
+ training_paras=None,):
try:
resource_pool = fetch_hostfile(hostfile)
except:
@@ -151,6 +152,9 @@ def launch_dist(launcher='distributed_deepspeed',
]
cmd_launch.extend(torch_distributed_args)
cmd_launch.append(training_script)
+ if training_paras:
+ cmd_launch.extend(training_paras)
+
cmd_launch.append('--not_call_launch')
run_cmd = ' '.join(cmd_launch)
log_dist(run_cmd)
@@ -196,6 +200,9 @@ def launch_dist(launcher='distributed_deepspeed',
if len(training_script_paras) > 0:
cmd_launch.extend(training_script_paras)
+ if training_paras:
+ cmd_launch.extend(training_paras)
+
cmd_launch.append('--not_call_launch')
run_cmd = ' '.join(cmd_launch)
log_dist(run_cmd)
@@ -226,6 +233,9 @@ def launch_dist(launcher='distributed_deepspeed',
if len(training_script_paras) > 0:
cmd_launch.extend(training_script_paras)
+ if training_paras:
+ cmd_launch.extend(training_paras)
+
run_cmd = ' '.join(cmd_launch)
log_dist(run_cmd)
subprocess.Popen(run_cmd, shell=True)
diff --git a/flagai/model/base_model.py b/flagai/model/base_model.py
index 5480b73b..c600cfb4 100644
--- a/flagai/model/base_model.py
+++ b/flagai/model/base_model.py
@@ -9,7 +9,6 @@
from flagai.model.file_utils import _get_model_id, _get_config_path, _get_checkpoint_path, _get_vocab_path, _get_model_files
import os
-
# The base model for models
class BaseModel(Module):
@@ -59,12 +58,34 @@ def from_pretrain(cls,
# downloading the files
model: Union[Module, None]
if model_id and model_id != "null":
+ model_files = eval(_get_model_files(model_name))
if not os.path.exists(os.path.join(download_path, 'vocab.txt')):
- _get_vocab_path(download_path, "vocab.txt", model_id)
+ if "vocab.txt" in model_files:
+ _get_vocab_path(download_path, "vocab.txt", model_id)
if not only_download_config and not os.path.exists(os.path.join(download_path, 'config.json')):
- model_files = eval(_get_model_files(model_name))
- if 'pytorch_model.bin' in model_files:
+ if os.getenv('ENV_TYPE') == 'deepspeed+mpu':
+ model_parallel_size = int(os.getenv("MODEL_PARALLEL_SIZE"))
+ if model_parallel_size > 1:
+                    # If the number of GPUs matches the number of model-hub
+                    # slices, the slices can be loaded directly; otherwise the
+                    # full pytorch_model.bin must be downloaded and re-cut.
+ model_hub_parallel_size = 0
+ for f in model_files:
+ if "pytorch_model_" in f:
+ model_hub_parallel_size += 1
+ else:
+ model_parallel_size = 1
+
+ if "pytorch_model_01.bin" in model_files and model_parallel_size > 1 and model_hub_parallel_size == model_parallel_size:
+                # Only download the model-parallel slices (Megatron-LM).
+ for file_to_load in model_files:
+ if "pytorch_model_" in file_to_load:
+ _get_checkpoint_path(download_path,
+ file_to_load,
+ model_id)
+
+ elif 'pytorch_model.bin' in model_files:
checkpoint_path = _get_checkpoint_path(download_path,
'pytorch_model.bin',
model_id)
diff --git a/flagai/model/vision/helpers.py b/flagai/model/vision/helpers.py
new file mode 100644
index 00000000..1e56190d
--- /dev/null
+++ b/flagai/model/vision/helpers.py
@@ -0,0 +1,70 @@
+
+import os
+if os.getenv('ENV_TYPE') == 'deepspeed':
+ from deepspeed.runtime.activation_checkpointing.checkpointing import checkpoint
+else:
+ from torch.utils.checkpoint import checkpoint
+import torch
+from itertools import chain
+
+def checkpoint_seq(
+ functions,
+ x,
+ every=1,
+ flatten=False,
+ skip_last=False,
+):
+ r"""A helper function for checkpointing sequential models.
+ Sequential models execute a list of modules/functions in order
+ (sequentially). Therefore, we can divide such a sequence into segments
+    and checkpoint each segment. All segments except the last run in :func:`torch.no_grad`
+ manner, i.e., not storing the intermediate activations. The inputs of each
+ checkpointed segment will be saved for re-running the segment in the backward pass.
+ See :func:`~torch.utils.checkpoint.checkpoint` on how checkpointing works.
+ .. warning::
+ Checkpointing currently only supports :func:`torch.autograd.backward`
+ and only if its `inputs` argument is not passed. :func:`torch.autograd.grad`
+ is not supported.
+ .. warning:
+ At least one of the inputs needs to have :code:`requires_grad=True` if
+ grads are needed for model inputs, otherwise the checkpointed part of the
+ model won't have gradients.
+ Args:
+ functions: A :class:`torch.nn.Sequential` or the list of modules or functions to run sequentially.
+ x: A Tensor that is input to :attr:`functions`
+ every: checkpoint every-n functions (default: 1)
+ flatten (bool): flatten nn.Sequential of nn.Sequentials
+ skip_last (bool): skip checkpointing the last function in the sequence if True
+ Returns:
+ Output of running :attr:`functions` sequentially on :attr:`*inputs`
+ Example:
+ >>> model = nn.Sequential(...)
+ >>> input_var = checkpoint_seq(model, input_var, every=2)
+ """
+ def run_function(start, end, functions):
+ def forward(_x):
+ for j in range(start, end + 1):
+ _x = functions[j](_x)
+ return _x
+ return forward
+
+ if isinstance(functions, torch.nn.Sequential):
+ functions = functions.children()
+ if flatten:
+ functions = chain.from_iterable(functions)
+ if not isinstance(functions, (tuple, list)):
+ functions = tuple(functions)
+
+ num_checkpointed = len(functions)
+ if skip_last:
+ num_checkpointed -= 1
+ end = -1
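+    # Checkpoint the functions in groups of `every`; each group's forward is
+    # re-run during the backward pass instead of storing its activations.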
+ for start in range(0, num_checkpointed, every):
+ end = min(start + every - 1, num_checkpointed - 1)
+ x = checkpoint(run_function(start, end, functions), x)
+ if skip_last:
+ return run_function(end + 1, len(functions) - 1, functions)(x)
+ return x
+
diff --git a/flagai/model/vision/layers/__init__.py b/flagai/model/vision/layers/__init__.py
new file mode 100755
index 00000000..7e9e7b19
--- /dev/null
+++ b/flagai/model/vision/layers/__init__.py
@@ -0,0 +1,42 @@
+from .activations import *
+from .adaptive_avgmax_pool import \
+ adaptive_avgmax_pool2d, select_adaptive_pool2d, AdaptiveAvgMaxPool2d, SelectAdaptivePool2d
+from .blur_pool import BlurPool2d
+from .classifier import ClassifierHead, create_classifier
+from .cond_conv2d import CondConv2d, get_condconv_initializer
+from .config import is_exportable, is_scriptable, is_no_jit, set_exportable, set_scriptable, set_no_jit,\
+ set_layer_config
+from .conv2d_same import Conv2dSame, conv2d_same
+from .conv_bn_act import ConvNormAct, ConvNormActAa, ConvBnAct
+from .create_act import create_act_layer, get_act_layer, get_act_fn
+from .create_attn import get_attn, create_attn
+from .create_conv2d import create_conv2d
+from .create_norm_act import get_norm_act_layer, create_norm_act_layer, get_norm_act_layer
+from .drop import DropBlock2d, DropPath, drop_block_2d, drop_path
+from .eca import EcaModule, CecaModule, EfficientChannelAttn, CircularEfficientChannelAttn
+from .evo_norm import EvoNorm2dB0, EvoNorm2dB1, EvoNorm2dB2,\
+ EvoNorm2dS0, EvoNorm2dS0a, EvoNorm2dS1, EvoNorm2dS1a, EvoNorm2dS2, EvoNorm2dS2a
+from .filter_response_norm import FilterResponseNormTlu2d, FilterResponseNormAct2d
+from .gather_excite import GatherExcite
+from .global_context import GlobalContext
+from .helpers import to_ntuple, to_2tuple, to_3tuple, to_4tuple, make_divisible
+from .inplace_abn import InplaceAbn
+from .linear import Linear
+from .mixed_conv2d import MixedConv2d
+from .mlp import Mlp, GluMlp, GatedMlp, ConvMlp
+from .non_local_attn import NonLocalAttn, BatNonLocalAttn
+from .norm import GroupNorm, LayerNorm2d
+from .norm_act import BatchNormAct2d, GroupNormAct
+from .padding import get_padding, get_same_padding, pad_same
+from .patch_embed import PatchEmbed
+from .pool2d_same import AvgPool2dSame, create_pool2d
+from .squeeze_excite import SEModule, SqueezeExcite, EffectiveSEModule, EffectiveSqueezeExcite
+from .selective_kernel import SelectiveKernel
+from .separable_conv import SeparableConv2d, SeparableConvNormAct
+from .space_to_depth import SpaceToDepthModule
+from .split_attn import SplitAttn
+from .split_batchnorm import SplitBatchNorm2d, convert_splitbn_model
+from .std_conv import StdConv2d, StdConv2dSame, ScaledStdConv2d, ScaledStdConv2dSame
+from .test_time_pool import TestTimePoolHead, apply_test_time_pool
+from .trace_utils import _assert, _float_to_int
+from .weight_init import trunc_normal_, variance_scaling_, lecun_normal_
diff --git a/flagai/model/vision/layers/activations.py b/flagai/model/vision/layers/activations.py
new file mode 100755
index 00000000..e16b3bd3
--- /dev/null
+++ b/flagai/model/vision/layers/activations.py
@@ -0,0 +1,145 @@
+""" Activations
+
+A collection of activations fn and modules with a common interface so that they can
+easily be swapped. All have an `inplace` arg even if not used.
+
+Hacked together by / Copyright 2020 Ross Wightman
+"""
+
+import torch
+from torch import nn as nn
+from torch.nn import functional as F
+
+
+def swish(x, inplace: bool = False):
+ """Swish - Described in: https://arxiv.org/abs/1710.05941
+ """
+ return x.mul_(x.sigmoid()) if inplace else x.mul(x.sigmoid())
+
+
+class Swish(nn.Module):
+ def __init__(self, inplace: bool = False):
+ super(Swish, self).__init__()
+ self.inplace = inplace
+
+ def forward(self, x):
+ return swish(x, self.inplace)
+
+
+def mish(x, inplace: bool = False):
+ """Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681
+ NOTE: I don't have a working inplace variant
+ """
+ return x.mul(F.softplus(x).tanh())
+
+
+class Mish(nn.Module):
+ """Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681
+ """
+ def __init__(self, inplace: bool = False):
+ super(Mish, self).__init__()
+
+ def forward(self, x):
+ return mish(x)
+
+
+def sigmoid(x, inplace: bool = False):
+ return x.sigmoid_() if inplace else x.sigmoid()
+
+
+# PyTorch has this, but not with a consistent inplace argument interface
+class Sigmoid(nn.Module):
+ def __init__(self, inplace: bool = False):
+ super(Sigmoid, self).__init__()
+ self.inplace = inplace
+
+ def forward(self, x):
+ return x.sigmoid_() if self.inplace else x.sigmoid()
+
+
+def tanh(x, inplace: bool = False):
+ return x.tanh_() if inplace else x.tanh()
+
+
+# PyTorch has this, but not with a consistent inplace argument interface
+class Tanh(nn.Module):
+ def __init__(self, inplace: bool = False):
+ super(Tanh, self).__init__()
+ self.inplace = inplace
+
+ def forward(self, x):
+ return x.tanh_() if self.inplace else x.tanh()
+
+
+def hard_swish(x, inplace: bool = False):
+ inner = F.relu6(x + 3.).div_(6.)
+ return x.mul_(inner) if inplace else x.mul(inner)
+
+
+class HardSwish(nn.Module):
+ def __init__(self, inplace: bool = False):
+ super(HardSwish, self).__init__()
+ self.inplace = inplace
+
+ def forward(self, x):
+ return hard_swish(x, self.inplace)
+
+
+def hard_sigmoid(x, inplace: bool = False):
+ if inplace:
+ return x.add_(3.).clamp_(0., 6.).div_(6.)
+ else:
+ return F.relu6(x + 3.) / 6.
+
+
+class HardSigmoid(nn.Module):
+ def __init__(self, inplace: bool = False):
+ super(HardSigmoid, self).__init__()
+ self.inplace = inplace
+
+ def forward(self, x):
+ return hard_sigmoid(x, self.inplace)
+
+
+def hard_mish(x, inplace: bool = False):
+ """ Hard Mish
+ Experimental, based on notes by Mish author Diganta Misra at
+ https://github.com/digantamisra98/H-Mish/blob/0da20d4bc58e696b6803f2523c58d3c8a82782d0/README.md
+ """
+ if inplace:
+ return x.mul_(0.5 * (x + 2).clamp(min=0, max=2))
+ else:
+ return 0.5 * x * (x + 2).clamp(min=0, max=2)
+
+
+class HardMish(nn.Module):
+ def __init__(self, inplace: bool = False):
+ super(HardMish, self).__init__()
+ self.inplace = inplace
+
+ def forward(self, x):
+ return hard_mish(x, self.inplace)
+
+
+class PReLU(nn.PReLU):
+ """Applies PReLU (w/ dummy inplace arg)
+ """
+ def __init__(self, num_parameters: int = 1, init: float = 0.25, inplace: bool = False) -> None:
+ super(PReLU, self).__init__(num_parameters=num_parameters, init=init)
+
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
+ return F.prelu(input, self.weight)
+
+
+def gelu(x: torch.Tensor, inplace: bool = False) -> torch.Tensor:
+ return F.gelu(x)
+
+
+class GELU(nn.Module):
+ """Applies the Gaussian Error Linear Units function (w/ dummy inplace arg)
+ """
+ def __init__(self, inplace: bool = False):
+ super(GELU, self).__init__()
+
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
+ return F.gelu(input)
diff --git a/flagai/model/vision/layers/activations_jit.py b/flagai/model/vision/layers/activations_jit.py
new file mode 100755
index 00000000..b4a51653
--- /dev/null
+++ b/flagai/model/vision/layers/activations_jit.py
@@ -0,0 +1,90 @@
+""" Activations
+
+A collection of jit-scripted activations fn and modules with a common interface so that they can
+easily be swapped. All have an `inplace` arg even if not used.
+
+All jit scripted activations are lacking in-place variations on purpose, scripted kernel fusion does not
+currently work across in-place op boundaries, thus performance is equal to or less than the non-scripted
+versions if they contain in-place ops.
+
+Hacked together by / Copyright 2020 Ross Wightman
+"""
+
+import torch
+from torch import nn as nn
+from torch.nn import functional as F
+
+
+@torch.jit.script
+def swish_jit(x, inplace: bool = False):
+ """Swish - Described in: https://arxiv.org/abs/1710.05941
+ """
+ return x.mul(x.sigmoid())
+
+
+@torch.jit.script
+def mish_jit(x, _inplace: bool = False):
+ """Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681
+ """
+ return x.mul(F.softplus(x).tanh())
+
+
+class SwishJit(nn.Module):
+ def __init__(self, inplace: bool = False):
+ super(SwishJit, self).__init__()
+
+ def forward(self, x):
+ return swish_jit(x)
+
+
+class MishJit(nn.Module):
+ def __init__(self, inplace: bool = False):
+ super(MishJit, self).__init__()
+
+ def forward(self, x):
+ return mish_jit(x)
+
+
+@torch.jit.script
+def hard_sigmoid_jit(x, inplace: bool = False):
+ # return F.relu6(x + 3.) / 6.
+ return (x + 3).clamp(min=0, max=6).div(6.) # clamp seems ever so slightly faster?
+
+
+class HardSigmoidJit(nn.Module):
+ def __init__(self, inplace: bool = False):
+ super(HardSigmoidJit, self).__init__()
+
+ def forward(self, x):
+ return hard_sigmoid_jit(x)
+
+
+@torch.jit.script
+def hard_swish_jit(x, inplace: bool = False):
+ # return x * (F.relu6(x + 3.) / 6)
+ return x * (x + 3).clamp(min=0, max=6).div(6.) # clamp seems ever so slightly faster?
+
+
+class HardSwishJit(nn.Module):
+ def __init__(self, inplace: bool = False):
+ super(HardSwishJit, self).__init__()
+
+ def forward(self, x):
+ return hard_swish_jit(x)
+
+
+@torch.jit.script
+def hard_mish_jit(x, inplace: bool = False):
+ """ Hard Mish
+ Experimental, based on notes by Mish author Diganta Misra at
+ https://github.com/digantamisra98/H-Mish/blob/0da20d4bc58e696b6803f2523c58d3c8a82782d0/README.md
+ """
+ return 0.5 * x * (x + 2).clamp(min=0, max=2)
+
+
+class HardMishJit(nn.Module):
+ def __init__(self, inplace: bool = False):
+ super(HardMishJit, self).__init__()
+
+ def forward(self, x):
+ return hard_mish_jit(x)
diff --git a/flagai/model/vision/layers/activations_me.py b/flagai/model/vision/layers/activations_me.py
new file mode 100755
index 00000000..9a12bb7e
--- /dev/null
+++ b/flagai/model/vision/layers/activations_me.py
@@ -0,0 +1,218 @@
+""" Activations (memory-efficient w/ custom autograd)
+
+A collection of activations fn and modules with a common interface so that they can
+easily be swapped. All have an `inplace` arg even if not used.
+
+These activations are not compatible with jit scripting or ONNX export of the model, please use either
+the JIT or basic versions of the activations.
+
+Hacked together by / Copyright 2020 Ross Wightman
+"""
+
+import torch
+from torch import nn as nn
+from torch.nn import functional as F
+
+
+@torch.jit.script
+def swish_jit_fwd(x):
+ return x.mul(torch.sigmoid(x))
+
+
+@torch.jit.script
+def swish_jit_bwd(x, grad_output):
+ x_sigmoid = torch.sigmoid(x)
+ return grad_output * (x_sigmoid * (1 + x * (1 - x_sigmoid)))
+
+
+class SwishJitAutoFn(torch.autograd.Function):
+ """ torch.jit.script optimised Swish w/ memory-efficient checkpoint
+    Inspired by conversation between Jeremy Howard & Adam Paszke
+ https://twitter.com/jeremyphoward/status/1188251041835315200
+ """
+ @staticmethod
+ def symbolic(g, x):
+ return g.op("Mul", x, g.op("Sigmoid", x))
+
+ @staticmethod
+ def forward(ctx, x):
+ ctx.save_for_backward(x)
+ return swish_jit_fwd(x)
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ x = ctx.saved_tensors[0]
+ return swish_jit_bwd(x, grad_output)
+
+
+def swish_me(x, inplace=False):
+ return SwishJitAutoFn.apply(x)
+
+
+class SwishMe(nn.Module):
+ def __init__(self, inplace: bool = False):
+ super(SwishMe, self).__init__()
+
+ def forward(self, x):
+ return SwishJitAutoFn.apply(x)
+
+
+@torch.jit.script
+def mish_jit_fwd(x):
+ return x.mul(torch.tanh(F.softplus(x)))
+
+
+@torch.jit.script
+def mish_jit_bwd(x, grad_output):
+ x_sigmoid = torch.sigmoid(x)
+ x_tanh_sp = F.softplus(x).tanh()
+ return grad_output.mul(x_tanh_sp + x * x_sigmoid * (1 - x_tanh_sp * x_tanh_sp))
+
+
+class MishJitAutoFn(torch.autograd.Function):
+ """ Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681
+ A memory efficient, jit scripted variant of Mish
+ """
+ @staticmethod
+ def forward(ctx, x):
+ ctx.save_for_backward(x)
+ return mish_jit_fwd(x)
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ x = ctx.saved_tensors[0]
+ return mish_jit_bwd(x, grad_output)
+
+
+def mish_me(x, inplace=False):
+ return MishJitAutoFn.apply(x)
+
+
+class MishMe(nn.Module):
+ def __init__(self, inplace: bool = False):
+ super(MishMe, self).__init__()
+
+ def forward(self, x):
+ return MishJitAutoFn.apply(x)
+
+
+@torch.jit.script
+def hard_sigmoid_jit_fwd(x, inplace: bool = False):
+ return (x + 3).clamp(min=0, max=6).div(6.)
+
+
+@torch.jit.script
+def hard_sigmoid_jit_bwd(x, grad_output):
+ m = torch.ones_like(x) * ((x >= -3.) & (x <= 3.)) / 6.
+ return grad_output * m
+
+
+class HardSigmoidJitAutoFn(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, x):
+ ctx.save_for_backward(x)
+ return hard_sigmoid_jit_fwd(x)
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ x = ctx.saved_tensors[0]
+ return hard_sigmoid_jit_bwd(x, grad_output)
+
+
+def hard_sigmoid_me(x, inplace: bool = False):
+ return HardSigmoidJitAutoFn.apply(x)
+
+
+class HardSigmoidMe(nn.Module):
+ def __init__(self, inplace: bool = False):
+ super(HardSigmoidMe, self).__init__()
+
+ def forward(self, x):
+ return HardSigmoidJitAutoFn.apply(x)
+
+
+@torch.jit.script
+def hard_swish_jit_fwd(x):
+ return x * (x + 3).clamp(min=0, max=6).div(6.)
+
+
+@torch.jit.script
+def hard_swish_jit_bwd(x, grad_output):
+ m = torch.ones_like(x) * (x >= 3.)
+ m = torch.where((x >= -3.) & (x <= 3.), x / 3. + .5, m)
+ return grad_output * m
+
+
+class HardSwishJitAutoFn(torch.autograd.Function):
+ """A memory efficient, jit-scripted HardSwish activation"""
+ @staticmethod
+ def forward(ctx, x):
+ ctx.save_for_backward(x)
+ return hard_swish_jit_fwd(x)
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ x = ctx.saved_tensors[0]
+ return hard_swish_jit_bwd(x, grad_output)
+
+ @staticmethod
+ def symbolic(g, self):
+ input = g.op("Add", self, g.op('Constant', value_t=torch.tensor(3, dtype=torch.float)))
+ hardtanh_ = g.op("Clip", input, g.op('Constant', value_t=torch.tensor(0, dtype=torch.float)), g.op('Constant', value_t=torch.tensor(6, dtype=torch.float)))
+ hardtanh_ = g.op("Div", hardtanh_, g.op('Constant', value_t=torch.tensor(6, dtype=torch.float)))
+ return g.op("Mul", self, hardtanh_)
+
+
+def hard_swish_me(x, inplace=False):
+ return HardSwishJitAutoFn.apply(x)
+
+
+class HardSwishMe(nn.Module):
+ def __init__(self, inplace: bool = False):
+ super(HardSwishMe, self).__init__()
+
+ def forward(self, x):
+ return HardSwishJitAutoFn.apply(x)
+
+
+@torch.jit.script
+def hard_mish_jit_fwd(x):
+ return 0.5 * x * (x + 2).clamp(min=0, max=2)
+
+
+@torch.jit.script
+def hard_mish_jit_bwd(x, grad_output):
+ m = torch.ones_like(x) * (x >= -2.)
+ m = torch.where((x >= -2.) & (x <= 0.), x + 1., m)
+ return grad_output * m
+
+
+class HardMishJitAutoFn(torch.autograd.Function):
+ """ A memory efficient, jit scripted variant of Hard Mish
+ Experimental, based on notes by Mish author Diganta Misra at
+ https://github.com/digantamisra98/H-Mish/blob/0da20d4bc58e696b6803f2523c58d3c8a82782d0/README.md
+ """
+ @staticmethod
+ def forward(ctx, x):
+ ctx.save_for_backward(x)
+ return hard_mish_jit_fwd(x)
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ x = ctx.saved_tensors[0]
+ return hard_mish_jit_bwd(x, grad_output)
+
+
+def hard_mish_me(x, inplace: bool = False):
+ return HardMishJitAutoFn.apply(x)
+
+
+class HardMishMe(nn.Module):
+ def __init__(self, inplace: bool = False):
+ super(HardMishMe, self).__init__()
+
+ def forward(self, x):
+ return HardMishJitAutoFn.apply(x)
+
+
+
diff --git a/flagai/model/vision/layers/adaptive_avgmax_pool.py b/flagai/model/vision/layers/adaptive_avgmax_pool.py
new file mode 100755
index 00000000..ebc6ada8
--- /dev/null
+++ b/flagai/model/vision/layers/adaptive_avgmax_pool.py
@@ -0,0 +1,118 @@
+""" PyTorch selectable adaptive pooling
+Adaptive pooling with the ability to select the type of pooling from:
+ * 'avg' - Average pooling
+ * 'max' - Max pooling
+ * 'avgmax' - Sum of average and max pooling re-scaled by 0.5
+ * 'avgmaxc' - Concatenation of average and max pooling along feature dim, doubles feature dim
+
+Both a functional and a nn.Module version of the pooling is provided.
+
+Hacked together by / Copyright 2020 Ross Wightman
+"""
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+def adaptive_pool_feat_mult(pool_type='avg'):
+ if pool_type == 'catavgmax':
+ return 2
+ else:
+ return 1
+
+
+def adaptive_avgmax_pool2d(x, output_size=1):
+ x_avg = F.adaptive_avg_pool2d(x, output_size)
+ x_max = F.adaptive_max_pool2d(x, output_size)
+ return 0.5 * (x_avg + x_max)
+
+
+def adaptive_catavgmax_pool2d(x, output_size=1):
+ x_avg = F.adaptive_avg_pool2d(x, output_size)
+ x_max = F.adaptive_max_pool2d(x, output_size)
+ return torch.cat((x_avg, x_max), 1)
+
+
+def select_adaptive_pool2d(x, pool_type='avg', output_size=1):
+ """Selectable global pooling function with dynamic input kernel size
+ """
+ if pool_type == 'avg':
+ x = F.adaptive_avg_pool2d(x, output_size)
+ elif pool_type == 'avgmax':
+ x = adaptive_avgmax_pool2d(x, output_size)
+ elif pool_type == 'catavgmax':
+ x = adaptive_catavgmax_pool2d(x, output_size)
+ elif pool_type == 'max':
+ x = F.adaptive_max_pool2d(x, output_size)
+ else:
+ assert False, 'Invalid pool type: %s' % pool_type
+ return x
+
+
+class FastAdaptiveAvgPool2d(nn.Module):
+ def __init__(self, flatten=False):
+ super(FastAdaptiveAvgPool2d, self).__init__()
+ self.flatten = flatten
+
+ def forward(self, x):
+ return x.mean((2, 3), keepdim=not self.flatten)
+
+
+class AdaptiveAvgMaxPool2d(nn.Module):
+ def __init__(self, output_size=1):
+ super(AdaptiveAvgMaxPool2d, self).__init__()
+ self.output_size = output_size
+
+ def forward(self, x):
+ return adaptive_avgmax_pool2d(x, self.output_size)
+
+
+class AdaptiveCatAvgMaxPool2d(nn.Module):
+ def __init__(self, output_size=1):
+ super(AdaptiveCatAvgMaxPool2d, self).__init__()
+ self.output_size = output_size
+
+ def forward(self, x):
+ return adaptive_catavgmax_pool2d(x, self.output_size)
+
+
+class SelectAdaptivePool2d(nn.Module):
+ """Selectable global pooling layer with dynamic input kernel size
+ """
+ def __init__(self, output_size=1, pool_type='fast', flatten=False):
+ super(SelectAdaptivePool2d, self).__init__()
+ self.pool_type = pool_type or '' # convert other falsy values to empty string for consistent TS typing
+ self.flatten = nn.Flatten(1) if flatten else nn.Identity()
+ if pool_type == '':
+ self.pool = nn.Identity() # pass through
+ elif pool_type == 'fast':
+ assert output_size == 1
+ self.pool = FastAdaptiveAvgPool2d(flatten)
+ self.flatten = nn.Identity()
+ elif pool_type == 'avg':
+ self.pool = nn.AdaptiveAvgPool2d(output_size)
+ elif pool_type == 'avgmax':
+ self.pool = AdaptiveAvgMaxPool2d(output_size)
+ elif pool_type == 'catavgmax':
+ self.pool = AdaptiveCatAvgMaxPool2d(output_size)
+ elif pool_type == 'max':
+ self.pool = nn.AdaptiveMaxPool2d(output_size)
+ else:
+ assert False, 'Invalid pool type: %s' % pool_type
+
+ def is_identity(self):
+ return not self.pool_type
+
+ def forward(self, x):
+ x = self.pool(x)
+ x = self.flatten(x)
+ return x
+
+ def feat_mult(self):
+ return adaptive_pool_feat_mult(self.pool_type)
+
+ def __repr__(self):
+ return self.__class__.__name__ + ' (' \
+ + 'pool_type=' + self.pool_type \
+ + ', flatten=' + str(self.flatten) + ')'
+
diff --git a/flagai/model/vision/layers/attention_pool2d.py b/flagai/model/vision/layers/attention_pool2d.py
new file mode 100755
index 00000000..a13a6881
--- /dev/null
+++ b/flagai/model/vision/layers/attention_pool2d.py
@@ -0,0 +1,131 @@
+""" Attention Pool 2D
+
+Implementations of 2D spatial feature pooling using multi-head attention instead of average pool.
+
+Based on idea in CLIP by OpenAI, licensed Apache 2.0
+https://github.com/openai/CLIP/blob/3b473b0e682c091a9e53623eebc1ca1657385717/clip/model.py
+
+Hacked together by / Copyright 2021 Ross Wightman
+"""
+from typing import Union, Tuple
+
+import torch
+import torch.nn as nn
+
+from .helpers import to_2tuple
+from .pos_embed import apply_rot_embed, RotaryEmbedding
+from .weight_init import trunc_normal_
+
+
+class RotAttentionPool2d(nn.Module):
+ """ Attention based 2D feature pooling w/ rotary (relative) pos embedding.
+ This is a multi-head attention based replacement for (spatial) average pooling in NN architectures.
+
+ Adapted from the AttentionPool2d in CLIP w/ rotary embedding instead of learned embed.
+ https://github.com/openai/CLIP/blob/3b473b0e682c091a9e53623eebc1ca1657385717/clip/model.py
+
+    NOTE: While this impl does not require a fixed feature size, performance at differing resolutions from
+ train varies widely and falls off dramatically. I'm not sure if there is a way around this... -RW
+ """
+ def __init__(
+ self,
+ in_features: int,
+ out_features: int = None,
+ embed_dim: int = None,
+ num_heads: int = 4,
+ qkv_bias: bool = True,
+ ):
+ super().__init__()
+ embed_dim = embed_dim or in_features
+ out_features = out_features or in_features
+ self.qkv = nn.Linear(in_features, embed_dim * 3, bias=qkv_bias)
+ self.proj = nn.Linear(embed_dim, out_features)
+ self.num_heads = num_heads
+ assert embed_dim % num_heads == 0
+ self.head_dim = embed_dim // num_heads
+ self.scale = self.head_dim ** -0.5
+ self.pos_embed = RotaryEmbedding(self.head_dim)
+
+ trunc_normal_(self.qkv.weight, std=in_features ** -0.5)
+ nn.init.zeros_(self.qkv.bias)
+
+ def forward(self, x):
+ B, _, H, W = x.shape
+ N = H * W
+ x = x.reshape(B, -1, N).permute(0, 2, 1)
+
+ x = torch.cat([x.mean(1, keepdim=True), x], dim=1)
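+        # prepend the spatial mean as an extra "pooling" token; only its
+        # attention output (index 0) is returned at the end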
+
+ x = self.qkv(x).reshape(B, N + 1, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
+ q, k, v = x[0], x[1], x[2]
+
+ qc, q = q[:, :, :1], q[:, :, 1:]
+ sin_emb, cos_emb = self.pos_embed.get_embed((H, W))
+ q = apply_rot_embed(q, sin_emb, cos_emb)
+ q = torch.cat([qc, q], dim=2)
+
+ kc, k = k[:, :, :1], k[:, :, 1:]
+ k = apply_rot_embed(k, sin_emb, cos_emb)
+ k = torch.cat([kc, k], dim=2)
+
+ attn = (q @ k.transpose(-2, -1)) * self.scale
+ attn = attn.softmax(dim=-1)
+
+ x = (attn @ v).transpose(1, 2).reshape(B, N + 1, -1)
+ x = self.proj(x)
+ return x[:, 0]
+
+
+class AttentionPool2d(nn.Module):
+ """ Attention based 2D feature pooling w/ learned (absolute) pos embedding.
+ This is a multi-head attention based replacement for (spatial) average pooling in NN architectures.
+
+ It was based on impl in CLIP by OpenAI
+ https://github.com/openai/CLIP/blob/3b473b0e682c091a9e53623eebc1ca1657385717/clip/model.py
+
+    NOTE: This requires the feature size upon construction and will prevent adaptive sizing of the network.
+ """
+ def __init__(
+ self,
+ in_features: int,
+ feat_size: Union[int, Tuple[int, int]],
+ out_features: int = None,
+ embed_dim: int = None,
+ num_heads: int = 4,
+ qkv_bias: bool = True,
+ ):
+ super().__init__()
+
+ embed_dim = embed_dim or in_features
+ out_features = out_features or in_features
+ assert embed_dim % num_heads == 0
+ self.feat_size = to_2tuple(feat_size)
+ self.qkv = nn.Linear(in_features, embed_dim * 3, bias=qkv_bias)
+ self.proj = nn.Linear(embed_dim, out_features)
+ self.num_heads = num_heads
+ self.head_dim = embed_dim // num_heads
+ self.scale = self.head_dim ** -0.5
+
+ spatial_dim = self.feat_size[0] * self.feat_size[1]
+ self.pos_embed = nn.Parameter(torch.zeros(spatial_dim + 1, in_features))
+ trunc_normal_(self.pos_embed, std=in_features ** -0.5)
+ trunc_normal_(self.qkv.weight, std=in_features ** -0.5)
+ nn.init.zeros_(self.qkv.bias)
+
+ def forward(self, x):
+ B, _, H, W = x.shape
+ N = H * W
+ assert self.feat_size[0] == H
+ assert self.feat_size[1] == W
+ x = x.reshape(B, -1, N).permute(0, 2, 1)
+ x = torch.cat([x.mean(1, keepdim=True), x], dim=1)
+ x = x + self.pos_embed.unsqueeze(0).to(x.dtype)
+
+ x = self.qkv(x).reshape(B, N + 1, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
+ q, k, v = x[0], x[1], x[2]
+ attn = (q @ k.transpose(-2, -1)) * self.scale
+ attn = attn.softmax(dim=-1)
+
+ x = (attn @ v).transpose(1, 2).reshape(B, N + 1, -1)
+ x = self.proj(x)
+ return x[:, 0]
diff --git a/flagai/model/vision/layers/blur_pool.py b/flagai/model/vision/layers/blur_pool.py
new file mode 100755
index 00000000..e73d8863
--- /dev/null
+++ b/flagai/model/vision/layers/blur_pool.py
@@ -0,0 +1,42 @@
+"""
+BlurPool layer inspired by
+ - Kornia's Max_BlurPool2d
+ - Making Convolutional Networks Shift-Invariant Again :cite:`zhang2019shiftinvar`
+
+Hacked together by Chris Ha and Ross Wightman
+"""
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+from .padding import get_padding
+
+
+class BlurPool2d(nn.Module):
+    r"""Creates a module that blurs and downsamples a given feature map.
+ See :cite:`zhang2019shiftinvar` for more details.
+ Corresponds to the Downsample class, which does blurring and subsampling
+
+ Args:
+        channels (int): Number of input channels
+        filt_size (int): binomial filter size for blurring. Currently supports 3 (default) and 5.
+        stride (int): downsampling filter stride
+
+ Returns:
+ torch.Tensor: the transformed tensor.
+ """
+ def __init__(self, channels, filt_size=3, stride=2) -> None:
+ super(BlurPool2d, self).__init__()
+ assert filt_size > 1
+ self.channels = channels
+ self.filt_size = filt_size
+ self.stride = stride
+ self.padding = [get_padding(filt_size, stride, dilation=1)] * 4
+ coeffs = torch.tensor((np.poly1d((0.5, 0.5)) ** (self.filt_size - 1)).coeffs.astype(np.float32))
+ blur_filter = (coeffs[:, None] * coeffs[None, :])[None, None, :, :].repeat(self.channels, 1, 1, 1)
+ self.register_buffer('filt', blur_filter, persistent=False)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x = F.pad(x, self.padding, 'reflect')
+ return F.conv2d(x, self.filt, stride=self.stride, groups=self.channels)
diff --git a/flagai/model/vision/layers/bottleneck_attn.py b/flagai/model/vision/layers/bottleneck_attn.py
new file mode 100755
index 00000000..c3db464e
--- /dev/null
+++ b/flagai/model/vision/layers/bottleneck_attn.py
@@ -0,0 +1,157 @@
+""" Bottleneck Self Attention (Bottleneck Transformers)
+
+Paper: `Bottleneck Transformers for Visual Recognition` - https://arxiv.org/abs/2101.11605
+
+@misc{2101.11605,
+Author = {Aravind Srinivas and Tsung-Yi Lin and Niki Parmar and Jonathon Shlens and Pieter Abbeel and Ashish Vaswani},
+Title = {Bottleneck Transformers for Visual Recognition},
+Year = {2021},
+}
+
+Based on ref gist at: https://gist.github.com/aravindsrinivas/56359b79f0ce4449bcb04ab4b56a57a2
+
+This impl is a WIP but given that it is based on the ref gist likely not too far off.
+
+Hacked together by / Copyright 2021 Ross Wightman
+"""
+from typing import List
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .helpers import to_2tuple, make_divisible
+from .weight_init import trunc_normal_
+from .trace_utils import _assert
+
+
+def rel_logits_1d(q, rel_k, permute_mask: List[int]):
+ """ Compute relative logits along one dimension
+
+ As per: https://gist.github.com/aravindsrinivas/56359b79f0ce4449bcb04ab4b56a57a2
+ Originally from: `Attention Augmented Convolutional Networks` - https://arxiv.org/abs/1904.09925
+
+ Args:
+ q: (batch, heads, height, width, dim)
+ rel_k: (2 * width - 1, dim)
+ permute_mask: permute output dim according to this
+ """
+ B, H, W, dim = q.shape
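+    # NOTE: batch and heads are already folded together here, so B is
+    # effectively batch * heads (cf. the 5-dim shape in the docstring above).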
+ x = (q @ rel_k.transpose(-1, -2))
+ x = x.reshape(-1, W, 2 * W -1)
+
+ # pad to shift from relative to absolute indexing
+ x_pad = F.pad(x, [0, 1]).flatten(1)
+ x_pad = F.pad(x_pad, [0, W - 1])
+
+ # reshape and slice out the padded elements
+ x_pad = x_pad.reshape(-1, W + 1, 2 * W - 1)
+ x = x_pad[:, :W, W - 1:]
+
+ # reshape and tile
+ x = x.reshape(B, H, 1, W, W).expand(-1, -1, H, -1, -1)
+ return x.permute(permute_mask)
+
+
+class PosEmbedRel(nn.Module):
+ """ Relative Position Embedding
+ As per: https://gist.github.com/aravindsrinivas/56359b79f0ce4449bcb04ab4b56a57a2
+ Originally from: `Attention Augmented Convolutional Networks` - https://arxiv.org/abs/1904.09925
+ """
+ def __init__(self, feat_size, dim_head, scale):
+ super().__init__()
+ self.height, self.width = to_2tuple(feat_size)
+ self.dim_head = dim_head
+ self.height_rel = nn.Parameter(torch.randn(self.height * 2 - 1, dim_head) * scale)
+ self.width_rel = nn.Parameter(torch.randn(self.width * 2 - 1, dim_head) * scale)
+
+ def forward(self, q):
+ B, HW, _ = q.shape
+
+ # relative logits in width dimension.
+ q = q.reshape(B, self.height, self.width, -1)
+ rel_logits_w = rel_logits_1d(q, self.width_rel, permute_mask=(0, 1, 3, 2, 4))
+
+ # relative logits in height dimension.
+ q = q.transpose(1, 2)
+ rel_logits_h = rel_logits_1d(q, self.height_rel, permute_mask=(0, 3, 1, 4, 2))
+
+ rel_logits = rel_logits_h + rel_logits_w
+ rel_logits = rel_logits.reshape(B, HW, HW)
+ return rel_logits
+
+
+class BottleneckAttn(nn.Module):
+ """ Bottleneck Attention
+ Paper: `Bottleneck Transformers for Visual Recognition` - https://arxiv.org/abs/2101.11605
+
+ The internal dimensions of the attention module are controlled by the interaction of several arguments.
+ * the output dimension of the module is specified by dim_out, which falls back to input dim if not set
+ * the value (v) dimension is set to dim_out // num_heads, the v projection determines the output dim
+ * the query and key (qk) dimensions are determined by
+ * num_heads * dim_head if dim_head is not None
+ * num_heads * (dim_out * attn_ratio // num_heads) if dim_head is None
+ * as seen above, attn_ratio determines the ratio of q and k relative to the output if dim_head not used
+
+ Args:
+ dim (int): input dimension to the module
+ dim_out (int): output dimension of the module, same as dim if not set
+ stride (int): output stride of the module, avg pool used if stride == 2 (default: 1).
+ num_heads (int): parallel attention heads (default: 4)
+ dim_head (int): dimension of query and key heads, calculated from dim_out * attn_ratio // num_heads if not set
+ qk_ratio (float): ratio of q and k dimensions to output dimension when dim_head not set. (default: 1.0)
+ qkv_bias (bool): add bias to q, k, and v projections
+ scale_pos_embed (bool): scale the position embedding as well as Q @ K
+ """
+ def __init__(
+ self, dim, dim_out=None, feat_size=None, stride=1, num_heads=4, dim_head=None,
+ qk_ratio=1.0, qkv_bias=False, scale_pos_embed=False):
+ super().__init__()
+ assert feat_size is not None, 'A concrete feature size matching expected input (H, W) is required'
+ dim_out = dim_out or dim
+ assert dim_out % num_heads == 0
+ self.num_heads = num_heads
+ self.dim_head_qk = dim_head or make_divisible(dim_out * qk_ratio, divisor=8) // num_heads
+ self.dim_head_v = dim_out // self.num_heads
+ self.dim_out_qk = num_heads * self.dim_head_qk
+ self.dim_out_v = num_heads * self.dim_head_v
+ self.scale = self.dim_head_qk ** -0.5
+ self.scale_pos_embed = scale_pos_embed
+
+ self.qkv = nn.Conv2d(dim, self.dim_out_qk * 2 + self.dim_out_v, 1, bias=qkv_bias)
+
+ # NOTE I'm only supporting relative pos embedding for now
+ self.pos_embed = PosEmbedRel(feat_size, dim_head=self.dim_head_qk, scale=self.scale)
+
+ self.pool = nn.AvgPool2d(2, 2) if stride == 2 else nn.Identity()
+
+ self.reset_parameters()
+
+ def reset_parameters(self):
+ trunc_normal_(self.qkv.weight, std=self.qkv.weight.shape[1] ** -0.5) # fan-in
+ trunc_normal_(self.pos_embed.height_rel, std=self.scale)
+ trunc_normal_(self.pos_embed.width_rel, std=self.scale)
+
+ def forward(self, x):
+ B, C, H, W = x.shape
+ _assert(H == self.pos_embed.height, '')
+ _assert(W == self.pos_embed.width, '')
+
+ x = self.qkv(x) # B, (2 * dim_head_qk + dim_head_v) * num_heads, H, W
+
+ # NOTE head vs channel split ordering in qkv projection was decided before I allowed qk to differ from v
+ # So, this is more verbose than if heads were before qkv splits, but throughput is not impacted.
+ q, k, v = torch.split(x, [self.dim_out_qk, self.dim_out_qk, self.dim_out_v], dim=1)
+ q = q.reshape(B * self.num_heads, self.dim_head_qk, -1).transpose(-1, -2)
+ k = k.reshape(B * self.num_heads, self.dim_head_qk, -1) # no transpose, for q @ k
+ v = v.reshape(B * self.num_heads, self.dim_head_v, -1).transpose(-1, -2)
+
+ if self.scale_pos_embed:
+ attn = (q @ k + self.pos_embed(q)) * self.scale # B * num_heads, H * W, H * W
+ else:
+ attn = (q @ k) * self.scale + self.pos_embed(q)
+ attn = attn.softmax(dim=-1)
+
+ out = (attn @ v).transpose(-1, -2).reshape(B, self.dim_out_v, H, W) # B, dim_out, H, W
+ out = self.pool(out)
+ return out
diff --git a/flagai/model/vision/layers/cbam.py b/flagai/model/vision/layers/cbam.py
new file mode 100755
index 00000000..576a8306
--- /dev/null
+++ b/flagai/model/vision/layers/cbam.py
@@ -0,0 +1,112 @@
+""" CBAM (sort-of) Attention
+
+Experimental impl of CBAM: Convolutional Block Attention Module: https://arxiv.org/abs/1807.06521
+
+WARNING: Results with these attention layers have been mixed. They can significantly reduce performance on
+some tasks, especially fine-grained ones, it seems. I may end up removing this impl.
+
+Hacked together by / Copyright 2020 Ross Wightman
+"""
+import torch
+from torch import nn as nn
+import torch.nn.functional as F
+
+from .conv_bn_act import ConvNormAct
+from .create_act import create_act_layer, get_act_layer
+from .helpers import make_divisible
+
+
+class ChannelAttn(nn.Module):
+ """ Original CBAM channel attention module, currently avg + max pool variant only.
+ """
+ def __init__(
+ self, channels, rd_ratio=1./16, rd_channels=None, rd_divisor=1,
+ act_layer=nn.ReLU, gate_layer='sigmoid', mlp_bias=False):
+ super(ChannelAttn, self).__init__()
+ if not rd_channels:
+ rd_channels = make_divisible(channels * rd_ratio, rd_divisor, round_limit=0.)
+ self.fc1 = nn.Conv2d(channels, rd_channels, 1, bias=mlp_bias)
+ self.act = act_layer(inplace=True)
+ self.fc2 = nn.Conv2d(rd_channels, channels, 1, bias=mlp_bias)
+ self.gate = create_act_layer(gate_layer)
+
+ def forward(self, x):
+ x_avg = self.fc2(self.act(self.fc1(x.mean((2, 3), keepdim=True))))
+ x_max = self.fc2(self.act(self.fc1(x.amax((2, 3), keepdim=True))))
+ return x * self.gate(x_avg + x_max)
+
+
+class LightChannelAttn(ChannelAttn):
+ """An experimental 'lightweight' that sums avg + max pool first
+ """
+ def __init__(
+ self, channels, rd_ratio=1./16, rd_channels=None, rd_divisor=1,
+ act_layer=nn.ReLU, gate_layer='sigmoid', mlp_bias=False):
+ super(LightChannelAttn, self).__init__(
+ channels, rd_ratio, rd_channels, rd_divisor, act_layer, gate_layer, mlp_bias)
+
+ def forward(self, x):
+ x_pool = 0.5 * x.mean((2, 3), keepdim=True) + 0.5 * x.amax((2, 3), keepdim=True)
+ x_attn = self.fc2(self.act(self.fc1(x_pool)))
+        return x * self.gate(x_attn)
+
+
+class SpatialAttn(nn.Module):
+ """ Original CBAM spatial attention module
+ """
+ def __init__(self, kernel_size=7, gate_layer='sigmoid'):
+ super(SpatialAttn, self).__init__()
+ self.conv = ConvNormAct(2, 1, kernel_size, apply_act=False)
+ self.gate = create_act_layer(gate_layer)
+
+ def forward(self, x):
+ x_attn = torch.cat([x.mean(dim=1, keepdim=True), x.amax(dim=1, keepdim=True)], dim=1)
+ x_attn = self.conv(x_attn)
+ return x * self.gate(x_attn)
+
+
+class LightSpatialAttn(nn.Module):
+ """An experimental 'lightweight' variant that sums avg_pool and max_pool results.
+ """
+ def __init__(self, kernel_size=7, gate_layer='sigmoid'):
+ super(LightSpatialAttn, self).__init__()
+ self.conv = ConvNormAct(1, 1, kernel_size, apply_act=False)
+ self.gate = create_act_layer(gate_layer)
+
+ def forward(self, x):
+ x_attn = 0.5 * x.mean(dim=1, keepdim=True) + 0.5 * x.amax(dim=1, keepdim=True)
+ x_attn = self.conv(x_attn)
+ return x * self.gate(x_attn)
+
+
+class CbamModule(nn.Module):
+ def __init__(
+ self, channels, rd_ratio=1./16, rd_channels=None, rd_divisor=1,
+ spatial_kernel_size=7, act_layer=nn.ReLU, gate_layer='sigmoid', mlp_bias=False):
+ super(CbamModule, self).__init__()
+ self.channel = ChannelAttn(
+ channels, rd_ratio=rd_ratio, rd_channels=rd_channels,
+ rd_divisor=rd_divisor, act_layer=act_layer, gate_layer=gate_layer, mlp_bias=mlp_bias)
+ self.spatial = SpatialAttn(spatial_kernel_size, gate_layer=gate_layer)
+
+ def forward(self, x):
+ x = self.channel(x)
+ x = self.spatial(x)
+ return x
+
+
+class LightCbamModule(nn.Module):
+ def __init__(
+ self, channels, rd_ratio=1./16, rd_channels=None, rd_divisor=1,
+ spatial_kernel_size=7, act_layer=nn.ReLU, gate_layer='sigmoid', mlp_bias=False):
+ super(LightCbamModule, self).__init__()
+ self.channel = LightChannelAttn(
+ channels, rd_ratio=rd_ratio, rd_channels=rd_channels,
+ rd_divisor=rd_divisor, act_layer=act_layer, gate_layer=gate_layer, mlp_bias=mlp_bias)
+        self.spatial = LightSpatialAttn(spatial_kernel_size, gate_layer=gate_layer)
+
+ def forward(self, x):
+ x = self.channel(x)
+ x = self.spatial(x)
+ return x
+
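+
+if __name__ == '__main__':
+    # Minimal usage sketch (illustrative assumption): channel attention followed by spatial
+    # attention; the output keeps the input shape.
+    m = CbamModule(channels=64)
+    y = m(torch.randn(2, 64, 32, 32))
+    assert y.shape == (2, 64, 32, 32)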
diff --git a/flagai/model/vision/layers/classifier.py b/flagai/model/vision/layers/classifier.py
new file mode 100755
index 00000000..3ac33387
--- /dev/null
+++ b/flagai/model/vision/layers/classifier.py
@@ -0,0 +1,56 @@
+""" Classifier head and layer factory
+
+Hacked together by / Copyright 2020 Ross Wightman
+"""
+from torch import nn as nn
+from torch.nn import functional as F
+
+from .adaptive_avgmax_pool import SelectAdaptivePool2d
+
+
+def _create_pool(num_features, num_classes, pool_type='avg', use_conv=False):
+ flatten_in_pool = not use_conv # flatten when we use a Linear layer after pooling
+ if not pool_type:
+ assert num_classes == 0 or use_conv,\
+ 'Pooling can only be disabled if classifier is also removed or conv classifier is used'
+ flatten_in_pool = False # disable flattening if pooling is pass-through (no pooling)
+ global_pool = SelectAdaptivePool2d(pool_type=pool_type, flatten=flatten_in_pool)
+ num_pooled_features = num_features * global_pool.feat_mult()
+ return global_pool, num_pooled_features
+
+
+def _create_fc(num_features, num_classes, use_conv=False):
+ if num_classes <= 0:
+ fc = nn.Identity() # pass-through (no classifier)
+ elif use_conv:
+ fc = nn.Conv2d(num_features, num_classes, 1, bias=True)
+ else:
+ fc = nn.Linear(num_features, num_classes, bias=True)
+ return fc
+
+
+def create_classifier(num_features, num_classes, pool_type='avg', use_conv=False):
+ global_pool, num_pooled_features = _create_pool(num_features, num_classes, pool_type, use_conv=use_conv)
+ fc = _create_fc(num_pooled_features, num_classes, use_conv=use_conv)
+ return global_pool, fc
+
+
+class ClassifierHead(nn.Module):
+ """Classifier head w/ configurable global pooling and dropout."""
+
+ def __init__(self, in_chs, num_classes, pool_type='avg', drop_rate=0., use_conv=False):
+ super(ClassifierHead, self).__init__()
+ self.drop_rate = drop_rate
+ self.global_pool, num_pooled_features = _create_pool(in_chs, num_classes, pool_type, use_conv=use_conv)
+ self.fc = _create_fc(num_pooled_features, num_classes, use_conv=use_conv)
+ self.flatten = nn.Flatten(1) if use_conv and pool_type else nn.Identity()
+
+ def forward(self, x, pre_logits: bool = False):
+ x = self.global_pool(x)
+ if self.drop_rate:
+ x = F.dropout(x, p=float(self.drop_rate), training=self.training)
+ if pre_logits:
+ return x.flatten(1)
+ else:
+ x = self.fc(x)
+ return self.flatten(x)
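+
+
+if __name__ == '__main__':
+    # Minimal usage sketch (illustrative assumption): global avg pool + linear classifier over a
+    # (B, C, H, W) feature map, the common non-conv configuration of create_classifier.
+    import torch
+    head = ClassifierHead(in_chs=512, num_classes=10, pool_type='avg', drop_rate=0.1)
+    logits = head(torch.randn(2, 512, 7, 7))
+    assert logits.shape == (2, 10)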
diff --git a/flagai/model/vision/layers/cond_conv2d.py b/flagai/model/vision/layers/cond_conv2d.py
new file mode 100755
index 00000000..43654c59
--- /dev/null
+++ b/flagai/model/vision/layers/cond_conv2d.py
@@ -0,0 +1,123 @@
+""" PyTorch Conditionally Parameterized Convolution (CondConv)
+
+Paper: CondConv: Conditionally Parameterized Convolutions for Efficient Inference
+(https://arxiv.org/abs/1904.04971)
+
+Hacked together by / Copyright 2020 Ross Wightman
+"""
+
+import math
+from functools import partial
+import numpy as np
+import torch
+from torch import nn as nn
+from torch.nn import functional as F
+
+from .helpers import to_2tuple
+from .conv2d_same import conv2d_same
+from .padding import get_padding_value
+
+
+def get_condconv_initializer(initializer, num_experts, expert_shape):
+ def condconv_initializer(weight):
+ """CondConv initializer function."""
+ num_params = np.prod(expert_shape)
+ if (len(weight.shape) != 2 or weight.shape[0] != num_experts or
+ weight.shape[1] != num_params):
+            raise ValueError(
+                'CondConv variables must have shape [num_experts, num_params]')
+ for i in range(num_experts):
+ initializer(weight[i].view(expert_shape))
+ return condconv_initializer
+
+
+class CondConv2d(nn.Module):
+ """ Conditionally Parameterized Convolution
+ Inspired by: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/condconv/condconv_layers.py
+
+ Grouped convolution hackery for parallel execution of the per-sample kernel filters inspired by this discussion:
+ https://github.com/pytorch/pytorch/issues/17983
+ """
+ __constants__ = ['in_channels', 'out_channels', 'dynamic_padding']
+
+ def __init__(self, in_channels, out_channels, kernel_size=3,
+ stride=1, padding='', dilation=1, groups=1, bias=False, num_experts=4):
+ super(CondConv2d, self).__init__()
+
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.kernel_size = to_2tuple(kernel_size)
+ self.stride = to_2tuple(stride)
+ padding_val, is_padding_dynamic = get_padding_value(
+ padding, kernel_size, stride=stride, dilation=dilation)
+        self.dynamic_padding = is_padding_dynamic  # checked in forward; kept as an attribute for torchscript compat
+ self.padding = to_2tuple(padding_val)
+ self.dilation = to_2tuple(dilation)
+ self.groups = groups
+ self.num_experts = num_experts
+
+ self.weight_shape = (self.out_channels, self.in_channels // self.groups) + self.kernel_size
+ weight_num_param = 1
+ for wd in self.weight_shape:
+ weight_num_param *= wd
+ self.weight = torch.nn.Parameter(torch.Tensor(self.num_experts, weight_num_param))
+
+ if bias:
+ self.bias_shape = (self.out_channels,)
+ self.bias = torch.nn.Parameter(torch.Tensor(self.num_experts, self.out_channels))
+ else:
+ self.register_parameter('bias', None)
+
+ self.reset_parameters()
+
+ def reset_parameters(self):
+ init_weight = get_condconv_initializer(
+ partial(nn.init.kaiming_uniform_, a=math.sqrt(5)), self.num_experts, self.weight_shape)
+ init_weight(self.weight)
+ if self.bias is not None:
+ fan_in = np.prod(self.weight_shape[1:])
+ bound = 1 / math.sqrt(fan_in)
+ init_bias = get_condconv_initializer(
+ partial(nn.init.uniform_, a=-bound, b=bound), self.num_experts, self.bias_shape)
+ init_bias(self.bias)
+
+ def forward(self, x, routing_weights):
+ B, C, H, W = x.shape
+ weight = torch.matmul(routing_weights, self.weight)
+ new_weight_shape = (B * self.out_channels, self.in_channels // self.groups) + self.kernel_size
+ weight = weight.view(new_weight_shape)
+ bias = None
+ if self.bias is not None:
+ bias = torch.matmul(routing_weights, self.bias)
+ bias = bias.view(B * self.out_channels)
+ # move batch elements with channels so each batch element can be efficiently convolved with separate kernel
+ # reshape instead of view to work with channels_last input
+ x = x.reshape(1, B * C, H, W)
+ if self.dynamic_padding:
+ out = conv2d_same(
+ x, weight, bias, stride=self.stride, padding=self.padding,
+ dilation=self.dilation, groups=self.groups * B)
+ else:
+ out = F.conv2d(
+ x, weight, bias, stride=self.stride, padding=self.padding,
+ dilation=self.dilation, groups=self.groups * B)
+ out = out.permute([1, 0, 2, 3]).view(B, self.out_channels, out.shape[-2], out.shape[-1])
+
+ # Literal port (from TF definition)
+ # x = torch.split(x, 1, 0)
+ # weight = torch.split(weight, 1, 0)
+ # if self.bias is not None:
+ # bias = torch.matmul(routing_weights, self.bias)
+ # bias = torch.split(bias, 1, 0)
+ # else:
+ # bias = [None] * B
+ # out = []
+ # for xi, wi, bi in zip(x, weight, bias):
+ # wi = wi.view(*self.weight_shape)
+ # if bi is not None:
+ # bi = bi.view(*self.bias_shape)
+ # out.append(self.conv_fn(
+ # xi, wi, bi, stride=self.stride, padding=self.padding,
+ # dilation=self.dilation, groups=self.groups))
+ # out = torch.cat(out, 0)
+ return out
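+
+
+if __name__ == '__main__':
+    # Minimal usage sketch (illustrative assumption): routing_weights mix the per-sample expert
+    # kernels; in practice they come from a small routing head (e.g. pooled features -> linear -> gate).
+    x = torch.randn(2, 16, 8, 8)
+    routing_weights = torch.softmax(torch.randn(2, 4), dim=-1)  # (B, num_experts)
+    conv = CondConv2d(16, 32, kernel_size=3, num_experts=4)
+    y = conv(x, routing_weights)
+    assert y.shape == (2, 32, 8, 8)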
diff --git a/flagai/model/vision/layers/config.py b/flagai/model/vision/layers/config.py
new file mode 100755
index 00000000..f07b9d78
--- /dev/null
+++ b/flagai/model/vision/layers/config.py
@@ -0,0 +1,115 @@
+""" Model / Layer Config singleton state
+"""
+from typing import Any, Optional
+
+__all__ = [
+ 'is_exportable', 'is_scriptable', 'is_no_jit',
+ 'set_exportable', 'set_scriptable', 'set_no_jit', 'set_layer_config'
+]
+
+# Set to True if prefer to have layers with no jit optimization (includes activations)
+_NO_JIT = False
+
+# Set to True if prefer to have activation layers with no jit optimization
+# NOTE not currently used, as there is no difference between no_jit and no_activation_jit while the only
+# layers obeying the jit flags so far are activations. This will change as more layers are updated and/or added.
+_NO_ACTIVATION_JIT = False
+
+# Set to True if exporting a model with Same padding via ONNX
+_EXPORTABLE = False
+
+# Set to True if wanting to use torch.jit.script on a model
+_SCRIPTABLE = False
+
+
+def is_no_jit():
+ return _NO_JIT
+
+
+class set_no_jit:
+ def __init__(self, mode: bool) -> None:
+ global _NO_JIT
+ self.prev = _NO_JIT
+ _NO_JIT = mode
+
+ def __enter__(self) -> None:
+ pass
+
+ def __exit__(self, *args: Any) -> bool:
+ global _NO_JIT
+ _NO_JIT = self.prev
+ return False
+
+
+def is_exportable():
+ return _EXPORTABLE
+
+
+class set_exportable:
+ def __init__(self, mode: bool) -> None:
+ global _EXPORTABLE
+ self.prev = _EXPORTABLE
+ _EXPORTABLE = mode
+
+ def __enter__(self) -> None:
+ pass
+
+ def __exit__(self, *args: Any) -> bool:
+ global _EXPORTABLE
+ _EXPORTABLE = self.prev
+ return False
+
+
+def is_scriptable():
+ return _SCRIPTABLE
+
+
+class set_scriptable:
+ def __init__(self, mode: bool) -> None:
+ global _SCRIPTABLE
+ self.prev = _SCRIPTABLE
+ _SCRIPTABLE = mode
+
+ def __enter__(self) -> None:
+ pass
+
+ def __exit__(self, *args: Any) -> bool:
+ global _SCRIPTABLE
+ _SCRIPTABLE = self.prev
+ return False
+
+
+class set_layer_config:
+ """ Layer config context manager that allows setting all layer config flags at once.
+ If a flag arg is None, it will not change the current value.
+ """
+ def __init__(
+ self,
+ scriptable: Optional[bool] = None,
+ exportable: Optional[bool] = None,
+ no_jit: Optional[bool] = None,
+ no_activation_jit: Optional[bool] = None):
+ global _SCRIPTABLE
+ global _EXPORTABLE
+ global _NO_JIT
+ global _NO_ACTIVATION_JIT
+ self.prev = _SCRIPTABLE, _EXPORTABLE, _NO_JIT, _NO_ACTIVATION_JIT
+ if scriptable is not None:
+ _SCRIPTABLE = scriptable
+ if exportable is not None:
+ _EXPORTABLE = exportable
+ if no_jit is not None:
+ _NO_JIT = no_jit
+ if no_activation_jit is not None:
+ _NO_ACTIVATION_JIT = no_activation_jit
+
+ def __enter__(self) -> None:
+ pass
+
+ def __exit__(self, *args: Any) -> bool:
+ global _SCRIPTABLE
+ global _EXPORTABLE
+ global _NO_JIT
+ global _NO_ACTIVATION_JIT
+ _SCRIPTABLE, _EXPORTABLE, _NO_JIT, _NO_ACTIVATION_JIT = self.prev
+ return False
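+
+
+if __name__ == '__main__':
+    # Minimal usage sketch (illustrative assumption): the flags are module-level globals and the
+    # context managers restore the previous values on exit.
+    assert not is_scriptable() and not is_no_jit()
+    with set_layer_config(scriptable=True, no_jit=True):
+        assert is_scriptable() and is_no_jit()
+    assert not is_scriptable() and not is_no_jit()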
diff --git a/flagai/model/vision/layers/conv2d_same.py b/flagai/model/vision/layers/conv2d_same.py
new file mode 100755
index 00000000..75f0f98d
--- /dev/null
+++ b/flagai/model/vision/layers/conv2d_same.py
@@ -0,0 +1,42 @@
+""" Conv2d w/ Same Padding
+
+Hacked together by / Copyright 2020 Ross Wightman
+"""
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from typing import Tuple, Optional
+
+from .padding import pad_same, get_padding_value
+
+
+def conv2d_same(
+ x, weight: torch.Tensor, bias: Optional[torch.Tensor] = None, stride: Tuple[int, int] = (1, 1),
+ padding: Tuple[int, int] = (0, 0), dilation: Tuple[int, int] = (1, 1), groups: int = 1):
+ x = pad_same(x, weight.shape[-2:], stride, dilation)
+ return F.conv2d(x, weight, bias, stride, (0, 0), dilation, groups)
+
+
+class Conv2dSame(nn.Conv2d):
+ """ Tensorflow like 'SAME' convolution wrapper for 2D convolutions
+ """
+
+ def __init__(self, in_channels, out_channels, kernel_size, stride=1,
+ padding=0, dilation=1, groups=1, bias=True):
+ super(Conv2dSame, self).__init__(
+ in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias)
+
+ def forward(self, x):
+ return conv2d_same(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
+
+
+def create_conv2d_pad(in_chs, out_chs, kernel_size, **kwargs):
+ padding = kwargs.pop('padding', '')
+ kwargs.setdefault('bias', False)
+ padding, is_dynamic = get_padding_value(padding, kernel_size, **kwargs)
+ if is_dynamic:
+ return Conv2dSame(in_chs, out_chs, kernel_size, **kwargs)
+ else:
+ return nn.Conv2d(in_chs, out_chs, kernel_size, padding=padding, **kwargs)
+
+
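+if __name__ == '__main__':
+    # Minimal usage sketch (illustrative assumption): with TF-style 'SAME' padding the output
+    # spatial size is ceil(input / stride), independent of the kernel size.
+    conv = Conv2dSame(3, 8, kernel_size=3, stride=2)
+    y = conv(torch.randn(1, 3, 7, 7))
+    assert y.shape == (1, 8, 4, 4)  # ceil(7 / 2) == 4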
diff --git a/flagai/model/vision/layers/conv_bn_act.py b/flagai/model/vision/layers/conv_bn_act.py
new file mode 100755
index 00000000..af010573
--- /dev/null
+++ b/flagai/model/vision/layers/conv_bn_act.py
@@ -0,0 +1,73 @@
+""" Conv2d + BN + Act
+
+Hacked together by / Copyright 2020 Ross Wightman
+"""
+from torch import nn as nn
+
+from .create_conv2d import create_conv2d
+from .create_norm_act import get_norm_act_layer
+
+
+class ConvNormAct(nn.Module):
+ def __init__(
+ self, in_channels, out_channels, kernel_size=1, stride=1, padding='', dilation=1, groups=1,
+ bias=False, apply_act=True, norm_layer=nn.BatchNorm2d, act_layer=nn.ReLU, drop_layer=None):
+ super(ConvNormAct, self).__init__()
+ self.conv = create_conv2d(
+ in_channels, out_channels, kernel_size, stride=stride,
+ padding=padding, dilation=dilation, groups=groups, bias=bias)
+
+ # NOTE for backwards compatibility with models that use separate norm and act layer definitions
+ norm_act_layer = get_norm_act_layer(norm_layer, act_layer)
+ # NOTE for backwards (weight) compatibility, norm layer name remains `.bn`
+ norm_kwargs = dict(drop_layer=drop_layer) if drop_layer is not None else {}
+ self.bn = norm_act_layer(out_channels, apply_act=apply_act, **norm_kwargs)
+
+ @property
+ def in_channels(self):
+ return self.conv.in_channels
+
+ @property
+ def out_channels(self):
+ return self.conv.out_channels
+
+ def forward(self, x):
+ x = self.conv(x)
+ x = self.bn(x)
+ return x
+
+
+ConvBnAct = ConvNormAct
+
+
+class ConvNormActAa(nn.Module):
+ def __init__(
+ self, in_channels, out_channels, kernel_size=1, stride=1, padding='', dilation=1, groups=1,
+ bias=False, apply_act=True, norm_layer=nn.BatchNorm2d, act_layer=nn.ReLU, aa_layer=None, drop_layer=None):
+ super(ConvNormActAa, self).__init__()
+ use_aa = aa_layer is not None
+
+ self.conv = create_conv2d(
+ in_channels, out_channels, kernel_size, stride=1 if use_aa else stride,
+ padding=padding, dilation=dilation, groups=groups, bias=bias)
+
+ # NOTE for backwards compatibility with models that use separate norm and act layer definitions
+ norm_act_layer = get_norm_act_layer(norm_layer, act_layer)
+ # NOTE for backwards (weight) compatibility, norm layer name remains `.bn`
+ norm_kwargs = dict(drop_layer=drop_layer) if drop_layer is not None else {}
+ self.bn = norm_act_layer(out_channels, apply_act=apply_act, **norm_kwargs)
+ self.aa = aa_layer(channels=out_channels) if stride == 2 and use_aa else nn.Identity()
+
+ @property
+ def in_channels(self):
+ return self.conv.in_channels
+
+ @property
+ def out_channels(self):
+ return self.conv.out_channels
+
+ def forward(self, x):
+ x = self.conv(x)
+ x = self.bn(x)
+ x = self.aa(x)
+ return x
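+
+
+if __name__ == '__main__':
+    # Minimal usage sketch (illustrative assumption): a conv + BN + ReLU block; the combined
+    # norm + act pair lives under `.bn` for weight compatibility, as noted above.
+    import torch
+    block = ConvNormAct(3, 16, kernel_size=3, stride=2)
+    y = block(torch.randn(1, 3, 32, 32))
+    assert y.shape == (1, 16, 16, 16)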
diff --git a/flagai/model/vision/layers/create_act.py b/flagai/model/vision/layers/create_act.py
new file mode 100755
index 00000000..e38f2e03
--- /dev/null
+++ b/flagai/model/vision/layers/create_act.py
@@ -0,0 +1,148 @@
+""" Activation Factory
+Hacked together by / Copyright 2020 Ross Wightman
+"""
+from typing import Union, Callable, Type
+
+from .activations import *
+from .activations_jit import *
+from .activations_me import *
+from .config import is_exportable, is_scriptable, is_no_jit
+
+# PyTorch has an optimized, native 'silu' (aka 'swish') operator as of PyTorch 1.7.
+# Also hardsigmoid, hardswish, and (in newer versions) mish. This code will use the native versions if present.
+# Eventually, the custom SiLU, Mish, Hard*, layers will be removed and only native variants will be used.
+_has_silu = 'silu' in dir(torch.nn.functional)
+_has_hardswish = 'hardswish' in dir(torch.nn.functional)
+_has_hardsigmoid = 'hardsigmoid' in dir(torch.nn.functional)
+_has_mish = 'mish' in dir(torch.nn.functional)
+
+
+_ACT_FN_DEFAULT = dict(
+ silu=F.silu if _has_silu else swish,
+ swish=F.silu if _has_silu else swish,
+ mish=F.mish if _has_mish else mish,
+ relu=F.relu,
+ relu6=F.relu6,
+ leaky_relu=F.leaky_relu,
+ elu=F.elu,
+ celu=F.celu,
+ selu=F.selu,
+ gelu=gelu,
+ sigmoid=sigmoid,
+ tanh=tanh,
+ hard_sigmoid=F.hardsigmoid if _has_hardsigmoid else hard_sigmoid,
+ hard_swish=F.hardswish if _has_hardswish else hard_swish,
+ hard_mish=hard_mish,
+)
+
+_ACT_FN_JIT = dict(
+ silu=F.silu if _has_silu else swish_jit,
+ swish=F.silu if _has_silu else swish_jit,
+ mish=F.mish if _has_mish else mish_jit,
+ hard_sigmoid=F.hardsigmoid if _has_hardsigmoid else hard_sigmoid_jit,
+ hard_swish=F.hardswish if _has_hardswish else hard_swish_jit,
+ hard_mish=hard_mish_jit
+)
+
+_ACT_FN_ME = dict(
+ silu=F.silu if _has_silu else swish_me,
+ swish=F.silu if _has_silu else swish_me,
+ mish=F.mish if _has_mish else mish_me,
+ hard_sigmoid=F.hardsigmoid if _has_hardsigmoid else hard_sigmoid_me,
+ hard_swish=F.hardswish if _has_hardswish else hard_swish_me,
+ hard_mish=hard_mish_me,
+)
+
+_ACT_FNS = (_ACT_FN_ME, _ACT_FN_JIT, _ACT_FN_DEFAULT)
+for a in _ACT_FNS:
+ a.setdefault('hardsigmoid', a.get('hard_sigmoid'))
+ a.setdefault('hardswish', a.get('hard_swish'))
+
+
+_ACT_LAYER_DEFAULT = dict(
+ silu=nn.SiLU if _has_silu else Swish,
+ swish=nn.SiLU if _has_silu else Swish,
+ mish=nn.Mish if _has_mish else Mish,
+ relu=nn.ReLU,
+ relu6=nn.ReLU6,
+ leaky_relu=nn.LeakyReLU,
+ elu=nn.ELU,
+ prelu=PReLU,
+ celu=nn.CELU,
+ selu=nn.SELU,
+ gelu=GELU,
+ sigmoid=Sigmoid,
+ tanh=Tanh,
+ hard_sigmoid=nn.Hardsigmoid if _has_hardsigmoid else HardSigmoid,
+ hard_swish=nn.Hardswish if _has_hardswish else HardSwish,
+ hard_mish=HardMish,
+)
+
+_ACT_LAYER_JIT = dict(
+ silu=nn.SiLU if _has_silu else SwishJit,
+ swish=nn.SiLU if _has_silu else SwishJit,
+ mish=nn.Mish if _has_mish else MishJit,
+ hard_sigmoid=nn.Hardsigmoid if _has_hardsigmoid else HardSigmoidJit,
+ hard_swish=nn.Hardswish if _has_hardswish else HardSwishJit,
+ hard_mish=HardMishJit
+)
+
+_ACT_LAYER_ME = dict(
+ silu=nn.SiLU if _has_silu else SwishMe,
+ swish=nn.SiLU if _has_silu else SwishMe,
+ mish=nn.Mish if _has_mish else MishMe,
+ hard_sigmoid=nn.Hardsigmoid if _has_hardsigmoid else HardSigmoidMe,
+ hard_swish=nn.Hardswish if _has_hardswish else HardSwishMe,
+ hard_mish=HardMishMe,
+)
+
+_ACT_LAYERS = (_ACT_LAYER_ME, _ACT_LAYER_JIT, _ACT_LAYER_DEFAULT)
+for a in _ACT_LAYERS:
+ a.setdefault('hardsigmoid', a.get('hard_sigmoid'))
+ a.setdefault('hardswish', a.get('hard_swish'))
+
+
+def get_act_fn(name: Union[Callable, str] = 'relu'):
+ """ Activation Function Factory
+ Fetching activation fns by name with this function allows export or torch script friendly
+ functions to be returned dynamically based on current config.
+ """
+ if not name:
+ return None
+ if isinstance(name, Callable):
+ return name
+ if not (is_no_jit() or is_exportable() or is_scriptable()):
+        # If not exporting or scripting the model, first look for a memory-efficient version with
+        # custom autograd, then fall back to the jit and default versions
+ if name in _ACT_FN_ME:
+ return _ACT_FN_ME[name]
+ if not (is_no_jit() or is_exportable()):
+ if name in _ACT_FN_JIT:
+ return _ACT_FN_JIT[name]
+ return _ACT_FN_DEFAULT[name]
+
+
+def get_act_layer(name: Union[Type[nn.Module], str] = 'relu'):
+ """ Activation Layer Factory
+ Fetching activation layers by name with this function allows export or torch script friendly
+ functions to be returned dynamically based on current config.
+ """
+ if not name:
+ return None
+ if not isinstance(name, str):
+ # callable, module, etc
+ return name
+ if not (is_no_jit() or is_exportable() or is_scriptable()):
+ if name in _ACT_LAYER_ME:
+ return _ACT_LAYER_ME[name]
+ if not (is_no_jit() or is_exportable()):
+ if name in _ACT_LAYER_JIT:
+ return _ACT_LAYER_JIT[name]
+ return _ACT_LAYER_DEFAULT[name]
+
+
+def create_act_layer(name: Union[nn.Module, str], inplace=None, **kwargs):
+ act_layer = get_act_layer(name)
+ if act_layer is None:
+ return None
+ return act_layer(**kwargs) if inplace is None else act_layer(inplace=inplace, **kwargs)
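+
+
+if __name__ == '__main__':
+    # Minimal usage sketch (illustrative assumption): names resolve to the memory-efficient, jit
+    # scripted, or default variants depending on the config flags in .config.
+    import torch
+    act = create_act_layer('hard_swish', inplace=True)
+    fn = get_act_fn('silu')
+    y = fn(act(torch.randn(4, 8)))
+    assert y.shape == (4, 8)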
diff --git a/flagai/model/vision/layers/create_attn.py b/flagai/model/vision/layers/create_attn.py
new file mode 100755
index 00000000..028c0f75
--- /dev/null
+++ b/flagai/model/vision/layers/create_attn.py
@@ -0,0 +1,89 @@
+""" Attention Factory
+
+Hacked together by / Copyright 2021 Ross Wightman
+"""
+import torch
+from functools import partial
+
+from .bottleneck_attn import BottleneckAttn
+from .cbam import CbamModule, LightCbamModule
+from .eca import EcaModule, CecaModule
+from .gather_excite import GatherExcite
+from .global_context import GlobalContext
+from .halo_attn import HaloAttn
+from .lambda_layer import LambdaLayer
+from .non_local_attn import NonLocalAttn, BatNonLocalAttn
+from .selective_kernel import SelectiveKernel
+from .split_attn import SplitAttn
+from .squeeze_excite import SEModule, EffectiveSEModule
+
+
+def get_attn(attn_type):
+ if isinstance(attn_type, torch.nn.Module):
+ return attn_type
+ module_cls = None
+ if attn_type is not None:
+ if isinstance(attn_type, str):
+ attn_type = attn_type.lower()
+ # Lightweight attention modules (channel and/or coarse spatial).
+ # Typically added to existing network architecture blocks in addition to existing convolutions.
+ if attn_type == 'se':
+ module_cls = SEModule
+ elif attn_type == 'ese':
+ module_cls = EffectiveSEModule
+ elif attn_type == 'eca':
+ module_cls = EcaModule
+ elif attn_type == 'ecam':
+ module_cls = partial(EcaModule, use_mlp=True)
+ elif attn_type == 'ceca':
+ module_cls = CecaModule
+ elif attn_type == 'ge':
+ module_cls = GatherExcite
+ elif attn_type == 'gc':
+ module_cls = GlobalContext
+ elif attn_type == 'gca':
+ module_cls = partial(GlobalContext, fuse_add=True, fuse_scale=False)
+ elif attn_type == 'cbam':
+ module_cls = CbamModule
+ elif attn_type == 'lcbam':
+ module_cls = LightCbamModule
+
+ # Attention / attention-like modules w/ significant params
+ # Typically replace some of the existing workhorse convs in a network architecture.
+ # All of these accept a stride argument and can spatially downsample the input.
+ elif attn_type == 'sk':
+ module_cls = SelectiveKernel
+ elif attn_type == 'splat':
+ module_cls = SplitAttn
+
+ # Self-attention / attention-like modules w/ significant compute and/or params
+ # Typically replace some of the existing workhorse convs in a network architecture.
+ # All of these accept a stride argument and can spatially downsample the input.
+ elif attn_type == 'lambda':
+ return LambdaLayer
+ elif attn_type == 'bottleneck':
+ return BottleneckAttn
+ elif attn_type == 'halo':
+ return HaloAttn
+ elif attn_type == 'nl':
+ module_cls = NonLocalAttn
+ elif attn_type == 'bat':
+ module_cls = BatNonLocalAttn
+
+ # Woops!
+ else:
+ assert False, "Invalid attn module (%s)" % attn_type
+ elif isinstance(attn_type, bool):
+ if attn_type:
+ module_cls = SEModule
+ else:
+ module_cls = attn_type
+ return module_cls
+
+
+def create_attn(attn_type, channels, **kwargs):
+ module_cls = get_attn(attn_type)
+ if module_cls is not None:
+ # NOTE: it's expected the first (positional) argument of all attention layers is the # input channels
+ return module_cls(channels, **kwargs)
+ return None
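+
+
+if __name__ == '__main__':
+    # Minimal usage sketch (illustrative assumption): lightweight channel-attn types ('se', 'eca',
+    # 'cbam', ...) keep the input shape; the self-attn types ('bottleneck', 'halo', 'lambda')
+    # usually need extra kwargs such as feat_size or stride.
+    se = create_attn('se', channels=64)
+    y = se(torch.randn(2, 64, 16, 16))
+    assert y.shape == (2, 64, 16, 16)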
diff --git a/flagai/model/vision/layers/create_conv2d.py b/flagai/model/vision/layers/create_conv2d.py
new file mode 100755
index 00000000..ac9489ce
--- /dev/null
+++ b/flagai/model/vision/layers/create_conv2d.py
@@ -0,0 +1,36 @@
+""" Create Conv2d Factory Method
+
+Hacked together by / Copyright 2020 Ross Wightman
+"""
+
+from .mixed_conv2d import MixedConv2d
+from .cond_conv2d import CondConv2d
+from .conv2d_same import create_conv2d_pad
+
+
+def create_conv2d(in_channels, out_channels, kernel_size, **kwargs):
+ """ Select a 2d convolution implementation based on arguments
+ Creates and returns one of torch.nn.Conv2d, Conv2dSame, MixedConv2d, or CondConv2d.
+
+ Used extensively by EfficientNet, MobileNetv3 and related networks.
+ """
+ if isinstance(kernel_size, list):
+ assert 'num_experts' not in kwargs # MixNet + CondConv combo not supported currently
+ if 'groups' in kwargs:
+ groups = kwargs.pop('groups')
+ if groups == in_channels:
+ kwargs['depthwise'] = True
+ else:
+ assert groups == 1
+        # We're going to use only lists for defining the MixedConv2d kernel groups; ints, tuples and
+        # other iterables will continue to pass through to a normal conv and specify h, w.
+ m = MixedConv2d(in_channels, out_channels, kernel_size, **kwargs)
+ else:
+ depthwise = kwargs.pop('depthwise', False)
+        # for depthwise (DW) convs, out_channels must be a multiple of in_channels since out_channels % groups == 0 is required
+ groups = in_channels if depthwise else kwargs.pop('groups', 1)
+ if 'num_experts' in kwargs and kwargs['num_experts'] > 0:
+ m = CondConv2d(in_channels, out_channels, kernel_size, groups=groups, **kwargs)
+ else:
+ m = create_conv2d_pad(in_channels, out_channels, kernel_size, groups=groups, **kwargs)
+ return m
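+
+
+if __name__ == '__main__':
+    # Minimal usage sketch (illustrative assumption): an int kernel with groups == in_channels
+    # becomes a depthwise conv, a list kernel routes to MixedConv2d, and num_experts > 0 selects CondConv2d.
+    import torch
+    dw = create_conv2d(32, 32, kernel_size=3, depthwise=True)      # nn.Conv2d with groups=32
+    mixed = create_conv2d(32, 32, kernel_size=[3, 5], groups=32)   # MixedConv2d, depthwise per split
+    y = dw(torch.randn(1, 32, 8, 8))
+    assert y.shape == (1, 32, 8, 8)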
diff --git a/flagai/model/vision/layers/create_norm_act.py b/flagai/model/vision/layers/create_norm_act.py
new file mode 100755
index 00000000..cd15c2f8
--- /dev/null
+++ b/flagai/model/vision/layers/create_norm_act.py
@@ -0,0 +1,88 @@
+""" NormAct (Normalizaiton + Activation Layer) Factory
+
+Create norm + act combo modules that attempt to be backwards compatible with separate norm + act
+instances in models. Where these are used, it will be possible to swap separate BN + act layers with
+combined modules like IABN or EvoNorms.
+
+Hacked together by / Copyright 2020 Ross Wightman
+"""
+import types
+import functools
+
+from .evo_norm import *
+from .filter_response_norm import FilterResponseNormAct2d, FilterResponseNormTlu2d
+from .norm_act import BatchNormAct2d, GroupNormAct, LayerNormAct, LayerNormAct2d
+from .inplace_abn import InplaceAbn
+
+_NORM_ACT_MAP = dict(
+ batchnorm=BatchNormAct2d,
+ batchnorm2d=BatchNormAct2d,
+ groupnorm=GroupNormAct,
+ layernorm=LayerNormAct,
+ layernorm2d=LayerNormAct2d,
+ evonormb0=EvoNorm2dB0,
+ evonormb1=EvoNorm2dB1,
+ evonormb2=EvoNorm2dB2,
+ evonorms0=EvoNorm2dS0,
+ evonorms0a=EvoNorm2dS0a,
+ evonorms1=EvoNorm2dS1,
+ evonorms1a=EvoNorm2dS1a,
+ evonorms2=EvoNorm2dS2,
+ evonorms2a=EvoNorm2dS2a,
+ frn=FilterResponseNormAct2d,
+ frntlu=FilterResponseNormTlu2d,
+ inplaceabn=InplaceAbn,
+ iabn=InplaceAbn,
+)
+_NORM_ACT_TYPES = {m for n, m in _NORM_ACT_MAP.items()}
+# has act_layer arg to define act type
+_NORM_ACT_REQUIRES_ARG = {
+ BatchNormAct2d, GroupNormAct, LayerNormAct, LayerNormAct2d, FilterResponseNormAct2d, InplaceAbn}
+
+
+def create_norm_act_layer(layer_name, num_features, act_layer=None, apply_act=True, jit=False, **kwargs):
+ layer = get_norm_act_layer(layer_name, act_layer=act_layer)
+ layer_instance = layer(num_features, apply_act=apply_act, **kwargs)
+ if jit:
+ layer_instance = torch.jit.script(layer_instance)
+ return layer_instance
+
+
+def get_norm_act_layer(norm_layer, act_layer=None):
+ assert isinstance(norm_layer, (type, str, types.FunctionType, functools.partial))
+ assert act_layer is None or isinstance(act_layer, (type, str, types.FunctionType, functools.partial))
+ norm_act_kwargs = {}
+
+ # unbind partial fn, so args can be rebound later
+ if isinstance(norm_layer, functools.partial):
+ norm_act_kwargs.update(norm_layer.keywords)
+ norm_layer = norm_layer.func
+
+ if isinstance(norm_layer, str):
+ layer_name = norm_layer.replace('_', '').lower().split('-')[0]
+ norm_act_layer = _NORM_ACT_MAP.get(layer_name, None)
+ elif norm_layer in _NORM_ACT_TYPES:
+ norm_act_layer = norm_layer
+ elif isinstance(norm_layer, types.FunctionType):
+ # if function type, must be a lambda/fn that creates a norm_act layer
+ norm_act_layer = norm_layer
+ else:
+ type_name = norm_layer.__name__.lower()
+ if type_name.startswith('batchnorm'):
+ norm_act_layer = BatchNormAct2d
+ elif type_name.startswith('groupnorm'):
+ norm_act_layer = GroupNormAct
+ elif type_name.startswith('layernorm2d'):
+ norm_act_layer = LayerNormAct2d
+ elif type_name.startswith('layernorm'):
+ norm_act_layer = LayerNormAct
+ else:
+ assert False, f"No equivalent norm_act layer for {type_name}"
+
+ if norm_act_layer in _NORM_ACT_REQUIRES_ARG:
+ # pass `act_layer` through for backwards compat where `act_layer=None` implies no activation.
+ # In the future, may force use of `apply_act` with `act_layer` arg bound to relevant NormAct types
+ norm_act_kwargs.setdefault('act_layer', act_layer)
+ if norm_act_kwargs:
+ norm_act_layer = functools.partial(norm_act_layer, **norm_act_kwargs) # bind/rebind args
+ return norm_act_layer
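+
+
+if __name__ == '__main__':
+    # Minimal usage sketch (illustrative assumption): a string (or a plain norm class) resolves to
+    # the combined norm + act type, with the requested act_layer bound in.
+    import torch
+    from torch import nn
+    layer = create_norm_act_layer('batchnorm2d', 32, act_layer=nn.ReLU)
+    y = layer(torch.randn(2, 32, 8, 8))
+    assert y.shape == (2, 32, 8, 8)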
diff --git a/flagai/model/vision/layers/drop.py b/flagai/model/vision/layers/drop.py
new file mode 100755
index 00000000..ae065277
--- /dev/null
+++ b/flagai/model/vision/layers/drop.py
@@ -0,0 +1,166 @@
+""" DropBlock, DropPath
+
+PyTorch implementations of DropBlock and DropPath (Stochastic Depth) regularization layers.
+
+Papers:
+DropBlock: A regularization method for convolutional networks (https://arxiv.org/abs/1810.12890)
+
+Deep Networks with Stochastic Depth (https://arxiv.org/abs/1603.09382)
+
+Code:
+DropBlock impl inspired by two Tensorflow impls that I liked:
+ - https://github.com/tensorflow/tpu/blob/master/models/official/resnet/resnet_model.py#L74
+ - https://github.com/clovaai/assembled-cnn/blob/master/nets/blocks.py
+
+Hacked together by / Copyright 2020 Ross Wightman
+"""
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+def drop_block_2d(
+ x, drop_prob: float = 0.1, block_size: int = 7, gamma_scale: float = 1.0,
+ with_noise: bool = False, inplace: bool = False, batchwise: bool = False):
+ """ DropBlock. See https://arxiv.org/pdf/1810.12890.pdf
+
+ DropBlock with an experimental gaussian noise option. This layer has been tested on a few training
+ runs with success, but needs further validation and possibly optimization for lower runtime impact.
+ """
+ B, C, H, W = x.shape
+ total_size = W * H
+ clipped_block_size = min(block_size, min(W, H))
+ # seed_drop_rate, the gamma parameter
+ gamma = gamma_scale * drop_prob * total_size / clipped_block_size ** 2 / (
+ (W - block_size + 1) * (H - block_size + 1))
+
+ # Forces the block to be inside the feature map.
+ w_i, h_i = torch.meshgrid(torch.arange(W).to(x.device), torch.arange(H).to(x.device))
+ valid_block = ((w_i >= clipped_block_size // 2) & (w_i < W - (clipped_block_size - 1) // 2)) & \
+ ((h_i >= clipped_block_size // 2) & (h_i < H - (clipped_block_size - 1) // 2))
+ valid_block = torch.reshape(valid_block, (1, 1, H, W)).to(dtype=x.dtype)
+
+ if batchwise:
+ # one mask for whole batch, quite a bit faster
+ uniform_noise = torch.rand((1, C, H, W), dtype=x.dtype, device=x.device)
+ else:
+ uniform_noise = torch.rand_like(x)
+ block_mask = ((2 - gamma - valid_block + uniform_noise) >= 1).to(dtype=x.dtype)
+ block_mask = -F.max_pool2d(
+ -block_mask,
+ kernel_size=clipped_block_size, # block_size,
+ stride=1,
+ padding=clipped_block_size // 2)
+
+ if with_noise:
+ normal_noise = torch.randn((1, C, H, W), dtype=x.dtype, device=x.device) if batchwise else torch.randn_like(x)
+ if inplace:
+ x.mul_(block_mask).add_(normal_noise * (1 - block_mask))
+ else:
+ x = x * block_mask + normal_noise * (1 - block_mask)
+ else:
+ normalize_scale = (block_mask.numel() / block_mask.to(dtype=torch.float32).sum().add(1e-7)).to(x.dtype)
+ if inplace:
+ x.mul_(block_mask * normalize_scale)
+ else:
+ x = x * block_mask * normalize_scale
+ return x
+
+
+def drop_block_fast_2d(
+ x: torch.Tensor, drop_prob: float = 0.1, block_size: int = 7,
+ gamma_scale: float = 1.0, with_noise: bool = False, inplace: bool = False):
+ """ DropBlock. See https://arxiv.org/pdf/1810.12890.pdf
+
+    DropBlock with an experimental gaussian noise option. Simplified from the above, without concern for valid
+ block mask at edges.
+ """
+ B, C, H, W = x.shape
+ total_size = W * H
+ clipped_block_size = min(block_size, min(W, H))
+ gamma = gamma_scale * drop_prob * total_size / clipped_block_size ** 2 / (
+ (W - block_size + 1) * (H - block_size + 1))
+
+ block_mask = torch.empty_like(x).bernoulli_(gamma)
+ block_mask = F.max_pool2d(
+ block_mask.to(x.dtype), kernel_size=clipped_block_size, stride=1, padding=clipped_block_size // 2)
+
+ if with_noise:
+ normal_noise = torch.empty_like(x).normal_()
+ if inplace:
+ x.mul_(1. - block_mask).add_(normal_noise * block_mask)
+ else:
+ x = x * (1. - block_mask) + normal_noise * block_mask
+ else:
+ block_mask = 1 - block_mask
+ normalize_scale = (block_mask.numel() / block_mask.to(dtype=torch.float32).sum().add(1e-6)).to(dtype=x.dtype)
+ if inplace:
+ x.mul_(block_mask * normalize_scale)
+ else:
+ x = x * block_mask * normalize_scale
+ return x
+
+
+class DropBlock2d(nn.Module):
+ """ DropBlock. See https://arxiv.org/pdf/1810.12890.pdf
+ """
+
+ def __init__(
+ self,
+ drop_prob: float = 0.1,
+ block_size: int = 7,
+ gamma_scale: float = 1.0,
+ with_noise: bool = False,
+ inplace: bool = False,
+ batchwise: bool = False,
+ fast: bool = True):
+ super(DropBlock2d, self).__init__()
+ self.drop_prob = drop_prob
+ self.gamma_scale = gamma_scale
+ self.block_size = block_size
+ self.with_noise = with_noise
+ self.inplace = inplace
+ self.batchwise = batchwise
+ self.fast = fast # FIXME finish comparisons of fast vs not
+
+ def forward(self, x):
+ if not self.training or not self.drop_prob:
+ return x
+ if self.fast:
+ return drop_block_fast_2d(
+ x, self.drop_prob, self.block_size, self.gamma_scale, self.with_noise, self.inplace)
+ else:
+ return drop_block_2d(
+ x, self.drop_prob, self.block_size, self.gamma_scale, self.with_noise, self.inplace, self.batchwise)
+
+
+def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+
+    This is the same as the DropConnect impl I created for EfficientNet, etc. networks; however,
+    the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
+ See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
+ changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
+ 'survival rate' as the argument.
+
+ """
+ if drop_prob == 0. or not training:
+ return x
+ keep_prob = 1 - drop_prob
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
+ random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
+ if keep_prob > 0.0 and scale_by_keep:
+ random_tensor.div_(keep_prob)
+ return x * random_tensor
+
+
+class DropPath(nn.Module):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+ """
+ def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True):
+ super(DropPath, self).__init__()
+ self.drop_prob = drop_prob
+ self.scale_by_keep = scale_by_keep
+
+ def forward(self, x):
+ return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
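+
+
+if __name__ == '__main__':
+    # Minimal usage sketch (illustrative assumption): DropPath is the identity in eval mode; in
+    # train mode each sample is dropped with prob drop_prob and survivors are scaled by 1 / keep_prob.
+    dp = DropPath(drop_prob=0.2)
+    x = torch.randn(8, 64, 7, 7)
+    dp.eval()
+    assert torch.equal(dp(x), x)
+    dp.train()
+    y = dp(x)  # some samples zeroed out, the rest scaled by 1 / 0.8
+    assert y.shape == x.shape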
diff --git a/flagai/model/vision/layers/eca.py b/flagai/model/vision/layers/eca.py
new file mode 100755
index 00000000..e29be6ac
--- /dev/null
+++ b/flagai/model/vision/layers/eca.py
@@ -0,0 +1,145 @@
+"""
+ECA module from ECAnet
+
+paper: ECA-Net: Efficient Channel Attention for Deep Convolutional Neural Networks
+https://arxiv.org/abs/1910.03151
+
+Original ECA model borrowed from https://github.com/BangguWu/ECANet
+
+Modified circular ECA implementation and adaptation for use in the timm package
+by Chris Ha https://github.com/VRandme
+
+Original License:
+
+MIT License
+
+Copyright (c) 2019 BangguWu, Qilong Wang
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
+"""
+import math
+from torch import nn
+import torch.nn.functional as F
+
+
+from .create_act import create_act_layer
+from .helpers import make_divisible
+
+
+class EcaModule(nn.Module):
+ """Constructs an ECA module.
+
+ Args:
+        channels: Number of channels of the input feature map. When given, the kernel size is
+            adaptively selected from the channel count using the mapping function from the original
+            paper https://arxiv.org/pdf/1910.03151.pdf (default=None; if not given, the passed
+            kernel_size is used directly).
+        kernel_size: Kernel size of the 1d conv when channels is not given (default=3)
+        gamma: parameter of the adaptive kernel size mapping function, see above
+        beta: parameter of the adaptive kernel size mapping function, see above
+ act_layer: optional non-linearity after conv, enables conv bias, this is an experiment
+ gate_layer: gating non-linearity to use
+ """
+ def __init__(
+ self, channels=None, kernel_size=3, gamma=2, beta=1, act_layer=None, gate_layer='sigmoid',
+ rd_ratio=1/8, rd_channels=None, rd_divisor=8, use_mlp=False):
+ super(EcaModule, self).__init__()
+ if channels is not None:
+ t = int(abs(math.log(channels, 2) + beta) / gamma)
+ kernel_size = max(t if t % 2 else t + 1, 3)
+ assert kernel_size % 2 == 1
+ padding = (kernel_size - 1) // 2
+ if use_mlp:
+ # NOTE 'mlp' mode is a timm experiment, not in paper
+ assert channels is not None
+ if rd_channels is None:
+ rd_channels = make_divisible(channels * rd_ratio, divisor=rd_divisor)
+ act_layer = act_layer or nn.ReLU
+ self.conv = nn.Conv1d(1, rd_channels, kernel_size=1, padding=0, bias=True)
+ self.act = create_act_layer(act_layer)
+ self.conv2 = nn.Conv1d(rd_channels, 1, kernel_size=kernel_size, padding=padding, bias=True)
+ else:
+ self.conv = nn.Conv1d(1, 1, kernel_size=kernel_size, padding=padding, bias=False)
+ self.act = None
+ self.conv2 = None
+ self.gate = create_act_layer(gate_layer)
+
+ def forward(self, x):
+ y = x.mean((2, 3)).view(x.shape[0], 1, -1) # view for 1d conv
+ y = self.conv(y)
+ if self.conv2 is not None:
+ y = self.act(y)
+ y = self.conv2(y)
+ y = self.gate(y).view(x.shape[0], -1, 1, 1)
+ return x * y.expand_as(x)
+
+
+EfficientChannelAttn = EcaModule # alias
+
+
+class CecaModule(nn.Module):
+ """Constructs a circular ECA module.
+
+ ECA module where the conv uses circular padding rather than zero padding.
+    Unlike the spatial dimension, the channels do not have inherent ordering or
+    locality. Although this module, in essence, applies such an assumption, there is no need
+    to prevent the channels on either "edge" from being circularly adapted to each other.
+    This will fundamentally increase connectivity and possibly improve performance metrics
+    (accuracy, robustness), without significantly impacting resource metrics
+    (parameter size, throughput, latency, etc).
+
+ Args:
+        channels: Number of channels of the input feature map. When given, the kernel size is
+            adaptively selected from the channel count using the mapping function from the original
+            paper https://arxiv.org/pdf/1910.03151.pdf (default=None; if not given, the passed
+            kernel_size is used directly).
+        kernel_size: Kernel size of the 1d conv when channels is not given (default=3)
+        gamma: parameter of the adaptive kernel size mapping function, see above
+        beta: parameter of the adaptive kernel size mapping function, see above
+ act_layer: optional non-linearity after conv, enables conv bias, this is an experiment
+ gate_layer: gating non-linearity to use
+ """
+
+ def __init__(self, channels=None, kernel_size=3, gamma=2, beta=1, act_layer=None, gate_layer='sigmoid'):
+ super(CecaModule, self).__init__()
+ if channels is not None:
+ t = int(abs(math.log(channels, 2) + beta) / gamma)
+ kernel_size = max(t if t % 2 else t + 1, 3)
+ has_act = act_layer is not None
+ assert kernel_size % 2 == 1
+
+ # PyTorch circular padding mode is buggy as of pytorch 1.4
+ # see https://github.com/pytorch/pytorch/pull/17240
+ # implement manual circular padding
+ self.padding = (kernel_size - 1) // 2
+ self.conv = nn.Conv1d(1, 1, kernel_size=kernel_size, padding=0, bias=has_act)
+ self.gate = create_act_layer(gate_layer)
+
+ def forward(self, x):
+ y = x.mean((2, 3)).view(x.shape[0], 1, -1)
+        # Manually implement circular padding; F.pad with mode='circular' does not seem to be affected by the bug
+ y = F.pad(y, (self.padding, self.padding), mode='circular')
+ y = self.conv(y)
+ y = self.gate(y).view(x.shape[0], -1, 1, 1)
+ return x * y.expand_as(x)
+
+
+CircularEfficientChannelAttn = CecaModule
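+
+
+if __name__ == '__main__':
+    # Minimal usage sketch (illustrative assumption): with channels given, the 1d conv kernel size
+    # is t = int(abs(log2(channels) + beta) / gamma), made odd (t + 1 if even) and clamped to >= 3.
+    import torch
+    eca = EcaModule(channels=512)  # t = int(|9 + 1| / 2) = 5
+    assert eca.conv.kernel_size == (5,)
+    y = eca(torch.randn(2, 512, 14, 14))
+    assert y.shape == (2, 512, 14, 14)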
diff --git a/flagai/model/vision/layers/evo_norm.py b/flagai/model/vision/layers/evo_norm.py
new file mode 100755
index 00000000..b643302c
--- /dev/null
+++ b/flagai/model/vision/layers/evo_norm.py
@@ -0,0 +1,350 @@
+""" EvoNorm in PyTorch
+
+Based on `Evolving Normalization-Activation Layers` - https://arxiv.org/abs/2004.02967
+@inproceedings{NEURIPS2020,
+ author = {Liu, Hanxiao and Brock, Andy and Simonyan, Karen and Le, Quoc},
+ booktitle = {Advances in Neural Information Processing Systems},
+ editor = {H. Larochelle and M. Ranzato and R. Hadsell and M. F. Balcan and H. Lin},
+ pages = {13539--13550},
+ publisher = {Curran Associates, Inc.},
+ title = {Evolving Normalization-Activation Layers},
+ url = {https://proceedings.neurips.cc/paper/2020/file/9d4c03631b8b0c85ae08bf05eda37d0f-Paper.pdf},
+ volume = {33},
+ year = {2020}
+}
+
+An attempt at getting decently performing EvoNorms running in PyTorch.
+While faster than other PyTorch impls, these are still quite a ways off the built-in BatchNorm
+in terms of memory usage and throughput on GPUs.
+
+I'm testing these modules on TPU w/ PyTorch XLA. Promising start but
+currently working around some issues with builtin torch/tensor.var/std. Unlike
+GPU, similar train speeds for EvoNormS variants and BatchNorm.
+
+Hacked together by / Copyright 2020 Ross Wightman
+"""
+from typing import Sequence, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .create_act import create_act_layer
+from .trace_utils import _assert
+
+
+def instance_std(x, eps: float = 1e-5):
+ std = x.float().var(dim=(2, 3), unbiased=False, keepdim=True).add(eps).sqrt().to(x.dtype)
+ return std.expand(x.shape)
+
+
+def instance_std_tpu(x, eps: float = 1e-5):
+ std = manual_var(x, dim=(2, 3)).add(eps).sqrt()
+ return std.expand(x.shape)
+# instance_std = instance_std_tpu
+
+
+def instance_rms(x, eps: float = 1e-5):
+ rms = x.float().square().mean(dim=(2, 3), keepdim=True).add(eps).sqrt().to(x.dtype)
+ return rms.expand(x.shape)
+
+
+def manual_var(x, dim: Union[int, Sequence[int]], diff_sqm: bool = False):
+ xm = x.mean(dim=dim, keepdim=True)
+ if diff_sqm:
+        # difference of squared mean and mean squared; faster on TPU but can be less stable
+ var = ((x * x).mean(dim=dim, keepdim=True) - (xm * xm)).clamp(0)
+ else:
+ var = ((x - xm) * (x - xm)).mean(dim=dim, keepdim=True)
+ return var
+
+
+def group_std(x, groups: int = 32, eps: float = 1e-5, flatten: bool = False):
+ B, C, H, W = x.shape
+ x_dtype = x.dtype
+ _assert(C % groups == 0, '')
+ if flatten:
+ x = x.reshape(B, groups, -1) # FIXME simpler shape causing TPU / XLA issues
+ std = x.float().var(dim=2, unbiased=False, keepdim=True).add(eps).sqrt().to(x_dtype)
+ else:
+ x = x.reshape(B, groups, C // groups, H, W)
+ std = x.float().var(dim=(2, 3, 4), unbiased=False, keepdim=True).add(eps).sqrt().to(x_dtype)
+ return std.expand(x.shape).reshape(B, C, H, W)
+
+
+def group_std_tpu(x, groups: int = 32, eps: float = 1e-5, diff_sqm: bool = False, flatten: bool = False):
+ # This is a workaround for some stability / odd behaviour of .var and .std
+    # running on PyTorch XLA w/ TPUs. These manual var impls produce much better results
+ B, C, H, W = x.shape
+ _assert(C % groups == 0, '')
+ if flatten:
+ x = x.reshape(B, groups, -1) # FIXME simpler shape causing TPU / XLA issues
+ var = manual_var(x, dim=-1, diff_sqm=diff_sqm)
+ else:
+ x = x.reshape(B, groups, C // groups, H, W)
+ var = manual_var(x, dim=(2, 3, 4), diff_sqm=diff_sqm)
+ return var.add(eps).sqrt().expand(x.shape).reshape(B, C, H, W)
+#group_std = group_std_tpu # FIXME TPU temporary
+
+
+def group_rms(x, groups: int = 32, eps: float = 1e-5):
+ B, C, H, W = x.shape
+ _assert(C % groups == 0, '')
+ x_dtype = x.dtype
+ x = x.reshape(B, groups, C // groups, H, W)
+ rms = x.float().square().mean(dim=(2, 3, 4), keepdim=True).add(eps).sqrt_().to(x_dtype)
+ return rms.expand(x.shape).reshape(B, C, H, W)
+
+
+class EvoNorm2dB0(nn.Module):
+ def __init__(self, num_features, apply_act=True, momentum=0.1, eps=1e-3, **_):
+ super().__init__()
+ self.apply_act = apply_act # apply activation (non-linearity)
+ self.momentum = momentum
+ self.eps = eps
+ self.weight = nn.Parameter(torch.ones(num_features))
+ self.bias = nn.Parameter(torch.zeros(num_features))
+ self.v = nn.Parameter(torch.ones(num_features)) if apply_act else None
+ self.register_buffer('running_var', torch.ones(num_features))
+ self.reset_parameters()
+
+ def reset_parameters(self):
+ nn.init.ones_(self.weight)
+ nn.init.zeros_(self.bias)
+ if self.v is not None:
+ nn.init.ones_(self.v)
+
+ def forward(self, x):
+ _assert(x.dim() == 4, 'expected 4D input')
+ x_dtype = x.dtype
+ v_shape = (1, -1, 1, 1)
+ if self.v is not None:
+ if self.training:
+ var = x.float().var(dim=(0, 2, 3), unbiased=False)
+ # var = manual_var(x, dim=(0, 2, 3)).squeeze()
+ n = x.numel() / x.shape[1]
+ self.running_var.copy_(
+ self.running_var * (1 - self.momentum) +
+ var.detach() * self.momentum * (n / (n - 1)))
+ else:
+ var = self.running_var
+ left = var.add(self.eps).sqrt_().to(x_dtype).view(v_shape).expand_as(x)
+ v = self.v.to(x_dtype).view(v_shape)
+ right = x * v + instance_std(x, self.eps)
+ x = x / left.max(right)
+ return x * self.weight.to(x_dtype).view(v_shape) + self.bias.to(x_dtype).view(v_shape)
+
+
+class EvoNorm2dB1(nn.Module):
+ def __init__(self, num_features, apply_act=True, momentum=0.1, eps=1e-5, **_):
+ super().__init__()
+ self.apply_act = apply_act # apply activation (non-linearity)
+ self.momentum = momentum
+ self.eps = eps
+ self.weight = nn.Parameter(torch.ones(num_features))
+ self.bias = nn.Parameter(torch.zeros(num_features))
+ self.register_buffer('running_var', torch.ones(num_features))
+ self.reset_parameters()
+
+ def reset_parameters(self):
+ nn.init.ones_(self.weight)
+ nn.init.zeros_(self.bias)
+
+ def forward(self, x):
+ _assert(x.dim() == 4, 'expected 4D input')
+ x_dtype = x.dtype
+ v_shape = (1, -1, 1, 1)
+ if self.apply_act:
+ if self.training:
+ var = x.float().var(dim=(0, 2, 3), unbiased=False)
+ n = x.numel() / x.shape[1]
+ self.running_var.copy_(
+ self.running_var * (1 - self.momentum) +
+ var.detach().to(self.running_var.dtype) * self.momentum * (n / (n - 1)))
+ else:
+ var = self.running_var
+ var = var.to(x_dtype).view(v_shape)
+ left = var.add(self.eps).sqrt_()
+ right = (x + 1) * instance_rms(x, self.eps)
+ x = x / left.max(right)
+ return x * self.weight.view(v_shape).to(x_dtype) + self.bias.view(v_shape).to(x_dtype)
+
+
+class EvoNorm2dB2(nn.Module):
+ def __init__(self, num_features, apply_act=True, momentum=0.1, eps=1e-5, **_):
+ super().__init__()
+ self.apply_act = apply_act # apply activation (non-linearity)
+ self.momentum = momentum
+ self.eps = eps
+ self.weight = nn.Parameter(torch.ones(num_features))
+ self.bias = nn.Parameter(torch.zeros(num_features))
+ self.register_buffer('running_var', torch.ones(num_features))
+ self.reset_parameters()
+
+ def reset_parameters(self):
+ nn.init.ones_(self.weight)
+ nn.init.zeros_(self.bias)
+
+ def forward(self, x):
+ _assert(x.dim() == 4, 'expected 4D input')
+ x_dtype = x.dtype
+ v_shape = (1, -1, 1, 1)
+ if self.apply_act:
+ if self.training:
+ var = x.float().var(dim=(0, 2, 3), unbiased=False)
+ n = x.numel() / x.shape[1]
+ self.running_var.copy_(
+ self.running_var * (1 - self.momentum) +
+ var.detach().to(self.running_var.dtype) * self.momentum * (n / (n - 1)))
+ else:
+ var = self.running_var
+ var = var.to(x_dtype).view(v_shape)
+ left = var.add(self.eps).sqrt_()
+ right = instance_rms(x, self.eps) - x
+ x = x / left.max(right)
+ return x * self.weight.view(v_shape).to(x_dtype) + self.bias.view(v_shape).to(x_dtype)
+
+
+class EvoNorm2dS0(nn.Module):
+ def __init__(self, num_features, groups=32, group_size=None, apply_act=True, eps=1e-5, **_):
+ super().__init__()
+ self.apply_act = apply_act # apply activation (non-linearity)
+ if group_size:
+ assert num_features % group_size == 0
+ self.groups = num_features // group_size
+ else:
+ self.groups = groups
+ self.eps = eps
+ self.weight = nn.Parameter(torch.ones(num_features))
+ self.bias = nn.Parameter(torch.zeros(num_features))
+ self.v = nn.Parameter(torch.ones(num_features)) if apply_act else None
+ self.reset_parameters()
+
+ def reset_parameters(self):
+ nn.init.ones_(self.weight)
+ nn.init.zeros_(self.bias)
+ if self.v is not None:
+ nn.init.ones_(self.v)
+
+ def forward(self, x):
+ _assert(x.dim() == 4, 'expected 4D input')
+ x_dtype = x.dtype
+ v_shape = (1, -1, 1, 1)
+ if self.v is not None:
+ v = self.v.view(v_shape).to(x_dtype)
+ x = x * (x * v).sigmoid() / group_std(x, self.groups, self.eps)
+ return x * self.weight.view(v_shape).to(x_dtype) + self.bias.view(v_shape).to(x_dtype)
+
+
+class EvoNorm2dS0a(EvoNorm2dS0):
+ def __init__(self, num_features, groups=32, group_size=None, apply_act=True, eps=1e-3, **_):
+ super().__init__(
+ num_features, groups=groups, group_size=group_size, apply_act=apply_act, eps=eps)
+
+ def forward(self, x):
+ _assert(x.dim() == 4, 'expected 4D input')
+ x_dtype = x.dtype
+ v_shape = (1, -1, 1, 1)
+ d = group_std(x, self.groups, self.eps)
+ if self.v is not None:
+ v = self.v.view(v_shape).to(x_dtype)
+ x = x * (x * v).sigmoid()
+ x = x / d
+ return x * self.weight.view(v_shape).to(x_dtype) + self.bias.view(v_shape).to(x_dtype)
+
+
+class EvoNorm2dS1(nn.Module):
+ def __init__(
+ self, num_features, groups=32, group_size=None,
+ apply_act=True, act_layer=nn.SiLU, eps=1e-5, **_):
+ super().__init__()
+ self.apply_act = apply_act # apply activation (non-linearity)
+ if act_layer is not None and apply_act:
+ self.act = create_act_layer(act_layer)
+ else:
+ self.act = nn.Identity()
+ if group_size:
+ assert num_features % group_size == 0
+ self.groups = num_features // group_size
+ else:
+ self.groups = groups
+ self.eps = eps
+ self.pre_act_norm = False
+ self.weight = nn.Parameter(torch.ones(num_features))
+ self.bias = nn.Parameter(torch.zeros(num_features))
+ self.reset_parameters()
+
+ def reset_parameters(self):
+ nn.init.ones_(self.weight)
+ nn.init.zeros_(self.bias)
+
+ def forward(self, x):
+ _assert(x.dim() == 4, 'expected 4D input')
+ x_dtype = x.dtype
+ v_shape = (1, -1, 1, 1)
+ if self.apply_act:
+ x = self.act(x) / group_std(x, self.groups, self.eps)
+ return x * self.weight.view(v_shape).to(x_dtype) + self.bias.view(v_shape).to(x_dtype)
+
+
+class EvoNorm2dS1a(EvoNorm2dS1):
+ def __init__(
+ self, num_features, groups=32, group_size=None,
+ apply_act=True, act_layer=nn.SiLU, eps=1e-3, **_):
+ super().__init__(
+ num_features, groups=groups, group_size=group_size, apply_act=apply_act, act_layer=act_layer, eps=eps)
+
+ def forward(self, x):
+ _assert(x.dim() == 4, 'expected 4D input')
+ x_dtype = x.dtype
+ v_shape = (1, -1, 1, 1)
+ x = self.act(x) / group_std(x, self.groups, self.eps)
+ return x * self.weight.view(v_shape).to(x_dtype) + self.bias.view(v_shape).to(x_dtype)
+
+
+class EvoNorm2dS2(nn.Module):
+ def __init__(
+ self, num_features, groups=32, group_size=None,
+ apply_act=True, act_layer=nn.SiLU, eps=1e-5, **_):
+ super().__init__()
+ self.apply_act = apply_act # apply activation (non-linearity)
+ if act_layer is not None and apply_act:
+ self.act = create_act_layer(act_layer)
+ else:
+ self.act = nn.Identity()
+ if group_size:
+ assert num_features % group_size == 0
+ self.groups = num_features // group_size
+ else:
+ self.groups = groups
+ self.eps = eps
+ self.weight = nn.Parameter(torch.ones(num_features))
+ self.bias = nn.Parameter(torch.zeros(num_features))
+ self.reset_parameters()
+
+ def reset_parameters(self):
+ nn.init.ones_(self.weight)
+ nn.init.zeros_(self.bias)
+
+ def forward(self, x):
+ _assert(x.dim() == 4, 'expected 4D input')
+ x_dtype = x.dtype
+ v_shape = (1, -1, 1, 1)
+ if self.apply_act:
+ x = self.act(x) / group_rms(x, self.groups, self.eps)
+ return x * self.weight.view(v_shape).to(x_dtype) + self.bias.view(v_shape).to(x_dtype)
+
+
+class EvoNorm2dS2a(EvoNorm2dS2):
+ def __init__(
+ self, num_features, groups=32, group_size=None,
+ apply_act=True, act_layer=nn.SiLU, eps=1e-3, **_):
+ super().__init__(
+ num_features, groups=groups, group_size=group_size, apply_act=apply_act, act_layer=act_layer, eps=eps)
+
+ def forward(self, x):
+ _assert(x.dim() == 4, 'expected 4D input')
+ x_dtype = x.dtype
+ v_shape = (1, -1, 1, 1)
+ x = self.act(x) / group_rms(x, self.groups, self.eps)
+ return x * self.weight.view(v_shape).to(x_dtype) + self.bias.view(v_shape).to(x_dtype)
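+
+
+# Usage sketch (added for illustration, not part of the upstream timm port): the EvoNorm
+# variants are drop-in replacements for a norm + activation pair on NCHW tensors. The
+# channel/group counts below are arbitrary assumptions.
+if __name__ == '__main__':
+    import torch
+    x = torch.randn(2, 64, 8, 8)
+    for norm_cls in (EvoNorm2dS1, EvoNorm2dS2):
+        layer = norm_cls(64, groups=32)
+        assert layer(x).shape == x.shape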
diff --git a/flagai/model/vision/layers/filter_response_norm.py b/flagai/model/vision/layers/filter_response_norm.py
new file mode 100755
index 00000000..a66a1cd4
--- /dev/null
+++ b/flagai/model/vision/layers/filter_response_norm.py
@@ -0,0 +1,68 @@
+""" Filter Response Norm in PyTorch
+
+Based on `Filter Response Normalization Layer` - https://arxiv.org/abs/1911.09737
+
+Hacked together by / Copyright 2021 Ross Wightman
+"""
+import torch
+import torch.nn as nn
+
+from .create_act import create_act_layer
+from .trace_utils import _assert
+
+
+def inv_instance_rms(x, eps: float = 1e-5):
+ rms = x.square().float().mean(dim=(2, 3), keepdim=True).add(eps).rsqrt().to(x.dtype)
+ return rms.expand(x.shape)
+
+
+class FilterResponseNormTlu2d(nn.Module):
+ def __init__(self, num_features, apply_act=True, eps=1e-5, rms=True, **_):
+ super(FilterResponseNormTlu2d, self).__init__()
+ self.apply_act = apply_act # apply activation (non-linearity)
+ self.rms = rms
+ self.eps = eps
+ self.weight = nn.Parameter(torch.ones(num_features))
+ self.bias = nn.Parameter(torch.zeros(num_features))
+ self.tau = nn.Parameter(torch.zeros(num_features)) if apply_act else None
+ self.reset_parameters()
+
+ def reset_parameters(self):
+ nn.init.ones_(self.weight)
+ nn.init.zeros_(self.bias)
+ if self.tau is not None:
+ nn.init.zeros_(self.tau)
+
+ def forward(self, x):
+ _assert(x.dim() == 4, 'expected 4D input')
+ x_dtype = x.dtype
+ v_shape = (1, -1, 1, 1)
+ x = x * inv_instance_rms(x, self.eps)
+ x = x * self.weight.view(v_shape).to(dtype=x_dtype) + self.bias.view(v_shape).to(dtype=x_dtype)
+ return torch.maximum(x, self.tau.reshape(v_shape).to(dtype=x_dtype)) if self.tau is not None else x
+
+
+class FilterResponseNormAct2d(nn.Module):
+ def __init__(self, num_features, apply_act=True, act_layer=nn.ReLU, inplace=None, rms=True, eps=1e-5, **_):
+ super(FilterResponseNormAct2d, self).__init__()
+ if act_layer is not None and apply_act:
+ self.act = create_act_layer(act_layer, inplace=inplace)
+ else:
+ self.act = nn.Identity()
+ self.rms = rms
+ self.eps = eps
+ self.weight = nn.Parameter(torch.ones(num_features))
+ self.bias = nn.Parameter(torch.zeros(num_features))
+ self.reset_parameters()
+
+ def reset_parameters(self):
+ nn.init.ones_(self.weight)
+ nn.init.zeros_(self.bias)
+
+ def forward(self, x):
+ _assert(x.dim() == 4, 'expected 4D input')
+ x_dtype = x.dtype
+ v_shape = (1, -1, 1, 1)
+ x = x * inv_instance_rms(x, self.eps)
+ x = x * self.weight.view(v_shape).to(dtype=x_dtype) + self.bias.view(v_shape).to(dtype=x_dtype)
+ return self.act(x)
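+
+
+# Usage sketch (added for illustration, not part of the upstream file): FRN normalizes each
+# channel by its spatial RMS; the TLU variant adds a learned activation threshold instead of
+# a fixed activation layer. Shapes below are arbitrary assumptions.
+if __name__ == '__main__':
+    x = torch.randn(2, 32, 16, 16)
+    assert FilterResponseNormAct2d(32, act_layer=nn.ReLU)(x).shape == x.shape
+    assert FilterResponseNormTlu2d(32)(x).shape == x.shape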
diff --git a/flagai/model/vision/layers/gather_excite.py b/flagai/model/vision/layers/gather_excite.py
new file mode 100755
index 00000000..2d60dc96
--- /dev/null
+++ b/flagai/model/vision/layers/gather_excite.py
@@ -0,0 +1,90 @@
+""" Gather-Excite Attention Block
+
+Paper: `Gather-Excite: Exploiting Feature Context in CNNs` - https://arxiv.org/abs/1810.12348
+
+Official code is here, but it's only a partial impl in Caffe: https://github.com/hujie-frank/GENet
+
+I've tried to support all of the extent settings, both w/ and w/o params. I don't believe I've seen
+another impl that covers all of the cases.
+
+NOTE: extent=0 + extra_params=False is equivalent to Squeeze-and-Excitation
+
+Hacked together by / Copyright 2021 Ross Wightman
+"""
+import math
+
+from torch import nn as nn
+import torch.nn.functional as F
+
+from .create_act import create_act_layer, get_act_layer
+from .create_conv2d import create_conv2d
+from .helpers import make_divisible
+from .mlp import ConvMlp
+
+
+class GatherExcite(nn.Module):
+ """ Gather-Excite Attention Module
+ """
+ def __init__(
+ self, channels, feat_size=None, extra_params=False, extent=0, use_mlp=True,
+ rd_ratio=1./16, rd_channels=None, rd_divisor=1, add_maxpool=False,
+ act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, gate_layer='sigmoid'):
+ super(GatherExcite, self).__init__()
+ self.add_maxpool = add_maxpool
+ act_layer = get_act_layer(act_layer)
+ self.extent = extent
+ if extra_params:
+ self.gather = nn.Sequential()
+ if extent == 0:
+ assert feat_size is not None, 'spatial feature size must be specified for global extent w/ params'
+ self.gather.add_module(
+ 'conv1', create_conv2d(channels, channels, kernel_size=feat_size, stride=1, depthwise=True))
+ if norm_layer:
+                    self.gather.add_module('norm1', nn.BatchNorm2d(channels))
+ else:
+ assert extent % 2 == 0
+ num_conv = int(math.log2(extent))
+ for i in range(num_conv):
+ self.gather.add_module(
+ f'conv{i + 1}',
+ create_conv2d(channels, channels, kernel_size=3, stride=2, depthwise=True))
+ if norm_layer:
+ self.gather.add_module(f'norm{i + 1}', nn.BatchNorm2d(channels))
+ if i != num_conv - 1:
+ self.gather.add_module(f'act{i + 1}', act_layer(inplace=True))
+ else:
+ self.gather = None
+ if self.extent == 0:
+ self.gk = 0
+ self.gs = 0
+ else:
+ assert extent % 2 == 0
+ self.gk = self.extent * 2 - 1
+ self.gs = self.extent
+
+ if not rd_channels:
+ rd_channels = make_divisible(channels * rd_ratio, rd_divisor, round_limit=0.)
+ self.mlp = ConvMlp(channels, rd_channels, act_layer=act_layer) if use_mlp else nn.Identity()
+ self.gate = create_act_layer(gate_layer)
+
+ def forward(self, x):
+ size = x.shape[-2:]
+ if self.gather is not None:
+ x_ge = self.gather(x)
+ else:
+ if self.extent == 0:
+ # global extent
+                x_ge = x.mean(dim=(2, 3), keepdim=True)
+ if self.add_maxpool:
+ # experimental codepath, may remove or change
+ x_ge = 0.5 * x_ge + 0.5 * x.amax((2, 3), keepdim=True)
+ else:
+ x_ge = F.avg_pool2d(
+ x, kernel_size=self.gk, stride=self.gs, padding=self.gk // 2, count_include_pad=False)
+ if self.add_maxpool:
+ # experimental codepath, may remove or change
+ x_ge = 0.5 * x_ge + 0.5 * F.max_pool2d(x, kernel_size=self.gk, stride=self.gs, padding=self.gk // 2)
+ x_ge = self.mlp(x_ge)
+ if x_ge.shape[-1] != 1 or x_ge.shape[-2] != 1:
+ x_ge = F.interpolate(x_ge, size=size)
+ return x * self.gate(x_ge)
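+
+
+# Usage sketch (added for illustration, not part of the upstream file): with extent=0 and
+# extra_params=False the block reduces to Squeeze-and-Excitation style global gating, while
+# a non-zero extent gathers context from a local window. The channel count is an arbitrary assumption.
+if __name__ == '__main__':
+    import torch
+    x = torch.randn(2, 64, 32, 32)
+    se_like = GatherExcite(64, extent=0, extra_params=False)
+    local = GatherExcite(64, extent=2, extra_params=False)
+    assert se_like(x).shape == local(x).shape == x.shape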
diff --git a/flagai/model/vision/layers/global_context.py b/flagai/model/vision/layers/global_context.py
new file mode 100755
index 00000000..de7fb5c1
--- /dev/null
+++ b/flagai/model/vision/layers/global_context.py
@@ -0,0 +1,67 @@
+""" Global Context Attention Block
+
+Paper: `GCNet: Non-local Networks Meet Squeeze-Excitation Networks and Beyond`
+ - https://arxiv.org/abs/1904.11492
+
+Official code consulted as reference: https://github.com/xvjiarui/GCNet
+
+Hacked together by / Copyright 2021 Ross Wightman
+"""
+from torch import nn as nn
+import torch.nn.functional as F
+
+from .create_act import create_act_layer, get_act_layer
+from .helpers import make_divisible
+from .mlp import ConvMlp
+from .norm import LayerNorm2d
+
+
+class GlobalContext(nn.Module):
+
+ def __init__(self, channels, use_attn=True, fuse_add=False, fuse_scale=True, init_last_zero=False,
+ rd_ratio=1./8, rd_channels=None, rd_divisor=1, act_layer=nn.ReLU, gate_layer='sigmoid'):
+ super(GlobalContext, self).__init__()
+ act_layer = get_act_layer(act_layer)
+
+ self.conv_attn = nn.Conv2d(channels, 1, kernel_size=1, bias=True) if use_attn else None
+
+ if rd_channels is None:
+ rd_channels = make_divisible(channels * rd_ratio, rd_divisor, round_limit=0.)
+ if fuse_add:
+ self.mlp_add = ConvMlp(channels, rd_channels, act_layer=act_layer, norm_layer=LayerNorm2d)
+ else:
+ self.mlp_add = None
+ if fuse_scale:
+ self.mlp_scale = ConvMlp(channels, rd_channels, act_layer=act_layer, norm_layer=LayerNorm2d)
+ else:
+ self.mlp_scale = None
+
+ self.gate = create_act_layer(gate_layer)
+ self.init_last_zero = init_last_zero
+ self.reset_parameters()
+
+ def reset_parameters(self):
+ if self.conv_attn is not None:
+ nn.init.kaiming_normal_(self.conv_attn.weight, mode='fan_in', nonlinearity='relu')
+ if self.mlp_add is not None:
+ nn.init.zeros_(self.mlp_add.fc2.weight)
+
+ def forward(self, x):
+ B, C, H, W = x.shape
+
+ if self.conv_attn is not None:
+ attn = self.conv_attn(x).reshape(B, 1, H * W) # (B, 1, H * W)
+ attn = F.softmax(attn, dim=-1).unsqueeze(3) # (B, 1, H * W, 1)
+ context = x.reshape(B, C, H * W).unsqueeze(1) @ attn
+ context = context.view(B, C, 1, 1)
+ else:
+ context = x.mean(dim=(2, 3), keepdim=True)
+
+ if self.mlp_scale is not None:
+ mlp_x = self.mlp_scale(context)
+ x = x * self.gate(mlp_x)
+ if self.mlp_add is not None:
+ mlp_x = self.mlp_add(context)
+ x = x + mlp_x
+
+ return x
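+
+
+# Usage sketch (added for illustration, not part of the upstream file): the block pools a
+# single attention-weighted context vector and uses it to rescale and/or shift the input
+# feature map, so the output shape always matches the input.
+if __name__ == '__main__':
+    import torch
+    x = torch.randn(2, 64, 14, 14)
+    gc_block = GlobalContext(64, fuse_add=True, fuse_scale=True)
+    assert gc_block(x).shape == x.shape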
diff --git a/flagai/model/vision/layers/halo_attn.py b/flagai/model/vision/layers/halo_attn.py
new file mode 100755
index 00000000..f2ac64f8
--- /dev/null
+++ b/flagai/model/vision/layers/halo_attn.py
@@ -0,0 +1,233 @@
+""" Halo Self Attention
+
+Paper: `Scaling Local Self-Attention for Parameter Efficient Visual Backbones`
+ - https://arxiv.org/abs/2103.12731
+
+@misc{2103.12731,
+Author = {Ashish Vaswani and Prajit Ramachandran and Aravind Srinivas and Niki Parmar and Blake Hechtman and
+ Jonathon Shlens},
+Title = {Scaling Local Self-Attention for Parameter Efficient Visual Backbones},
+Year = {2021},
+}
+
+Status:
+This impl is a WIP; there is no official ref impl, and some details in the paper weren't clear to me.
+The attention mechanism works but it's slow as implemented.
+
+Hacked together by / Copyright 2021 Ross Wightman
+"""
+from typing import List
+
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+from .helpers import make_divisible
+from .weight_init import trunc_normal_
+from .trace_utils import _assert
+
+
+def rel_logits_1d(q, rel_k, permute_mask: List[int]):
+ """ Compute relative logits along one dimension
+
+ As per: https://gist.github.com/aravindsrinivas/56359b79f0ce4449bcb04ab4b56a57a2
+ Originally from: `Attention Augmented Convolutional Networks` - https://arxiv.org/abs/1904.09925
+
+ Args:
+ q: (batch, height, width, dim)
+ rel_k: (2 * window - 1, dim)
+ permute_mask: permute output dim according to this
+ """
+ B, H, W, dim = q.shape
+ rel_size = rel_k.shape[0]
+ win_size = (rel_size + 1) // 2
+
+ x = (q @ rel_k.transpose(-1, -2))
+ x = x.reshape(-1, W, rel_size)
+
+ # pad to shift from relative to absolute indexing
+ x_pad = F.pad(x, [0, 1]).flatten(1)
+ x_pad = F.pad(x_pad, [0, rel_size - W])
+
+ # reshape and slice out the padded elements
+ x_pad = x_pad.reshape(-1, W + 1, rel_size)
+ x = x_pad[:, :W, win_size - 1:]
+
+ # reshape and tile
+ x = x.reshape(B, H, 1, W, win_size).expand(-1, -1, win_size, -1, -1)
+ return x.permute(permute_mask)
+
+
+class PosEmbedRel(nn.Module):
+ """ Relative Position Embedding
+ As per: https://gist.github.com/aravindsrinivas/56359b79f0ce4449bcb04ab4b56a57a2
+ Originally from: `Attention Augmented Convolutional Networks` - https://arxiv.org/abs/1904.09925
+
+ """
+ def __init__(self, block_size, win_size, dim_head, scale):
+ """
+ Args:
+ block_size (int): block size
+ win_size (int): neighbourhood window size
+ dim_head (int): attention head dim
+ scale (float): scale factor (for init)
+ """
+ super().__init__()
+ self.block_size = block_size
+ self.dim_head = dim_head
+ self.height_rel = nn.Parameter(torch.randn(win_size * 2 - 1, dim_head) * scale)
+ self.width_rel = nn.Parameter(torch.randn(win_size * 2 - 1, dim_head) * scale)
+
+ def forward(self, q):
+ B, BB, HW, _ = q.shape
+
+ # relative logits in width dimension.
+ q = q.reshape(-1, self.block_size, self.block_size, self.dim_head)
+ rel_logits_w = rel_logits_1d(q, self.width_rel, permute_mask=(0, 1, 3, 2, 4))
+
+ # relative logits in height dimension.
+ q = q.transpose(1, 2)
+ rel_logits_h = rel_logits_1d(q, self.height_rel, permute_mask=(0, 3, 1, 4, 2))
+
+ rel_logits = rel_logits_h + rel_logits_w
+ rel_logits = rel_logits.reshape(B, BB, HW, -1)
+ return rel_logits
+
+
+class HaloAttn(nn.Module):
+ """ Halo Attention
+
+ Paper: `Scaling Local Self-Attention for Parameter Efficient Visual Backbones`
+ - https://arxiv.org/abs/2103.12731
+
+ The internal dimensions of the attention module are controlled by the interaction of several arguments.
+ * the output dimension of the module is specified by dim_out, which falls back to input dim if not set
+ * the value (v) dimension is set to dim_out // num_heads, the v projection determines the output dim
+ * the query and key (qk) dimensions are determined by
+ * num_heads * dim_head if dim_head is not None
+ * num_heads * (dim_out * attn_ratio // num_heads) if dim_head is None
+ * as seen above, attn_ratio determines the ratio of q and k relative to the output if dim_head not used
+
+ Args:
+ dim (int): input dimension to the module
+ dim_out (int): output dimension of the module, same as dim if not set
+ feat_size (Tuple[int, int]): size of input feature_map (not used, for arg compat with bottle/lambda)
+ stride: output stride of the module, query downscaled if > 1 (default: 1).
+ num_heads: parallel attention heads (default: 8).
+ dim_head: dimension of query and key heads, calculated from dim_out * attn_ratio // num_heads if not set
+ block_size (int): size of blocks. (default: 8)
+ halo_size (int): size of halo overlap. (default: 3)
+ qk_ratio (float): ratio of q and k dimensions to output dimension when dim_head not set. (default: 1.0)
+ qkv_bias (bool) : add bias to q, k, and v projections
+ avg_down (bool): use average pool downsample instead of strided query blocks
+ scale_pos_embed (bool): scale the position embedding as well as Q @ K
+ """
+ def __init__(
+ self, dim, dim_out=None, feat_size=None, stride=1, num_heads=8, dim_head=None, block_size=8, halo_size=3,
+ qk_ratio=1.0, qkv_bias=False, avg_down=False, scale_pos_embed=False):
+ super().__init__()
+ dim_out = dim_out or dim
+ assert dim_out % num_heads == 0
+ assert stride in (1, 2)
+ self.num_heads = num_heads
+ self.dim_head_qk = dim_head or make_divisible(dim_out * qk_ratio, divisor=8) // num_heads
+ self.dim_head_v = dim_out // self.num_heads
+ self.dim_out_qk = num_heads * self.dim_head_qk
+ self.dim_out_v = num_heads * self.dim_head_v
+ self.scale = self.dim_head_qk ** -0.5
+ self.scale_pos_embed = scale_pos_embed
+ self.block_size = self.block_size_ds = block_size
+ self.halo_size = halo_size
+ self.win_size = block_size + halo_size * 2 # neighbourhood window size
+ self.block_stride = 1
+ use_avg_pool = False
+ if stride > 1:
+ use_avg_pool = avg_down or block_size % stride != 0
+ self.block_stride = 1 if use_avg_pool else stride
+ self.block_size_ds = self.block_size // self.block_stride
+
+ # FIXME not clear if this stride behaviour is what the paper intended
+ # Also, the paper mentions using a 3D conv for dealing with the blocking/gather, and leaving
+ # data in unfolded block form. I haven't wrapped my head around how that'd look.
+ self.q = nn.Conv2d(dim, self.dim_out_qk, 1, stride=self.block_stride, bias=qkv_bias)
+ self.kv = nn.Conv2d(dim, self.dim_out_qk + self.dim_out_v, 1, bias=qkv_bias)
+
+ self.pos_embed = PosEmbedRel(
+ block_size=self.block_size_ds, win_size=self.win_size, dim_head=self.dim_head_qk, scale=self.scale)
+
+ self.pool = nn.AvgPool2d(2, 2) if use_avg_pool else nn.Identity()
+
+ self.reset_parameters()
+
+ def reset_parameters(self):
+ std = self.q.weight.shape[1] ** -0.5 # fan-in
+ trunc_normal_(self.q.weight, std=std)
+ trunc_normal_(self.kv.weight, std=std)
+ trunc_normal_(self.pos_embed.height_rel, std=self.scale)
+ trunc_normal_(self.pos_embed.width_rel, std=self.scale)
+
+ def forward(self, x):
+ B, C, H, W = x.shape
+ _assert(H % self.block_size == 0, '')
+ _assert(W % self.block_size == 0, '')
+ num_h_blocks = H // self.block_size
+ num_w_blocks = W // self.block_size
+ num_blocks = num_h_blocks * num_w_blocks
+
+ q = self.q(x)
+ # unfold
+ q = q.reshape(
+ -1, self.dim_head_qk,
+ num_h_blocks, self.block_size_ds, num_w_blocks, self.block_size_ds).permute(0, 1, 3, 5, 2, 4)
+ # B, num_heads * dim_head * block_size ** 2, num_blocks
+ q = q.reshape(B * self.num_heads, self.dim_head_qk, -1, num_blocks).transpose(1, 3)
+ # B * num_heads, num_blocks, block_size ** 2, dim_head
+
+ kv = self.kv(x)
+ # Generate overlapping windows for kv. This approach is good for GPU and CPU. However, unfold() is not
+ # lowered for PyTorch XLA so it will be very slow. See code at bottom of file for XLA friendly approach.
+ # FIXME figure out how to switch impl between this and conv2d if XLA being used.
+ kv = F.pad(kv, [self.halo_size, self.halo_size, self.halo_size, self.halo_size])
+ kv = kv.unfold(2, self.win_size, self.block_size).unfold(3, self.win_size, self.block_size).reshape(
+ B * self.num_heads, self.dim_head_qk + self.dim_head_v, num_blocks, -1).permute(0, 2, 3, 1)
+ k, v = torch.split(kv, [self.dim_head_qk, self.dim_head_v], dim=-1)
+ # B * num_heads, num_blocks, win_size ** 2, dim_head_qk or dim_head_v
+
+ if self.scale_pos_embed:
+ attn = (q @ k.transpose(-1, -2) + self.pos_embed(q)) * self.scale
+ else:
+ attn = (q @ k.transpose(-1, -2)) * self.scale + self.pos_embed(q)
+ # B * num_heads, num_blocks, block_size ** 2, win_size ** 2
+ attn = attn.softmax(dim=-1)
+
+ out = (attn @ v).transpose(1, 3) # B * num_heads, dim_head_v, block_size ** 2, num_blocks
+ # fold
+ out = out.reshape(-1, self.block_size_ds, self.block_size_ds, num_h_blocks, num_w_blocks)
+ out = out.permute(0, 3, 1, 4, 2).contiguous().view(
+ B, self.dim_out_v, H // self.block_stride, W // self.block_stride)
+ # B, dim_out, H // block_stride, W // block_stride
+ out = self.pool(out)
+ return out
+
+
+""" Three alternatives for overlapping windows.
+
+`.unfold().unfold()` is the same speed as stride tricks, with similar clarity to F.unfold()
+
+ if is_xla:
+ # This code achieves haloing on PyTorch XLA with reasonable runtime trade-off, it is
+ # EXTREMELY slow for backward on a GPU though so I need a way of selecting based on environment.
+ WW = self.win_size ** 2
+ pw = torch.eye(WW, dtype=x.dtype, device=x.device).reshape(WW, 1, self.win_size, self.win_size)
+ kv = F.conv2d(kv.reshape(-1, 1, H, W), pw, stride=self.block_size, padding=self.halo_size)
+ elif self.stride_tricks:
+ kv = F.pad(kv, [self.halo_size, self.halo_size, self.halo_size, self.halo_size]).contiguous()
+ kv = kv.as_strided((
+ B, self.dim_out_qk + self.dim_out_v, self.win_size, self.win_size, num_h_blocks, num_w_blocks),
+ stride=(kv.stride(0), kv.stride(1), kv.shape[-1], 1, self.block_size * kv.shape[-1], self.block_size))
+ else:
+ kv = F.unfold(kv, kernel_size=self.win_size, stride=self.block_size, padding=self.halo_size)
+
+ kv = kv.reshape(
+ B * self.num_heads, self.dim_head_qk + self.dim_head_v, -1, num_blocks).transpose(1, 3)
+"""
diff --git a/flagai/model/vision/layers/helpers.py b/flagai/model/vision/layers/helpers.py
new file mode 100755
index 00000000..cc54ca7f
--- /dev/null
+++ b/flagai/model/vision/layers/helpers.py
@@ -0,0 +1,31 @@
+""" Layer/Module Helpers
+
+Hacked together by / Copyright 2020 Ross Wightman
+"""
+from itertools import repeat
+import collections.abc
+
+
+# From PyTorch internals
+def _ntuple(n):
+ def parse(x):
+ if isinstance(x, collections.abc.Iterable):
+ return x
+ return tuple(repeat(x, n))
+ return parse
+
+
+to_1tuple = _ntuple(1)
+to_2tuple = _ntuple(2)
+to_3tuple = _ntuple(3)
+to_4tuple = _ntuple(4)
+to_ntuple = _ntuple
+
+
+def make_divisible(v, divisor=8, min_value=None, round_limit=.9):
+ min_value = min_value or divisor
+ new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
+ # Make sure that round down does not go down by more than 10%.
+ if new_v < round_limit * v:
+ new_v += divisor
+ return new_v
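+
+
+# Usage sketch (added for illustration, not part of the upstream file): the to_ntuple helpers
+# broadcast scalars to tuples and pass iterables through unchanged; make_divisible rounds a
+# channel count to a multiple of `divisor` without shrinking it by more than ~10%.
+if __name__ == '__main__':
+    assert to_2tuple(7) == (7, 7)
+    assert to_2tuple((3, 5)) == (3, 5)
+    assert make_divisible(30, divisor=8) == 32
+    assert make_divisible(62, divisor=8) == 64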
diff --git a/flagai/model/vision/layers/inplace_abn.py b/flagai/model/vision/layers/inplace_abn.py
new file mode 100755
index 00000000..a8088933
--- /dev/null
+++ b/flagai/model/vision/layers/inplace_abn.py
@@ -0,0 +1,87 @@
+import torch
+from torch import nn as nn
+
+try:
+ from inplace_abn.functions import inplace_abn, inplace_abn_sync
+ has_iabn = True
+except ImportError:
+ has_iabn = False
+
+ def inplace_abn(x, weight, bias, running_mean, running_var,
+ training=True, momentum=0.1, eps=1e-05, activation="leaky_relu", activation_param=0.01):
+ raise ImportError(
+ "Please install InplaceABN:'pip install git+https://github.com/mapillary/inplace_abn.git@v1.0.12'")
+
+ def inplace_abn_sync(**kwargs):
+ inplace_abn(**kwargs)
+
+
+class InplaceAbn(nn.Module):
+ """Activated Batch Normalization
+
+    This combines a BatchNorm layer and an activation function in a single module
+
+ Parameters
+ ----------
+ num_features : int
+ Number of feature channels in the input and output.
+ eps : float
+ Small constant to prevent numerical issues.
+ momentum : float
+ Momentum factor applied to compute running statistics.
+ affine : bool
+ If `True` apply learned scale and shift transformation after normalization.
+ act_layer : str or nn.Module type
+ Name or type of the activation functions, one of: `leaky_relu`, `elu`
+ act_param : float
+ Negative slope for the `leaky_relu` activation.
+ """
+
+ def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, apply_act=True,
+ act_layer="leaky_relu", act_param=0.01, drop_layer=None):
+ super(InplaceAbn, self).__init__()
+ self.num_features = num_features
+ self.affine = affine
+ self.eps = eps
+ self.momentum = momentum
+ if apply_act:
+ if isinstance(act_layer, str):
+ assert act_layer in ('leaky_relu', 'elu', 'identity', '')
+ self.act_name = act_layer if act_layer else 'identity'
+ else:
+ # convert act layer passed as type to string
+ if act_layer == nn.ELU:
+ self.act_name = 'elu'
+ elif act_layer == nn.LeakyReLU:
+ self.act_name = 'leaky_relu'
+ elif act_layer is None or act_layer == nn.Identity:
+ self.act_name = 'identity'
+ else:
+ assert False, f'Invalid act layer {act_layer.__name__} for IABN'
+ else:
+ self.act_name = 'identity'
+ self.act_param = act_param
+ if self.affine:
+ self.weight = nn.Parameter(torch.ones(num_features))
+ self.bias = nn.Parameter(torch.zeros(num_features))
+ else:
+ self.register_parameter('weight', None)
+ self.register_parameter('bias', None)
+ self.register_buffer('running_mean', torch.zeros(num_features))
+ self.register_buffer('running_var', torch.ones(num_features))
+ self.reset_parameters()
+
+ def reset_parameters(self):
+ nn.init.constant_(self.running_mean, 0)
+ nn.init.constant_(self.running_var, 1)
+ if self.affine:
+ nn.init.constant_(self.weight, 1)
+ nn.init.constant_(self.bias, 0)
+
+ def forward(self, x):
+ output = inplace_abn(
+ x, self.weight, self.bias, self.running_mean, self.running_var,
+ self.training, self.momentum, self.eps, self.act_name, self.act_param)
+ if isinstance(output, tuple):
+ output = output[0]
+ return output
diff --git a/flagai/model/vision/layers/lambda_layer.py b/flagai/model/vision/layers/lambda_layer.py
new file mode 100755
index 00000000..e50b43c8
--- /dev/null
+++ b/flagai/model/vision/layers/lambda_layer.py
@@ -0,0 +1,133 @@
+""" Lambda Layer
+
+Paper: `LambdaNetworks: Modeling Long-Range Interactions Without Attention`
+ - https://arxiv.org/abs/2102.08602
+
+@misc{2102.08602,
+Author = {Irwan Bello},
+Title = {LambdaNetworks: Modeling Long-Range Interactions Without Attention},
+Year = {2021},
+}
+
+Status:
+This impl is a WIP. Code snippets in the paper were used as reference, but there is a good
+chance some details are missing/wrong.
+
+I've only implemented local lambda conv based pos embeddings.
+
+For a PyTorch impl that includes other embedding options checkout
+https://github.com/lucidrains/lambda-networks
+
+Hacked together by / Copyright 2021 Ross Wightman
+"""
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+from .helpers import to_2tuple, make_divisible
+from .weight_init import trunc_normal_
+
+
+def rel_pos_indices(size):
+ size = to_2tuple(size)
+ pos = torch.stack(torch.meshgrid(torch.arange(size[0]), torch.arange(size[1]))).flatten(1)
+ rel_pos = pos[:, None, :] - pos[:, :, None]
+ rel_pos[0] += size[0] - 1
+ rel_pos[1] += size[1] - 1
+ return rel_pos # 2, H * W, H * W
+
+
+class LambdaLayer(nn.Module):
+ """Lambda Layer
+
+ Paper: `LambdaNetworks: Modeling Long-Range Interactions Without Attention`
+ - https://arxiv.org/abs/2102.08602
+
+ NOTE: intra-depth parameter 'u' is fixed at 1. It did not appear worth the complexity to add.
+
+ The internal dimensions of the lambda module are controlled via the interaction of several arguments.
+ * the output dimension of the module is specified by dim_out, which falls back to input dim if not set
+ * the value (v) dimension is set to dim_out // num_heads, the v projection determines the output dim
+ * the query (q) and key (k) dimension are determined by
+ * dim_head = (dim_out * attn_ratio // num_heads) if dim_head is None
+ * q = num_heads * dim_head, k = dim_head
+ * as seen above, attn_ratio determines the ratio of q and k relative to the output if dim_head not set
+
+ Args:
+ dim (int): input dimension to the module
+ dim_out (int): output dimension of the module, same as dim if not set
+ feat_size (Tuple[int, int]): size of input feature_map for relative pos variant H, W
+ stride (int): output stride of the module, avg pool used if stride == 2
+ num_heads (int): parallel attention heads.
+ dim_head (int): dimension of query and key heads, calculated from dim_out * attn_ratio // num_heads if not set
+ r (int): local lambda convolution radius. Use lambda conv if set, else relative pos if not. (default: 9)
+ qk_ratio (float): ratio of q and k dimensions to output dimension when dim_head not set. (default: 1.0)
+ qkv_bias (bool): add bias to q, k, and v projections
+ """
+ def __init__(
+ self, dim, dim_out=None, feat_size=None, stride=1, num_heads=4, dim_head=16, r=9,
+ qk_ratio=1.0, qkv_bias=False):
+ super().__init__()
+ dim_out = dim_out or dim
+        assert dim_out % num_heads == 0, 'dim_out should be divisible by num_heads'
+ self.dim_qk = dim_head or make_divisible(dim_out * qk_ratio, divisor=8) // num_heads
+ self.num_heads = num_heads
+ self.dim_v = dim_out // num_heads
+
+ self.qkv = nn.Conv2d(
+ dim,
+ num_heads * self.dim_qk + self.dim_qk + self.dim_v,
+ kernel_size=1, bias=qkv_bias)
+ self.norm_q = nn.BatchNorm2d(num_heads * self.dim_qk)
+ self.norm_v = nn.BatchNorm2d(self.dim_v)
+
+ if r is not None:
+ # local lambda convolution for pos
+ self.conv_lambda = nn.Conv3d(1, self.dim_qk, (r, r, 1), padding=(r // 2, r // 2, 0))
+ self.pos_emb = None
+ self.rel_pos_indices = None
+ else:
+ # relative pos embedding
+ assert feat_size is not None
+ feat_size = to_2tuple(feat_size)
+ rel_size = [2 * s - 1 for s in feat_size]
+ self.conv_lambda = None
+ self.pos_emb = nn.Parameter(torch.zeros(rel_size[0], rel_size[1], self.dim_qk))
+ self.register_buffer('rel_pos_indices', rel_pos_indices(feat_size), persistent=False)
+
+ self.pool = nn.AvgPool2d(2, 2) if stride == 2 else nn.Identity()
+
+ self.reset_parameters()
+
+ def reset_parameters(self):
+ trunc_normal_(self.qkv.weight, std=self.qkv.weight.shape[1] ** -0.5) # fan-in
+ if self.conv_lambda is not None:
+ trunc_normal_(self.conv_lambda.weight, std=self.dim_qk ** -0.5)
+ if self.pos_emb is not None:
+ trunc_normal_(self.pos_emb, std=.02)
+
+ def forward(self, x):
+ B, C, H, W = x.shape
+ M = H * W
+ qkv = self.qkv(x)
+ q, k, v = torch.split(qkv, [
+ self.num_heads * self.dim_qk, self.dim_qk, self.dim_v], dim=1)
+ q = self.norm_q(q).reshape(B, self.num_heads, self.dim_qk, M).transpose(-1, -2) # B, num_heads, M, K
+ v = self.norm_v(v).reshape(B, self.dim_v, M).transpose(-1, -2) # B, M, V
+ k = F.softmax(k.reshape(B, self.dim_qk, M), dim=-1) # B, K, M
+
+ content_lam = k @ v # B, K, V
+ content_out = q @ content_lam.unsqueeze(1) # B, num_heads, M, V
+
+ if self.pos_emb is None:
+ position_lam = self.conv_lambda(v.reshape(B, 1, H, W, self.dim_v)) # B, H, W, V, K
+ position_lam = position_lam.reshape(B, 1, self.dim_qk, H * W, self.dim_v).transpose(2, 3) # B, 1, M, K, V
+ else:
+ # FIXME relative pos embedding path not fully verified
+ pos_emb = self.pos_emb[self.rel_pos_indices[0], self.rel_pos_indices[1]].expand(B, -1, -1, -1)
+ position_lam = (pos_emb.transpose(-1, -2) @ v.unsqueeze(1)).unsqueeze(1) # B, 1, M, K, V
+ position_out = (q.unsqueeze(-2) @ position_lam).squeeze(-2) # B, num_heads, M, V
+
+ out = (content_out + position_out).transpose(-1, -2).reshape(B, C, H, W) # B, C (num_heads * V), H, W
+ out = self.pool(out)
+ return out
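+
+
+# Usage sketch (added for illustration, not part of the upstream file): with the default
+# local lambda convolution (r=9) no feat_size is needed; pass r=None plus feat_size to use
+# the relative position embedding path instead. Sizes below are arbitrary assumptions.
+if __name__ == '__main__':
+    x = torch.randn(2, 64, 16, 16)
+    assert LambdaLayer(dim=64, num_heads=4, r=9)(x).shape == x.shape
+    assert LambdaLayer(dim=64, num_heads=4, r=None, feat_size=16)(x).shape == x.shape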
diff --git a/flagai/model/vision/layers/linear.py b/flagai/model/vision/layers/linear.py
new file mode 100755
index 00000000..38fe3380
--- /dev/null
+++ b/flagai/model/vision/layers/linear.py
@@ -0,0 +1,19 @@
+""" Linear layer (alternate definition)
+"""
+import torch
+import torch.nn.functional as F
+from torch import nn as nn
+
+
+class Linear(nn.Linear):
+ r"""Applies a linear transformation to the incoming data: :math:`y = xA^T + b`
+
+ Wraps torch.nn.Linear to support AMP + torchscript usage by manually casting
+ weight & bias to input.dtype to work around an issue w/ torch.addmm in this use case.
+ """
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
+ if torch.jit.is_scripting():
+ bias = self.bias.to(dtype=input.dtype) if self.bias is not None else None
+ return F.linear(input, self.weight.to(dtype=input.dtype), bias=bias)
+ else:
+ return F.linear(input, self.weight, self.bias)
diff --git a/flagai/model/vision/layers/median_pool.py b/flagai/model/vision/layers/median_pool.py
new file mode 100755
index 00000000..40bd71a7
--- /dev/null
+++ b/flagai/model/vision/layers/median_pool.py
@@ -0,0 +1,49 @@
+""" Median Pool
+Hacked together by / Copyright 2020 Ross Wightman
+"""
+import torch.nn as nn
+import torch.nn.functional as F
+from .helpers import to_2tuple, to_4tuple
+
+
+class MedianPool2d(nn.Module):
+ """ Median pool (usable as median filter when stride=1) module.
+
+ Args:
+ kernel_size: size of pooling kernel, int or 2-tuple
+ stride: pool stride, int or 2-tuple
+ padding: pool padding, int or 4-tuple (l, r, t, b) as in pytorch F.pad
+ same: override padding and enforce same padding, boolean
+ """
+ def __init__(self, kernel_size=3, stride=1, padding=0, same=False):
+ super(MedianPool2d, self).__init__()
+ self.k = to_2tuple(kernel_size)
+ self.stride = to_2tuple(stride)
+ self.padding = to_4tuple(padding) # convert to l, r, t, b
+ self.same = same
+
+ def _padding(self, x):
+ if self.same:
+ ih, iw = x.size()[2:]
+ if ih % self.stride[0] == 0:
+ ph = max(self.k[0] - self.stride[0], 0)
+ else:
+ ph = max(self.k[0] - (ih % self.stride[0]), 0)
+ if iw % self.stride[1] == 0:
+ pw = max(self.k[1] - self.stride[1], 0)
+ else:
+ pw = max(self.k[1] - (iw % self.stride[1]), 0)
+ pl = pw // 2
+ pr = pw - pl
+ pt = ph // 2
+ pb = ph - pt
+ padding = (pl, pr, pt, pb)
+ else:
+ padding = self.padding
+ return padding
+
+ def forward(self, x):
+ x = F.pad(x, self._padding(x), mode='reflect')
+ x = x.unfold(2, self.k[0], self.stride[0]).unfold(3, self.k[1], self.stride[1])
+ x = x.contiguous().view(x.size()[:4] + (-1,)).median(dim=-1)[0]
+ return x
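+
+
+# Usage sketch (added for illustration, not part of the upstream file): with stride=1 and
+# same=True this acts as a resolution-preserving 3x3 median filter.
+if __name__ == '__main__':
+    import torch
+    x = torch.randn(1, 3, 8, 8)
+    assert MedianPool2d(kernel_size=3, stride=1, same=True)(x).shape == x.shape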
diff --git a/flagai/model/vision/layers/mixed_conv2d.py b/flagai/model/vision/layers/mixed_conv2d.py
new file mode 100755
index 00000000..fa0ce565
--- /dev/null
+++ b/flagai/model/vision/layers/mixed_conv2d.py
@@ -0,0 +1,51 @@
+""" PyTorch Mixed Convolution
+
+Paper: MixConv: Mixed Depthwise Convolutional Kernels (https://arxiv.org/abs/1907.09595)
+
+Hacked together by / Copyright 2020 Ross Wightman
+"""
+
+import torch
+from torch import nn as nn
+
+from .conv2d_same import create_conv2d_pad
+
+
+def _split_channels(num_chan, num_groups):
+ split = [num_chan // num_groups for _ in range(num_groups)]
+ split[0] += num_chan - sum(split)
+ return split
+
+
+class MixedConv2d(nn.ModuleDict):
+ """ Mixed Grouped Convolution
+
+ Based on MDConv and GroupedConv in MixNet impl:
+ https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mixnet/custom_layers.py
+ """
+ def __init__(self, in_channels, out_channels, kernel_size=3,
+ stride=1, padding='', dilation=1, depthwise=False, **kwargs):
+ super(MixedConv2d, self).__init__()
+
+ kernel_size = kernel_size if isinstance(kernel_size, list) else [kernel_size]
+ num_groups = len(kernel_size)
+ in_splits = _split_channels(in_channels, num_groups)
+ out_splits = _split_channels(out_channels, num_groups)
+ self.in_channels = sum(in_splits)
+ self.out_channels = sum(out_splits)
+ for idx, (k, in_ch, out_ch) in enumerate(zip(kernel_size, in_splits, out_splits)):
+ conv_groups = in_ch if depthwise else 1
+ # use add_module to keep key space clean
+ self.add_module(
+ str(idx),
+ create_conv2d_pad(
+ in_ch, out_ch, k, stride=stride,
+ padding=padding, dilation=dilation, groups=conv_groups, **kwargs)
+ )
+ self.splits = in_splits
+
+ def forward(self, x):
+ x_split = torch.split(x, self.splits, 1)
+ x_out = [c(x_split[i]) for i, c in enumerate(self.values())]
+ x = torch.cat(x_out, 1)
+ return x
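+
+
+# Usage sketch (added for illustration, not part of the upstream file): channels are split
+# across the listed kernel sizes and each split gets its own conv; the outputs are then
+# concatenated, so spatial dims are preserved for stride=1. Shapes are arbitrary assumptions.
+if __name__ == '__main__':
+    x = torch.randn(2, 32, 28, 28)
+    conv = MixedConv2d(32, 64, kernel_size=[3, 5, 7], depthwise=False)
+    assert conv(x).shape == (2, 64, 28, 28)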
diff --git a/flagai/model/vision/layers/ml_decoder.py b/flagai/model/vision/layers/ml_decoder.py
new file mode 100755
index 00000000..3f828c6d
--- /dev/null
+++ b/flagai/model/vision/layers/ml_decoder.py
@@ -0,0 +1,156 @@
+from typing import Optional
+
+import torch
+from torch import nn
+from torch import nn, Tensor
+from torch.nn.modules.transformer import _get_activation_fn
+
+
+def add_ml_decoder_head(model):
+ if hasattr(model, 'global_pool') and hasattr(model, 'fc'): # most CNN models, like Resnet50
+ model.global_pool = nn.Identity()
+ del model.fc
+ num_classes = model.num_classes
+ num_features = model.num_features
+ model.fc = MLDecoder(num_classes=num_classes, initial_num_features=num_features)
+ elif hasattr(model, 'global_pool') and hasattr(model, 'classifier'): # EfficientNet
+ model.global_pool = nn.Identity()
+ del model.classifier
+ num_classes = model.num_classes
+ num_features = model.num_features
+ model.classifier = MLDecoder(num_classes=num_classes, initial_num_features=num_features)
+ elif 'RegNet' in model._get_name() or 'TResNet' in model._get_name(): # hasattr(model, 'head')
+ del model.head
+ num_classes = model.num_classes
+ num_features = model.num_features
+ model.head = MLDecoder(num_classes=num_classes, initial_num_features=num_features)
+ else:
+        print("Model architecture is not currently supported by the ML-Decoder head wrapper")
+ exit(-1)
+ if hasattr(model, 'drop_rate'): # Ml-Decoder has inner dropout
+ model.drop_rate = 0
+ return model
+
+
+class TransformerDecoderLayerOptimal(nn.Module):
+ def __init__(self, d_model, nhead=8, dim_feedforward=2048, dropout=0.1, activation="relu",
+ layer_norm_eps=1e-5) -> None:
+ super(TransformerDecoderLayerOptimal, self).__init__()
+ self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps)
+ self.dropout = nn.Dropout(dropout)
+ self.dropout1 = nn.Dropout(dropout)
+ self.dropout2 = nn.Dropout(dropout)
+ self.dropout3 = nn.Dropout(dropout)
+
+ self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
+
+ # Implementation of Feedforward model
+ self.linear1 = nn.Linear(d_model, dim_feedforward)
+ self.linear2 = nn.Linear(dim_feedforward, d_model)
+
+ self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps)
+ self.norm3 = nn.LayerNorm(d_model, eps=layer_norm_eps)
+
+ self.activation = _get_activation_fn(activation)
+
+ def __setstate__(self, state):
+ if 'activation' not in state:
+ state['activation'] = torch.nn.functional.relu
+ super(TransformerDecoderLayerOptimal, self).__setstate__(state)
+
+ def forward(self, tgt: Tensor, memory: Tensor, tgt_mask: Optional[Tensor] = None,
+ memory_mask: Optional[Tensor] = None,
+ tgt_key_padding_mask: Optional[Tensor] = None,
+ memory_key_padding_mask: Optional[Tensor] = None) -> Tensor:
+ tgt = tgt + self.dropout1(tgt)
+ tgt = self.norm1(tgt)
+ tgt2 = self.multihead_attn(tgt, memory, memory)[0]
+ tgt = tgt + self.dropout2(tgt2)
+ tgt = self.norm2(tgt)
+ tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
+ tgt = tgt + self.dropout3(tgt2)
+ tgt = self.norm3(tgt)
+ return tgt
+
+
+# @torch.jit.script
+# class ExtrapClasses(object):
+# def __init__(self, num_queries: int, group_size: int):
+# self.num_queries = num_queries
+# self.group_size = group_size
+#
+# def __call__(self, h: torch.Tensor, class_embed_w: torch.Tensor, class_embed_b: torch.Tensor, out_extrap:
+# torch.Tensor):
+# # h = h.unsqueeze(-1).expand(-1, -1, -1, self.group_size)
+# h = h[..., None].repeat(1, 1, 1, self.group_size) # torch.Size([bs, 5, 768, groups])
+# w = class_embed_w.view((self.num_queries, h.shape[2], self.group_size))
+# out = (h * w).sum(dim=2) + class_embed_b
+# out = out.view((h.shape[0], self.group_size * self.num_queries))
+# return out
+
+@torch.jit.script
+class GroupFC(object):
+ def __init__(self, embed_len_decoder: int):
+ self.embed_len_decoder = embed_len_decoder
+
+ def __call__(self, h: torch.Tensor, duplicate_pooling: torch.Tensor, out_extrap: torch.Tensor):
+ for i in range(self.embed_len_decoder):
+ h_i = h[:, i, :]
+ w_i = duplicate_pooling[i, :, :]
+ out_extrap[:, i, :] = torch.matmul(h_i, w_i)
+
+
+class MLDecoder(nn.Module):
+ def __init__(self, num_classes, num_of_groups=-1, decoder_embedding=768, initial_num_features=2048):
+ super(MLDecoder, self).__init__()
+ embed_len_decoder = 100 if num_of_groups < 0 else num_of_groups
+ if embed_len_decoder > num_classes:
+ embed_len_decoder = num_classes
+
+ # switching to 768 initial embeddings
+ decoder_embedding = 768 if decoder_embedding < 0 else decoder_embedding
+ self.embed_standart = nn.Linear(initial_num_features, decoder_embedding)
+
+ # decoder
+ decoder_dropout = 0.1
+ num_layers_decoder = 1
+ dim_feedforward = 2048
+ layer_decode = TransformerDecoderLayerOptimal(d_model=decoder_embedding,
+ dim_feedforward=dim_feedforward, dropout=decoder_dropout)
+ self.decoder = nn.TransformerDecoder(layer_decode, num_layers=num_layers_decoder)
+
+ # non-learnable queries
+ self.query_embed = nn.Embedding(embed_len_decoder, decoder_embedding)
+ self.query_embed.requires_grad_(False)
+
+ # group fully-connected
+ self.num_classes = num_classes
+ self.duplicate_factor = int(num_classes / embed_len_decoder + 0.999)
+ self.duplicate_pooling = torch.nn.Parameter(
+ torch.Tensor(embed_len_decoder, decoder_embedding, self.duplicate_factor))
+ self.duplicate_pooling_bias = torch.nn.Parameter(torch.Tensor(num_classes))
+ torch.nn.init.xavier_normal_(self.duplicate_pooling)
+ torch.nn.init.constant_(self.duplicate_pooling_bias, 0)
+ self.group_fc = GroupFC(embed_len_decoder)
+
+ def forward(self, x):
+        if len(x.shape) == 4:  # CNN feature map, e.g. [bs, 2048, 7, 7]
+            embedding_spatial = x.flatten(2).transpose(1, 2)
+        else:  # token sequence, e.g. [bs, 197, 768]
+ embedding_spatial = x
+ embedding_spatial_786 = self.embed_standart(embedding_spatial)
+ embedding_spatial_786 = torch.nn.functional.relu(embedding_spatial_786, inplace=True)
+
+ bs = embedding_spatial_786.shape[0]
+ query_embed = self.query_embed.weight
+ # tgt = query_embed.unsqueeze(1).repeat(1, bs, 1)
+ tgt = query_embed.unsqueeze(1).expand(-1, bs, -1) # no allocation of memory with expand
+ h = self.decoder(tgt, embedding_spatial_786.transpose(0, 1)) # [embed_len_decoder, batch, 768]
+ h = h.transpose(0, 1)
+
+ out_extrap = torch.zeros(h.shape[0], h.shape[1], self.duplicate_factor, device=h.device, dtype=h.dtype)
+ self.group_fc(h, self.duplicate_pooling, out_extrap)
+ h_out = out_extrap.flatten(1)[:, :self.num_classes]
+ h_out += self.duplicate_pooling_bias
+ logits = h_out
+ return logits
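+
+
+# Usage sketch (added for illustration, not part of the upstream file): MLDecoder replaces a
+# global-pool + FC classification head and accepts either CNN feature maps [bs, C, H, W] or
+# token sequences [bs, N, C]. Feature/class counts are arbitrary assumptions; this relies on
+# the torch 1.x nn.TransformerDecoder call signature this file was written against.
+if __name__ == '__main__':
+    decoder = MLDecoder(num_classes=1000, initial_num_features=2048)
+    feats = torch.randn(2, 2048, 7, 7)
+    assert decoder(feats).shape == (2, 1000)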
diff --git a/flagai/model/vision/layers/mlp.py b/flagai/model/vision/layers/mlp.py
new file mode 100755
index 00000000..91e80a84
--- /dev/null
+++ b/flagai/model/vision/layers/mlp.py
@@ -0,0 +1,126 @@
+""" MLP module w/ dropout and configurable activation layer
+
+Hacked together by / Copyright 2020 Ross Wightman
+"""
+from torch import nn as nn
+
+from .helpers import to_2tuple
+
+
+class Mlp(nn.Module):
+ """ MLP as used in Vision Transformer, MLP-Mixer and related networks
+ """
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, bias=True, drop=0.):
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ bias = to_2tuple(bias)
+ drop_probs = to_2tuple(drop)
+
+ self.fc1 = nn.Linear(in_features, hidden_features, bias=bias[0])
+ self.act = act_layer()
+ self.drop1 = nn.Dropout(drop_probs[0])
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=bias[1])
+ self.drop2 = nn.Dropout(drop_probs[1])
+
+ def forward(self, x):
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.drop1(x)
+ x = self.fc2(x)
+ x = self.drop2(x)
+ return x
+
+
+class GluMlp(nn.Module):
+ """ MLP w/ GLU style gating
+ See: https://arxiv.org/abs/1612.08083, https://arxiv.org/abs/2002.05202
+ """
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.Sigmoid, bias=True, drop=0.):
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ assert hidden_features % 2 == 0
+ bias = to_2tuple(bias)
+ drop_probs = to_2tuple(drop)
+
+ self.fc1 = nn.Linear(in_features, hidden_features, bias=bias[0])
+ self.act = act_layer()
+ self.drop1 = nn.Dropout(drop_probs[0])
+ self.fc2 = nn.Linear(hidden_features // 2, out_features, bias=bias[1])
+ self.drop2 = nn.Dropout(drop_probs[1])
+
+ def init_weights(self):
+ # override init of fc1 w/ gate portion set to weight near zero, bias=1
+ fc1_mid = self.fc1.bias.shape[0] // 2
+ nn.init.ones_(self.fc1.bias[fc1_mid:])
+ nn.init.normal_(self.fc1.weight[fc1_mid:], std=1e-6)
+
+ def forward(self, x):
+ x = self.fc1(x)
+ x, gates = x.chunk(2, dim=-1)
+ x = x * self.act(gates)
+ x = self.drop1(x)
+ x = self.fc2(x)
+ x = self.drop2(x)
+ return x
+
+
+class GatedMlp(nn.Module):
+ """ MLP as used in gMLP
+ """
+ def __init__(
+ self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU,
+ gate_layer=None, bias=True, drop=0.):
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ bias = to_2tuple(bias)
+ drop_probs = to_2tuple(drop)
+
+ self.fc1 = nn.Linear(in_features, hidden_features, bias=bias[0])
+ self.act = act_layer()
+ self.drop1 = nn.Dropout(drop_probs[0])
+ if gate_layer is not None:
+ assert hidden_features % 2 == 0
+ self.gate = gate_layer(hidden_features)
+ hidden_features = hidden_features // 2 # FIXME base reduction on gate property?
+ else:
+ self.gate = nn.Identity()
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=bias[1])
+ self.drop2 = nn.Dropout(drop_probs[1])
+
+ def forward(self, x):
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.drop1(x)
+ x = self.gate(x)
+ x = self.fc2(x)
+ x = self.drop2(x)
+ return x
+
+
+class ConvMlp(nn.Module):
+ """ MLP using 1x1 convs that keeps spatial dims
+ """
+ def __init__(
+ self, in_features, hidden_features=None, out_features=None, act_layer=nn.ReLU,
+ norm_layer=None, bias=True, drop=0.):
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ bias = to_2tuple(bias)
+
+ self.fc1 = nn.Conv2d(in_features, hidden_features, kernel_size=1, bias=bias[0])
+ self.norm = norm_layer(hidden_features) if norm_layer else nn.Identity()
+ self.act = act_layer()
+ self.drop = nn.Dropout(drop)
+ self.fc2 = nn.Conv2d(hidden_features, out_features, kernel_size=1, bias=bias[1])
+
+ def forward(self, x):
+ x = self.fc1(x)
+ x = self.norm(x)
+ x = self.act(x)
+ x = self.drop(x)
+ x = self.fc2(x)
+ return x
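+
+
+# Usage sketch (added for illustration, not part of the upstream file): Mlp is the ViT-style
+# token MLP, while ConvMlp is the 1x1-conv variant that keeps NCHW spatial dims. Shapes are
+# arbitrary assumptions.
+if __name__ == '__main__':
+    import torch
+    tokens = torch.randn(2, 196, 768)
+    assert Mlp(768, hidden_features=3072)(tokens).shape == tokens.shape
+    feats = torch.randn(2, 256, 14, 14)
+    assert ConvMlp(256, hidden_features=64)(feats).shape == feats.shape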
diff --git a/flagai/model/vision/layers/non_local_attn.py b/flagai/model/vision/layers/non_local_attn.py
new file mode 100755
index 00000000..670e8f24
--- /dev/null
+++ b/flagai/model/vision/layers/non_local_attn.py
@@ -0,0 +1,145 @@
+""" Bilinear-Attention-Transform and Non-Local Attention
+
+Paper: `Non-Local Neural Networks With Grouped Bilinear Attentional Transforms`
+ - https://openaccess.thecvf.com/content_CVPR_2020/html/Chi_Non-Local_Neural_Networks_With_Grouped_Bilinear_Attentional_Transforms_CVPR_2020_paper.html
+Adapted from original code: https://github.com/BA-Transform/BAT-Image-Classification
+"""
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from .conv_bn_act import ConvNormAct
+from .helpers import make_divisible
+from .trace_utils import _assert
+
+
+class NonLocalAttn(nn.Module):
+ """Spatial NL block for image classification.
+
+ This was adapted from https://github.com/BA-Transform/BAT-Image-Classification
+ Their NonLocal impl inspired by https://github.com/facebookresearch/video-nonlocal-net.
+ """
+
+ def __init__(self, in_channels, use_scale=True, rd_ratio=1/8, rd_channels=None, rd_divisor=8, **kwargs):
+ super(NonLocalAttn, self).__init__()
+ if rd_channels is None:
+ rd_channels = make_divisible(in_channels * rd_ratio, divisor=rd_divisor)
+ self.scale = in_channels ** -0.5 if use_scale else 1.0
+ self.t = nn.Conv2d(in_channels, rd_channels, kernel_size=1, stride=1, bias=True)
+ self.p = nn.Conv2d(in_channels, rd_channels, kernel_size=1, stride=1, bias=True)
+ self.g = nn.Conv2d(in_channels, rd_channels, kernel_size=1, stride=1, bias=True)
+ self.z = nn.Conv2d(rd_channels, in_channels, kernel_size=1, stride=1, bias=True)
+ self.norm = nn.BatchNorm2d(in_channels)
+ self.reset_parameters()
+
+ def forward(self, x):
+ shortcut = x
+
+ t = self.t(x)
+ p = self.p(x)
+ g = self.g(x)
+
+ B, C, H, W = t.size()
+ t = t.view(B, C, -1).permute(0, 2, 1)
+ p = p.view(B, C, -1)
+ g = g.view(B, C, -1).permute(0, 2, 1)
+
+ att = torch.bmm(t, p) * self.scale
+ att = F.softmax(att, dim=2)
+ x = torch.bmm(att, g)
+
+ x = x.permute(0, 2, 1).reshape(B, C, H, W)
+ x = self.z(x)
+ x = self.norm(x) + shortcut
+
+ return x
+
+ def reset_parameters(self):
+ for name, m in self.named_modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal_(
+ m.weight, mode='fan_out', nonlinearity='relu')
+ if len(list(m.parameters())) > 1:
+ nn.init.constant_(m.bias, 0.0)
+ elif isinstance(m, nn.BatchNorm2d):
+ nn.init.constant_(m.weight, 0)
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.GroupNorm):
+ nn.init.constant_(m.weight, 0)
+ nn.init.constant_(m.bias, 0)
+
+
+class BilinearAttnTransform(nn.Module):
+
+ def __init__(self, in_channels, block_size, groups, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d):
+ super(BilinearAttnTransform, self).__init__()
+
+ self.conv1 = ConvNormAct(in_channels, groups, 1, act_layer=act_layer, norm_layer=norm_layer)
+ self.conv_p = nn.Conv2d(groups, block_size * block_size * groups, kernel_size=(block_size, 1))
+ self.conv_q = nn.Conv2d(groups, block_size * block_size * groups, kernel_size=(1, block_size))
+ self.conv2 = ConvNormAct(in_channels, in_channels, 1, act_layer=act_layer, norm_layer=norm_layer)
+ self.block_size = block_size
+ self.groups = groups
+ self.in_channels = in_channels
+
+ def resize_mat(self, x, t: int):
+ B, C, block_size, block_size1 = x.shape
+ _assert(block_size == block_size1, '')
+ if t <= 1:
+ return x
+ x = x.view(B * C, -1, 1, 1)
+ x = x * torch.eye(t, t, dtype=x.dtype, device=x.device)
+ x = x.view(B * C, block_size, block_size, t, t)
+ x = torch.cat(torch.split(x, 1, dim=1), dim=3)
+ x = torch.cat(torch.split(x, 1, dim=2), dim=4)
+ x = x.view(B, C, block_size * t, block_size * t)
+ return x
+
+ def forward(self, x):
+ _assert(x.shape[-1] % self.block_size == 0, '')
+ _assert(x.shape[-2] % self.block_size == 0, '')
+ B, C, H, W = x.shape
+ out = self.conv1(x)
+ rp = F.adaptive_max_pool2d(out, (self.block_size, 1))
+ cp = F.adaptive_max_pool2d(out, (1, self.block_size))
+ p = self.conv_p(rp).view(B, self.groups, self.block_size, self.block_size).sigmoid()
+ q = self.conv_q(cp).view(B, self.groups, self.block_size, self.block_size).sigmoid()
+ p = p / p.sum(dim=3, keepdim=True)
+ q = q / q.sum(dim=2, keepdim=True)
+ p = p.view(B, self.groups, 1, self.block_size, self.block_size).expand(x.size(
+ 0), self.groups, C // self.groups, self.block_size, self.block_size).contiguous()
+ p = p.view(B, C, self.block_size, self.block_size)
+ q = q.view(B, self.groups, 1, self.block_size, self.block_size).expand(x.size(
+ 0), self.groups, C // self.groups, self.block_size, self.block_size).contiguous()
+ q = q.view(B, C, self.block_size, self.block_size)
+ p = self.resize_mat(p, H // self.block_size)
+ q = self.resize_mat(q, W // self.block_size)
+ y = p.matmul(x)
+ y = y.matmul(q)
+
+ y = self.conv2(y)
+ return y
+
+
+class BatNonLocalAttn(nn.Module):
+ """ BAT
+ Adapted from: https://github.com/BA-Transform/BAT-Image-Classification
+ """
+
+ def __init__(
+ self, in_channels, block_size=7, groups=2, rd_ratio=0.25, rd_channels=None, rd_divisor=8,
+ drop_rate=0.2, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, **_):
+ super().__init__()
+ if rd_channels is None:
+ rd_channels = make_divisible(in_channels * rd_ratio, divisor=rd_divisor)
+ self.conv1 = ConvNormAct(in_channels, rd_channels, 1, act_layer=act_layer, norm_layer=norm_layer)
+ self.ba = BilinearAttnTransform(rd_channels, block_size, groups, act_layer=act_layer, norm_layer=norm_layer)
+ self.conv2 = ConvNormAct(rd_channels, in_channels, 1, act_layer=act_layer, norm_layer=norm_layer)
+ self.dropout = nn.Dropout2d(p=drop_rate)
+
+ def forward(self, x):
+ xl = self.conv1(x)
+ y = self.ba(xl)
+ y = self.conv2(y)
+ y = self.dropout(y)
+ return y + x
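+
+
+# Usage sketch (added for illustration, not part of the upstream file): NonLocalAttn is the
+# plain non-local block, while BatNonLocalAttn applies the bilinear attentional transform;
+# both preserve the input shape (BAT additionally needs H and W divisible by block_size).
+if __name__ == '__main__':
+    x = torch.randn(2, 64, 14, 14)
+    assert NonLocalAttn(64)(x).shape == x.shape
+    assert BatNonLocalAttn(64, block_size=7)(x).shape == x.shape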
diff --git a/flagai/model/vision/layers/norm.py b/flagai/model/vision/layers/norm.py
new file mode 100755
index 00000000..85297420
--- /dev/null
+++ b/flagai/model/vision/layers/norm.py
@@ -0,0 +1,24 @@
+""" Normalization layers and wrappers
+"""
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class GroupNorm(nn.GroupNorm):
+ def __init__(self, num_channels, num_groups=32, eps=1e-5, affine=True):
+ # NOTE num_channels is swapped to first arg for consistency in swapping norm layers with BN
+ super().__init__(num_groups, num_channels, eps=eps, affine=affine)
+
+ def forward(self, x):
+ return F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
+
+
+class LayerNorm2d(nn.LayerNorm):
+ """ LayerNorm for channels of '2D' spatial BCHW tensors """
+ def __init__(self, num_channels):
+ super().__init__(num_channels)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return F.layer_norm(
+ x.permute(0, 2, 3, 1), self.normalized_shape, self.weight, self.bias, self.eps).permute(0, 3, 1, 2)
diff --git a/flagai/model/vision/layers/norm_act.py b/flagai/model/vision/layers/norm_act.py
new file mode 100755
index 00000000..34c4fd64
--- /dev/null
+++ b/flagai/model/vision/layers/norm_act.py
@@ -0,0 +1,151 @@
+""" Normalization + Activation Layers
+"""
+from typing import Union, List
+
+import torch
+from torch import nn as nn
+from torch.nn import functional as F
+
+from .trace_utils import _assert
+from .create_act import get_act_layer
+
+
+class BatchNormAct2d(nn.BatchNorm2d):
+ """BatchNorm + Activation
+
+ This module performs BatchNorm + Activation in a manner that will remain backwards
+ compatible with weights trained with separate bn, act. This is why we inherit from BN
+ instead of composing it as a .bn member.
+ """
+ def __init__(
+ self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True,
+ apply_act=True, act_layer=nn.ReLU, inplace=True, drop_layer=None):
+ super(BatchNormAct2d, self).__init__(
+ num_features, eps=eps, momentum=momentum, affine=affine, track_running_stats=track_running_stats)
+ self.drop = drop_layer() if drop_layer is not None else nn.Identity()
+ act_layer = get_act_layer(act_layer) # string -> nn.Module
+ if act_layer is not None and apply_act:
+ act_args = dict(inplace=True) if inplace else {}
+ self.act = act_layer(**act_args)
+ else:
+ self.act = nn.Identity()
+
+ def forward(self, x):
+ # cut & paste of torch.nn.BatchNorm2d.forward impl to avoid issues with torchscript and tracing
+ _assert(x.ndim == 4, f'expected 4D input (got {x.ndim}D input)')
+
+ # exponential_average_factor is set to self.momentum
+ # (when it is available) only so that it gets updated
+ # in ONNX graph when this node is exported to ONNX.
+ if self.momentum is None:
+ exponential_average_factor = 0.0
+ else:
+ exponential_average_factor = self.momentum
+
+ if self.training and self.track_running_stats:
+ # TODO: if statement only here to tell the jit to skip emitting this when it is None
+ if self.num_batches_tracked is not None: # type: ignore[has-type]
+ self.num_batches_tracked = self.num_batches_tracked + 1 # type: ignore[has-type]
+ if self.momentum is None: # use cumulative moving average
+ exponential_average_factor = 1.0 / float(self.num_batches_tracked)
+ else: # use exponential moving average
+ exponential_average_factor = self.momentum
+
+ r"""
+ Decide whether the mini-batch stats should be used for normalization rather than the buffers.
+ Mini-batch stats are used in training mode, and in eval mode when buffers are None.
+ """
+ if self.training:
+ bn_training = True
+ else:
+ bn_training = (self.running_mean is None) and (self.running_var is None)
+
+ r"""
+ Buffers are only updated if they are to be tracked and we are in training mode. Thus they only need to be
+ passed when the update should occur (i.e. in training mode when they are tracked), or when buffer stats are
+ used for normalization (i.e. in eval mode when buffers are not None).
+ """
+ x = F.batch_norm(
+ x,
+ # If buffers are not to be tracked, ensure that they won't be updated
+ self.running_mean if not self.training or self.track_running_stats else None,
+ self.running_var if not self.training or self.track_running_stats else None,
+ self.weight,
+ self.bias,
+ bn_training,
+ exponential_average_factor,
+ self.eps,
+ )
+ x = self.drop(x)
+ x = self.act(x)
+ return x
+
+
+def _num_groups(num_channels, num_groups, group_size):
+ if group_size:
+ assert num_channels % group_size == 0
+ return num_channels // group_size
+ return num_groups
+
+
+class GroupNormAct(nn.GroupNorm):
+    # NOTE num_channels and num_groups order flipped for easier layer swaps / binding of fixed args
+ def __init__(
+ self, num_channels, num_groups=32, eps=1e-5, affine=True, group_size=None,
+ apply_act=True, act_layer=nn.ReLU, inplace=True, drop_layer=None):
+ super(GroupNormAct, self).__init__(
+ _num_groups(num_channels, num_groups, group_size), num_channels, eps=eps, affine=affine)
+ self.drop = drop_layer() if drop_layer is not None else nn.Identity()
+ act_layer = get_act_layer(act_layer) # string -> nn.Module
+ if act_layer is not None and apply_act:
+ act_args = dict(inplace=True) if inplace else {}
+ self.act = act_layer(**act_args)
+ else:
+ self.act = nn.Identity()
+
+ def forward(self, x):
+ x = F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
+ x = self.drop(x)
+ x = self.act(x)
+ return x
+
+
+class LayerNormAct(nn.LayerNorm):
+ def __init__(
+ self, normalization_shape: Union[int, List[int], torch.Size], eps=1e-5, affine=True,
+ apply_act=True, act_layer=nn.ReLU, inplace=True, drop_layer=None):
+ super(LayerNormAct, self).__init__(normalization_shape, eps=eps, elementwise_affine=affine)
+ self.drop = drop_layer() if drop_layer is not None else nn.Identity()
+ act_layer = get_act_layer(act_layer) # string -> nn.Module
+ if act_layer is not None and apply_act:
+ act_args = dict(inplace=True) if inplace else {}
+ self.act = act_layer(**act_args)
+ else:
+ self.act = nn.Identity()
+
+ def forward(self, x):
+ x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
+ x = self.drop(x)
+ x = self.act(x)
+ return x
+
+
+class LayerNormAct2d(nn.LayerNorm):
+ def __init__(
+ self, num_channels, eps=1e-5, affine=True,
+ apply_act=True, act_layer=nn.ReLU, inplace=True, drop_layer=None):
+ super(LayerNormAct2d, self).__init__(num_channels, eps=eps, elementwise_affine=affine)
+ self.drop = drop_layer() if drop_layer is not None else nn.Identity()
+ act_layer = get_act_layer(act_layer) # string -> nn.Module
+ if act_layer is not None and apply_act:
+ act_args = dict(inplace=True) if inplace else {}
+ self.act = act_layer(**act_args)
+ else:
+ self.act = nn.Identity()
+
+ def forward(self, x):
+ x = F.layer_norm(
+ x.permute(0, 2, 3, 1), self.normalized_shape, self.weight, self.bias, self.eps).permute(0, 3, 1, 2)
+ x = self.drop(x)
+ x = self.act(x)
+ return x
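+
+
+# Usage sketch (added for illustration, not part of the upstream file): these layers fuse a
+# norm with an activation so model code can swap BN/GN/LN variants through a single
+# norm_layer argument. The channel count is an arbitrary assumption.
+if __name__ == '__main__':
+    x = torch.randn(2, 64, 8, 8)
+    assert BatchNormAct2d(64)(x).shape == x.shape
+    assert GroupNormAct(64, num_groups=32)(x).shape == x.shape
+    assert LayerNormAct2d(64)(x).shape == x.shape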
diff --git a/flagai/model/vision/layers/padding.py b/flagai/model/vision/layers/padding.py
new file mode 100755
index 00000000..34afc37c
--- /dev/null
+++ b/flagai/model/vision/layers/padding.py
@@ -0,0 +1,56 @@
+""" Padding Helpers
+
+Hacked together by / Copyright 2020 Ross Wightman
+"""
+import math
+from typing import List, Tuple
+
+import torch.nn.functional as F
+
+
+# Calculate symmetric padding for a convolution
+def get_padding(kernel_size: int, stride: int = 1, dilation: int = 1, **_) -> int:
+ padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2
+ return padding
+
+
+# Calculate asymmetric TensorFlow-like 'SAME' padding for a convolution
+def get_same_padding(x: int, k: int, s: int, d: int):
+ return max((math.ceil(x / s) - 1) * s + (k - 1) * d + 1 - x, 0)
+
+
+# Can SAME padding for given args be done statically?
+def is_static_pad(kernel_size: int, stride: int = 1, dilation: int = 1, **_):
+ return stride == 1 and (dilation * (kernel_size - 1)) % 2 == 0
+
+
+# Dynamically pad input x with 'SAME' padding for conv with specified args
+def pad_same(x, k: List[int], s: List[int], d: List[int] = (1, 1), value: float = 0):
+ ih, iw = x.size()[-2:]
+ pad_h, pad_w = get_same_padding(ih, k[0], s[0], d[0]), get_same_padding(iw, k[1], s[1], d[1])
+ if pad_h > 0 or pad_w > 0:
+ x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2], value=value)
+ return x
+
+
+def get_padding_value(padding, kernel_size, **kwargs) -> Tuple[Tuple, bool]:
+ dynamic = False
+ if isinstance(padding, str):
+ # for any string padding, the padding will be calculated for you, one of three ways
+ padding = padding.lower()
+ if padding == 'same':
+ # TF compatible 'SAME' padding, has a performance and GPU memory allocation impact
+ if is_static_pad(kernel_size, **kwargs):
+ # static case, no extra overhead
+ padding = get_padding(kernel_size, **kwargs)
+ else:
+ # dynamic 'SAME' padding, has runtime/GPU memory overhead
+ padding = 0
+ dynamic = True
+ elif padding == 'valid':
+ # 'VALID' padding, same as padding=0
+ padding = 0
+ else:
+ # Default to PyTorch style 'same'-ish symmetric padding
+ padding = get_padding(kernel_size, **kwargs)
+ return padding, dynamic
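A small sketch (not part of the diff) of how the SAME-padding helpers above resolve string padding, assuming the `flagai.model.vision.layers.padding` module from this PR is importable:

```python
import torch
import torch.nn.functional as F
from flagai.model.vision.layers.padding import get_padding_value, pad_same

# 'same' with stride 1 resolves to a static symmetric pad of 1 for a 3x3 kernel
print(get_padding_value('same', kernel_size=3, stride=1))  # (1, False)

# with stride 2 the pad amount depends on the input size, so dynamic padding is flagged
print(get_padding_value('same', kernel_size=3, stride=2))  # (0, True)

# pad_same applies the TF-style asymmetric pad at runtime: 7x7 -> 9x9 before a stride-2 conv
x = pad_same(torch.randn(1, 3, 7, 7), k=[3, 3], s=[2, 2])
y = F.conv2d(x, torch.randn(8, 3, 3, 3), stride=2)
print(y.shape)  # torch.Size([1, 8, 4, 4]) == ceil(7 / 2) output positions per side
```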
diff --git a/flagai/model/vision/layers/patch_embed.py b/flagai/model/vision/layers/patch_embed.py
new file mode 100755
index 00000000..b074798b
--- /dev/null
+++ b/flagai/model/vision/layers/patch_embed.py
@@ -0,0 +1,40 @@
+""" Image to Patch Embedding using Conv2d
+
+A convolution based approach to patchifying a 2D image w/ embedding projection.
+
+Based on the impl in https://github.com/google-research/vision_transformer
+
+Hacked together by / Copyright 2020 Ross Wightman
+"""
+from torch import nn as nn
+
+from .helpers import to_2tuple
+from .trace_utils import _assert
+
+
+class PatchEmbed(nn.Module):
+ """ 2D Image to Patch Embedding
+ """
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True):
+ super().__init__()
+ img_size = to_2tuple(img_size)
+ patch_size = to_2tuple(patch_size)
+ self.img_size = img_size
+ self.patch_size = patch_size
+ self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
+ self.num_patches = self.grid_size[0] * self.grid_size[1]
+ self.flatten = flatten
+
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
+
+ def forward(self, x):
+ B, C, H, W = x.shape
+ _assert(H == self.img_size[0], f"Input image height ({H}) doesn't match model ({self.img_size[0]}).")
+ _assert(W == self.img_size[1], f"Input image width ({W}) doesn't match model ({self.img_size[1]}).")
+
+ x = self.proj(x)
+ if self.flatten:
+ x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
+ x = self.norm(x)
+ return x
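A minimal sketch (not part of the diff) of the patch embedding above, assuming the module path from this PR is importable; a 224x224 image cut into 16x16 patches should yield 14*14 = 196 tokens:

```python
import torch
from flagai.model.vision.layers.patch_embed import PatchEmbed

embed = PatchEmbed(img_size=224, patch_size=16, in_chans=3, embed_dim=768)
x = torch.randn(2, 3, 224, 224)
tokens = embed(x)  # BCHW -> BNC when flatten=True

print(embed.num_patches)  # 196
print(tokens.shape)       # torch.Size([2, 196, 768])
```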
diff --git a/flagai/model/vision/layers/pool2d_same.py b/flagai/model/vision/layers/pool2d_same.py
new file mode 100755
index 00000000..4c2a1c44
--- /dev/null
+++ b/flagai/model/vision/layers/pool2d_same.py
@@ -0,0 +1,73 @@
+""" AvgPool2d w/ Same Padding
+
+Hacked together by / Copyright 2020 Ross Wightman
+"""
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from typing import List, Tuple, Optional
+
+from .helpers import to_2tuple
+from .padding import pad_same, get_padding_value
+
+
+def avg_pool2d_same(x, kernel_size: List[int], stride: List[int], padding: List[int] = (0, 0),
+ ceil_mode: bool = False, count_include_pad: bool = True):
+ # FIXME how to deal with count_include_pad vs not for external padding?
+ x = pad_same(x, kernel_size, stride)
+ return F.avg_pool2d(x, kernel_size, stride, (0, 0), ceil_mode, count_include_pad)
+
+
+class AvgPool2dSame(nn.AvgPool2d):
+ """ Tensorflow like 'SAME' wrapper for 2D average pooling
+ """
+ def __init__(self, kernel_size: int, stride=None, padding=0, ceil_mode=False, count_include_pad=True):
+ kernel_size = to_2tuple(kernel_size)
+ stride = to_2tuple(stride)
+ super(AvgPool2dSame, self).__init__(kernel_size, stride, (0, 0), ceil_mode, count_include_pad)
+
+ def forward(self, x):
+ x = pad_same(x, self.kernel_size, self.stride)
+ return F.avg_pool2d(
+ x, self.kernel_size, self.stride, self.padding, self.ceil_mode, self.count_include_pad)
+
+
+def max_pool2d_same(
+ x, kernel_size: List[int], stride: List[int], padding: List[int] = (0, 0),
+ dilation: List[int] = (1, 1), ceil_mode: bool = False):
+ x = pad_same(x, kernel_size, stride, value=-float('inf'))
+ return F.max_pool2d(x, kernel_size, stride, (0, 0), dilation, ceil_mode)
+
+
+class MaxPool2dSame(nn.MaxPool2d):
+ """ Tensorflow like 'SAME' wrapper for 2D max pooling
+ """
+ def __init__(self, kernel_size: int, stride=None, padding=0, dilation=1, ceil_mode=False):
+ kernel_size = to_2tuple(kernel_size)
+ stride = to_2tuple(stride)
+ dilation = to_2tuple(dilation)
+ super(MaxPool2dSame, self).__init__(kernel_size, stride, (0, 0), dilation, ceil_mode=ceil_mode)
+
+ def forward(self, x):
+ x = pad_same(x, self.kernel_size, self.stride, value=-float('inf'))
+ return F.max_pool2d(x, self.kernel_size, self.stride, (0, 0), self.dilation, self.ceil_mode)
+
+
+def create_pool2d(pool_type, kernel_size, stride=None, **kwargs):
+ stride = stride or kernel_size
+ padding = kwargs.pop('padding', '')
+ padding, is_dynamic = get_padding_value(padding, kernel_size, stride=stride, **kwargs)
+ if is_dynamic:
+ if pool_type == 'avg':
+ return AvgPool2dSame(kernel_size, stride=stride, **kwargs)
+ elif pool_type == 'max':
+ return MaxPool2dSame(kernel_size, stride=stride, **kwargs)
+ else:
+ assert False, f'Unsupported pool type {pool_type}'
+ else:
+ if pool_type == 'avg':
+ return nn.AvgPool2d(kernel_size, stride=stride, padding=padding, **kwargs)
+ elif pool_type == 'max':
+ return nn.MaxPool2d(kernel_size, stride=stride, padding=padding, **kwargs)
+ else:
+ assert False, f'Unsupported pool type {pool_type}'
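A short sketch (not part of the diff) showing how `create_pool2d` above falls back to the dynamically padded wrapper when static SAME padding is impossible, assuming the module path from this PR is importable:

```python
import torch
from flagai.model.vision.layers.pool2d_same import create_pool2d

# padding='same' with stride != 1 cannot be padded statically, so the dynamically
# padded MaxPool2dSame wrapper is returned instead of a plain nn.MaxPool2d
pool = create_pool2d('max', kernel_size=3, stride=2, padding='same')
x = torch.randn(1, 8, 7, 7)

print(type(pool).__name__)  # MaxPool2dSame
print(pool(x).shape)        # torch.Size([1, 8, 4, 4]) == ceil(7 / 2) per spatial dim
```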
diff --git a/flagai/model/vision/layers/pos_embed.py b/flagai/model/vision/layers/pos_embed.py
new file mode 100755
index 00000000..99a122a0
--- /dev/null
+++ b/flagai/model/vision/layers/pos_embed.py
@@ -0,0 +1,207 @@
+import math
+from typing import List, Tuple, Optional, Union
+
+import torch
+from torch import nn as nn
+
+
+def pixel_freq_bands(
+ num_bands: int,
+ max_freq: float = 224.,
+ linear_bands: bool = True,
+ dtype: torch.dtype = torch.float32,
+ device: Optional[torch.device] = None,
+):
+ if linear_bands:
+ bands = torch.linspace(1.0, max_freq / 2, num_bands, dtype=dtype, device=device)
+ else:
+ bands = 2 ** torch.linspace(0, math.log(max_freq, 2) - 1, num_bands, dtype=dtype, device=device)
+ return bands * torch.pi
+
+
+def inv_freq_bands(
+ num_bands: int,
+ temperature: float = 100000.,
+ step: int = 2,
+ dtype: torch.dtype = torch.float32,
+ device: Optional[torch.device] = None,
+) -> torch.Tensor:
+ inv_freq = 1. / (temperature ** (torch.arange(0, num_bands, step, dtype=dtype, device=device) / num_bands))
+ return inv_freq
+
+
+def build_sincos2d_pos_embed(
+ feat_shape: List[int],
+ dim: int = 64,
+ temperature: float = 10000.,
+ reverse_coord: bool = False,
+ interleave_sin_cos: bool = False,
+ dtype: torch.dtype = torch.float32,
+ device: Optional[torch.device] = None
+) -> torch.Tensor:
+ """
+
+ Args:
+ feat_shape:
+ dim:
+ temperature:
+ reverse_coord: stack grid order W, H instead of H, W
+ interleave_sin_cos: sin, cos, sin, cos stack instead of sin, sin, cos, cos
+ dtype:
+ device:
+
+ Returns:
+
+ """
+ assert dim % 4 == 0, 'Embed dimension must be divisible by 4 for sin-cos 2D position embedding'
+ pos_dim = dim // 4
+ bands = inv_freq_bands(pos_dim, temperature=temperature, step=1, dtype=dtype, device=device)
+
+ if reverse_coord:
+ feat_shape = feat_shape[::-1] # stack W, H instead of H, W
+ grid = torch.stack(
+ torch.meshgrid([torch.arange(s, device=device, dtype=dtype) for s in feat_shape])).flatten(1).transpose(0, 1)
+ pos2 = grid.unsqueeze(-1) * bands.unsqueeze(0)
+ # FIXME add support for unflattened spatial dim?
+
+ stack_dim = 2 if interleave_sin_cos else 1 # stack sin, cos, sin, cos instead of sin sin cos cos
+ pos_emb = torch.stack([torch.sin(pos2), torch.cos(pos2)], dim=stack_dim).flatten(1)
+ return pos_emb
+
+
+def build_fourier_pos_embed(
+ feat_shape: List[int],
+ bands: Optional[torch.Tensor] = None,
+ num_bands: int = 64,
+ max_res: int = 224,
+ linear_bands: bool = False,
+ include_grid: bool = False,
+ concat_out: bool = True,
+ in_pixels: bool = True,
+ dtype: torch.dtype = torch.float32,
+ device: Optional[torch.device] = None,
+) -> List[torch.Tensor]:
+ if bands is None:
+ if in_pixels:
+ bands = pixel_freq_bands(num_bands, float(max_res), linear_bands=linear_bands, dtype=dtype, device=device)
+ else:
+ bands = inv_freq_bands(num_bands, step=1, dtype=dtype, device=device)
+ else:
+ if device is None:
+ device = bands.device
+ if dtype is None:
+ dtype = bands.dtype
+
+ if in_pixels:
+ grid = torch.stack(torch.meshgrid(
+ [torch.linspace(-1., 1., steps=s, device=device, dtype=dtype) for s in feat_shape]), dim=-1)
+ else:
+ grid = torch.stack(torch.meshgrid(
+ [torch.arange(s, device=device, dtype=dtype) for s in feat_shape]), dim=-1)
+ grid = grid.unsqueeze(-1)
+ pos = grid * bands
+
+ pos_sin, pos_cos = pos.sin(), pos.cos()
+ out = (grid, pos_sin, pos_cos) if include_grid else (pos_sin, pos_cos)
+ # FIXME torchscript doesn't like multiple return types, probably need to always cat?
+ if concat_out:
+ out = torch.cat(out, dim=-1)
+ return out
+
+
+class FourierEmbed(nn.Module):
+
+ def __init__(self, max_res: int = 224, num_bands: int = 64, concat_grid=True, keep_spatial=False):
+ super().__init__()
+ self.max_res = max_res
+ self.num_bands = num_bands
+ self.concat_grid = concat_grid
+ self.keep_spatial = keep_spatial
+ self.register_buffer('bands', pixel_freq_bands(num_bands, float(max_res)), persistent=False)  # num_bands first, matching the signature above
+
+ def forward(self, x):
+ B, C = x.shape[:2]
+ feat_shape = x.shape[2:]
+ emb = build_fourier_pos_embed(
+ feat_shape,
+ self.bands,
+ include_grid=self.concat_grid,
+ dtype=x.dtype,
+ device=x.device)
+ emb = emb.transpose(-1, -2).flatten(len(feat_shape))
+ batch_expand = (B,) + (-1,) * (x.ndim - 1)
+
+ # FIXME support nD
+ if self.keep_spatial:
+ x = torch.cat([x, emb.unsqueeze(0).expand(batch_expand).permute(0, 3, 1, 2)], dim=1)
+ else:
+ x = torch.cat([x.permute(0, 2, 3, 1), emb.unsqueeze(0).expand(batch_expand)], dim=-1)
+ x = x.reshape(B, feat_shape.numel(), -1)
+
+ return x
+
+
+def rot(x):
+ return torch.stack([-x[..., 1::2], x[..., ::2]], -1).reshape(x.shape)
+
+
+def apply_rot_embed(x: torch.Tensor, sin_emb, cos_emb):
+ return x * cos_emb + rot(x) * sin_emb
+
+
+def apply_rot_embed_list(x: List[torch.Tensor], sin_emb, cos_emb):
+ if isinstance(x, torch.Tensor):
+ x = [x]
+ return [t * cos_emb + rot(t) * sin_emb for t in x]
+
+
+def apply_rot_embed_split(x: torch.Tensor, emb):
+ split = emb.shape[-1] // 2
+ return x * emb[:, :split] + rot(x) * emb[:, split:]
+
+
+def build_rotary_pos_embed(
+ feat_shape: List[int],
+ bands: Optional[torch.Tensor] = None,
+ dim: int = 64,
+ max_freq: float = 224,
+ linear_bands: bool = False,
+ dtype: torch.dtype = torch.float32,
+ device: Optional[torch.device] = None,
+):
+ """
+ NOTE: shape arg should include spatial dim only
+ """
+ feat_shape = torch.Size(feat_shape)
+
+ sin_emb, cos_emb = build_fourier_pos_embed(
+ feat_shape, bands=bands, num_bands=dim // 4, max_res=max_freq, linear_bands=linear_bands,
+ concat_out=False, device=device, dtype=dtype)
+ N = feat_shape.numel()
+ sin_emb = sin_emb.reshape(N, -1).repeat_interleave(2, -1)
+ cos_emb = cos_emb.reshape(N, -1).repeat_interleave(2, -1)
+ return sin_emb, cos_emb
+
+
+class RotaryEmbedding(nn.Module):
+ """ Rotary position embedding
+
+ NOTE: This is my initial attempt at implementing rotary embeddings for spatial use; it has not
+ been well tested and will likely change. It will be moved to its own file.
+
+ The following impl/resources were referenced for this impl:
+ * https://github.com/lucidrains/vit-pytorch/blob/6f3a5fcf0bca1c5ec33a35ef48d97213709df4ba/vit_pytorch/rvt.py
+ * https://blog.eleuther.ai/rotary-embeddings/
+ """
+ def __init__(self, dim, max_res=224, linear_bands: bool = False):
+ super().__init__()
+ self.dim = dim
+ self.register_buffer('bands', pixel_freq_bands(dim // 4, max_res, linear_bands=linear_bands), persistent=False)
+
+ def get_embed(self, shape: List[int]):
+ return build_rotary_pos_embed(shape, self.bands)
+
+ def forward(self, x):
+ # assuming channel-first tensor where spatial dim are >= 2
+ sin_emb, cos_emb = self.get_embed(x.shape[2:])
+ return apply_rot_embed(x, sin_emb, cos_emb)
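A minimal sketch (not part of the diff) of the position-embedding builders above, assuming the module path from this PR is importable; the shapes follow from the code (dim/4 frequency bands, one row per grid position):

```python
import torch
from flagai.model.vision.layers.pos_embed import build_sincos2d_pos_embed, build_rotary_pos_embed

# fixed 2D sin-cos table for a 14x14 feature grid and a 768-dim embedding
pos = build_sincos2d_pos_embed([14, 14], dim=768)
print(pos.shape)  # torch.Size([196, 768])

# rotary embeddings: per-position sin/cos tensors for a 64-dim attention head
sin_emb, cos_emb = build_rotary_pos_embed([14, 14], dim=64)
print(sin_emb.shape, cos_emb.shape)  # torch.Size([196, 64]) torch.Size([196, 64])
```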
diff --git a/flagai/model/vision/layers/selective_kernel.py b/flagai/model/vision/layers/selective_kernel.py
new file mode 100755
index 00000000..3d71e3aa
--- /dev/null
+++ b/flagai/model/vision/layers/selective_kernel.py
@@ -0,0 +1,119 @@
+""" Selective Kernel Convolution/Attention
+
+Paper: Selective Kernel Networks (https://arxiv.org/abs/1903.06586)
+
+Hacked together by / Copyright 2020 Ross Wightman
+"""
+import torch
+from torch import nn as nn
+
+from .conv_bn_act import ConvNormActAa
+from .helpers import make_divisible
+from .trace_utils import _assert
+
+
+def _kernel_valid(k):
+    if isinstance(k, (list, tuple)):
+        for ki in k:
+            _kernel_valid(ki)  # validate every entry, not just the first
+        return
+    assert k >= 3 and k % 2
+
+
+class SelectiveKernelAttn(nn.Module):
+ def __init__(self, channels, num_paths=2, attn_channels=32, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d):
+ """ Selective Kernel Attention Module
+
+ Selective Kernel attention mechanism factored out into its own module.
+
+ """
+ super(SelectiveKernelAttn, self).__init__()
+ self.num_paths = num_paths
+ self.fc_reduce = nn.Conv2d(channels, attn_channels, kernel_size=1, bias=False)
+ self.bn = norm_layer(attn_channels)
+ self.act = act_layer(inplace=True)
+ self.fc_select = nn.Conv2d(attn_channels, channels * num_paths, kernel_size=1, bias=False)
+
+ def forward(self, x):
+ _assert(x.shape[1] == self.num_paths, '')
+ x = x.sum(1).mean((2, 3), keepdim=True)
+ x = self.fc_reduce(x)
+ x = self.bn(x)
+ x = self.act(x)
+ x = self.fc_select(x)
+ B, C, H, W = x.shape
+ x = x.view(B, self.num_paths, C // self.num_paths, H, W)
+ x = torch.softmax(x, dim=1)
+ return x
+
+
+class SelectiveKernel(nn.Module):
+
+ def __init__(self, in_channels, out_channels=None, kernel_size=None, stride=1, dilation=1, groups=1,
+ rd_ratio=1./16, rd_channels=None, rd_divisor=8, keep_3x3=True, split_input=True,
+ act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, aa_layer=None, drop_layer=None):
+ """ Selective Kernel Convolution Module
+
+ As described in Selective Kernel Networks (https://arxiv.org/abs/1903.06586) with some modifications.
+
+ The largest change is the input split, which divides the input channels across each convolution path. This
+ can be viewed as a form of grouping, but the output channel count expands to the module-level value. This keeps
+ the parameter count from ballooning when the convolutions themselves don't have groups, but still provides
+ a noteworthy increase in performance over similar param count models without this attention layer. -Ross W
+
+ Args:
+ in_channels (int): module input (feature) channel count
+ out_channels (int): module output (feature) channel count
+ kernel_size (int, list): kernel size for each convolution branch
+ stride (int): stride for convolutions
+ dilation (int): dilation for module as a whole, impacts dilation of each branch
+ groups (int): number of groups for each branch
+ rd_ratio (int, float): reduction factor for attention features
+ keep_3x3 (bool): keep all branch convolution kernels as 3x3, changing larger kernels for dilations
+ split_input (bool): split input channels evenly across each convolution branch, keeps param count lower,
+ can be viewed as grouping by path, output expands to module out_channels count
+ act_layer (nn.Module): activation layer to use
+ norm_layer (nn.Module): batchnorm/norm layer to use
+ aa_layer (nn.Module): anti-aliasing module
+ drop_layer (nn.Module): spatial drop module in convs (drop block, etc)
+ """
+ super(SelectiveKernel, self).__init__()
+ out_channels = out_channels or in_channels
+ kernel_size = kernel_size or [3, 5] # default to one 3x3 and one 5x5 branch. 5x5 -> 3x3 + dilation
+ _kernel_valid(kernel_size)
+ if not isinstance(kernel_size, list):
+ kernel_size = [kernel_size] * 2
+ if keep_3x3:
+ dilation = [dilation * (k - 1) // 2 for k in kernel_size]
+ kernel_size = [3] * len(kernel_size)
+ else:
+ dilation = [dilation] * len(kernel_size)
+ self.num_paths = len(kernel_size)
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.split_input = split_input
+ if self.split_input:
+ assert in_channels % self.num_paths == 0
+ in_channels = in_channels // self.num_paths
+ groups = min(out_channels, groups)
+
+ conv_kwargs = dict(
+ stride=stride, groups=groups, act_layer=act_layer, norm_layer=norm_layer,
+ aa_layer=aa_layer, drop_layer=drop_layer)
+ self.paths = nn.ModuleList([
+ ConvNormActAa(in_channels, out_channels, kernel_size=k, dilation=d, **conv_kwargs)
+ for k, d in zip(kernel_size, dilation)])
+
+ attn_channels = rd_channels or make_divisible(out_channels * rd_ratio, divisor=rd_divisor)
+ self.attn = SelectiveKernelAttn(out_channels, self.num_paths, attn_channels)
+
+ def forward(self, x):
+ if self.split_input:
+ x_split = torch.split(x, self.in_channels // self.num_paths, 1)
+ x_paths = [op(x_split[i]) for i, op in enumerate(self.paths)]
+ else:
+ x_paths = [op(x) for op in self.paths]
+ x = torch.stack(x_paths, dim=1)
+ x_attn = self.attn(x)
+ x = x * x_attn
+ x = torch.sum(x, dim=1)
+ return x
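A minimal usage sketch (not part of the diff) of `SelectiveKernel`, assuming the sibling `conv_bn_act` and `helpers` modules referenced above are present in the same package; with the defaults it builds two 3x3 branches (the second dilated to stand in for 5x5) and splits the input channels across them:

```python
import torch
from flagai.model.vision.layers.selective_kernel import SelectiveKernel

sk = SelectiveKernel(64, 128)       # split_input=True: each branch sees 32 input channels
x = torch.randn(2, 64, 56, 56)

print(len(sk.paths))  # 2 convolution branches
print(sk(x).shape)    # torch.Size([2, 128, 56, 56]) — attention-weighted sum over branches
```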
diff --git a/flagai/model/vision/layers/separable_conv.py b/flagai/model/vision/layers/separable_conv.py
new file mode 100755
index 00000000..c081e02b
--- /dev/null
+++ b/flagai/model/vision/layers/separable_conv.py
@@ -0,0 +1,76 @@
+""" Depthwise Separable Conv Modules
+
+Basic DWS convs. Other variations of DWS exist with batch norm or activations between the
+DW and PW convs such as the Depthwise modules in MobileNetV2 / EfficientNet and Xception.
+
+Hacked together by / Copyright 2020 Ross Wightman
+"""
+from torch import nn as nn
+
+from .create_conv2d import create_conv2d
+from .create_norm_act import get_norm_act_layer
+
+
+class SeparableConvNormAct(nn.Module):
+ """ Separable Conv w/ trailing Norm and Activation
+ """
+ def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, dilation=1, padding='', bias=False,
+ channel_multiplier=1.0, pw_kernel_size=1, norm_layer=nn.BatchNorm2d, act_layer=nn.ReLU,
+ apply_act=True, drop_layer=None):
+ super(SeparableConvNormAct, self).__init__()
+
+ self.conv_dw = create_conv2d(
+ in_channels, int(in_channels * channel_multiplier), kernel_size,
+ stride=stride, dilation=dilation, padding=padding, depthwise=True)
+
+ self.conv_pw = create_conv2d(
+ int(in_channels * channel_multiplier), out_channels, pw_kernel_size, padding=padding, bias=bias)
+
+ norm_act_layer = get_norm_act_layer(norm_layer, act_layer)
+ norm_kwargs = dict(drop_layer=drop_layer) if drop_layer is not None else {}
+ self.bn = norm_act_layer(out_channels, apply_act=apply_act, **norm_kwargs)
+
+ @property
+ def in_channels(self):
+ return self.conv_dw.in_channels
+
+ @property
+ def out_channels(self):
+ return self.conv_pw.out_channels
+
+ def forward(self, x):
+ x = self.conv_dw(x)
+ x = self.conv_pw(x)
+ x = self.bn(x)
+ return x
+
+
+SeparableConvBnAct = SeparableConvNormAct
+
+
+class SeparableConv2d(nn.Module):
+ """ Separable Conv
+ """
+ def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, dilation=1, padding='', bias=False,
+ channel_multiplier=1.0, pw_kernel_size=1):
+ super(SeparableConv2d, self).__init__()
+
+ self.conv_dw = create_conv2d(
+ in_channels, int(in_channels * channel_multiplier), kernel_size,
+ stride=stride, dilation=dilation, padding=padding, depthwise=True)
+
+ self.conv_pw = create_conv2d(
+ int(in_channels * channel_multiplier), out_channels, pw_kernel_size, padding=padding, bias=bias)
+
+ @property
+ def in_channels(self):
+ return self.conv_dw.in_channels
+
+ @property
+ def out_channels(self):
+ return self.conv_pw.out_channels
+
+ def forward(self, x):
+ x = self.conv_dw(x)
+ x = self.conv_pw(x)
+ return x
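A minimal sketch (not part of the diff) of the depthwise-separable modules above, assuming the sibling `create_conv2d` and `create_norm_act` helpers are present in the package:

```python
import torch
import torch.nn as nn
from flagai.model.vision.layers.separable_conv import SeparableConv2d, SeparableConvNormAct

x = torch.randn(2, 32, 28, 28)

# depthwise 3x3 followed by a pointwise 1x1 projection to 64 channels
sep = SeparableConv2d(32, 64, kernel_size=3)
print(sep(x).shape)  # torch.Size([2, 64, 28, 28])

# same structure with a fused BatchNorm + ReLU applied to the pointwise output
sep_bn = SeparableConvNormAct(32, 64, kernel_size=3, norm_layer=nn.BatchNorm2d, act_layer=nn.ReLU)
print(sep_bn(x).shape)  # torch.Size([2, 64, 28, 28])
```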
diff --git a/flagai/model/vision/layers/space_to_depth.py b/flagai/model/vision/layers/space_to_depth.py
new file mode 100755
index 00000000..a7e8e0b2
--- /dev/null
+++ b/flagai/model/vision/layers/space_to_depth.py
@@ -0,0 +1,53 @@
+import torch
+import torch.nn as nn
+
+
+class SpaceToDepth(nn.Module):
+ def __init__(self, block_size=4):
+ super().__init__()
+ assert block_size == 4
+ self.bs = block_size
+
+ def forward(self, x):
+ N, C, H, W = x.size()
+ x = x.view(N, C, H // self.bs, self.bs, W // self.bs, self.bs) # (N, C, H//bs, bs, W//bs, bs)
+ x = x.permute(0, 3, 5, 1, 2, 4).contiguous() # (N, bs, bs, C, H//bs, W//bs)
+ x = x.view(N, C * (self.bs ** 2), H // self.bs, W // self.bs) # (N, C*bs^2, H//bs, W//bs)
+ return x
+
+
+@torch.jit.script
+class SpaceToDepthJit(object):
+ def __call__(self, x: torch.Tensor):
+ # assuming hard-coded that block_size==4 for acceleration
+ N, C, H, W = x.size()
+ x = x.view(N, C, H // 4, 4, W // 4, 4) # (N, C, H//bs, bs, W//bs, bs)
+ x = x.permute(0, 3, 5, 1, 2, 4).contiguous() # (N, bs, bs, C, H//bs, W//bs)
+ x = x.view(N, C * 16, H // 4, W // 4) # (N, C*bs^2, H//bs, W//bs)
+ return x
+
+
+class SpaceToDepthModule(nn.Module):
+ def __init__(self, no_jit=False):
+ super().__init__()
+ if not no_jit:
+ self.op = SpaceToDepthJit()
+ else:
+ self.op = SpaceToDepth()
+
+ def forward(self, x):
+ return self.op(x)
+
+
+class DepthToSpace(nn.Module):
+
+ def __init__(self, block_size):
+ super().__init__()
+ self.bs = block_size
+
+ def forward(self, x):
+ N, C, H, W = x.size()
+ x = x.view(N, self.bs, self.bs, C // (self.bs ** 2), H, W) # (N, bs, bs, C//bs^2, H, W)
+ x = x.permute(0, 3, 4, 1, 5, 2).contiguous() # (N, C//bs^2, H, bs, W, bs)
+ x = x.view(N, C // (self.bs ** 2), H * self.bs, W * self.bs) # (N, C//bs^2, H * bs, W * bs)
+ return x
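A quick sketch (not part of the diff) of the space-to-depth rearrangement above; note that the block size is hard-coded to 4 in both the JIT and the non-JIT path:

```python
import torch
from flagai.model.vision.layers.space_to_depth import SpaceToDepthModule, DepthToSpace

x = torch.randn(1, 3, 32, 32)
s2d = SpaceToDepthModule(no_jit=True)   # uses the plain SpaceToDepth module (block_size=4)
y = s2d(x)

print(y.shape)                   # torch.Size([1, 48, 8, 8])  — C*16, H/4, W/4
print(DepthToSpace(4)(y).shape)  # torch.Size([1, 3, 32, 32]) — back to the original shape
```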
diff --git a/flagai/model/vision/layers/split_attn.py b/flagai/model/vision/layers/split_attn.py
new file mode 100755
index 00000000..ac54f898
--- /dev/null
+++ b/flagai/model/vision/layers/split_attn.py
@@ -0,0 +1,84 @@
+""" Split Attention Conv2d (for ResNeSt Models)
+
+Paper: `ResNeSt: Split-Attention Networks` - https://arxiv.org/abs/2004.08955
+
+Adapted from original PyTorch impl at https://github.com/zhanghang1989/ResNeSt
+
+Modified for torchscript compat, performance, and consistency with timm by Ross Wightman
+"""
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from .helpers import make_divisible
+
+
+class RadixSoftmax(nn.Module):
+ def __init__(self, radix, cardinality):
+ super(RadixSoftmax, self).__init__()
+ self.radix = radix
+ self.cardinality = cardinality
+
+ def forward(self, x):
+ batch = x.size(0)
+ if self.radix > 1:
+ x = x.view(batch, self.cardinality, self.radix, -1).transpose(1, 2)
+ x = F.softmax(x, dim=1)
+ x = x.reshape(batch, -1)
+ else:
+ x = torch.sigmoid(x)
+ return x
+
+
+class SplitAttn(nn.Module):
+ """Split-Attention (aka Splat)
+ """
+ def __init__(self, in_channels, out_channels=None, kernel_size=3, stride=1, padding=None,
+ dilation=1, groups=1, bias=False, radix=2, rd_ratio=0.25, rd_channels=None, rd_divisor=8,
+ act_layer=nn.ReLU, norm_layer=None, drop_layer=None, **kwargs):
+ super(SplitAttn, self).__init__()
+ out_channels = out_channels or in_channels
+ self.radix = radix
+ mid_chs = out_channels * radix
+ if rd_channels is None:
+ attn_chs = make_divisible(in_channels * radix * rd_ratio, min_value=32, divisor=rd_divisor)
+ else:
+ attn_chs = rd_channels * radix
+
+ padding = kernel_size // 2 if padding is None else padding
+ self.conv = nn.Conv2d(
+ in_channels, mid_chs, kernel_size, stride, padding, dilation,
+ groups=groups * radix, bias=bias, **kwargs)
+ self.bn0 = norm_layer(mid_chs) if norm_layer else nn.Identity()
+ self.drop = drop_layer() if drop_layer is not None else nn.Identity()
+ self.act0 = act_layer(inplace=True)
+ self.fc1 = nn.Conv2d(out_channels, attn_chs, 1, groups=groups)
+ self.bn1 = norm_layer(attn_chs) if norm_layer else nn.Identity()
+ self.act1 = act_layer(inplace=True)
+ self.fc2 = nn.Conv2d(attn_chs, mid_chs, 1, groups=groups)
+ self.rsoftmax = RadixSoftmax(radix, groups)
+
+ def forward(self, x):
+ x = self.conv(x)
+ x = self.bn0(x)
+ x = self.drop(x)
+ x = self.act0(x)
+
+ B, RC, H, W = x.shape
+ if self.radix > 1:
+ x = x.reshape((B, self.radix, RC // self.radix, H, W))
+ x_gap = x.sum(dim=1)
+ else:
+ x_gap = x
+ x_gap = x_gap.mean((2, 3), keepdim=True)
+ x_gap = self.fc1(x_gap)
+ x_gap = self.bn1(x_gap)
+ x_gap = self.act1(x_gap)
+ x_attn = self.fc2(x_gap)
+
+ x_attn = self.rsoftmax(x_attn).view(B, -1, 1, 1)
+ if self.radix > 1:
+ out = (x * x_attn.reshape((B, self.radix, RC // self.radix, 1, 1))).sum(dim=1)
+ else:
+ out = x * x_attn
+ return out.contiguous()
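A minimal sketch (not part of the diff) of `SplitAttn` with radix 2, as it would sit inside a ResNeSt bottleneck, assuming the module path from this PR is importable:

```python
import torch
import torch.nn as nn
from flagai.model.vision.layers.split_attn import SplitAttn

# radix-2 split attention over a 64-channel feature map
splat = SplitAttn(64, 64, kernel_size=3, radix=2, norm_layer=nn.BatchNorm2d)
x = torch.randn(2, 64, 28, 28)

print(splat(x).shape)  # torch.Size([2, 64, 28, 28]) — radix groups re-weighted and summed
```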
diff --git a/flagai/model/vision/layers/split_batchnorm.py b/flagai/model/vision/layers/split_batchnorm.py
new file mode 100755
index 00000000..830781b3
--- /dev/null
+++ b/flagai/model/vision/layers/split_batchnorm.py
@@ -0,0 +1,75 @@
+""" Split BatchNorm
+
+A PyTorch BatchNorm layer that splits the input batch into N equal parts and passes each part through
+a separate BN layer. The first split is passed through the parent BN layers with weight/bias
+keys the same as the original BN. All other splits pass through BN sub-layers under the '.aux_bn'
+namespace.
+
+This allows easily removing the auxiliary BN layers after training to efficiently
+achieve the 'Auxiliary BatchNorm' as described in the AdvProp Paper, section 4.2,
+'Disentangled Learning via An Auxiliary BN'
+
+Hacked together by / Copyright 2020 Ross Wightman
+"""
+import torch
+import torch.nn as nn
+
+
+class SplitBatchNorm2d(torch.nn.BatchNorm2d):
+
+ def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True,
+ track_running_stats=True, num_splits=2):
+ super().__init__(num_features, eps, momentum, affine, track_running_stats)
+ assert num_splits > 1, 'Should have at least one aux BN layer (num_splits at least 2)'
+ self.num_splits = num_splits
+ self.aux_bn = nn.ModuleList([
+ nn.BatchNorm2d(num_features, eps, momentum, affine, track_running_stats) for _ in range(num_splits - 1)])
+
+ def forward(self, input: torch.Tensor):
+ if self.training: # aux BN only relevant while training
+ split_size = input.shape[0] // self.num_splits
+ assert input.shape[0] == split_size * self.num_splits, "batch size must be evenly divisible by num_splits"
+ split_input = input.split(split_size)
+ x = [super().forward(split_input[0])]
+ for i, a in enumerate(self.aux_bn):
+ x.append(a(split_input[i + 1]))
+ return torch.cat(x, dim=0)
+ else:
+ return super().forward(input)
+
+
+def convert_splitbn_model(module, num_splits=2):
+ """
+ Recursively traverse module and its children to replace all instances of
+ ``torch.nn.modules.batchnorm._BatchNorm`` with `SplitBatchNorm2d`.
+ Args:
+ module (torch.nn.Module): input module
+ num_splits: number of separate batchnorm layers to split input across
+ Example::
+ >>> # model is an instance of torch.nn.Module
+ >>> model = timm.models.convert_splitbn_model(model, num_splits=2)
+ """
+ mod = module
+ if isinstance(module, torch.nn.modules.instancenorm._InstanceNorm):
+ return module
+ if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
+ mod = SplitBatchNorm2d(
+ module.num_features, module.eps, module.momentum, module.affine,
+ module.track_running_stats, num_splits=num_splits)
+ mod.running_mean = module.running_mean
+ mod.running_var = module.running_var
+ mod.num_batches_tracked = module.num_batches_tracked
+ if module.affine:
+ mod.weight.data = module.weight.data.clone().detach()
+ mod.bias.data = module.bias.data.clone().detach()
+ for aux in mod.aux_bn:
+ aux.running_mean = module.running_mean.clone()
+ aux.running_var = module.running_var.clone()
+ aux.num_batches_tracked = module.num_batches_tracked.clone()
+ if module.affine:
+ aux.weight.data = module.weight.data.clone().detach()
+ aux.bias.data = module.bias.data.clone().detach()
+ for name, child in module.named_children():
+ mod.add_module(name, convert_splitbn_model(child, num_splits=num_splits))
+ del module
+ return mod
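A short sketch (not part of the diff) of `convert_splitbn_model`, assuming the module path from this PR is importable; during training the batch size must divide evenly by `num_splits`:

```python
import torch
import torch.nn as nn
from flagai.model.vision.layers.split_batchnorm import convert_splitbn_model

model = nn.Sequential(nn.Conv2d(3, 16, 3, padding=1), nn.BatchNorm2d(16), nn.ReLU())
model = convert_splitbn_model(model, num_splits=2)
print(type(model[1]).__name__)  # SplitBatchNorm2d

# first half of the batch updates the main BN, second half the auxiliary BN
model.train()
out = model(torch.randn(4, 3, 8, 8))
print(out.shape)  # torch.Size([4, 16, 8, 8])
```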
diff --git a/flagai/model/vision/layers/squeeze_excite.py b/flagai/model/vision/layers/squeeze_excite.py
new file mode 100755
index 00000000..e5da29ef
--- /dev/null
+++ b/flagai/model/vision/layers/squeeze_excite.py
@@ -0,0 +1,74 @@
+""" Squeeze-and-Excitation Channel Attention
+
+An SE implementation originally based on PyTorch SE-Net impl.
+Has since evolved with additional functionality / configuration.
+
+Paper: `Squeeze-and-Excitation Networks` - https://arxiv.org/abs/1709.01507
+
+Also included is Effective Squeeze-Excitation (ESE).
+Paper: `CenterMask : Real-Time Anchor-Free Instance Segmentation` - https://arxiv.org/abs/1911.06667
+
+Hacked together by / Copyright 2021 Ross Wightman
+"""
+from torch import nn as nn
+
+from .create_act import create_act_layer
+from .helpers import make_divisible
+
+
+class SEModule(nn.Module):
+ """ SE Module as defined in original SE-Nets with a few additions
+ Additions include:
+ * divisor can be specified to keep channels % div == 0 (default: 8)
+ * reduction channels can be specified directly by arg (if rd_channels is set)
+ * reduction channels can be specified by float rd_ratio (default: 1/16)
+ * global max pooling can be added to the squeeze aggregation
+ * customizable activation, normalization, and gate layer
+ """
+ def __init__(
+ self, channels, rd_ratio=1. / 16, rd_channels=None, rd_divisor=8, add_maxpool=False,
+ act_layer=nn.ReLU, norm_layer=None, gate_layer='sigmoid'):
+ super(SEModule, self).__init__()
+ self.add_maxpool = add_maxpool
+ if not rd_channels:
+ rd_channels = make_divisible(channels * rd_ratio, rd_divisor, round_limit=0.)
+ self.fc1 = nn.Conv2d(channels, rd_channels, kernel_size=1, bias=True)
+ self.bn = norm_layer(rd_channels) if norm_layer else nn.Identity()
+ self.act = create_act_layer(act_layer, inplace=True)
+ self.fc2 = nn.Conv2d(rd_channels, channels, kernel_size=1, bias=True)
+ self.gate = create_act_layer(gate_layer)
+
+ def forward(self, x):
+ x_se = x.mean((2, 3), keepdim=True)
+ if self.add_maxpool:
+ # experimental codepath, may remove or change
+ x_se = 0.5 * x_se + 0.5 * x.amax((2, 3), keepdim=True)
+ x_se = self.fc1(x_se)
+ x_se = self.act(self.bn(x_se))
+ x_se = self.fc2(x_se)
+ return x * self.gate(x_se)
+
+
+SqueezeExcite = SEModule # alias
+
+
+class EffectiveSEModule(nn.Module):
+ """ 'Effective Squeeze-Excitation
+ From `CenterMask : Real-Time Anchor-Free Instance Segmentation` - https://arxiv.org/abs/1911.06667
+ """
+ def __init__(self, channels, add_maxpool=False, gate_layer='hard_sigmoid', **_):
+ super(EffectiveSEModule, self).__init__()
+ self.add_maxpool = add_maxpool
+ self.fc = nn.Conv2d(channels, channels, kernel_size=1, padding=0)
+ self.gate = create_act_layer(gate_layer)
+
+ def forward(self, x):
+ x_se = x.mean((2, 3), keepdim=True)
+ if self.add_maxpool:
+ # experimental codepath, may remove or change
+ x_se = 0.5 * x_se + 0.5 * x.amax((2, 3), keepdim=True)
+ x_se = self.fc(x_se)
+ return x * self.gate(x_se)
+
+
+EffectiveSqueezeExcite = EffectiveSEModule # alias
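A minimal sketch (not part of the diff) of the two SE variants above, assuming the sibling `create_act` and `helpers` modules are present in the package:

```python
import torch
from flagai.model.vision.layers.squeeze_excite import SEModule, EffectiveSEModule

x = torch.randn(2, 64, 14, 14)

se = SEModule(64)            # squeeze to make_divisible(64 / 16, 8) = 8 channels, then re-expand and gate
print(se(x).shape)           # torch.Size([2, 64, 14, 14])

ese = EffectiveSEModule(64)  # single 1x1 conv + hard-sigmoid gate, no channel reduction
print(ese(x).shape)          # torch.Size([2, 64, 14, 14])
```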
diff --git a/flagai/model/vision/layers/std_conv.py b/flagai/model/vision/layers/std_conv.py
new file mode 100755
index 00000000..d896ba5c
--- /dev/null
+++ b/flagai/model/vision/layers/std_conv.py
@@ -0,0 +1,133 @@
+""" Convolution with Weight Standardization (StdConv and ScaledStdConv)
+
+StdConv:
+@article{weightstandardization,
+ author = {Siyuan Qiao and Huiyu Wang and Chenxi Liu and Wei Shen and Alan Yuille},
+ title = {Weight Standardization},
+ journal = {arXiv preprint arXiv:1903.10520},
+ year = {2019},
+}
+Code: https://github.com/joe-siyuan-qiao/WeightStandardization
+
+ScaledStdConv:
+Paper: `Characterizing signal propagation to close the performance gap in unnormalized ResNets`
+ - https://arxiv.org/abs/2101.08692
+Official Deepmind JAX code: https://github.com/deepmind/deepmind-research/tree/master/nfnets
+
+Hacked together by / copyright Ross Wightman, 2021.
+"""
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .padding import get_padding, get_padding_value, pad_same
+
+
+class StdConv2d(nn.Conv2d):
+ """Conv2d with Weight Standardization. Used for BiT ResNet-V2 models.
+
+ Paper: `Micro-Batch Training with Batch-Channel Normalization and Weight Standardization` -
+ https://arxiv.org/abs/1903.10520v2
+ """
+ def __init__(
+ self, in_channel, out_channels, kernel_size, stride=1, padding=None,
+ dilation=1, groups=1, bias=False, eps=1e-6):
+ if padding is None:
+ padding = get_padding(kernel_size, stride, dilation)
+ super().__init__(
+ in_channel, out_channels, kernel_size, stride=stride,
+ padding=padding, dilation=dilation, groups=groups, bias=bias)
+ self.eps = eps
+
+ def forward(self, x):
+ weight = F.batch_norm(
+ self.weight.reshape(1, self.out_channels, -1), None, None,
+ training=True, momentum=0., eps=self.eps).reshape_as(self.weight)
+ x = F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
+ return x
+
+
+class StdConv2dSame(nn.Conv2d):
+ """Conv2d with Weight Standardization. TF compatible SAME padding. Used for ViT Hybrid model.
+
+ Paper: `Micro-Batch Training with Batch-Channel Normalization and Weight Standardization` -
+ https://arxiv.org/abs/1903.10520v2
+ """
+ def __init__(
+ self, in_channel, out_channels, kernel_size, stride=1, padding='SAME',
+ dilation=1, groups=1, bias=False, eps=1e-6):
+ padding, is_dynamic = get_padding_value(padding, kernel_size, stride=stride, dilation=dilation)
+ super().__init__(
+ in_channel, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation,
+ groups=groups, bias=bias)
+ self.same_pad = is_dynamic
+ self.eps = eps
+
+ def forward(self, x):
+ if self.same_pad:
+ x = pad_same(x, self.kernel_size, self.stride, self.dilation)
+ weight = F.batch_norm(
+ self.weight.reshape(1, self.out_channels, -1), None, None,
+ training=True, momentum=0., eps=self.eps).reshape_as(self.weight)
+ x = F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
+ return x
+
+
+class ScaledStdConv2d(nn.Conv2d):
+ """Conv2d layer with Scaled Weight Standardization.
+
+ Paper: `Characterizing signal propagation to close the performance gap in unnormalized ResNets` -
+ https://arxiv.org/abs/2101.08692
+
+ NOTE: the operations used in this impl differ slightly from the DeepMind Haiku impl. The impact is minor.
+ """
+
+ def __init__(
+ self, in_channels, out_channels, kernel_size, stride=1, padding=None,
+ dilation=1, groups=1, bias=True, gamma=1.0, eps=1e-6, gain_init=1.0):
+ if padding is None:
+ padding = get_padding(kernel_size, stride, dilation)
+ super().__init__(
+ in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation,
+ groups=groups, bias=bias)
+ self.gain = nn.Parameter(torch.full((self.out_channels, 1, 1, 1), gain_init))
+ self.scale = gamma * self.weight[0].numel() ** -0.5 # gamma * 1 / sqrt(fan-in)
+ self.eps = eps
+
+ def forward(self, x):
+ weight = F.batch_norm(
+ self.weight.reshape(1, self.out_channels, -1), None, None,
+ weight=(self.gain * self.scale).view(-1),
+ training=True, momentum=0., eps=self.eps).reshape_as(self.weight)
+ return F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
+
+
+class ScaledStdConv2dSame(nn.Conv2d):
+ """Conv2d layer with Scaled Weight Standardization and Tensorflow-like SAME padding support
+
+ Paper: `Characterizing signal propagation to close the performance gap in unnormalized ResNets` -
+ https://arxiv.org/abs/2101.08692
+
+ NOTE: the operations used in this impl differ slightly from the DeepMind Haiku impl. The impact is minor.
+ """
+
+ def __init__(
+ self, in_channels, out_channels, kernel_size, stride=1, padding='SAME',
+ dilation=1, groups=1, bias=True, gamma=1.0, eps=1e-6, gain_init=1.0):
+ padding, is_dynamic = get_padding_value(padding, kernel_size, stride=stride, dilation=dilation)
+ super().__init__(
+ in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation,
+ groups=groups, bias=bias)
+ self.gain = nn.Parameter(torch.full((self.out_channels, 1, 1, 1), gain_init))
+ self.scale = gamma * self.weight[0].numel() ** -0.5
+ self.same_pad = is_dynamic
+ self.eps = eps
+
+ def forward(self, x):
+ if self.same_pad:
+ x = pad_same(x, self.kernel_size, self.stride, self.dilation)
+ weight = F.batch_norm(
+ self.weight.reshape(1, self.out_channels, -1), None, None,
+ weight=(self.gain * self.scale).view(-1),
+ training=True, momentum=0., eps=self.eps).reshape_as(self.weight)
+ return F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
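A minimal sketch (not part of the diff) of the weight-standardized convolutions above, assuming the module path from this PR is importable; both behave like drop-in replacements for `nn.Conv2d`:

```python
import torch
from flagai.model.vision.layers.std_conv import StdConv2d, ScaledStdConv2d

x = torch.randn(2, 3, 32, 32)

conv = StdConv2d(3, 16, kernel_size=3)           # weights standardized on every forward pass
nf_conv = ScaledStdConv2d(3, 16, kernel_size=3)  # NFNet-style variant with a learnable per-channel gain

print(conv(x).shape)     # torch.Size([2, 16, 32, 32])
print(nf_conv(x).shape)  # torch.Size([2, 16, 32, 32])
```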
diff --git a/flagai/model/vision/layers/test_time_pool.py b/flagai/model/vision/layers/test_time_pool.py
new file mode 100755
index 00000000..98c0bf53
--- /dev/null
+++ b/flagai/model/vision/layers/test_time_pool.py
@@ -0,0 +1,52 @@
+""" Test Time Pooling (Average-Max Pool)
+
+Hacked together by / Copyright 2020 Ross Wightman
+"""
+
+import logging
+from torch import nn
+import torch.nn.functional as F
+
+from .adaptive_avgmax_pool import adaptive_avgmax_pool2d
+
+
+_logger = logging.getLogger(__name__)
+
+
+class TestTimePoolHead(nn.Module):
+ def __init__(self, base, original_pool=7):
+ super(TestTimePoolHead, self).__init__()
+ self.base = base
+ self.original_pool = original_pool
+ base_fc = self.base.get_classifier()
+ if isinstance(base_fc, nn.Conv2d):
+ self.fc = base_fc
+ else:
+ self.fc = nn.Conv2d(
+ self.base.num_features, self.base.num_classes, kernel_size=1, bias=True)
+ self.fc.weight.data.copy_(base_fc.weight.data.view(self.fc.weight.size()))
+ self.fc.bias.data.copy_(base_fc.bias.data.view(self.fc.bias.size()))
+ self.base.reset_classifier(0) # delete original fc layer
+
+ def forward(self, x):
+ x = self.base.forward_features(x)
+ x = F.avg_pool2d(x, kernel_size=self.original_pool, stride=1)
+ x = self.fc(x)
+ x = adaptive_avgmax_pool2d(x, 1)
+ return x.view(x.size(0), -1)
+
+
+def apply_test_time_pool(model, config, use_test_size=True):
+ test_time_pool = False
+ if not hasattr(model, 'default_cfg') or not model.default_cfg:
+ return model, False
+ if use_test_size and 'test_input_size' in model.default_cfg:
+ df_input_size = model.default_cfg['test_input_size']
+ else:
+ df_input_size = model.default_cfg['input_size']
+ if config['input_size'][-1] > df_input_size[-1] and config['input_size'][-2] > df_input_size[-2]:
+ _logger.info('Target input size %s > pretrained default %s, using test time pooling' %
+ (str(config['input_size'][-2:]), str(df_input_size[-2:])))
+ model = TestTimePoolHead(model, original_pool=model.default_cfg['pool_size'])
+ test_time_pool = True
+ return model, test_time_pool
diff --git a/flagai/model/vision/layers/trace_utils.py b/flagai/model/vision/layers/trace_utils.py
new file mode 100755
index 00000000..83970729
--- /dev/null
+++ b/flagai/model/vision/layers/trace_utils.py
@@ -0,0 +1,13 @@
+try:
+ from torch import _assert
+except ImportError:
+ def _assert(condition: bool, message: str):
+ assert condition, message
+
+
+def _float_to_int(x: float) -> int:
+ """
+ Symbolic tracing helper to substitute for inbuilt `int`.
+ Hint: Inbuilt `int` can't accept an argument of type `Proxy`
+ """
+ return int(x)
diff --git a/flagai/model/vision/layers/weight_init.py b/flagai/model/vision/layers/weight_init.py
new file mode 100755
index 00000000..24c0fa7c
--- /dev/null
+++ b/flagai/model/vision/layers/weight_init.py
@@ -0,0 +1,88 @@
+import torch
+import math
+import warnings
+
+from torch.nn.init import _calculate_fan_in_and_fan_out
+
+
+def _no_grad_trunc_normal_(tensor, mean, std, a, b):
+ # Cut & paste from PyTorch official master until it's in a few official releases - RW
+ # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
+ def norm_cdf(x):
+ # Computes standard normal cumulative distribution function
+ return (1. + math.erf(x / math.sqrt(2.))) / 2.
+
+ if (mean < a - 2 * std) or (mean > b + 2 * std):
+ warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
+ "The distribution of values may be incorrect.",
+ stacklevel=2)
+
+ with torch.no_grad():
+ # Values are generated by using a truncated uniform distribution and
+ # then using the inverse CDF for the normal distribution.
+ # Get upper and lower cdf values
+ l = norm_cdf((a - mean) / std)
+ u = norm_cdf((b - mean) / std)
+
+ # Uniformly fill tensor with values from [l, u], then translate to
+ # [2l-1, 2u-1].
+ tensor.uniform_(2 * l - 1, 2 * u - 1)
+
+ # Use inverse cdf transform for normal distribution to get truncated
+ # standard normal
+ tensor.erfinv_()
+
+ # Transform to proper mean, std
+ tensor.mul_(std * math.sqrt(2.))
+ tensor.add_(mean)
+
+ # Clamp to ensure it's in the proper range
+ tensor.clamp_(min=a, max=b)
+ return tensor
+
+
+def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
+ r"""Fills the input Tensor with values drawn from a truncated
+ normal distribution. The values are effectively drawn from the
+ normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
+ with values outside :math:`[a, b]` redrawn until they are within
+ the bounds. The method used for generating the random values works
+ best when :math:`a \leq \text{mean} \leq b`.
+ Args:
+ tensor: an n-dimensional `torch.Tensor`
+ mean: the mean of the normal distribution
+ std: the standard deviation of the normal distribution
+ a: the minimum cutoff value
+ b: the maximum cutoff value
+ Examples:
+ >>> w = torch.empty(3, 5)
+ >>> nn.init.trunc_normal_(w)
+ """
+ return _no_grad_trunc_normal_(tensor, mean, std, a, b)
+
+
+def variance_scaling_(tensor, scale=1.0, mode='fan_in', distribution='normal'):
+ fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
+ if mode == 'fan_in':
+ denom = fan_in
+ elif mode == 'fan_out':
+ denom = fan_out
+ elif mode == 'fan_avg':
+ denom = (fan_in + fan_out) / 2
+
+ variance = scale / denom
+
+ if distribution == "truncated_normal":
+ # constant is stddev of standard normal truncated to (-2, 2)
+ trunc_normal_(tensor, std=math.sqrt(variance) / .87962566103423978)
+ elif distribution == "normal":
+ tensor.normal_(std=math.sqrt(variance))
+ elif distribution == "uniform":
+ bound = math.sqrt(3 * variance)
+ tensor.uniform_(-bound, bound)
+ else:
+ raise ValueError(f"invalid distribution {distribution}")
+
+
+def lecun_normal_(tensor):
+ variance_scaling_(tensor, mode='fan_in', distribution='truncated_normal')
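A minimal sketch (not part of the diff) of the initializers above applied to ordinary `torch.nn` modules:

```python
import torch
import torch.nn as nn
from flagai.model.vision.layers.weight_init import trunc_normal_, lecun_normal_

linear = nn.Linear(768, 768)
trunc_normal_(linear.weight, std=.02)  # ViT-style init: normal(0, .02) truncated to [-2, 2]

conv = nn.Conv2d(3, 16, kernel_size=3)
lecun_normal_(conv.weight)             # fan-in scaled truncated normal (LeCun init)
```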
diff --git a/flagai/model/vision/vit.py b/flagai/model/vision/vit.py
new file mode 100644
index 00000000..6fc93865
--- /dev/null
+++ b/flagai/model/vision/vit.py
@@ -0,0 +1,497 @@
+"""
+# Copyright © 2022 BAAI. All rights reserved.
+"""
+
+"""
+Vision Transformer (ViT) in PyTorch
+A PyTorch implementation of Vision Transformers as described in:
+'An Image Is Worth 16 x 16 Words: Transformers for Image Recognition at Scale'
+ - https://arxiv.org/abs/2010.11929
+`How to train your ViT? Data, Augmentation, and Regularization in Vision Transformers`
+ - https://arxiv.org/abs/2106.10270
+The official jax code is released and available at https://github.com/google-research/vision_transformer
+Acknowledgments:
+* The paper authors for releasing code and weights, thanks!
+* I fixed my class token impl based on Phil Wang's https://github.com/lucidrains/vit-pytorch ... check it out
+for some einops/einsum fun
+* Simple transformer style inspired by Andrej Karpathy's https://github.com/karpathy/minGPT
+* Bert reference code checks against Huggingface Transformers and Tensorflow Bert
+Hacked together by / Copyright 2020, Ross Wightman
+"""
+
+import math
+from functools import partial
+from typing import Optional
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.checkpoint
+from typing import Callable
+from flagai.model.vision.layers.patch_embed import PatchEmbed
+from flagai.model.vision.layers.mlp import Mlp
+from flagai.model.vision.layers.drop import DropPath
+from flagai.model.vision.layers.weight_init import trunc_normal_, lecun_normal_
+from flagai.model.base_model import BaseModel
+from flagai.model.vision.helpers import checkpoint_seq
+
+class VitConfig:
+ def __init__(self,
+ img_size=224,
+ patch_size=16,
+ in_chans=3,
+ num_classes=1000,
+ global_pool='token',
+ embed_dim=768,
+ depth=12,
+ num_heads=12,
+ mlp_ratio=4.,
+ qkv_bias=True,
+ init_values=None,
+ class_token=True,
+ fc_norm=None,
+ drop_rate=0.,
+ attn_drop_rate=0.,
+ drop_path_rate=0.,
+ weight_init='',
+ checkpoint_activations=False):
+ self.img_size=img_size
+ self.patch_size=patch_size
+ self.in_chans=in_chans
+ self.num_classes=num_classes
+ self.global_pool=global_pool
+ self.embed_dim=embed_dim
+ self.depth=depth
+ self.num_heads=num_heads
+ self.mlp_ratio=mlp_ratio
+ self.qkv_bias=qkv_bias
+ self.init_values=init_values
+ self.class_token=class_token
+ self.fc_norm=fc_norm
+ self.drop_rate=drop_rate
+ self.attn_drop_rate=attn_drop_rate
+ self.drop_path_rate=drop_path_rate
+ self.weight_init=weight_init
+ self.checkpoint_activations = checkpoint_activations
+
+def named_apply(fn: Callable, module: nn.Module, name='', depth_first=True, include_root=False) -> nn.Module:
+ if not depth_first and include_root:
+ fn(module=module, name=name)
+ for child_name, child_module in module.named_children():
+ child_name = '.'.join((name, child_name)) if name else child_name
+ named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True)
+ if depth_first and include_root:
+ fn(module=module, name=name)
+ return module
+
+def adapt_input_conv(in_chans, conv_weight):
+ conv_type = conv_weight.dtype
+ conv_weight = conv_weight.float() # Some weights are in torch.half, ensure it's float for sum on CPU
+ O, I, J, K = conv_weight.shape
+ if in_chans == 1:
+ if I > 3:
+ assert conv_weight.shape[1] % 3 == 0
+ # For models with space2depth stems
+ conv_weight = conv_weight.reshape(O, I // 3, 3, J, K)
+ conv_weight = conv_weight.sum(dim=2, keepdim=False)
+ else:
+ conv_weight = conv_weight.sum(dim=1, keepdim=True)
+ elif in_chans != 3:
+ if I != 3:
+ raise NotImplementedError('Weight format not supported by conversion.')
+ else:
+ # NOTE this strategy should be better than random init, but there could be other combinations of
+ # the original RGB input layer weights that'd work better for specific cases.
+ repeat = int(math.ceil(in_chans / 3))
+ conv_weight = conv_weight.repeat(1, repeat, 1, 1)[:, :in_chans, :, :]
+ conv_weight *= (3 / float(in_chans))
+ conv_weight = conv_weight.to(conv_type)
+ return conv_weight
+
+class Attention(nn.Module):
+ def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
+ super().__init__()
+ assert dim % num_heads == 0, 'dim should be divisible by num_heads'
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ self.scale = head_dim ** -0.5
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ def forward(self, x):
+ B, N, C = x.shape
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+ q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
+
+ attn = (q @ k.transpose(-2, -1)) * self.scale
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+class LayerScale(nn.Module):
+ def __init__(self, dim, init_values=1e-5, inplace=False):
+ super().__init__()
+ self.inplace = inplace
+ self.gamma = nn.Parameter(init_values * torch.ones(dim))
+
+ def forward(self, x):
+ return x.mul_(self.gamma) if self.inplace else x * self.gamma
+
+
+class Block(nn.Module):
+
+ def __init__(
+ self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., init_values=None,
+ drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
+ super().__init__()
+ self.norm1 = norm_layer(dim)
+ self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
+ self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
+ self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+
+ self.norm2 = norm_layer(dim)
+ self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)
+ self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
+ self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+
+ def forward(self, x):
+ x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x))))
+ x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
+ return x
+
+class VisionTransformer(BaseModel):
+ """ Vision Transformer
+
+ A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale`
+ - https://arxiv.org/abs/2010.11929
+ """
+
+ def __init__(
+ self, config, num_classes=1000):
+ """
+ Args:
+ img_size (int, tuple): input image size
+ patch_size (int, tuple): patch size
+ in_chans (int): number of input channels
+ num_classes (int): number of classes for classification head
+ global_pool (str): type of global pooling for final sequence (default: 'token')
+ embed_dim (int): embedding dimension
+ depth (int): depth of transformer
+ num_heads (int): number of attention heads
+ mlp_ratio (int): ratio of mlp hidden dim to embedding dim
+ qkv_bias (bool): enable bias for qkv if True
+ init_values: (float): layer-scale init values
+ class_token (bool): use class token
+ fc_norm (Optional[bool]): apply norm before the classifier head; if None, defaults to True when global_pool == 'avg' (default: None)
+ drop_rate (float): dropout rate
+ attn_drop_rate (float): attention dropout rate
+ drop_path_rate (float): stochastic depth rate
+ weight_init (str): weight init scheme
+ embed_layer (nn.Module): patch embedding layer
+ norm_layer: (nn.Module): normalization layer
+ act_layer: (nn.Module): MLP activation layer
+ """
+ super().__init__(config)
+ embed_layer=PatchEmbed
+ block_fn=Block
+ vit_config = VitConfig(**config)
+ vit_config.num_classes = num_classes
+ # config = vit_config
+
+ assert vit_config.global_pool in ('', 'avg', 'token')
+ assert vit_config.class_token or vit_config.global_pool != 'token'
+ use_fc_norm = vit_config.global_pool == 'avg' if vit_config.fc_norm is None else vit_config.fc_norm
+ norm_layer = partial(nn.LayerNorm, eps=1e-6)
+ act_layer = nn.GELU
+
+ self.num_classes = num_classes
+ self.global_pool = vit_config.global_pool
+ self.num_features = self.embed_dim = vit_config.embed_dim # num_features for consistency with other models
+ self.num_tokens = 1 if vit_config.class_token else 0
+ self.grad_checkpointing = vit_config.checkpoint_activations
+
+ self.patch_embed = embed_layer(
+ img_size=vit_config.img_size, patch_size=vit_config.patch_size, in_chans=vit_config.in_chans, embed_dim=vit_config.embed_dim)
+ num_patches = self.patch_embed.num_patches
+
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, vit_config.embed_dim)) if self.num_tokens > 0 else None
+ self.pos_embed = nn.Parameter(torch.randn(1, num_patches + self.num_tokens, vit_config.embed_dim) * .02)
+ self.pos_drop = nn.Dropout(p=vit_config.drop_rate)
+
+ dpr = [x.item() for x in torch.linspace(0, vit_config.drop_path_rate, vit_config.depth)] # stochastic depth decay rule
+ self.blocks = nn.Sequential(*[
+ block_fn(
+ dim=vit_config.embed_dim, num_heads=vit_config.num_heads, mlp_ratio=vit_config.mlp_ratio, qkv_bias=vit_config.qkv_bias, init_values=vit_config.init_values,
+ drop=vit_config.drop_rate, attn_drop=vit_config.attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, act_layer=act_layer)
+ for i in range(vit_config.depth)])
+ self.norm = norm_layer(vit_config.embed_dim) if not use_fc_norm else nn.Identity()
+
+ # Classifier Head
+ self.fc_norm = norm_layer(vit_config.embed_dim) if use_fc_norm else nn.Identity()
+ self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
+
+ if vit_config.weight_init != 'skip':
+ self.init_weights(vit_config.weight_init)
+
+ def init_weights(self, mode=''):
+ assert mode in ('jax', 'jax_nlhb', 'moco', '')
+ head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0.
+ trunc_normal_(self.pos_embed, std=.02)
+ if self.cls_token is not None:
+ nn.init.normal_(self.cls_token, std=1e-6)
+ named_apply(get_init_weights_vit(mode, head_bias), self)
+
+ def _init_weights(self, m):
+ # this fn left here for compat with downstream users
+ init_weights_vit_timm(m)
+
+ @torch.jit.ignore()
+ def load_weights(self, checkpoint_path, prefix=''):
+ _load_weights(self, checkpoint_path, prefix)
+
+ @torch.jit.ignore
+ def no_weight_decay(self):
+ return {'pos_embed', 'cls_token', 'dist_token'}
+
+ @torch.jit.ignore
+ def group_matcher(self, coarse=False):
+ return dict(
+ stem=r'^cls_token|pos_embed|patch_embed', # stem and embed
+ blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))]
+ )
+
+ @torch.jit.ignore
+ def set_grad_checkpointing(self, enable=True):
+ self.grad_checkpointing = enable
+
+ @torch.jit.ignore
+ def get_classifier(self):
+ return self.head
+
+ def reset_classifier(self, num_classes: int, global_pool=None):
+ self.num_classes = num_classes
+ if global_pool is not None:
+ assert global_pool in ('', 'avg', 'token')
+ self.global_pool = global_pool
+ self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
+
+ def forward_features(self, x):
+ x = self.patch_embed(x)
+ if self.cls_token is not None:
+ x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
+ x = self.pos_drop(x + self.pos_embed)
+
+ if self.config["checkpoint_activations"]:
+ x = checkpoint_seq(self.blocks, x)
+ else:
+ x = self.blocks(x)
+ x = self.norm(x)
+ return x
+
+ def forward_head(self, x, pre_logits: bool = False):
+ if self.global_pool:
+ x = x[:, self.num_tokens:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0]
+ x = self.fc_norm(x)
+ return x if pre_logits else self.head(x)
+
+ def compute_loss(self, logits, labels):
+ loss_func = nn.CrossEntropyLoss()
+ return loss_func(logits, labels)
+
+ def forward(self, images=None, labels=None, **kwargs):
+
+ x = self.forward_features(images)
+ x = self.forward_head(x)
+ loss = None
+ if labels is not None:
+ loss = self.compute_loss(x, labels)
+ return_data = {"logits": x, "hidden_states": x, "loss": loss}
+
+ return return_data
+
+
+def init_weights_vit_timm(module: nn.Module, name: str = ''):
+ """ ViT weight initialization, original timm impl (for reproducibility) """
+ if isinstance(module, nn.Linear):
+ trunc_normal_(module.weight, std=.02)
+ if module.bias is not None:
+ nn.init.zeros_(module.bias)
+ elif hasattr(module, 'init_weights'):
+ module.init_weights()
+
+
+def init_weights_vit_jax(module: nn.Module, name: str = '', head_bias: float = 0.):
+ """ ViT weight initialization, matching JAX (Flax) impl """
+ if isinstance(module, nn.Linear):
+ if name.startswith('head'):
+ nn.init.zeros_(module.weight)
+ nn.init.constant_(module.bias, head_bias)
+ else:
+ nn.init.xavier_uniform_(module.weight)
+ if module.bias is not None:
+ nn.init.normal_(module.bias, std=1e-6) if 'mlp' in name else nn.init.zeros_(module.bias)
+ elif isinstance(module, nn.Conv2d):
+ lecun_normal_(module.weight)
+ if module.bias is not None:
+ nn.init.zeros_(module.bias)
+ elif hasattr(module, 'init_weights'):
+ module.init_weights()
+
+
+def init_weights_vit_moco(module: nn.Module, name: str = ''):
+ """ ViT weight initialization, matching moco-v3 impl minus fixed PatchEmbed """
+ if isinstance(module, nn.Linear):
+ if 'qkv' in name:
+ # treat the weights of Q, K, V separately
+ val = math.sqrt(6. / float(module.weight.shape[0] // 3 + module.weight.shape[1]))
+ nn.init.uniform_(module.weight, -val, val)
+ else:
+ nn.init.xavier_uniform_(module.weight)
+ if module.bias is not None:
+ nn.init.zeros_(module.bias)
+ elif hasattr(module, 'init_weights'):
+ module.init_weights()
+
+
+def get_init_weights_vit(mode='jax', head_bias: float = 0.):
+ if 'jax' in mode:
+ return partial(init_weights_vit_jax, head_bias=head_bias)
+ elif 'moco' in mode:
+ return init_weights_vit_moco
+ else:
+ return init_weights_vit_timm
+
+
+@torch.no_grad()
+def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = ''):
+ """ Load weights from .npz checkpoints for official Google Brain Flax implementation
+ """
+ import numpy as np
+
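+    # _n2p: numpy -> torch, transposing Flax conv kernels (HWIO -> OIHW) and dense kernels ((in, out) -> (out, in)).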
+ def _n2p(w, t=True):
+ if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1:
+ w = w.flatten()
+ if t:
+ if w.ndim == 4:
+ w = w.transpose([3, 2, 0, 1])
+ elif w.ndim == 3:
+ w = w.transpose([2, 0, 1])
+ elif w.ndim == 2:
+ w = w.transpose([1, 0])
+ return torch.from_numpy(w)
+
+ w = np.load(checkpoint_path)
+ if not prefix and 'opt/target/embedding/kernel' in w:
+ prefix = 'opt/target/'
+
+ if hasattr(model.patch_embed, 'backbone'):
+ # hybrid
+ backbone = model.patch_embed.backbone
+ stem_only = not hasattr(backbone, 'stem')
+ stem = backbone if stem_only else backbone.stem
+ stem.conv.weight.copy_(adapt_input_conv(stem.conv.weight.shape[1], _n2p(w[f'{prefix}conv_root/kernel'])))
+ stem.norm.weight.copy_(_n2p(w[f'{prefix}gn_root/scale']))
+ stem.norm.bias.copy_(_n2p(w[f'{prefix}gn_root/bias']))
+ if not stem_only:
+ for i, stage in enumerate(backbone.stages):
+ for j, block in enumerate(stage.blocks):
+ bp = f'{prefix}block{i + 1}/unit{j + 1}/'
+ for r in range(3):
+ getattr(block, f'conv{r + 1}').weight.copy_(_n2p(w[f'{bp}conv{r + 1}/kernel']))
+ getattr(block, f'norm{r + 1}').weight.copy_(_n2p(w[f'{bp}gn{r + 1}/scale']))
+ getattr(block, f'norm{r + 1}').bias.copy_(_n2p(w[f'{bp}gn{r + 1}/bias']))
+ if block.downsample is not None:
+ block.downsample.conv.weight.copy_(_n2p(w[f'{bp}conv_proj/kernel']))
+ block.downsample.norm.weight.copy_(_n2p(w[f'{bp}gn_proj/scale']))
+ block.downsample.norm.bias.copy_(_n2p(w[f'{bp}gn_proj/bias']))
+ embed_conv_w = _n2p(w[f'{prefix}embedding/kernel'])
+ else:
+ embed_conv_w = adapt_input_conv(
+ model.patch_embed.proj.weight.shape[1], _n2p(w[f'{prefix}embedding/kernel']))
+ model.patch_embed.proj.weight.copy_(embed_conv_w)
+ model.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias']))
+ model.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False))
+ pos_embed_w = _n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False)
+ if pos_embed_w.shape != model.pos_embed.shape:
+ pos_embed_w = resize_pos_embed( # resize pos embedding when different size from pretrained weights
+ pos_embed_w, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size)
+ model.pos_embed.copy_(pos_embed_w)
+ model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale']))
+ model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias']))
+ if isinstance(model.head, nn.Linear) and model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]:
+ model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel']))
+ model.head.bias.copy_(_n2p(w[f'{prefix}head/bias']))
+ # NOTE representation layer has been removed, not used in latest 21k/1k pretrained weights
+ # if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w:
+ # model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel']))
+ # model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias']))
+ for i, block in enumerate(model.blocks.children()):
+ block_prefix = f'{prefix}Transformer/encoderblock_{i}/'
+ mha_prefix = block_prefix + 'MultiHeadDotProductAttention_1/'
+ block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale']))
+ block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias']))
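+        # Flax stores q/k/v as separate (embed_dim, heads, head_dim) kernels; flatten and concatenate them to fill the fused qkv projection.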
+ block.attn.qkv.weight.copy_(torch.cat([
+ _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')]))
+ block.attn.qkv.bias.copy_(torch.cat([
+ _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')]))
+ block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1))
+ block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))
+ for r in range(2):
+ getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/kernel']))
+ getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/bias']))
+ block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/scale']))
+ block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/bias']))
+
+
+def resize_pos_embed(posemb, posemb_new, num_tokens=1, gs_new=()):
+ # Rescale the grid of position embeddings when loading from state_dict. Adapted from
+ # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224
+ ntok_new = posemb_new.shape[1]
+ if num_tokens:
+ posemb_tok, posemb_grid = posemb[:, :num_tokens], posemb[0, num_tokens:]
+ ntok_new -= num_tokens
+ else:
+ posemb_tok, posemb_grid = posemb[:, :0], posemb[0]
+ gs_old = int(math.sqrt(len(posemb_grid)))
+ if not len(gs_new): # backwards compatibility
+ gs_new = [int(math.sqrt(ntok_new))] * 2
+ assert len(gs_new) >= 2
+ posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
+ posemb_grid = F.interpolate(posemb_grid, size=gs_new, mode='bicubic', align_corners=False)
+ posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_new[0] * gs_new[1], -1)
+ posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
+ return posemb
+
+
+def checkpoint_filter_fn(state_dict, model):
+ """ convert patch embedding weight from manual patchify + linear proj to conv"""
+ out_dict = {}
+ if 'model' in state_dict:
+ # For deit models
+ state_dict = state_dict['model']
+ for k, v in state_dict.items():
+ if 'patch_embed.proj.weight' in k and len(v.shape) < 4:
+ # For old models that I trained prior to conv based patchification
+ O, I, H, W = model.patch_embed.proj.weight.shape
+ v = v.reshape(O, -1, H, W)
+ elif k == 'pos_embed' and v.shape != model.pos_embed.shape:
+ # To resize pos embedding when using model at different size from pretrained weights
+ v = resize_pos_embed(
+ v, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size)
+ elif 'pre_logits' in k:
+ # NOTE representation layer removed as not used in latest 21k/1k pretrained weights
+ continue
+ out_dict[k] = v
+ return out_dict
+
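The position-embedding helper above follows the upstream timm/`vision_transformer` recipe: the class-token slot is carried over unchanged while the patch-grid portion is reshaped to 2-D and bicubically interpolated to the new grid. A rough, self-contained sketch of the same computation; the 14x14 -> 24x24 grid sizes and 768-dim embeddings are illustrative assumptions (ViT-B/16 at 224px vs. 384px input), not values taken from this patch:

```python
import torch
import torch.nn.functional as F

# Illustrative shapes only: a ViT-B/16 pretrained at 224px has a 14x14 patch grid,
# while 384px input yields 24x24. Keep the class-token embedding, interpolate the
# grid part -- the same steps resize_pos_embed performs above.
pos_embed = torch.randn(1, 1 + 14 * 14, 768)               # [cls] + 196 patch positions
tok, grid = pos_embed[:, :1], pos_embed[:, 1:]              # split class token from grid

grid = grid.reshape(1, 14, 14, 768).permute(0, 3, 1, 2)     # -> (1, C, 14, 14)
grid = F.interpolate(grid, size=(24, 24), mode='bicubic', align_corners=False)
grid = grid.permute(0, 2, 3, 1).reshape(1, 24 * 24, 768)    # -> (1, 576, C)

resized = torch.cat([tok, grid], dim=1)                     # (1, 577, 768)
print(resized.shape)                                        # torch.Size([1, 577, 768])
```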
diff --git a/flagai/trainer.py b/flagai/trainer.py
index 7ad20a1c..f2fd8dd1 100644
--- a/flagai/trainer.py
+++ b/flagai/trainer.py
@@ -198,22 +198,30 @@ def __init__(
self.not_call_launch = True
self.deepspeed_config = deepspeed_config
self.model_parallel_size = model_parallel_size
- if 'deepspeed' in env_type or env_type == 'pytorchDDP':
+ self.num_nodes = num_nodes
+ self.num_gpus = num_gpus
+ self.master_ip = master_ip
+ self.master_port = master_port
+ self.hostfile = hostfile
+ self.training_script = training_script
+
+ if 'deepspeed' in self.env_type or self.env_type == 'pytorchDDP':
+ training_paras = self.get_dist_args()
# Implement for AutoLaunch
# >>> python train.py # will call get_dist_args()
            # `--not_call_launch` defaults to 'False'
            # So, if `env_type` is `pytorch`, the `Trainer` will not call launch_dist()
            # Otherwise, launch_dist() is called to launch 'train.py' with `--not_call_launch`
- self.get_dist_args()
if not self.not_call_launch:
launch_dist(launcher='distributed_deepspeed' if 'deepspeed'
in env_type else 'distributed_torch',
- num_nodes=num_nodes,
- gpus_per_node=num_gpus,
- master_addr=master_ip,
- master_port=master_port,
- hostfile=hostfile,
- training_script=training_script)
+ num_nodes=self.num_nodes,
+ gpus_per_node=self.num_gpus,
+ master_addr=self.master_ip,
+ master_port=self.master_port,
+ hostfile=self.hostfile,
+ training_script=self.training_script,
+ training_paras=training_paras)
os._exit(1)
self.initialize_distributed()
@@ -239,6 +247,7 @@ def get_dist_args(self):
self.master_addr = os.environ.get('MASTER_ADDR', '127.0.0.1')
self.master_port = os.environ.get('MASTER_PORT', '17500')
log_dist("not_call_launch: {}".format(ds_args.not_call_launch))
+ return []
def set_seed(self, seed=1234):
"""Set random seed for reproducability."""
@@ -310,6 +319,9 @@ def get_dataloader(self, dataset, collate_fn, shuffle=False):
else:
if self.env_type == 'deepspeed+mpu':
rank = mpu.get_model_parallel_src_rank()
+                    print("*" * 80)
+                    print("local rank:", self.rank, "model parallel src rank:", rank)
+                    print("*" * 80)
sampler = torch.utils.data.distributed.DistributedSampler(
dataset,
# num_replicas=num_replicas,
@@ -336,7 +348,8 @@ def train(self,
train_dataset=None,
valid_dataset=None,
metric_methods=[],
- collate_fn=None):
+ collate_fn=None,
+ find_unused_parameters=True):
"""Training Loops"""
"""
        Trainer is a simple but unified training and eval loop for PyTorch/Deepspeed/Megatron-LM.
@@ -404,7 +417,7 @@ def train(self,
model.to(torch.device('cuda', self.local_rank))
model = DDP(model,
device_ids=[self.local_rank],
- find_unused_parameters=True)
+ find_unused_parameters=find_unused_parameters)
elif self.env_type == 'pytorch':
model.to(self.pytorch_device)
@@ -506,7 +519,8 @@ def train(self,
lr_scheduler,
single_step=True)
dist.barrier()
- total_lm_loss += lm_loss.data.detach().float()
+ if lm_loss is not None:
+ total_lm_loss += lm_loss.data.detach().float()
# Logging.
if (self.iteration + 1) % self.log_interval == 0:
@@ -1025,3 +1039,4 @@ def evaluate_and_print_results(
log_dist(string, [0])
log_dist('-' * length, [0])
return eval_dict
+
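Two trainer changes above are worth a usage note: the auto-launch path now forwards the parsed distributed arguments to `launch_dist`, and `train()` exposes DDP's `find_unused_parameters` flag so the per-step scan for unused parameters can be switched off when every registered parameter receives a gradient. A minimal sketch of the new flag, assuming the constructor arguments shown in this diff; `model`, `train_dataset` and `collate_fn` are placeholders for objects built elsewhere, and passing the model as the first argument of `train()` is inferred from the hunk body rather than shown in it:

```python
from flagai.trainer import Trainer

# Hedged sketch: only env_type/num_nodes/num_gpus/master_ip/master_port/hostfile/
# training_script/find_unused_parameters appear in this patch; everything else is
# a placeholder value or object.
trainer = Trainer(
    env_type="pytorchDDP",     # triggers the auto-launch path patched above
    num_nodes=1,
    num_gpus=2,
    master_ip="127.0.0.1",
    master_port=17500,
    hostfile="./hostfile",
    training_script="train.py",
)

trainer.train(
    model,                          # placeholder: any FlagAI model
    train_dataset=train_dataset,    # placeholder dataset
    collate_fn=collate_fn,          # placeholder collate function
    # With no unused parameters, skipping DDP's unused-parameter scan
    # removes its per-iteration overhead.
    find_unused_parameters=False,
)
```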
diff --git a/flagai/utils.py b/flagai/utils.py
index 64845757..06c30026 100644
--- a/flagai/utils.py
+++ b/flagai/utils.py
@@ -206,8 +206,7 @@ def save_checkpoint(iteration,
sd['rng_tracker_states'] = mpu.get_cuda_rng_tracker().get_states()
if env_type == 'pytorch' or (env_type != 'deepspeed+mpu'
and dist.get_rank() == 0) or (
- env_type == 'deepspeed+mpu'
- and mpu.get_model_parallel_src_rank() == 0):
+                            env_type == 'deepspeed+mpu' and mpu.get_model_parallel_src_rank() == 0):
ensure_directory_exists(checkpoint_name)
config_path = os.path.join(save_dir, str(iteration), 'config.json')
@@ -220,6 +219,7 @@ def save_checkpoint(iteration,
tracker_filename = get_checkpoint_tracker_filename(save_dir)
with open(tracker_filename, 'w') as f:
f.write(str(iteration) + '\t' + str(best_iteration))
+
# Wait so everyone is done (necessary)
if barrier and dist.is_initialized():
torch.distributed.barrier()
diff --git a/flagai_wechat.png b/flagai_wechat.png
index 387bded2..e9dd1d30 100644
Binary files a/flagai_wechat.png and b/flagai_wechat.png differ
diff --git a/setup.py b/setup.py
index 33c2090b..b36d19f8 100644
--- a/setup.py
+++ b/setup.py
@@ -5,8 +5,7 @@
setup(
name="flagai",
- version="v1.1.3",
-
+ version="v1.2.0",
description="FlagAI aims to help researchers and developers to freely train and test large-scale models for NLP tasks.",
long_description=open("README.md", encoding="utf-8").read(),
long_description_content_type="text/markdown",