From 81451de914a742a999855ceeeed0639cfd6ed028 Mon Sep 17 00:00:00 2001
From: thejumpman2323 <kartik@superduperdb.com>
Date: Mon, 29 Jan 2024 18:04:14 +0530
Subject: [PATCH] Replace identifier with model in openai models

Review fixes folded into the new code introduced by this patch:
- make the 'dall-e' exemption a one-element tuple ('dall-e',) so the
  check is a real membership test rather than a substring test against
  the string 'dall-e'
- raise the missing-key error only when the key is absent from BOTH
  os.environ and client_kwargs (the previous `or` rejected a key passed
  via client_kwargs whenever OPENAI_API_KEY was unset, and rejected a
  non-empty client_kwargs without api_key even when the env var was set)
- serialize client_kwargs with sort_keys=True so @cache on
  _available_models sees one canonical key per logical configuration
---
 superduperdb/ext/openai/model.py | 53 ++++++++++++++++++--------------
 1 file changed, 30 insertions(+), 23 deletions(-)

diff --git a/superduperdb/ext/openai/model.py b/superduperdb/ext/openai/model.py
index c5f5dbdd9..12677564e 100644
--- a/superduperdb/ext/openai/model.py
+++ b/superduperdb/ext/openai/model.py
@@ -2,6 +2,7 @@
 import base64
 import dataclasses as dc
 import itertools
+import json
 import os
 import typing as t
 
@@ -28,8 +29,9 @@
 
 
 @cache
-def _available_models():
-    return tuple([r.id for r in SyncOpenAI().models.list().data])
+def _available_models(skwargs):
+    kwargs = json.loads(skwargs)
+    return tuple([r.id for r in SyncOpenAI(**kwargs).models.list().data])
 
 
 @dc.dataclass(kw_only=True)
@@ -45,18 +47,25 @@ class _OpenAI(APIModel):
 
     def __post_init__(self):
        super().__post_init__()
+
+        assert isinstance(self.client_kwargs, dict)
         # dall-e is not currently included in list returned by OpenAI model endpoint
-        if self.model not in (mo := _available_models()) and self.model not in (
-            'dall-e'
-        ):
+        if self.model not in (
+            mo := _available_models(json.dumps(self.client_kwargs, sort_keys=True))
+        ) and self.model not in ('dall-e',):
             msg = f'model {self.model} not in OpenAI available models, {mo}'
             raise ValueError(msg)
 
         self.syncClient = SyncOpenAI(**self.client_kwargs)
         self.asyncClient = AsyncOpenAI(**self.client_kwargs)
 
-        if 'OPENAI_API_KEY' not in os.environ:
-            raise ValueError('OPENAI_API_KEY not set')
+        if 'OPENAI_API_KEY' not in os.environ and (
+            'api_key' not in self.client_kwargs
+        ):
+            raise ValueError(
+                'OPENAI_API_KEY not available neither in environment vars '
+                'nor in `client_kwargs`'
+            )
 
 
 @dc.dataclass(kw_only=True)
@@ -87,27 +96,25 @@ def pre_create(self, db):
 
     @retry
     def _predict_one(self, X: str, **kwargs):
-        e = self.syncClient.embeddings.create(input=X, model=self.identifier, **kwargs)
+        e = self.syncClient.embeddings.create(input=X, model=self.model, **kwargs)
         return e.data[0].embedding
 
     @retry
     async def _apredict_one(self, X: str, **kwargs):
         e = await self.asyncClient.embeddings.create(
-            input=X, model=self.identifier, **kwargs
+            input=X, model=self.model, **kwargs
         )
         return e.data[0].embedding
 
     @retry
     def _predict_a_batch(self, texts: t.List[str], **kwargs):
-        out = self.syncClient.embeddings.create(
-            input=texts, model=self.identifier, **kwargs
-        )
+        out = self.syncClient.embeddings.create(input=texts, model=self.model, **kwargs)
         return [r.embedding for r in out.data]
 
     @retry
     async def _apredict_a_batch(self, texts: t.List[str], **kwargs):
         out = await self.asyncClient.embeddings.create(
-            input=texts, model=self.identifier, **kwargs
+            input=texts, model=self.model, **kwargs
         )
         return [r.embedding for r in out.data]
 
@@ -162,7 +169,7 @@ def _predict_one(self, X, context: t.Optional[t.List[str]] = None, **kwargs):
         return (
             self.syncClient.chat.completions.create(
                 messages=[{'role': 'user', 'content': X}],
-                model=self.identifier,
+                model=self.model,
                 **kwargs,
             )
             .choices[0]
@@ -177,7 +184,7 @@ async def _apredict_one(self, X, context: t.Optional[t.List[str]] = None, **kwargs):
             (
                 await self.asyncClient.chat.completions.create(
                     messages=[{'role': 'user', 'content': X}],
-                    model=self.identifier,
+                    model=self.model,
                     **kwargs,
                 )
             )
@@ -484,7 +491,7 @@ def _predict_one(
             self.prompt = self.prompt.format(context='\n'.join(context))
         return self.syncClient.audio.transcriptions.create(
             file=file,
-            model=self.identifier,
+            model=self.model,
             prompt=self.prompt,
             **kwargs,
         ).text
@@ -499,7 +506,7 @@ async def _apredict_one(
         return (
             await self.asyncClient.audio.transcriptions.create(
                 file=file,
-                model=self.identifier,
+                model=self.model,
                 prompt=self.prompt,
                 **kwargs,
             )
@@ -510,7 +517,7 @@ def _predict_a_batch(self, files: t.List[t.BinaryIO], **kwargs):
         "Converts multiple file-like Audio recordings to text."
         resps = [
             self.syncClient.audio.transcriptions.create(
-                file=file, model=self.identifier, **kwargs
+                file=file, model=self.model, **kwargs
             )
             for file in files
         ]
@@ -522,7 +529,7 @@ async def _apredict_a_batch(self, files: t.List[t.BinaryIO], **kwargs):
         resps = await asyncio.gather(
             *[
                 self.asyncClient.audio.transcriptions.create(
-                    file=file, model=self.identifier, **kwargs
+                    file=file, model=self.model, **kwargs
                 )
                 for file in files
             ]
@@ -587,7 +594,7 @@ def _predict_one(
         return (
             self.syncClient.audio.translations.create(
                 file=file,
-                model=self.identifier,
+                model=self.model,
                 prompt=self.prompt,
                 **kwargs,
             )
@@ -603,7 +610,7 @@ async def _apredict_one(
         return (
             await self.asyncClient.audio.translations.create(
                 file=file,
-                model=self.identifier,
+                model=self.model,
                 prompt=self.prompt,
                 **kwargs,
             )
@@ -614,7 +621,7 @@ def _predict_a_batch(self, files: t.List[t.BinaryIO], **kwargs):
         "Translates multiple file-like Audio recordings to English."
         resps = [
             self.syncClient.audio.translations.create(
-                file=file, model=self.identifier, **kwargs
+                file=file, model=self.model, **kwargs
             )
             for file in files
         ]
@@ -626,7 +633,7 @@ async def _apredict_a_batch(self, files: t.List[t.BinaryIO], **kwargs):
         resps = await asyncio.gather(
             *[
                 self.asyncClient.audio.translations.create(
-                    file=file, model=self.identifier, **kwargs
+                    file=file, model=self.model, **kwargs
                 )
                 for file in files
            ]