From 81451de914a742a999855ceeeed0639cfd6ed028 Mon Sep 17 00:00:00 2001
From: thejumpman2323 <kartik@superduperdb.com>
Date: Mon, 29 Jan 2024 18:04:14 +0530
Subject: [PATCH] Replace identifier with model in openai models

---
 superduperdb/ext/openai/model.py | 53 ++++++++++++++++++--------------
 1 file changed, 30 insertions(+), 23 deletions(-)

diff --git a/superduperdb/ext/openai/model.py b/superduperdb/ext/openai/model.py
index c5f5dbdd9..12677564e 100644
--- a/superduperdb/ext/openai/model.py
+++ b/superduperdb/ext/openai/model.py
@@ -2,6 +2,7 @@
 import base64
 import dataclasses as dc
 import itertools
+import json
 import os
 import typing as t
 
@@ -28,8 +29,9 @@
 
 
 @cache
-def _available_models():
-    return tuple([r.id for r in SyncOpenAI().models.list().data])
+def _available_models(skwargs):
+    kwargs = json.loads(skwargs)
+    return tuple([r.id for r in SyncOpenAI(**kwargs).models.list().data])
 
 
 @dc.dataclass(kw_only=True)
@@ -45,18 +47,25 @@ class _OpenAI(APIModel):
     def __post_init__(self):
         super().__post_init__()
 
+        assert isinstance(self.client_kwargs, dict)
+
         # dall-e is not currently included in list returned by OpenAI model endpoint
-        if self.model not in (mo := _available_models()) and self.model not in (
-            'dall-e'
-        ):
+        if self.model not in (
+            mo := _available_models(json.dumps(self.client_kwargs))
+        ) and self.model not in ('dall-e'):
             msg = f'model {self.model} not in OpenAI available models, {mo}'
             raise ValueError(msg)
 
         self.syncClient = SyncOpenAI(**self.client_kwargs)
         self.asyncClient = AsyncOpenAI(**self.client_kwargs)
 
-        if 'OPENAI_API_KEY' not in os.environ:
-            raise ValueError('OPENAI_API_KEY not set')
+        if 'OPENAI_API_KEY' not in os.environ or (
+            'api_key' not in self.client_kwargs.keys() and self.client_kwargs
+        ):
+            raise ValueError(
+                'OPENAI_API_KEY not available neither in environment vars '
+                'nor in `client_kwargs`'
+            )
 
 
 @dc.dataclass(kw_only=True)
@@ -87,27 +96,25 @@ def pre_create(self, db):
 
     @retry
     def _predict_one(self, X: str, **kwargs):
-        e = self.syncClient.embeddings.create(input=X, model=self.identifier, **kwargs)
+        e = self.syncClient.embeddings.create(input=X, model=self.model, **kwargs)
         return e.data[0].embedding
 
     @retry
     async def _apredict_one(self, X: str, **kwargs):
         e = await self.asyncClient.embeddings.create(
-            input=X, model=self.identifier, **kwargs
+            input=X, model=self.model, **kwargs
         )
         return e.data[0].embedding
 
     @retry
     def _predict_a_batch(self, texts: t.List[str], **kwargs):
-        out = self.syncClient.embeddings.create(
-            input=texts, model=self.identifier, **kwargs
-        )
+        out = self.syncClient.embeddings.create(input=texts, model=self.model, **kwargs)
         return [r.embedding for r in out.data]
 
     @retry
     async def _apredict_a_batch(self, texts: t.List[str], **kwargs):
         out = await self.asyncClient.embeddings.create(
-            input=texts, model=self.identifier, **kwargs
+            input=texts, model=self.model, **kwargs
         )
         return [r.embedding for r in out.data]
 
@@ -162,7 +169,7 @@ def _predict_one(self, X, context: t.Optional[t.List[str]] = None, **kwargs):
         return (
             self.syncClient.chat.completions.create(
                 messages=[{'role': 'user', 'content': X}],
-                model=self.identifier,
+                model=self.model,
                 **kwargs,
             )
             .choices[0]
@@ -177,7 +184,7 @@ async def _apredict_one(self, X, context: t.Optional[t.List[str]] = None, **kwar
             (
                 await self.asyncClient.chat.completions.create(
                     messages=[{'role': 'user', 'content': X}],
-                    model=self.identifier,
+                    model=self.model,
                     **kwargs,
                 )
             )
@@ -484,7 +491,7 @@ def _predict_one(
             self.prompt = self.prompt.format(context='\n'.join(context))
         return self.syncClient.audio.transcriptions.create(
             file=file,
-            model=self.identifier,
+            model=self.model,
             prompt=self.prompt,
             **kwargs,
         ).text
@@ -499,7 +506,7 @@ async def _apredict_one(
         return (
             await self.asyncClient.audio.transcriptions.create(
                 file=file,
-                model=self.identifier,
+                model=self.model,
                 prompt=self.prompt,
                 **kwargs,
             )
@@ -510,7 +517,7 @@ def _predict_a_batch(self, files: t.List[t.BinaryIO], **kwargs):
         "Converts multiple file-like Audio recordings to text."
         resps = [
             self.syncClient.audio.transcriptions.create(
-                file=file, model=self.identifier, **kwargs
+                file=file, model=self.model, **kwargs
             )
             for file in files
         ]
@@ -522,7 +529,7 @@ async def _apredict_a_batch(self, files: t.List[t.BinaryIO], **kwargs):
         resps = await asyncio.gather(
             *[
                 self.asyncClient.audio.transcriptions.create(
-                    file=file, model=self.identifier, **kwargs
+                    file=file, model=self.model, **kwargs
                 )
                 for file in files
             ]
@@ -587,7 +594,7 @@ def _predict_one(
         return (
             self.syncClient.audio.translations.create(
                 file=file,
-                model=self.identifier,
+                model=self.model,
                 prompt=self.prompt,
                 **kwargs,
             )
@@ -603,7 +610,7 @@ async def _apredict_one(
         return (
             await self.asyncClient.audio.translations.create(
                 file=file,
-                model=self.identifier,
+                model=self.model,
                 prompt=self.prompt,
                 **kwargs,
             )
@@ -614,7 +621,7 @@ def _predict_a_batch(self, files: t.List[t.BinaryIO], **kwargs):
         "Translates multiple file-like Audio recordings to English."
         resps = [
             self.syncClient.audio.translations.create(
-                file=file, model=self.identifier, **kwargs
+                file=file, model=self.model, **kwargs
             )
             for file in files
         ]
@@ -626,7 +633,7 @@ async def _apredict_a_batch(self, files: t.List[t.BinaryIO], **kwargs):
         resps = await asyncio.gather(
             *[
                 self.asyncClient.audio.translations.create(
-                    file=file, model=self.identifier, **kwargs
+                    file=file, model=self.model, **kwargs
                 )
                 for file in files
             ]