Skip to content

Commit

Permalink
formatting and linting
Browse files Browse the repository at this point in the history
  • Loading branch information
devpramod committed Jun 3, 2024
1 parent 9ffa54f commit e031c84
Showing 1 changed file with 18 additions and 14 deletions.
32 changes: 18 additions & 14 deletions caikit_nlp/modules/text_embedding/embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@

# Standard
from collections.abc import Sized
from contextlib import nullcontext
from enum import Enum, auto
from typing import Callable, Dict, List, NamedTuple, Optional, TypeVar, Union
import importlib
from contextlib import nullcontext
import os
import time

Expand Down Expand Up @@ -790,11 +790,13 @@ def sum_token_count(

class SentenceTransformerWithTruncate(SentenceTransformer):
def __init__(self, *args, **kwargs):
super(SentenceTransformerWithTruncate, self).__init__(*args, **kwargs)
super().__init__(*args, **kwargs)
if GRAPH_MODE:
# Initialize the compiled model right after the base class initialization
self.compiled_model = self._apply_graph_mode() # Compile and store the graph model

self.compiled_model = (
self._apply_graph_mode()
) # Compile and store the graph model

def _truncate_input_tokens(
self,
truncate_input_tokens: int,
Expand Down Expand Up @@ -897,19 +899,20 @@ def _truncate_input_tokens(
)

return TruncatedTokensTuple(tokenized, input_token_count)

def _apply_graph_mode(self) -> torch.jit.ScriptModule:
"""
Compiles the model into a TorchScript graph using predefined fixed-size randomized input tensors.
The tensors simulate typical input structures without relying on actual input feature data.
Compiles the model into a TorchScript graph using predefined fixed-size randomized
input tensors.The tensors simulate typical input structures without relying
on actual input feature data.
:return: A TorchScript graph that is optimized for inference.
"""
self.eval()

max_seq_length = self.max_seq_length
vocab_size = self.tokenizer.vocab_size

# Generate random input_ids within the vocabulary range and a full attention mask
input_ids = torch.randint(low=0, high=vocab_size, size=(1, max_seq_length))
attention_mask = torch.ones(1, max_seq_length).int()
Expand All @@ -921,18 +924,19 @@ def _apply_graph_mode(self) -> torch.jit.ScriptModule:
# Trace the model with the synthetic input to create a TorchScript graph
compiled_graph = torch.jit.trace(
self,
({
"input_ids": input_ids,
"attention_mask": attention_mask,
}),
(
{
"input_ids": input_ids,
"attention_mask": attention_mask,
}
),
strict=False,
)

# Freeze the compiled graph to optimize it for runtime performance
compiled_graph = torch.jit.freeze(compiled_graph)

return compiled_graph

return compiled_graph

def encode(
self,
Expand Down

0 comments on commit e031c84

Please sign in to comment.