Support openai>=1.1.2
jieguangzhou committed Nov 10, 2023
1 parent dc73752 commit 07787bb
Showing 30 changed files with 2,711 additions and 103,341 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -75,7 +75,7 @@ torch = [
"torchaudio",
]
apis = [
"openai==0.27.6",
"openai>=1.1.2",
"cohere",
"anthropic",
]
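
The floor moves from the last 0.x pin to the 1.x SDK, which replaced module-level resource calls and dict-style responses with an explicit client object and typed attribute access. A minimal before/after sketch of the shape change (the model name is illustrative, not taken from this diff):

```python
# openai<1.0 style (what the old pin supported): module-level call, dict access.
#   import openai
#   text = openai.ChatCompletion.create(
#       model='gpt-3.5-turbo',
#       messages=[{'role': 'user', 'content': 'hi'}],
#   )['choices'][0]['message']['content']

# openai>=1.1.2 style (what this commit adopts): explicit client, attributes.
from openai import OpenAI

client = OpenAI()  # picks up OPENAI_API_KEY from the environment
text = client.chat.completions.create(
    model='gpt-3.5-turbo',
    messages=[{'role': 'user', 'content': 'hi'}],
).choices[0].message.content
```
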
210 changes: 132 additions & 78 deletions superduperdb/ext/openai/model.py
@@ -8,8 +8,13 @@
import aiohttp
import requests
import tqdm
from openai import Audio, ChatCompletion, Embedding, Image, Model as OpenAIModel
from openai.error import RateLimitError, ServiceUnavailableError, Timeout, TryAgain
from openai import (
APITimeoutError,
AsyncOpenAI,
InternalServerError,
OpenAI as SyncOpenAI,
RateLimitError,
)

from superduperdb.components.component import Component
from superduperdb.components.encoder import Encoder
@@ -18,14 +23,12 @@
from superduperdb.misc.compat import cache
from superduperdb.misc.retry import Retry

retry = Retry(
exception_types=(RateLimitError, ServiceUnavailableError, Timeout, TryAgain)
)
retry = Retry(exception_types=(RateLimitError, InternalServerError, APITimeoutError))


@cache
def _available_models():
return tuple([r['id'] for r in OpenAIModel.list()['data']])
return tuple([r.id for r in SyncOpenAI().models.list().data])
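
Two 1.x migration details show up in this hunk. The retryable exception set is renamed: ServiceUnavailableError, Timeout, and TryAgain have no direct 1.x equivalents, so the commit retries on InternalServerError and APITimeoutError instead. And list responses are now typed objects, turning r['id'] into r.id. A standalone sketch under those assumptions:

```python
from openai import APITimeoutError, InternalServerError, OpenAI, RateLimitError

client = OpenAI()

try:
    # models.list() returns a page object; .data is a list of Model objects,
    # each with an .id attribute instead of an 'id' dict key.
    available = tuple(m.id for m in client.models.list().data)
except (RateLimitError, InternalServerError, APITimeoutError):
    # The same transient-error set the rewritten Retry wrapper targets.
    available = ()
```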


@dc.dataclass
@@ -65,6 +68,9 @@ def __post_init__(self):

self.identifier = self.identifier or self.model

self.syncClient = SyncOpenAI()
self.asyncClient = AsyncOpenAI()

if 'OPENAI_API_KEY' not in os.environ:
raise ValueError('OPENAI_API_KEY not set')
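
One synchronous and one asynchronous client are now constructed up front instead of relying on module-level state. Note that in the 1.x SDK the constructor itself reads OPENAI_API_KEY and raises openai.OpenAIError when no key can be found, so the explicit environment check mostly buys a friendlier error message. A minimal sketch of the same setup outside the class:

```python
import os

from openai import AsyncOpenAI, OpenAI

# Checked here before construction for a clearer error; the committed code
# performs the same check just after building the clients.
if 'OPENAI_API_KEY' not in os.environ:
    raise ValueError('OPENAI_API_KEY not set')

sync_client = OpenAI()        # serves the blocking _predict_* paths
async_client = AsyncOpenAI()  # serves the _apredict_* coroutines
```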

@@ -88,23 +94,29 @@ def __post_init__(self):

@retry
def _predict_one(self, X: str, **kwargs):
e = Embedding.create(input=X, model=self.identifier, **kwargs)
return e['data'][0]['embedding']
e = self.syncClient.embeddings.create(input=X, model=self.identifier, **kwargs)
return e.data[0].embedding

@retry
async def _apredict_one(self, X: str, **kwargs):
e = await Embedding.acreate(input=X, model=self.identifier, **kwargs)
return e['data'][0]['embedding']
e = await self.asyncClient.embeddings.create(
input=X, model=self.identifier, **kwargs
)
return e.data[0].embedding

@retry
def _predict_a_batch(self, texts: t.List[str], **kwargs):
out = Embedding.create(input=texts, model=self.identifier, **kwargs)
return [r['embedding'] for r in out['data']]
out = self.syncClient.embeddings.create(
input=texts, model=self.identifier, **kwargs
)
return [r.embedding for r in out.data]

@retry
async def _apredict_a_batch(self, texts: t.List[str], **kwargs):
out = await Embedding.acreate(input=texts, model=self.identifier, **kwargs)
return [r['embedding'] for r in out['data']]
out = await self.asyncClient.embeddings.create(
input=texts, model=self.identifier, **kwargs
)
return [r.embedding for r in out.data]
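
The embedding calls move from Embedding.create / Embedding.acreate to embeddings.create on the sync or async client, with data[0].embedding replacing ['data'][0]['embedding']; batching still works by passing a list of strings in a single request. A sketch (the model name is assumed, not from the diff):

```python
import asyncio

from openai import AsyncOpenAI, OpenAI

sync_client = OpenAI()
async_client = AsyncOpenAI()

# Single input: attribute access replaces dict lookups.
one = sync_client.embeddings.create(
    input='hello', model='text-embedding-ada-002'
).data[0].embedding

# Batch input: one request embeds every string in the list.
batch = sync_client.embeddings.create(
    input=['hello', 'world'], model='text-embedding-ada-002'
)
vectors = [item.embedding for item in batch.data]

# Async: Embedding.acreate is gone; the async client exposes the same
# method path and the call is simply awaited.
async def embed(text: str) -> list:
    resp = await async_client.embeddings.create(
        input=text, model='text-embedding-ada-002'
    )
    return resp.data[0].embedding

print(len(asyncio.run(embed('hello'))))
```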

def _predict(self, X, one: bool = False, **kwargs):
if isinstance(X, str):
@@ -145,23 +157,31 @@ def _format_prompt(self, context, X):
def _predict_one(self, X, context: t.Optional[t.List[str]] = None, **kwargs):
if context is not None:
X = self._format_prompt(context, X)
return ChatCompletion.create(
messages=[{'role': 'user', 'content': X}],
model=self.identifier,
**kwargs,
)['choices'][0]['message']['content']
return (
self.syncClient.chat.completions.create(
messages=[{'role': 'user', 'content': X}],
model=self.identifier,
**kwargs,
)
.choices[0]
.message.content
)

@retry
async def _apredict_one(self, X, context: t.Optional[t.List[str]] = None, **kwargs):
if context is not None:
X = self._format_prompt(context, X)
return (
await ChatCompletion.acreate(
messages=[{'role': 'user', 'content': X}],
model=self.identifier,
**kwargs,
(
await self.asyncClient.chat.completions.create(
messages=[{'role': 'user', 'content': X}],
model=self.identifier,
**kwargs,
)
)
)['choices'][0]['message']['content']
.choices[0]
.message.content
)
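
Chat completions follow the same pattern: ChatCompletion.acreate disappears, the async client exposes an identical chat.completions.create path, and the nested dict lookup becomes choices[0].message.content. A minimal async sketch (model name assumed):

```python
import asyncio

from openai import AsyncOpenAI

async_client = AsyncOpenAI()

async def ask(question: str) -> str:
    resp = await async_client.chat.completions.create(
        model='gpt-3.5-turbo',
        messages=[{'role': 'user', 'content': question}],
    )
    # choices[0].message.content replaces ['choices'][0]['message']['content'].
    return resp.choices[0].message.content

print(asyncio.run(ask('What changed in openai 1.x?')))
```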

def _predict(
self, X, one: bool = True, context: t.Optional[t.List[str]] = None, **kwargs
@@ -209,12 +229,16 @@ def _predict_one(
if context is not None:
X = self._format_prompt(context, X)
if response_format == 'b64_json':
b64_json = Image.create(prompt=X, n=n, response_format='b64_json')['data'][
0
]['b64_json']
b64_json = (
self.syncClient.images.generate(
prompt=X, n=n, response_format='b64_json'
)
.data[0]
.b64_json
)
return base64.b64decode(b64_json)
else:
url = Image.create(prompt=X, n=n, **kwargs)['data'][0]['url']
url = self.syncClient.images.generate(prompt=X, n=n, **kwargs).data[0].url
return requests.get(url).content

@retry
@@ -229,12 +253,22 @@ async def _apredict_one(
if context is not None:
X = self._format_prompt(context, X)
if response_format == 'b64_json':
b64_json = (await Image.acreate(prompt=X, n=n, response_format='b64_json'))[
'data'
][0]['b64_json']
b64_json = (
(
await self.asyncClient.images.generate(
prompt=X, n=n, response_format='b64_json'
)
)
.data[0]
.b64_json
)
return base64.b64decode(b64_json)
else:
url = (await Image.acreate(prompt=X, n=n, **kwargs))['data'][0]['url']
url = (
(await self.asyncClient.images.generate(prompt=X, n=n, **kwargs))
.data[0]
.url
)
async with aiohttp.ClientSession() as session:
async with session.get(url) as resp:
return await resp.read()
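
Image generation moves from Image.create / Image.acreate to images.generate. As before, the code either base64-decodes an inline b64_json payload or downloads the returned URL; only the call site and the response navigation change. A sketch of both branches (the prompt is illustrative):

```python
import base64

import requests
from openai import OpenAI

client = OpenAI()

# b64_json branch: the PNG bytes come back inline in the response.
resp = client.images.generate(
    prompt='a red balloon', n=1, response_format='b64_json'
)
png_bytes = base64.b64decode(resp.data[0].b64_json)

# URL branch (the default response format): download the short-lived URL.
resp = client.images.generate(prompt='a red balloon', n=1)
png_bytes = requests.get(resp.data[0].url).content
```
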
@@ -300,22 +334,29 @@ def _predict_one(
if mask_png_path is not None:
with open(mask_png_path, 'rb') as f:
mask = f.read()
else:
mask = None
kwargs['mask'] = mask

if response_format == 'b64_json':
b64_json = Image.create_edit(
image=image,
mask=mask,
prompt=self.prompt,
n=n,
response_format='b64_json',
)['data'][0]['b64_json']
b64_json = (
self.syncClient.images.edit(
image=image,
prompt=self.prompt,
n=n,
response_format='b64_json',
**kwargs,
)
.data[0]
.b64_json
)
return base64.b64decode(b64_json)
else:
url = Image.create_edit(
image=image, mask=mask, prompt=self.prompt, n=n, **kwargs
)['data'][0]['url']
url = (
self.syncClient.images.edit(
image=image, prompt=self.prompt, n=n, **kwargs
)
.data[0]
.url
)
return requests.get(url).content

@retry
@@ -334,26 +375,33 @@ async def _apredict_one(
if mask_png_path is not None:
with open(mask_png_path, 'rb') as f:
mask = f.read()
else:
mask = None
kwargs['mask'] = mask

if response_format == 'b64_json':
b64_json = (
await Image.acreate_edit(
image=image,
mask=mask,
prompt=self.prompt,
n=n,
response_format='b64_json',
(
await self.asyncClient.images.edit(
image=image,
prompt=self.prompt,
n=n,
response_format='b64_json',
**kwargs,
)
)
)['data'][0]['b64_json']
.data[0]
.b64_json
)
return base64.b64decode(b64_json)
else:
url = (
await Image.acreate_edit(
image=image, mask=mask, prompt=self.prompt, n=n, **kwargs
(
await self.asyncClient.images.edit(
image=image, prompt=self.prompt, n=n, **kwargs
)
)
)['data'][0]['url']
.data[0]
.url
)
async with aiohttp.ClientSession() as session:
async with session.get(url) as resp:
return await resp.read()
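
Image edits move from Image.create_edit to images.edit. The commit also changes how the optional mask travels: it is injected into kwargs only when a mask file was supplied, instead of being passed as an explicit mask=None (the 1.x client treats omitted arguments differently from None). A sketch under those assumptions, with hypothetical file names:

```python
from openai import OpenAI

client = OpenAI()

kwargs = {}
mask_png_path = None  # or a path to an RGBA PNG marking the editable region
if mask_png_path is not None:
    with open(mask_png_path, 'rb') as f:
        kwargs['mask'] = f.read()

resp = client.images.edit(
    image=open('input.png', 'rb'),
    prompt='add a hat',
    n=1,
    **kwargs,  # mask only present when one was given
)
print(resp.data[0].url)
```
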
@@ -411,14 +459,12 @@ def _predict_one(
"Converts a file-like Audio recording to text."
if context is not None:
self.prompt = self.prompt.format(context='\n'.join(context))
return (
Audio.transcribe(
file=file,
model=self.identifier,
prompt=self.prompt,
**kwargs,
)
)['text']
return self.syncClient.audio.transcriptions.create(
file=file,
model=self.identifier,
prompt=self.prompt,
**kwargs,
).text

@retry
async def _apredict_one(
@@ -428,33 +474,37 @@ async def _apredict_one(
if context is not None:
self.prompt = self.prompt.format(context='\n'.join(context))
return (
await Audio.atranscribe(
await self.asyncClient.audio.transcriptions.create(
file=file,
model=self.identifier,
prompt=self.prompt,
**kwargs,
)
)['text']
).text

@retry
def _predict_a_batch(self, files: t.List[t.BinaryIO], **kwargs):
"Converts multiple file-like Audio recordings to text."
resps = [
Audio.transcribe(file=file, model=self.identifier, **kwargs)
self.syncClient.audio.transcriptions.create(
file=file, model=self.identifier, **kwargs
)
for file in files
]
return [resp['text'] for resp in resps]
return [resp.text for resp in resps]

@retry
async def _apredict_a_batch(self, files: t.List[t.BinaryIO], **kwargs):
"Converts multiple file-like Audio recordings to text."
resps = await asyncio.gather(
*[
Audio.atranscribe(file=file, model=self.identifier, **kwargs)
self.asyncClient.audio.transcriptions.create(
file=file, model=self.identifier, **kwargs
)
for file in files
]
)
return [resp['text'] for resp in resps]
return [resp.text for resp in resps]
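
Transcription moves from Audio.transcribe / Audio.atranscribe to audio.transcriptions.create, with .text replacing ['text']. There is no batch endpoint, so the batch helpers issue one request per file; the async path fans the requests out with asyncio.gather, exactly as the diff does. A sketch (file names and model are assumed):

```python
import asyncio

from openai import AsyncOpenAI, OpenAI

sync_client = OpenAI()
async_client = AsyncOpenAI()

# Single file, synchronous.
with open('clip.wav', 'rb') as f:
    text = sync_client.audio.transcriptions.create(
        file=f, model='whisper-1'
    ).text

# Many files, concurrently: one request each, gathered on the event loop.
async def transcribe_all(paths):
    async def one(path):
        with open(path, 'rb') as f:
            resp = await async_client.audio.transcriptions.create(
                file=f, model='whisper-1'
            )
        return resp.text
    return await asyncio.gather(*[one(p) for p in paths])

texts = asyncio.run(transcribe_all(['a.wav', 'b.wav']))
```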

def _predict(
self, X, one: bool = True, context: t.Optional[t.List[str]] = None, **kwargs
@@ -505,13 +555,13 @@ def _predict_one(
if context is not None:
self.prompt = self.prompt.format(context='\n'.join(context))
return (
Audio.translate(
self.syncClient.audio.translations.create(
file=file,
model=self.identifier,
prompt=self.prompt,
**kwargs,
)
)['text']
).text

@retry
async def _apredict_one(
@@ -521,33 +571,37 @@ async def _apredict_one(
if context is not None:
self.prompt = self.prompt.format(context='\n'.join(context))
return (
await Audio.atranslate(
await self.asyncClient.audio.translations.create(
file=file,
model=self.identifier,
prompt=self.prompt,
**kwargs,
)
)['text']
).text

@retry
def _predict_a_batch(self, files: t.List[t.BinaryIO], **kwargs):
"Translates multiple file-like Audio recordings to English."
resps = [
Audio.translate(file=file, model=self.identifier, **kwargs)
self.syncClient.audio.translations.create(
file=file, model=self.identifier, **kwargs
)
for file in files
]
return [resp['text'] for resp in resps]
return [resp.text for resp in resps]

@retry
async def _apredict_a_batch(self, files: t.List[t.BinaryIO], **kwargs):
"Translates multiple file-like Audio recordings to English."
resps = await asyncio.gather(
*[
Audio.atranslate(file=file, model=self.identifier, **kwargs)
self.asyncClient.audio.translations.create(
file=file, model=self.identifier, **kwargs
)
for file in files
]
)
return [resp['text'] for resp in resps]
return [resp.text for resp in resps]
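
Translations mirror transcriptions one-for-one: audio.translations.create replaces Audio.translate / Audio.atranslate and returns an object whose .text attribute holds the English text. A minimal sketch (file name and model assumed):

```python
from openai import OpenAI

client = OpenAI()

with open('clip_fr.wav', 'rb') as f:
    english = client.audio.translations.create(
        file=f, model='whisper-1'
    ).text
print(english)
```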

def _predict(
self, X, one: bool = True, context: t.Optional[t.List[str]] = None, **kwargs