Commit 70314f0

mypy

Borda committed Feb 7, 2023
1 parent aca9f38 commit 70314f0
Showing 2 changed files with 4 additions and 2 deletions.
4 changes: 3 additions & 1 deletion src/torchmetrics/functional/multimodal/clip_score.py
@@ -51,7 +51,9 @@ def _clip_score_update(
             f"Expected the number of images and text examples to be the same but got {len(images)} and {len(text)}"
         )
     device = images[0].device
-    processed_input = processor(text=text, images=[i.cpu() for i in images], return_tensors="pt", padding=True)
+    processed_input = processor(
+        text=text, images=[i.cpu() for i in images], return_tensors="pt", padding=True
+    )  # type: ignore
 
     img_features = model.get_image_features(processed_input["pixel_values"].to(device))
     img_features = img_features / img_features.norm(p=2, dim=-1, keepdim=True)
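Context for the ignore: in clip_score.py the `processor` argument is HuggingFace's `CLIPProcessor`, whose `__call__` appears to carry no type annotations upstream, so a strict mypy configuration rejects the call; the reflowed call simply gains a trailing `# type: ignore`. Below is a minimal, self-contained sketch of the same pattern, assuming the `transformers` API — the checkpoint name, caption, and dummy image are illustrative, not part of the commit:

import torch
from transformers import CLIPModel, CLIPProcessor

# Illustrative checkpoint; in clip_score.py `model` and `processor` arrive from the caller.
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch16")
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch16")

# Dummy uint8 image; the processor also accepts PIL images and numpy arrays.
image = torch.randint(0, 255, (3, 224, 224), dtype=torch.uint8)

# The call mypy complains about: suppressed with a line-level ignore, as in the hunk above.
processed = processor(text=["a photo of a cat"], images=[image], return_tensors="pt", padding=True)  # type: ignore

img_features = model.get_image_features(processed["pixel_values"])
img_features = img_features / img_features.norm(p=2, dim=-1, keepdim=True)  # L2-normalize, as the function does next

A `typing.cast` on the processor's return value could achieve the same effect without a blanket ignore, at the cost of a noisier diff.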
2 changes: 1 addition & 1 deletion src/torchmetrics/functional/text/infolm.py
@@ -384,7 +384,7 @@ def _get_batch_distribution(
     for mask_idx in range(seq_len):
         input_ids = batch["input_ids"].clone()
         input_ids[:, mask_idx] = special_tokens_map["mask_token_id"]
-        logits_distribution = model(input_ids, batch["attention_mask"]).logits
+        logits_distribution = model(input_ids, batch["attention_mask"]).logits  # type: ignore
         # [batch_size, seq_len, vocab_size] -> [batch_size, vocab_size]
         logits_distribution = logits_distribution[:, mask_idx, :]
         prob_distribution = F.softmax(logits_distribution / temperature, dim=-1)
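The infolm.py change is the same medicine: `model` reaches `_get_batch_distribution` as a generic pretrained model, and the `.logits` attribute on its call result is opaque to mypy, hence the line-level ignore. A runnable sketch of the masking loop this hunk touches, assuming a masked-LM checkpoint (the model name, input sentence, and temperature are illustrative, not part of the commit):

from torch.nn import functional as F
from transformers import AutoModelForMaskedLM, AutoTokenizer

# Illustrative checkpoint and inputs; in infolm.py `model`, `batch`, `temperature`,
# and `special_tokens_map` all arrive from the caller.
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = AutoModelForMaskedLM.from_pretrained("bert-base-uncased")
batch = tokenizer(["the cat sat on the mat"], return_tensors="pt")
temperature = 1.0
seq_len = batch["input_ids"].shape[1]

prob_rows = []
for mask_idx in range(seq_len):
    input_ids = batch["input_ids"].clone()
    input_ids[:, mask_idx] = tokenizer.mask_token_id  # stands in for special_tokens_map["mask_token_id"]
    logits = model(input_ids, batch["attention_mask"]).logits  # the attribute access mypy flags
    # keep only the distribution at the masked position:
    # [batch_size, seq_len, vocab_size] -> [batch_size, vocab_size]
    prob_rows.append(F.softmax(logits[:, mask_idx, :] / temperature, dim=-1))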
