Skip to content

Commit

Permalink
Update UNI download logic (#756)
Browse files Browse the repository at this point in the history
  • Loading branch information
nkaenzig authored Feb 4, 2025
1 parent 7844be6 commit c60fe16
Showing 1 changed file with 3 additions and 19 deletions.
22 changes: 3 additions & 19 deletions src/eva/vision/models/networks/backbones/pathology/mahmood.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,9 @@
"""Pathology FMs from MahmoodLab."""

import os
from pathlib import Path
from typing import Tuple

import huggingface_hub
import timm
import torch
from loguru import logger
from torch import nn

from eva.vision.models import wrappers
Expand All @@ -20,7 +16,6 @@ def mahmood_uni(
dynamic_img_size: bool = True,
out_indices: int | Tuple[int, ...] | None = None,
hf_token: str | None = None,
download_dir: str = os.path.join(str(Path.home()), ".cache/eva"),
) -> nn.Module:
"""Initializes UNI model from MahmoodLab.
Expand All @@ -29,31 +24,20 @@ def mahmood_uni(
the grid size (interpolate abs and/or ROPE pos) in the forward pass.
out_indices: Whether and which multi-level patch embeddings to return.
hf_token: HuggingFace token to download the model.
download_dir: Directory to download the model checkpoint.
Returns:
The model instance.
"""
checkpoint_path = os.path.join(download_dir, "uni.bin")
if not os.path.exists(checkpoint_path):
logger.info(f"Downloading the model checkpoint to {download_dir} ...")
os.makedirs(download_dir, exist_ok=True)
_utils.huggingface_login(hf_token)
huggingface_hub.hf_hub_download(
"MahmoodLab/UNI",
filename="uni.bin",
local_dir=download_dir,
force_download=True,
)
_utils.huggingface_login(hf_token)

return wrappers.TimmModel(
model_name="vit_large_patch16_224",
model_name="hf-hub:MahmoodLab/uni",
pretrained=True,
out_indices=out_indices,
model_kwargs={
"init_values": 1e-5,
"dynamic_img_size": dynamic_img_size,
},
checkpoint_path=checkpoint_path,
)


Expand Down

0 comments on commit c60fe16

Please sign in to comment.