Skip to content

Commit

Permalink
Update options for embedding models (#380)
Browse files Browse the repository at this point in the history
  • Loading branch information
caufieldjh authored May 20, 2024
2 parents b4b2cd8 + 2ef825b commit a84fa95
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 29 deletions.
7 changes: 7 additions & 0 deletions src/ontogpt/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""ontogpt package."""

from importlib import metadata
from pathlib import Path

Expand All @@ -16,6 +17,12 @@
DEFAULT_MODEL = model["alternative_names"][0]
break

# Names of the embedding models accepted by the embedding commands in cli.py
# (embed, text_similarity, text_distance, entity_similarity). Those commands
# raise NotImplementedError for any model name not listed here, and fall back
# to "text-embedding-ada-002" when no model is given.
OPENAI_EMBEDDING_MODELS = [
"text-embedding-ada-002",
"text-embedding-3-small",
"text-embedding-3-large",
]

try:
__version__ = metadata.version(__name__)
except metadata.PackageNotFoundError:
Expand Down
59 changes: 33 additions & 26 deletions src/ontogpt/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,13 @@
from sssom.util import to_mapping_set_dataframe

import ontogpt.ontex.extractor as extractor
from ontogpt import DEFAULT_MODEL, DEFAULT_MODEL_DETAILS, MODELS, __version__
from ontogpt import (
DEFAULT_MODEL,
DEFAULT_MODEL_DETAILS,
MODELS,
OPENAI_EMBEDDING_MODELS,
__version__,
)
from ontogpt.clients import OpenAIClient
from ontogpt.clients.pubmed_client import PubmedClient
from ontogpt.clients.soup_client import SoupClient
Expand Down Expand Up @@ -936,19 +942,18 @@ def embed(text, context, output, model, output_format, azure_select, **kwargs):
Not currently supported for open models.
"""
if model:
selectmodel = get_model_by_name(model)
model_source = selectmodel["provider"]

if model_source != "OpenAI":
raise NotImplementedError("Model not yet supported for embeddings.")
if model not in OPENAI_EMBEDDING_MODELS:
raise NotImplementedError("Model not recognized or not yet supported for embeddings.")
else:
model = "text-embedding-ada-002"

logging.info(f"Embedding with model {model}")

if not text:
raise ValueError("Text must be passed")

client = OpenAIClient(model=model, use_azure=azure_select)
resp = client.embeddings(text)
resp = client.embeddings(text=text, model=model)
print(resp)


Expand All @@ -969,14 +974,13 @@ def text_similarity(text, context, output, model, output_format, azure_select, *
Not currently supported for open models.
"""
if model:
selectmodel = get_model_by_name(model)
model_source = selectmodel["provider"]

if model_source != "OpenAI":
raise NotImplementedError("Model not yet supported for embeddings.")
if model not in OPENAI_EMBEDDING_MODELS:
raise NotImplementedError("Model not recognized or not yet supported for embeddings.")
else:
model = "text-embedding-ada-002"

logging.info(f"Embedding with model {model}")

if not text:
raise ValueError("Text must be passed")
text = list(text)
Expand Down Expand Up @@ -1010,14 +1014,13 @@ def text_distance(text, context, output, model, output_format, azure_select, **k
Not currently supported for open models.
"""
if model:
selectmodel = get_model_by_name(model)
model_source = selectmodel["provider"]

if model_source != "OpenAI":
raise NotImplementedError("Model not yet supported for embeddings.")
if model not in OPENAI_EMBEDDING_MODELS:
raise NotImplementedError("Model not recognized or not yet supported for embeddings.")
else:
model = "text-embedding-ada-002"

logging.info(f"Embedding with model {model}")

if not text:
raise ValueError("Text must be passed")
text = list(text)
Expand Down Expand Up @@ -1082,11 +1085,12 @@ def entity_similarity(terms, ontology, output, model, output_format, **kwargs):
Not currently supported for open models.
"""
if model:
selectmodel = get_model_by_name(model)
model_source = selectmodel["provider"]
if model not in OPENAI_EMBEDDING_MODELS:
raise NotImplementedError("Model not recognized or not yet supported for embeddings.")
else:
model = "text-embedding-ada-002"

if model_source != "OpenAI":
raise NotImplementedError("Model not yet supported for embeddings.")
logging.info(f"Embedding with model {model}")

if not terms:
raise ValueError("terms must be passed")
Expand Down Expand Up @@ -1202,13 +1206,16 @@ def diagnose(
@click.argument("output_directory")
@output_option_wb
def run_multilingual_analysis(
input_data_dir, output_directory, output, model="gpt-4-turbo",
input_data_dir,
output_directory,
output,
model="gpt-4-turbo",
):
"""Call the multilingual analysis function."""
multilingual_analysis(input_data_dir=input_data_dir,
output_directory=output_directory,
output=output,
model=model)
multilingual_analysis(
input_data_dir=input_data_dir, output_directory=output_directory, output=output, model=model
)


def get_kanjee_prompt() -> str:
"""Prompt from Kanjee et al. 2023."""
Expand Down
4 changes: 1 addition & 3 deletions src/ontogpt/clients/openai_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,11 +173,9 @@ def _must_use_chat_api(self) -> bool:
return False
return True

def embeddings(self, text: str, model: str = ""):
def embeddings(self, text: str, model: str):
text = str(text)

if model == "":
model = "text-embedding-ada-002"
cur = self.db_connection()
try:
logger.info("creating embeddings cache")
Expand Down

0 comments on commit a84fa95

Please sign in to comment.