Skip to content

Commit

Permalink
Replace identifier with model in openai models
Browse files Browse the repository at this point in the history
  • Loading branch information
kartik4949 committed Feb 1, 2024
1 parent 32b3a17 commit 81451de
Showing 1 changed file with 30 additions and 23 deletions.
53 changes: 30 additions & 23 deletions superduperdb/ext/openai/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import base64
import dataclasses as dc
import itertools
import json
import os
import typing as t

Expand All @@ -28,8 +29,9 @@


@cache
def _available_models(skwargs):
    """Return the tuple of model ids visible to the configured OpenAI account.

    :param skwargs: JSON-encoded ``dict`` of keyword arguments for
        ``SyncOpenAI`` — passed as a string (not a dict) so that ``@cache``
        receives a hashable argument.
    """
    kwargs = json.loads(skwargs)
    return tuple([r.id for r in SyncOpenAI(**kwargs).models.list().data])


@dc.dataclass(kw_only=True)
Expand All @@ -45,18 +47,25 @@ class _OpenAI(APIModel):
def __post_init__(self):
    """Validate credentials and the configured model, then build clients.

    :raises ValueError: if no API key is available in either the
        ``OPENAI_API_KEY`` environment variable or ``client_kwargs['api_key']``,
        or if ``self.model`` is not among the models reported by the
        OpenAI models endpoint.
    """
    super().__post_init__()

    assert isinstance(self.client_kwargs, dict)

    # Fail fast: check credential availability *before* any network call,
    # otherwise the models-list request below fails with an auth error and
    # this ValueError is never reachable. The key is acceptable from either
    # the environment or client_kwargs (hence `and`, not `or`).
    if 'OPENAI_API_KEY' not in os.environ and 'api_key' not in self.client_kwargs:
        raise ValueError(
            'OPENAI_API_KEY not available neither in environment vars '
            'nor in `client_kwargs`'
        )

    # dall-e is not currently included in list returned by OpenAI model endpoint.
    # client_kwargs is JSON-encoded so the @cache on _available_models gets a
    # hashable argument. Note ('dall-e',) must be a tuple: `x not in ('dall-e')`
    # would be a *substring* test against the string 'dall-e'.
    if self.model not in (
        mo := _available_models(json.dumps(self.client_kwargs))
    ) and self.model not in ('dall-e',):
        msg = f'model {self.model} not in OpenAI available models, {mo}'
        raise ValueError(msg)

    self.syncClient = SyncOpenAI(**self.client_kwargs)
    self.asyncClient = AsyncOpenAI(**self.client_kwargs)


@dc.dataclass(kw_only=True)
Expand Down Expand Up @@ -87,27 +96,25 @@ def pre_create(self, db):

@retry
def _predict_one(self, X: str, **kwargs):
    """Embed a single text string with the OpenAI embeddings endpoint.

    :param X: the text to embed.
    :param kwargs: extra arguments forwarded to ``embeddings.create``.
    :return: the embedding vector for ``X``.
    """
    # Post-commit form: the API model name is `self.model`, not the
    # database identifier of this component.
    e = self.syncClient.embeddings.create(input=X, model=self.model, **kwargs)
    return e.data[0].embedding

@retry
async def _apredict_one(self, X: str, **kwargs):
    """Asynchronously embed a single text string.

    :param X: the text to embed.
    :param kwargs: extra arguments forwarded to ``embeddings.create``.
    :return: the embedding vector for ``X``.
    """
    e = await self.asyncClient.embeddings.create(
        input=X, model=self.model, **kwargs
    )
    return e.data[0].embedding

@retry
def _predict_a_batch(self, texts: t.List[str], **kwargs):
    """Embed a batch of text strings in one embeddings request.

    :param texts: the texts to embed.
    :param kwargs: extra arguments forwarded to ``embeddings.create``.
    :return: one embedding vector per input text, in order.
    """
    out = self.syncClient.embeddings.create(input=texts, model=self.model, **kwargs)
    return [r.embedding for r in out.data]

@retry
async def _apredict_a_batch(self, texts: t.List[str], **kwargs):
    """Asynchronously embed a batch of text strings in one request.

    :param texts: the texts to embed.
    :param kwargs: extra arguments forwarded to ``embeddings.create``.
    :return: one embedding vector per input text, in order.
    """
    out = await self.asyncClient.embeddings.create(
        input=texts, model=self.model, **kwargs
    )
    return [r.embedding for r in out.data]

Expand Down Expand Up @@ -162,7 +169,7 @@ def _predict_one(self, X, context: t.Optional[t.List[str]] = None, **kwargs):
return (
self.syncClient.chat.completions.create(
messages=[{'role': 'user', 'content': X}],
model=self.identifier,
model=self.model,
**kwargs,
)
.choices[0]
Expand All @@ -177,7 +184,7 @@ async def _apredict_one(self, X, context: t.Optional[t.List[str]] = None, **kwar
(
await self.asyncClient.chat.completions.create(
messages=[{'role': 'user', 'content': X}],
model=self.identifier,
model=self.model,
**kwargs,
)
)
Expand Down Expand Up @@ -484,7 +491,7 @@ def _predict_one(
self.prompt = self.prompt.format(context='\n'.join(context))
return self.syncClient.audio.transcriptions.create(
file=file,
model=self.identifier,
model=self.model,
prompt=self.prompt,
**kwargs,
).text
Expand All @@ -499,7 +506,7 @@ async def _apredict_one(
return (
await self.asyncClient.audio.transcriptions.create(
file=file,
model=self.identifier,
model=self.model,
prompt=self.prompt,
**kwargs,
)
Expand All @@ -510,7 +517,7 @@ def _predict_a_batch(self, files: t.List[t.BinaryIO], **kwargs):
"Converts multiple file-like Audio recordings to text."
resps = [
self.syncClient.audio.transcriptions.create(
file=file, model=self.identifier, **kwargs
file=file, model=self.model, **kwargs
)
for file in files
]
Expand All @@ -522,7 +529,7 @@ async def _apredict_a_batch(self, files: t.List[t.BinaryIO], **kwargs):
resps = await asyncio.gather(
*[
self.asyncClient.audio.transcriptions.create(
file=file, model=self.identifier, **kwargs
file=file, model=self.model, **kwargs
)
for file in files
]
Expand Down Expand Up @@ -587,7 +594,7 @@ def _predict_one(
return (
self.syncClient.audio.translations.create(
file=file,
model=self.identifier,
model=self.model,
prompt=self.prompt,
**kwargs,
)
Expand All @@ -603,7 +610,7 @@ async def _apredict_one(
return (
await self.asyncClient.audio.translations.create(
file=file,
model=self.identifier,
model=self.model,
prompt=self.prompt,
**kwargs,
)
Expand All @@ -614,7 +621,7 @@ def _predict_a_batch(self, files: t.List[t.BinaryIO], **kwargs):
"Translates multiple file-like Audio recordings to English."
resps = [
self.syncClient.audio.translations.create(
file=file, model=self.identifier, **kwargs
file=file, model=self.model, **kwargs
)
for file in files
]
Expand All @@ -626,7 +633,7 @@ async def _apredict_a_batch(self, files: t.List[t.BinaryIO], **kwargs):
resps = await asyncio.gather(
*[
self.asyncClient.audio.translations.create(
file=file, model=self.identifier, **kwargs
file=file, model=self.model, **kwargs
)
for file in files
]
Expand Down

0 comments on commit 81451de

Please sign in to comment.