Feat: Add jasper (#1591)
* init jasper

* init jasper

* add to overview

* add to overview

* remove some params

* fix max length

* return sdpa

* add dtype

* add dtype

* fix convert_to_tensor

* change to encode

* return whitespace processing

* explicitly add instructions

* move seq length

* try float

* fix max_seq_length

* add prompt validation to format instruction

* don't use instructions only to s2p
Samoed authored Dec 23, 2024
1 parent d8dd96c commit ef5a068
Showing 4 changed files with 124 additions and 6 deletions.
5 changes: 0 additions & 5 deletions mteb/models/instruct_wrapper.py
@@ -77,9 +77,4 @@ def encode(
            embeddings = embeddings.cpu().detach().float().numpy()
            return embeddings

        def format_instruction(self, instruction: str) -> str:
            if isinstance(self.instruction_template, str):
                return self.instruction_template.format(instruction=instruction)
            return self.instruction_template(instruction)

    return InstructWrapper(model_name_or_path, mode, instruction_template, **kwargs)
96 changes: 96 additions & 0 deletions mteb/models/jasper_models.py
@@ -0,0 +1,96 @@
from __future__ import annotations

import logging
from collections.abc import Sequence
from functools import partial
from typing import Any, Callable

import numpy as np
import torch
from sentence_transformers import SentenceTransformer

import mteb
from mteb.encoder_interface import PromptType
from mteb.model_meta import ModelMeta

from .wrapper import Wrapper

logger = logging.getLogger(__name__)


class JasperWrapper(Wrapper):
    def __init__(
        self,
        model_name: str,
        revision: str,
        instruction_template: str | Callable[[str], str] | None = None,
        max_seq_length: int = 2048,
        **kwargs: Any,
    ):
        self.model_name = model_name
        self.model = SentenceTransformer(model_name, revision=revision, **kwargs)
        self.instruction_template = instruction_template
        self.model.max_seq_length = max_seq_length

    def encode(
        self,
        sentences: Sequence[str],
        *,
        task_name: str,
        prompt_type: PromptType | None = None,
        **kwargs: Any,
    ) -> np.ndarray:
        task = mteb.get_task(task_name=task_name)
        instruction = self.get_task_instruction(task_name, prompt_type)

        # for s2p tasks, instruction prompts are not applied to passages
        if prompt_type == PromptType.passage and task.metadata.type == "s2p":
            instruction = None

        embeddings = self.model.encode(
            sentences,
            normalize_embeddings=True,
            prompt=instruction,
            **kwargs,
        )

        if isinstance(embeddings, torch.Tensor):
            # kwargs may include return_tensors=True, in which case encode returns a tensor
            embeddings = embeddings.cpu().detach().float().numpy()
        return embeddings


jasper_en_v1 = ModelMeta(
    loader=partial(  # type: ignore
        JasperWrapper,
        model_name="infgrad/jasper_en_vision_language_v1",
        revision="d6330ce98f8a0d741e781df845904c9484f00efa",
        config_kwargs={"is_text_encoder": True, "vector_dim": 12288},
        model_kwargs={
            "attn_implementation": "sdpa",
            "torch_dtype": torch.float16,
        },
        trust_remote_code=True,
        max_seq_length=2048,
        instruction_template="Instruct: {instruction}\nQuery: ",
    ),
    name="infgrad/jasper_en_vision_language_v1",
    languages=["eng-Latn"],
    open_weights=True,
    revision="d6330ce98f8a0d741e781df845904c9484f00efa",
    release_date="2024-12-11",  # first commit
    n_parameters=1_999_000_000,
    memory_usage=None,
    max_tokens=131072,
    embed_dim=8960,
    license="apache-2.0",
    reference="https://huggingface.co/infgrad/jasper_en_vision_language_v1/tree/main",
    similarity_fn_name="cosine",
    framework=["Sentence Transformers", "PyTorch"],
    use_instructions=True,
    adapted_from=None,
    superseded_by=None,
    training_datasets={
        "non_mteb": ["BAAI/Infinity-MM", "HuggingFaceFW/fineweb-edu"],
    },
)
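
For context, a minimal usage sketch (not part of this diff) of how the metadata registered above can be exercised through mteb's public API; the task name is an arbitrary example:

import mteb

# Resolve the ModelMeta registered above; its loader constructs a JasperWrapper.
model = mteb.get_model("infgrad/jasper_en_vision_language_v1")

# "NFCorpus" is only an illustrative retrieval task.
tasks = mteb.get_tasks(tasks=["NFCorpus"])
evaluation = mteb.MTEB(tasks=tasks)
results = evaluation.run(model, output_folder="results")
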
8 changes: 8 additions & 0 deletions mteb/models/overview.py
@@ -22,6 +22,7 @@
    gritlm_models,
    gte_models,
    ibm_granite_models,
    jasper_models,
    jina_models,
    linq_models,
    llm2vec_models,
@@ -74,6 +75,13 @@
    ru_sentence_models,
    salesforce_models,
    sentence_transformers_models,
    voyage_models,
    google_models,
    repllama_models,
    promptriever_models,
    jina_models,
    jasper_models,
    uae_models,
    stella_models,
    uae_models,
    voyage_models,
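
As an aside, a hedged sketch of how a module list like the one above is typically folded into a name-to-ModelMeta registry; build_registry and model_modules are illustrative names, not necessarily mteb's actual internals:

from mteb.model_meta import ModelMeta

def build_registry(model_modules: list) -> dict[str, ModelMeta]:
    # Scan each models module for ModelMeta instances and index them by name,
    # which is why jasper_models only needs to be appended to the list above.
    registry: dict[str, ModelMeta] = {}
    for module in model_modules:
        for attr in vars(module).values():
            if isinstance(attr, ModelMeta):
                registry[attr.name] = attr
    return registry
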
21 changes: 20 additions & 1 deletion mteb/models/wrapper.py
@@ -1,7 +1,7 @@
from __future__ import annotations

import logging
from typing import get_args
from typing import Callable, get_args

import mteb
from mteb.abstasks.TaskMetadata import TASK_TYPE
@@ -15,6 +15,8 @@ class Wrapper:
    Also contains some utility functions for wrappers for working with prompts and instructions.
    """

    instruction_template: str | Callable[[str], str] | None = None

    @staticmethod
    def get_prompt_name(
        task_to_prompt: dict[str, str] | None,
@@ -100,3 +102,20 @@ def get_instruction(task_name: str, prompt_type: PromptType | None) -> str:
        if task_metadata.prompt:
            return task_metadata.prompt
        return task.abstask_prompt

    def format_instruction(self, instruction: str) -> str:
        if isinstance(self.instruction_template, str):
            if "{instruction}" not in self.instruction_template:
                raise ValueError(
                    "Instruction template must contain the string '{instruction}'."
                )
            return self.instruction_template.format(instruction=instruction)
        return self.instruction_template(instruction)

    def get_task_instruction(
        self, task_name: str, prompt_type: PromptType | None
    ) -> str:
        instruction = self.get_instruction(task_name, prompt_type)
        if self.instruction_template:
            return self.format_instruction(instruction)
        return instruction
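
A small sketch (not in the diff) of how the relocated format_instruction behaves for both template forms; Demo is a hypothetical subclass:

from mteb.models.wrapper import Wrapper

class Demo(Wrapper):
    # String templates must contain the '{instruction}' placeholder.
    instruction_template = "Instruct: {instruction}\nQuery: "

w = Demo()
print(w.format_instruction("Retrieve relevant passages."))
# Instruct: Retrieve relevant passages.
# Query:

# A callable template is applied directly, with no placeholder check.
w.instruction_template = lambda ins: f"<instruct>{ins}</instruct>"
print(w.format_instruction("Retrieve relevant passages."))

# A string template missing the placeholder now raises ValueError.
w.instruction_template = "Query: "
try:
    w.format_instruction("anything")
except ValueError as err:
    print(err)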
