diff --git a/src/open_clip/pretrained.py b/src/open_clip/pretrained.py
index c46d292..7946741 100644
--- a/src/open_clip/pretrained.py
+++ b/src/open_clip/pretrained.py
@@ -53,6 +53,7 @@
     laion400m_e31="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt",
     laion400m_e32="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt",
     metaclip_400m=("https://dl.fbaipublicfiles.com/MMPT/metaclip/b32_400m.pt", "3c68642594a329afc1ec0fe489ee2b58ab19c9d0556ccf7c404a59baa0762d71"),
+    metaclip2_5b=("https://dl.fbaipublicfiles.com/MMPT/metaclip/b32_fullcc2.5b.pt", "885b7ec11fe07a9826e2e6812d70e5011918e32fe9b12136b49d5dded92b4386"),
     metaclip_fullcc=("https://dl.fbaipublicfiles.com/MMPT/metaclip/b32_fullcc2.5b.pt", "885b7ec11fe07a9826e2e6812d70e5011918e32fe9b12136b49d5dded92b4386"),
 )
 
@@ -64,6 +65,7 @@
 
 _VITB16_quickgelu = dict(
     metaclip_400m=("https://dl.fbaipublicfiles.com/MMPT/metaclip/b16_400m.pt", "68dfb5996c52a8f4fecb9bd16601e97e1895236645082778bd9cede8429a8d49"),
+    metaclip2_5b=("https://dl.fbaipublicfiles.com/MMPT/metaclip/b16_fullcc2.5b.pt", "512ea0fb9f2cf88d027e96e4674247a1a91a96af18abc2e2fcdb8008c551e04b"),
     metaclip_fullcc=("https://dl.fbaipublicfiles.com/MMPT/metaclip/b16_fullcc2.5b.pt", "512ea0fb9f2cf88d027e96e4674247a1a91a96af18abc2e2fcdb8008c551e04b"),
 )
 
@@ -80,6 +82,7 @@
 
 _VITL14_quickgelu = dict(
     metaclip_400m=("https://dl.fbaipublicfiles.com/MMPT/metaclip/l14_400m.pt", "51c782959f920b030779e494517b8d545f56794df6b0a2796a4c310455a361be"),
+    metaclip2_5b=("https://dl.fbaipublicfiles.com/MMPT/metaclip/l14_fullcc2.5b.pt", "ce24750710544ee288ef0abdead2016730da1893a1d07447bda3a75e1c148f97"),
     metaclip_fullcc=("https://dl.fbaipublicfiles.com/MMPT/metaclip/l14_fullcc2.5b.pt", "ce24750710544ee288ef0abdead2016730da1893a1d07447bda3a75e1c148f97"),
 )
 
@@ -88,6 +91,7 @@
 )
 
 _VITH14_quickgelu = dict(
+    metaclip2_5b=("https://dl.fbaipublicfiles.com/MMPT/metaclip/h14_fullcc2.5b.pt", "1286807d5cc8d9a0b12563b47474efb53b9522eb3d7eac5a9a5d39c3a776ad5c"),
     metaclip_fullcc=("https://dl.fbaipublicfiles.com/MMPT/metaclip/h14_fullcc2.5b.pt", "1286807d5cc8d9a0b12563b47474efb53b9522eb3d7eac5a9a5d39c3a776ad5c"),
 )
 
@@ -186,4 +190,4 @@ def download_pretrained(url: str, root: str = os.path.expanduser("~/.cache/clip"
     if expected_sha256 and hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256:
         raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not not match")
 
-    return download_target
+    return download_target
\ No newline at end of file
diff --git a/tests/test.py b/tests/test.py
new file mode 100644
index 0000000..0434d3e
--- /dev/null
+++ b/tests/test.py
@@ -0,0 +1,37 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+import torch
+from PIL import Image
+from open_clip import tokenizer
+import open_clip
+import os
+
+# Hide GPUs so the smoke test always runs on CPU.
+os.environ["CUDA_VISIBLE_DEVICES"] = ""
+
+
+def test_inference():
+    # MetaCLIP weights are registered under the *-quickgelu model configs in
+    # pretrained.py, so only those names resolve the metaclip_* pretrained tags.
+    for model_name in ["ViT-B-32-quickgelu", "ViT-B-16-quickgelu", "ViT-L-14-quickgelu"]:
+        for pretrained in ["metaclip_400m", "metaclip2_5b"]:
+            model, _, preprocess = open_clip.create_model_and_transforms(
+                model_name, pretrained=pretrained
+            )
+
+            current_dir = os.path.dirname(os.path.realpath(__file__))
+
+            image = preprocess(Image.open(current_dir + "/../docs/CLIP.png")).unsqueeze(
+                0
+            )
+            text = tokenizer.tokenize(["a diagram", "a dog", "a cat"])
+
+            with torch.no_grad():
+                image_features = model.encode_image(image)
+                text_features = model.encode_text(text)
+
+            text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)
+
+            # CLIP.png depicts a diagram; the scaled softmax saturates in float32.
+            assert text_probs.cpu().numpy()[0].tolist() == [1.0, 0.0, 0.0]
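
The metaclip2_5b entries added above reuse the URL and SHA256 of the existing metaclip_fullcc entries, so the new tag is a pure alias: either name downloads and loads the same checkpoint. A minimal usage sketch, not part of the patch; it assumes the fork's standard create_model_and_transforms entry point and a working network connection:

    import open_clip

    # Both tags map to b32_fullcc2.5b.pt in _VITB32_quickgelu, so the second
    # call should load identical weights, served from the local download cache.
    model_a, _, preprocess = open_clip.create_model_and_transforms(
        "ViT-B-32-quickgelu", pretrained="metaclip_fullcc"
    )
    model_b, _, _ = open_clip.create_model_and_transforms(
        "ViT-B-32-quickgelu", pretrained="metaclip2_5b"
    )

The new test exercises the same path end to end. Note that it downloads six checkpoints (three model sizes times two tags), so running pytest tests/test.py needs network access and several gigabytes of disk.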
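
The final pretrained.py hunk only touches the trailing newline next to the SHA256 check in download_pretrained, but that check is what the digests registered above feed into. A standalone sketch of the same verification, for checking a manually downloaded checkpoint: the cache path is an assumption based on the function's default root, and unlike the in-tree open(...).read() it streams the file so multi-gigabyte checkpoints need not fit in memory:

    import hashlib
    import os

    def sha256_of(path: str, chunk_size: int = 1 << 20) -> str:
        # Stream the file in 1 MiB chunks and accumulate the digest.
        digest = hashlib.sha256()
        with open(path, "rb") as f:
            for chunk in iter(lambda: f.read(chunk_size), b""):
                digest.update(chunk)
        return digest.hexdigest()

    # Digest copied from the metaclip2_5b / metaclip_fullcc entry for ViT-B-32.
    expected = "885b7ec11fe07a9826e2e6812d70e5011918e32fe9b12136b49d5dded92b4386"
    target = os.path.expanduser("~/.cache/clip/b32_fullcc2.5b.pt")  # assumed cache location
    if sha256_of(target) != expected:
        raise RuntimeError("SHA256 checksum does not match the published digest")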