
Commit

Merge pull request #6 from 920232796/master
add altclip-m18
Anhforth authored Mar 29, 2023
2 parents c582324 + 4e220f0 commit e55b20e
Showing 10 changed files with 2,629 additions and 1 deletion.
441 changes: 441 additions & 0 deletions examples/AltCLIP-m18/README.md

Large diffs are not rendered by default.

89 changes: 89 additions & 0 deletions examples/AltCLIP-m18/altclip_evaluation.py
@@ -0,0 +1,89 @@
# Copyright © 2022 BAAI. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License")
import torch
from flagai.auto_model.auto_loader import AutoLoader
import zeroshot_classification
import json
import os
from torchvision.datasets import CIFAR10

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
maxlen = 256

dataset_root = "./clip_benchmark_datasets/"
dataset_name = "cifar10"

auto_loader = AutoLoader(
    task_name="txt_img_matching",
    model_dir="./checkpoints/",
    model_name="AltCLIP-XLMR-L-m18"  # load the checkpoint from ModelHub (model.baai.ac.cn/models)
)

model = auto_loader.get_model()
model.to(device)
model.eval()
tokenizer = auto_loader.get_tokenizer()
transform = auto_loader.get_transform()

dataset = CIFAR10(root=os.path.join(dataset_root, dataset_name),
                  transform=transform,
                  download=True)
batch_size = 128
num_workers = 4

template = {
    "cifar10": [
        "a photo of a {c}.",
        "a blurry photo of a {c}.",
        "a black and white photo of a {c}.",
        "a low contrast photo of a {c}.",
        "a high contrast photo of a {c}.",
        "a bad photo of a {c}.",
        "a good photo of a {c}.",
        "a photo of a small {c}.",
        "a photo of a big {c}.",
        "a photo of the {c}.",
        "a blurry photo of the {c}.",
        "a black and white photo of the {c}.",
        "a low contrast photo of the {c}.",
        "a high contrast photo of the {c}.",
        "a bad photo of the {c}.",
        "a good photo of the {c}.",
        "a photo of the small {c}.",
        "a photo of the big {c}."
    ]
}
def evaluate():
    if dataset:
        dataloader = torch.utils.data.DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=num_workers,
        )

        zeroshot_templates = template["cifar10"]
        classnames = dataset.classes if hasattr(dataset, "classes") else None

        metrics = zeroshot_classification.evaluate(
            model,
            dataloader,
            tokenizer,
            classnames,
            zeroshot_templates,
            device=device,
            amp=True,
        )

        dump = {
            "dataset": dataset_name,
            "metrics": metrics
        }

        print(dump)
        with open("./result.txt", "w") as f:
            json.dump(dump, f)
        return metrics

if __name__ == "__main__":
    evaluate()
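
The zeroshot_classification module imported above ships with this commit, but its diff is not rendered on this page. As a rough sketch of the protocol such a module typically implements (an assumption, not the committed code): every class name is expanded through all prompt templates, the resulting captions are encoded and averaged into one prototype per class, and images are then scored by cosine similarity against those prototypes.

# Hedged sketch of a zero-shot classifier head; the committed
# zeroshot_classification module may differ in its details.
import torch
import torch.nn.functional as F

def build_zeroshot_classifier(model, tokenizer, classnames, templates, device):
    prototypes = []
    with torch.no_grad():
        for classname in classnames:
            # Expand the class name through every prompt template.
            texts = [t.format(c=classname) for t in templates]
            tok = tokenizer(texts, padding=True, truncation=True,
                            max_length=77, return_tensors="pt")
            emb = model.get_text_features(tok["input_ids"].to(device),
                                          attention_mask=tok["attention_mask"].to(device))
            # Average the per-template embeddings into one class prototype.
            emb = F.normalize(emb, dim=-1).mean(dim=0)
            prototypes.append(F.normalize(emb, dim=0))
    # Shape (embed_dim, n_classes); image_features @ classifier gives class logits.
    return torch.stack(prototypes, dim=1)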
65 changes: 65 additions & 0 deletions examples/AltCLIP-m18/altclip_finetuning.py
@@ -0,0 +1,65 @@
# Copyright © 2022 BAAI. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License")
import torch
from flagai.auto_model.auto_loader import AutoLoader
import os
from flagai.trainer import Trainer
from torchvision.datasets import CIFAR10

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

dataset_root = "./clip_benchmark_datasets"
dataset_name = "cifar10"

batch_size = 4
classes = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']

auto_loader = AutoLoader(
    task_name="txt_img_matching",
    model_dir="./checkpoints",
    model_name="AltCLIP-XLMR-L-m18"  # load the checkpoint from ModelHub (model.baai.ac.cn/models)
)

model = auto_loader.get_model()
model.to(device)
model.eval()
tokenizer = auto_loader.get_tokenizer()
transform = auto_loader.get_transform()

trainer = Trainer(env_type="pytorch",
                  pytorch_device=device,
                  experiment_name="clip_finetuning",
                  batch_size=batch_size,  # defined above
                  lr=1e-4,
                  epochs=10,
                  log_interval=10)

dataset = CIFAR10(root=os.path.join(dataset_root, dataset_name),
                  transform=transform,
                  download=True)

def cifar10_collate_fn(batch):
    # image shape is (batch, 3, 224, 224)
    images = torch.stack([torch.as_tensor(b[0]["pixel_values"][0]) for b in batch])
    # CIFAR10 yields integer labels; map them to class names before building captions.
    texts = [f"a photo of a {classes[b[1]]}" for b in batch]
    # Tokenize the whole batch in one call so every caption is padded to a common
    # length (per-sample tokenization would produce ragged input_ids).
    tokens = tokenizer(texts,
                       padding=True,
                       truncation=True,
                       max_length=77,
                       return_tensors="pt")

    return {
        "pixel_values": images,
        "input_ids": tokens["input_ids"],
        "attention_mask": tokens["attention_mask"],
    }

if __name__ == "__main__":
    trainer.train(model=model, train_dataset=dataset, collate_fn=cifar10_collate_fn)
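
For readers who prefer plain PyTorch over FlagAI's Trainer, the loop below is a minimal equivalent sketch. It assumes the model's forward pass accepts the dict produced by cifar10_collate_fn and returns a dict containing a contrastive "loss" entry; the Trainer above remains the supported path and additionally handles logging and checkpointing.

# Minimal manual loop (assumed equivalent to trainer.train above;
# the "loss" key in the forward output is an assumption).
from torch.utils.data import DataLoader

def manual_finetune(epochs=10, lr=1e-4):
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True,
                        collate_fn=cifar10_collate_fn)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    model.train()
    for epoch in range(epochs):
        for step, batch in enumerate(loader):
            batch = {k: v.to(device) for k, v in batch.items()}
            loss = model(**batch)["loss"]
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if step % 10 == 0:
                print(f"epoch {epoch} step {step} loss {loss.item():.4f}")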
41 changes: 41 additions & 0 deletions examples/AltCLIP-m18/altclip_inference.py
@@ -0,0 +1,41 @@
import torch
from PIL import Image
from flagai.auto_model.auto_loader import AutoLoader

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

loader = AutoLoader(
    task_name="txt_img_matching",
    model_name="AltCLIP-XLMR-L-m18",  # load the checkpoint from ModelHub (model.baai.ac.cn/models)
    model_dir="./checkpoints"
)

model = loader.get_model()
tokenizer = loader.get_tokenizer()
transform = loader.get_transform()

model.eval()
model.to(device)

def inference():
    image = Image.open("./dog.jpeg")
    image = transform(image)
    image = torch.tensor(image["pixel_values"]).to(device)
    tokenizer_out = tokenizer(["a rat", "a dog", "a cat"],
                              padding=True,
                              truncation=True,
                              max_length=77,
                              return_tensors='pt')

    text = tokenizer_out["input_ids"].to(device)
    attention_mask = tokenizer_out["attention_mask"].to(device)
    with torch.no_grad():
        image_features = model.get_image_features(image)
        text_features = model.get_text_features(text, attention_mask=attention_mask)
        # Normalize both towers so the dot product is a cosine similarity,
        # as is standard for CLIP-style matching.
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)
        text_probs = (image_features @ text_features.T).softmax(dim=-1)

    print(text_probs.cpu().numpy()[0].tolist())

if __name__ == "__main__":
    inference()
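
Since the "m18" in the model name appears to denote the multilingual (18-language) variant with an XLM-R text tower, the same pipeline should accept non-English captions directly. A small hedged extension of inference() above, reusing model, tokenizer, transform, and device from the script (the captions are illustrative):

# Hedged multilingual variant of inference(); assumes non-English prompts
# encode directly through the XLM-R text tower.
def inference_multilingual():
    image = transform(Image.open("./dog.jpeg"))
    image = torch.tensor(image["pixel_values"]).to(device)
    captions = ["a dog", "un chien", "ein Hund", "一只狗"]
    tok = tokenizer(captions, padding=True, truncation=True,
                    max_length=77, return_tensors="pt")
    with torch.no_grad():
        img_f = model.get_image_features(image)
        txt_f = model.get_text_features(tok["input_ids"].to(device),
                                        attention_mask=tok["attention_mask"].to(device))
        img_f = img_f / img_f.norm(dim=-1, keepdim=True)
        txt_f = txt_f / txt_f.norm(dim=-1, keepdim=True)
        probs = (img_f @ txt_f.T).softmax(dim=-1)
    # Pair each caption with its probability for readable output.
    for caption, p in zip(captions, probs.cpu().numpy()[0].tolist()):
        print(f"{caption}: {p:.3f}")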
Binary file added examples/AltCLIP-m18/dog.jpeg
