diff --git a/src/ontogpt/__init__.py b/src/ontogpt/__init__.py
index 49e2d1cda..8b663e8aa 100644
--- a/src/ontogpt/__init__.py
+++ b/src/ontogpt/__init__.py
@@ -1,4 +1,5 @@
 """ontogpt package."""
+
 from importlib import metadata
 from pathlib import Path
 
@@ -16,6 +17,12 @@
         DEFAULT_MODEL = model["alternative_names"][0]
         break
 
+OPENAI_EMBEDDING_MODELS = [
+    "text-embedding-ada-002",
+    "text-embedding-3-small",
+    "text-embedding-3-large",
+]
+
 try:
     __version__ = metadata.version(__name__)
 except metadata.PackageNotFoundError:
diff --git a/src/ontogpt/cli.py b/src/ontogpt/cli.py
index 64bb16012..ac10d1b1e 100644
--- a/src/ontogpt/cli.py
+++ b/src/ontogpt/cli.py
@@ -24,7 +24,13 @@
 from sssom.util import to_mapping_set_dataframe
 
 import ontogpt.ontex.extractor as extractor
-from ontogpt import DEFAULT_MODEL, DEFAULT_MODEL_DETAILS, MODELS, __version__
+from ontogpt import (
+    DEFAULT_MODEL,
+    DEFAULT_MODEL_DETAILS,
+    MODELS,
+    OPENAI_EMBEDDING_MODELS,
+    __version__,
+)
 from ontogpt.clients import OpenAIClient
 from ontogpt.clients.pubmed_client import PubmedClient
 from ontogpt.clients.soup_client import SoupClient
@@ -936,19 +942,18 @@ def embed(text, context, output, model, output_format, azure_select, **kwargs):
     Not currently supported for open models.
     """
     if model:
-        selectmodel = get_model_by_name(model)
-        model_source = selectmodel["provider"]
-
-        if model_source != "OpenAI":
-            raise NotImplementedError("Model not yet supported for embeddings.")
+        if model not in OPENAI_EMBEDDING_MODELS:
+            raise NotImplementedError("Model not recognized or not yet supported for embeddings.")
     else:
         model = "text-embedding-ada-002"
 
+    logging.info(f"Embedding with model {model}")
+
     if not text:
         raise ValueError("Text must be passed")
 
     client = OpenAIClient(model=model, use_azure=azure_select)
-    resp = client.embeddings(text)
+    resp = client.embeddings(text=text, model=model)
 
     print(resp)
 
@@ -969,14 +974,13 @@ def text_similarity(text, context, output, model, output_format, azure_select, *
     Not currently supported for open models.
     """
     if model:
-        selectmodel = get_model_by_name(model)
-        model_source = selectmodel["provider"]
-
-        if model_source != "OpenAI":
-            raise NotImplementedError("Model not yet supported for embeddings.")
+        if model not in OPENAI_EMBEDDING_MODELS:
+            raise NotImplementedError("Model not recognized or not yet supported for embeddings.")
     else:
         model = "text-embedding-ada-002"
 
+    logging.info(f"Embedding with model {model}")
+
     if not text:
         raise ValueError("Text must be passed")
     text = list(text)
@@ -1010,14 +1014,13 @@ def text_distance(text, context, output, model, output_format, azure_select, **k
     Not currently supported for open models.
     """
     if model:
-        selectmodel = get_model_by_name(model)
-        model_source = selectmodel["provider"]
-
-        if model_source != "OpenAI":
-            raise NotImplementedError("Model not yet supported for embeddings.")
+        if model not in OPENAI_EMBEDDING_MODELS:
+            raise NotImplementedError("Model not recognized or not yet supported for embeddings.")
     else:
         model = "text-embedding-ada-002"
 
+    logging.info(f"Embedding with model {model}")
+
     if not text:
         raise ValueError("Text must be passed")
     text = list(text)
@@ -1082,11 +1085,12 @@ def entity_similarity(terms, ontology, output, model, output_format, **kwargs):
     Not currently supported for open models.
     """
     if model:
-        selectmodel = get_model_by_name(model)
-        model_source = selectmodel["provider"]
+        if model not in OPENAI_EMBEDDING_MODELS:
+            raise NotImplementedError("Model not recognized or not yet supported for embeddings.")
+    else:
+        model = "text-embedding-ada-002"
 
-        if model_source != "OpenAI":
-            raise NotImplementedError("Model not yet supported for embeddings.")
+    logging.info(f"Embedding with model {model}")
 
     if not terms:
         raise ValueError("terms must be passed")
@@ -1202,13 +1206,16 @@ def diagnose(
 @click.argument("output_directory")
 @output_option_wb
 def run_multilingual_analysis(
-    input_data_dir, output_directory, output, model="gpt-4-turbo",
+    input_data_dir,
+    output_directory,
+    output,
+    model="gpt-4-turbo",
 ):
     """Call the multilingual analysis function."""
-    multilingual_analysis(input_data_dir=input_data_dir,
-                          output_directory=output_directory,
-                          output=output,
-                          model=model)
+    multilingual_analysis(
+        input_data_dir=input_data_dir, output_directory=output_directory, output=output, model=model
+    )
+
 
 def get_kanjee_prompt() -> str:
     """Prompt from Kanjee et al. 2023."""
diff --git a/src/ontogpt/clients/openai_client.py b/src/ontogpt/clients/openai_client.py
index aed20e61c..fbf27cf46 100644
--- a/src/ontogpt/clients/openai_client.py
+++ b/src/ontogpt/clients/openai_client.py
@@ -173,11 +173,9 @@ def _must_use_chat_api(self) -> bool:
             return False
         return True
 
-    def embeddings(self, text: str, model: str = ""):
+    def embeddings(self, text: str, model: str):
         text = str(text)
 
-        if model == "":
-            model = "text-embedding-ada-002"
         cur = self.db_connection()
         try:
             logger.info("creating embeddings cache")