
Verify hash of onnx model #883

Closed
jeffchuber opened this issue Jul 26, 2023 · 5 comments
Labels: good first issue, Local Chroma, p0, taz-sprint-2, to-discuss

Comments

@jeffchuber
Contributor

Sometimes the ONNX model does not download correctly:

https://github.com/chroma-core/chroma/blob/main/chromadb/utils/embedding_functions.py#L202

We should store a hash of the model and compare it against the downloaded file to verify that we have the right artifact. If the hashes don't match, we should re-download.
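For illustration, here is a minimal sketch of what that check could look like. The `EXPECTED_SHA256` value and the `verify_or_redownload` helper are hypothetical, not existing Chroma code:

import hashlib
import os

# Hypothetical pinned digest of the known-good onnx.tar.gz (placeholder value)
EXPECTED_SHA256 = "replace-with-pinned-hex-digest"

def file_sha256(path: str, chunk_size: int = 1 << 20) -> str:
    # Hash the file in chunks so large archives don't have to fit in memory
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()

def verify_or_redownload(path: str, url: str, download) -> None:
    # Re-download once if the archive is missing or fails the digest check,
    # then fail loudly if the fresh copy still doesn't match.
    if not os.path.exists(path) or file_sha256(path) != EXPECTED_SHA256:
        if os.path.exists(path):
            os.remove(path)
        download(url, path)
    if file_sha256(path) != EXPECTED_SHA256:
        raise ValueError(f"SHA256 mismatch for downloaded model archive: {path}")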

jeffchuber added the Local Chroma label on Jul 26, 2023
@HammadB
Collaborator

HammadB commented Aug 14, 2023

Related: #976, where the aim is to avoid downloading the archive on boot by packaging it in instead.

@Josh-XT
Contributor

Josh-XT commented Aug 14, 2023

A note on this issue: I don't entirely agree, personally. Some people may choose to use a different model or modify the model. I would recommend making this as flexible as possible and setting sensible defaults.

@Josh-XT
Contributor

Josh-XT commented Aug 14, 2023

I had to build my solution into AGiXT separately until the downloader is fixed by my PR #976, but here is the slightly modified ONNX embedder class that I am currently using. I put the onnx.tar.gz file in the root of my repo for now, but may make it download at Docker build time instead. I am trying to keep users from hammering the download endpoint constantly, both for a smoother user experience and to avoid unnecessary downloads. Those who use my software are very likely to want to experiment with hacking the model at some point, so I personally wouldn't add hash comparisons unless you have a reason to keep people from trying a different model.

import importlib
import os
import tarfile
from typing import List, cast

import numpy as np
import numpy.typing as npt
import requests

from chromadb.api.types import Documents, EmbeddingFunction, Embeddings

HOME_DIR = os.getcwd()

class ONNX(EmbeddingFunction):
    def __init__(
        self,
        MODEL_NAME: str = "all-MiniLM-L6-v2",
        DOWNLOAD_PATH: str = HOME_DIR,
        EXTRACTED_FOLDER_NAME="onnx",
        ARCHIVE_FILENAME="onnx.tar.gz",
        MODEL_DOWNLOAD_URL=(
            "https://chroma-onnx-models.s3.amazonaws.com/all-MiniLM-L6-v2/onnx.tar.gz"
        ),
        tokenizer=None,
        model=None,
    ):
        # Import dependencies on demand to mirror other embedding functions. This
        # breaks typechecking, thus the ignores.
        self.MODEL_NAME = MODEL_NAME if MODEL_NAME else "all-MiniLM-L6-v2"
        self.DOWNLOAD_PATH = DOWNLOAD_PATH if DOWNLOAD_PATH else HOME_DIR
        self.EXTRACTED_FOLDER_NAME = (
            EXTRACTED_FOLDER_NAME if EXTRACTED_FOLDER_NAME else "onnx"
        )
        self.ARCHIVE_FILENAME = ARCHIVE_FILENAME if ARCHIVE_FILENAME else "onnx.tar.gz"
        self.MODEL_DOWNLOAD_URL = (
            MODEL_DOWNLOAD_URL
            if MODEL_DOWNLOAD_URL
            else "https://chroma-onnx-models.s3.amazonaws.com/all-MiniLM-L6-v2/onnx.tar.gz"
        )
        self.tokenizer = tokenizer
        self.model = model
        try:
            # Equivalent to import onnxruntime
            self.ort = importlib.import_module("onnxruntime")
        except ImportError:
            raise ValueError(
                "The onnxruntime python package is not installed. Please install it with `pip install onnxruntime`"
            )
        try:
            # Equivalent to from tokenizers import Tokenizer
            self.Tokenizer = importlib.import_module("tokenizers").Tokenizer
        except ImportError:
            raise ValueError(
                "The tokenizers python package is not installed. Please install it with `pip install tokenizers`"
            )
        try:
            # Equivalent to from tqdm import tqdm
            self.tqdm = importlib.import_module("tqdm").tqdm
        except ImportError:
            raise ValueError(
                "The tqdm python package is not installed. Please install it with `pip install tqdm`"
            )

    # Borrowed from https://gist.github.com/yanqd0/c13ed29e29432e3cf3e7c38467f42f51
    # Download with tqdm to preserve the sentence-transformers experience
    def _download(self, url: str, fname: str, chunk_size: int = 1024) -> None:
        resp = requests.get(url, stream=True)
        total = int(resp.headers.get("content-length", 0))
        with open(fname, "wb") as file, self.tqdm(
            desc=str(fname),
            total=total,
            unit="iB",
            unit_scale=True,
            unit_divisor=1024,
        ) as bar:
            for data in resp.iter_content(chunk_size=chunk_size):
                size = file.write(data)
                bar.update(size)

    # Use PyTorch's default epsilon to guard against division by zero
    # https://pytorch.org/docs/stable/generated/torch.nn.functional.normalize.html
    def _normalize(self, v: npt.NDArray) -> npt.NDArray:
        norm = np.linalg.norm(v, axis=1)
        norm[norm == 0] = 1e-12
        return v / norm[:, np.newaxis]

    def _forward(self, documents: List[str], batch_size: int = 32) -> npt.NDArray:
        # We need to cast to the correct type because the type checker doesn't know that init_model_and_tokenizer will set the values
        self.tokenizer = cast(self.Tokenizer, self.tokenizer)  # type: ignore
        self.model = cast(self.ort.InferenceSession, self.model)  # type: ignore
        all_embeddings = []
        for i in range(0, len(documents), batch_size):
            batch = documents[i : i + batch_size]
            encoded = [self.tokenizer.encode(d) for d in batch]
            input_ids = np.array([e.ids for e in encoded])
            attention_mask = np.array([e.attention_mask for e in encoded])
            onnx_input = {
                "input_ids": np.array(input_ids, dtype=np.int64),
                "attention_mask": np.array(attention_mask, dtype=np.int64),
                "token_type_ids": np.array(
                    [np.zeros(len(e), dtype=np.int64) for e in input_ids],
                    dtype=np.int64,
                ),
            }
            model_output = self.model.run(None, onnx_input)
            last_hidden_state = model_output[0]
            # Perform mean pooling with attention weighting
            input_mask_expanded = np.broadcast_to(
                np.expand_dims(attention_mask, -1), last_hidden_state.shape
            )
            embeddings = np.sum(last_hidden_state * input_mask_expanded, 1) / np.clip(
                input_mask_expanded.sum(1), a_min=1e-9, a_max=None
            )
            embeddings = self._normalize(embeddings).astype(np.float32)
            all_embeddings.append(embeddings)
        return np.concatenate(all_embeddings)

    def _init_model_and_tokenizer(self) -> None:
        if self.model is None and self.tokenizer is None:
            self.tokenizer = self.Tokenizer.from_file(
                os.path.join(
                    self.DOWNLOAD_PATH, self.EXTRACTED_FOLDER_NAME, "tokenizer.json"
                )
            )
            # max_seq_length = 256; for some reason sentence-transformers uses 256 even though the HF config has a max length of 128
            # https://github.com/UKPLab/sentence-transformers/blob/3e1929fddef16df94f8bc6e3b10598a98f46e62d/docs/_static/html/models_en_sentence_embeddings.html#LL480
            self.tokenizer.enable_truncation(max_length=256)
            self.tokenizer.enable_padding(pad_id=0, pad_token="[PAD]", length=256)
            self.model = self.ort.InferenceSession(
                os.path.join(
                    self.DOWNLOAD_PATH, self.EXTRACTED_FOLDER_NAME, "model.onnx"
                )
            )

    def __call__(self, texts: Documents) -> Embeddings:
        # Only download the model when it is actually used
        self._download_model_if_not_exists()
        self._init_model_and_tokenizer()
        res = cast(Embeddings, self._forward(texts).tolist())
        return res

    def _download_model_if_not_exists(self) -> None:
        # Model is not downloaded yet
        if not os.path.exists(
            os.path.join(self.DOWNLOAD_PATH, self.EXTRACTED_FOLDER_NAME, "model.onnx")
        ):
            os.makedirs(self.DOWNLOAD_PATH, exist_ok=True)
            if not os.path.exists(
                os.path.join(self.DOWNLOAD_PATH, self.ARCHIVE_FILENAME)
            ):
                self._download(
                    self.MODEL_DOWNLOAD_URL,
                    os.path.join(self.DOWNLOAD_PATH, self.ARCHIVE_FILENAME),
                )
            with tarfile.open(
                name=os.path.join(self.DOWNLOAD_PATH, self.ARCHIVE_FILENAME),
                mode="r:gz",
            ) as tar:
                tar.extractall(path=self.DOWNLOAD_PATH)
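
For reference, usage mirrors Chroma's other embedding functions. A hypothetical example (the path and texts are illustrative only):

embedder = ONNX(DOWNLOAD_PATH=os.path.join(HOME_DIR, "models"))
embeddings = embedder(["hello world", "goodbye world"])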

@HammadB
Collaborator

HammadB commented Aug 14, 2023

> Some people may choose to use a different model or modify the model. I would recommend making this as flexible as possible and setting sensible defaults.

If that is the intent, then they should use a different embedding function. This one is meant for this specific model; it's already fairly tightly coupled to it and not meant to be a generic ONNX runner.

jeffchuber added the good first issue label on Sep 13, 2023
tazarov added a commit to amikos-tech/chroma-core that referenced this issue Dec 11, 2023
tazarov added a commit to amikos-tech/chroma-core that referenced this issue Dec 12, 2023
- Added up to three retries when downloading the model, rethrowing the exception if all attempts fail.

Refs: chroma-core#883
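
For illustration, the retry behavior described above might be sketched like this; the helper name, attempt count, and backoff are assumptions, not the commit's actual code:

import time

def download_with_retries(url: str, path: str, download, attempts: int = 3) -> None:
    # Try the download up to `attempts` times, rethrowing the last exception
    for attempt in range(attempts):
        try:
            download(url, path)
            return
        except Exception:
            if attempt == attempts - 1:
                raise
            time.sleep(2 ** attempt)  # simple exponential backoff between attempts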
tazarov added a commit to amikos-tech/chroma-core that referenced this issue Dec 19, 2023
HammadB pushed a commit that referenced this issue Dec 19, 2023
Refs: #883

## Description of changes

- Improvements & bug fixes
  - Verify the ONNX all-MiniLM-L6-v2 model download from S3 against a static SHA256 digest (within the Python code)

## Test plan

- [x] Tests pass locally with `pytest` for python

## Documentation Changes
N/A
@HammadB
Collaborator

HammadB commented Jan 15, 2024

Done with #1493

HammadB closed this as completed on Jan 15, 2024