Commit de868ab

Add multi api inference engine
Signed-off-by: elronbandel <[email protected]>
elronbandel committed Nov 12, 2024
1 parent 2d46594 commit de868ab
Showing 6 changed files with 155 additions and 17 deletions.
30 changes: 30 additions & 0 deletions examples/evaluate_benchmark_with_custom_api.py
@@ -0,0 +1,30 @@
import unitxt
from unitxt import evaluate, get_from_catalog, load_dataset
from unitxt.text_utils import print_dict

with unitxt.settings.context(
    default_inference_api="watsonx",  # option (a): set your default API globally
    default_format="formats.chat_api",
    disable_hf_datasets_cache=False,
):
    data = load_dataset("benchmarks.glue[max_samples_per_subset=5]", split="test")

    model = get_from_catalog(
        "engines.model.llama_3_8b_instruct[api=watsonx]"
    )  # option (b): pin the API on the engine itself

    predictions = model.infer(data)

    evaluated_dataset = evaluate(predictions=predictions, data=data)

    print_dict(
        evaluated_dataset[0],
        keys_to_print=[
            "source",
            "prediction",
            "subset",
        ],
    )
    print_dict(
        evaluated_dataset[0]["score"]["subsets"],
    )
16 changes: 16 additions & 0 deletions prepare/engines/multi_api/llama3.py
@@ -0,0 +1,16 @@
from unitxt.catalog import add_to_catalog
from unitxt.inference import MultiAPIInferenceEngine

engine = MultiAPIInferenceEngine(
    model="llama-3-8b-instruct",
    api_model_map={
        "watsonx": {
            "llama-3-8b-instruct": "watsonx/meta-llama/llama-3-8b-instruct",
        },
        "together-ai": {
            "llama-3-8b-instruct": "together_ai/togethercomputer/llama-3-8b-instruct"
        },
    },
)

add_to_catalog(engine, "engines.model.llama_3_8b_instruct", overwrite=True)
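
A minimal sketch of how this catalog entry is meant to be consumed (not part of the commit; only the catalog name and the `api=...` override appear in this diff, the rest is assumed):

from unitxt import get_from_catalog

# Fetch the registered engine and pin it to one of the providers declared in
# api_model_map above; the engine then resolves the LiteLLM route
# "together_ai/togethercomputer/llama-3-8b-instruct" for this model.
engine = get_from_catalog("engines.model.llama_3_8b_instruct[api=together-ai]")
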
12 changes: 12 additions & 0 deletions src/unitxt/catalog/engines/model/llama_3_8b_instruct.json
@@ -0,0 +1,12 @@
{
    "__type__": "multi_api_inference_engine",
    "model": "llama-3-8b-instruct",
    "api_model_map": {
        "watsonx": {
            "llama-3-8b-instruct": "watsonx/meta-llama/llama-3-8b-instruct"
        },
        "together-ai": {
            "llama-3-8b-instruct": "together_ai/togethercomputer/llama-3-8b-instruct"
        }
    }
}
100 changes: 85 additions & 15 deletions src/unitxt/inference.py
@@ -1121,9 +1121,9 @@ def _infer(
         model, params = self._load_model_and_params()
 
         result = []
-        for instance in dataset:
+        for source in dataset["source"]:
             instance_result = model.generate(
-                prompt=instance["source"],
+                prompt=source,
                 params=self.to_dict([WMLInferenceEngineParamsMixin], keep_empty=False),
             )
             prediction = instance_result["results"][0]["generated_text"]
@@ -1364,9 +1364,7 @@ class LMMSEvalBaseInferenceEngine(
     batch_size: int = 1
     image_token = "<image>"
 
-    _requirements_list = {
-        "lmms_eval": "Install llms-eval package using 'pip install lmms-eval==0.2.4'",
-    }
+    _requirements_list = ["lmms-eval==0.2.4"]
 
     def prepare_engine(self):
         if not self.lazy_load:
@@ -1413,6 +1411,7 @@ def _infer(
         dataset: Union[List[Dict[str, Any]], DatasetDict],
         return_meta_data: bool = False,
     ) -> Union[List[str], List[TextGenerationInferenceOutput]]:
+        self.verify_not_chat_api(dataset)
         if not self._is_loaded():
             self._prepare_engine()
 
@@ -1562,12 +1561,26 @@ async def acquire(self, tokens=1):
         await asyncio.sleep(time_until_next_token)
 
 
-class LiteLLMInferenceEngine(InferenceEngine, PackageRequirementsMixin):
+class StandardAPIParamsMixin(Artifact):
     model: str
-    max_tokens: int = 256
-    seed: int = 1
-    temperature: float = 0.0
-    top_p: float = 1.0
+    frequency_penalty: Optional[float] = None
+    presence_penalty: Optional[float] = None
+    max_tokens: Optional[int] = None
+    seed: Optional[int] = None
+    stop: Union[Optional[str], List[str]] = None
+    temperature: Optional[float] = None
+    top_p: Optional[float] = None
+    top_logprobs: Optional[int] = 20
+    logit_bias: Optional[Dict[str, int]] = None
+    logprobs: Optional[bool] = True
+    n: Optional[int] = None
+    parallel_tool_calls: Optional[bool] = None
+    service_tier: Optional[Literal["auto", "default"]] = None
+
+
+class LiteLLMInferenceEngine(
+    InferenceEngine, StandardAPIParamsMixin, PackageRequirementsMixin
+):
     max_requests_per_second: float = 6
     max_retries: int = 5  # Set to 0 to prevent internal retries
 
@@ -1599,15 +1612,12 @@ async def _infer_instance(
         # Introduce a slight delay to prevent burstiness
         await asyncio.sleep(0.01)
         messages = self.to_messages(instance)
+        kwargs = self.to_dict([StandardAPIParamsMixin])
         response = await self._completion(
-            model=self.model,
             messages=messages,
-            seed=self.seed,
-            max_tokens=self.max_tokens,
-            temperature=self.temperature,
-            top_p=self.top_p,
             max_retries=self.max_retries,
             caching=True,
+            **kwargs,
         )
         usage = response.get("usage", {})
         return TextGenerationInferenceOutput(
@@ -1643,3 +1653,63 @@ def _infer(
             return responses
 
         return [response.prediction for response in responses]
+
+
+_supported_apis = Literal["watsonx", "together-ai", "open-ai"]
+
+
+class MultiAPIInferenceEngine(InferenceEngine, StandardAPIParamsMixin):
+    """Inference engine capable of dynamically switching between multiple APIs.
+
+    This class extends InferenceEngine and StandardAPIParamsMixin to enable
+    seamless integration with various API providers. The supported APIs are
+    specified in `_supported_apis`, allowing users to interact with multiple
+    models from different sources. The `api_model_map` dictionary maps each
+    API to specific model identifiers, enabling automatic configuration based
+    on user requests.
+
+    Attributes:
+        api: Optional; specifies the current API in use. Must be one of the
+            literals in `_supported_apis`.
+        api_model_map: Dictionary mapping each supported API to a corresponding
+            model identifier string. This mapping allows consistent access to
+            models across different API backends.
+    """
+
+    api: Optional[_supported_apis] = None
+
+    api_model_map: Dict[_supported_apis, Dict[str, str]] = {
+        "watsonx": {
+            "llama-3-8b-instruct": "watsonx/meta-llama/llama-3-8b-instruct",
+        },
+        "together-ai": {
+            "llama-3-8b-instruct": "together_ai/togethercomputer/llama-3-8b-instruct"
+        },
+    }
+
+    _api_to_base_class = {
+        "watsonx": LiteLLMInferenceEngine,
+        "open-ai": LiteLLMInferenceEngine,
+        "together-ai": LiteLLMInferenceEngine,
+    }
+
+    def get_api_name(self):
+        return self.api if self.api is not None else settings.default_inference_api
+
+    def prepare_engine(self):
+        api = self.get_api_name()
+        cls = self.__class__._api_to_base_class[api]
+        args = self.to_dict([StandardAPIParamsMixin])
+        args["model"] = self.api_model_map[api][self.model]
+        self.engine = cls(**args)
+
+    def _infer(
+        self,
+        dataset: Union[List[Dict[str, Any]], DatasetDict],
+        return_meta_data: bool = False,
+    ) -> Union[List[str], List[TextGenerationInferenceOutput]]:
+        return self.engine._infer(dataset, return_meta_data)
+
+    def get_engine_id(self):
+        api = self.get_api_name()
+        return get_model_and_label_id(self.api_model_map[api][self.model], api)
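
A minimal usage sketch for the new engine (not part of the diff; it assumes provider credentials, e.g. WatsonX or Together AI API keys, are already configured in the environment, and that `dataset` was produced by load_dataset as in the example above):

from unitxt.inference import MultiAPIInferenceEngine

# Select the backend explicitly; omitting `api` falls back to
# settings.default_inference_api ("watsonx" by default).
engine = MultiAPIInferenceEngine(
    model="llama-3-8b-instruct",
    api="together-ai",
    temperature=0.0,
    max_tokens=256,
)

# prepare_engine() resolves api_model_map["together-ai"]["llama-3-8b-instruct"]
# and delegates generation to the underlying LiteLLMInferenceEngine.
predictions = engine.infer(dataset)
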
2 changes: 2 additions & 0 deletions src/unitxt/settings_utils.py
@@ -151,6 +151,8 @@ def __getattr__(self, key):
     settings.disable_hf_datasets_cache = (bool, True)
     settings.loader_cache_size = (int, 1)
     settings.task_data_as_text = (bool, True)
+    settings.default_inference_api = "watsonx"
+    settings.default_format = None
 
 if Constants.is_uninitilized():
     constants = Constants()
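
For orientation, a sketch (not part of this diff) of how the two new defaults can be overridden in code, either globally or scoped with a context manager as in the example above; the environment-variable route mentioned in the comment is an assumption about unitxt's settings machinery rather than something shown here:

import unitxt

# Override globally for the current process...
unitxt.settings.default_inference_api = "together-ai"

# ...or scope the override to a block, as the example above does.
with unitxt.settings.context(default_inference_api="open-ai", default_format=None):
    pass

# (Assumption: unitxt settings can typically also be set via environment
# variables such as UNITXT_DEFAULT_INFERENCE_API before import.)
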
12 changes: 10 additions & 2 deletions src/unitxt/standard.py
@@ -1,5 +1,6 @@
 from typing import List, Optional, Union
 
+from .artifact import fetch_artifact
 from .augmentors import (
     Augmentor,
     FinalStateInputsAugmentor,
@@ -16,7 +17,7 @@
 from .recipe import Recipe
 from .schema import FinalizeDataset
 from .serializers import SingleTypeSerializer
-from .settings_utils import get_constants
+from .settings_utils import get_constants, get_settings
 from .splitters import ConstantSizeSample, RandomSizeSample, Sampler, SeparateSplit
 from .stream import MultiStream
 from .system_prompts import EmptySystemPrompt, SystemPrompt
@@ -25,6 +26,7 @@
 from .utils import LRUCache
 
 constants = get_constants()
+settings = get_settings()
 logger = get_logger()
 
 
@@ -39,7 +41,7 @@ class BaseRecipe(Recipe, SourceSequentialOperator):
     task: Task = None
     template: Union[Template, List[Template], TemplatesList] = None
     system_prompt: SystemPrompt = Field(default_factory=EmptySystemPrompt)
-    format: Format = Field(default_factory=SystemFormat)
+    format: Format = None
     serializer: Union[SingleTypeSerializer, List[SingleTypeSerializer]] = None
 
     # Additional parameters
@@ -263,6 +265,12 @@ def produce(self, task_instances):
         return list(multi_stream[constants.inference_stream])
 
     def reset_pipeline(self):
+        if self.format is None:
+            if settings.default_format is not None:
+                self.format, _ = fetch_artifact(settings.default_format)
+            else:
+                self.format = SystemFormat()
+
         if self.card and self.card.preprocess_steps is None:
             self.card.preprocess_steps = []
 
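
A minimal sketch of the new fallback in action (only the benchmark name is taken from this commit's example; nothing else is prescribed by the diff). Recipes that leave `format` unset now pick up `settings.default_format` when it is configured, and otherwise fall back to SystemFormat() exactly as before:

import unitxt
from unitxt import load_dataset

# With default_format set, every recipe in the benchmark that does not define
# its own format is rendered with formats.chat_api; with the setting unset,
# reset_pipeline() behaves as before (SystemFormat).
with unitxt.settings.context(default_format="formats.chat_api"):
    data = load_dataset("benchmarks.glue[max_samples_per_subset=5]", split="test")
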
