Skip to content

Commit

Permalink
Add option to show prompt (#183)
Browse files Browse the repository at this point in the history
Extraction-based commands can be passed the `--show_prompts` flag.
If used with `-vvv`, *all* prompts will be output to the logger in their
full form.
Without that verbose flag, no additional output will be provided.
  • Loading branch information
caufieldjh authored Aug 25, 2023
2 parents ac4c060 + 9258c33 commit a859db7
Show file tree
Hide file tree
Showing 5 changed files with 339 additions and 46 deletions.
80 changes: 49 additions & 31 deletions src/ontogpt/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,12 @@ def get_model_by_name(modelname: str):
default="AUTO",
help="Prefix to use for auto-generated classes. Default is AUTO.",
)
show_prompt_option = click.option(
"--show-prompt/--no-show-prompt",
default=False,
show_default=True,
help="If set, show all prompts passed to model through an API. Use with verbose setting.",
)


@click.group()
Expand Down Expand Up @@ -226,6 +232,7 @@ def main(verbose: int, quiet: bool, cache_db: str, skip_annotator):
@output_format_options
@use_textract_options
@auto_prefix_option
@show_prompt_option
@click.option(
"--set-slot-value",
"-S",
Expand All @@ -244,6 +251,7 @@ def extract(
set_slot_value,
use_textract,
model,
show_prompt,
**kwargs,
):
"""Extract knowledge from text guided by schema, using SPIRES engine.
Expand Down Expand Up @@ -308,7 +316,7 @@ def extract(
target_class_def = ke.schemaview.get_class(target_class)
else:
target_class_def = None
results = ke.extract_from_text(text, target_class_def)
results = ke.extract_from_text(text=text, cls=target_class_def, show_prompt=show_prompt)
if set_slot_value:
for slot_value in set_slot_value:
slot, value = slot_value.split("=")
Expand All @@ -324,8 +332,9 @@ def extract(
@output_option_wb
@output_format_options
@auto_prefix_option
@show_prompt_option
@click.argument("entity")
def generate_extract(model, entity, template, output, output_format, **kwargs):
def generate_extract(model, entity, template, output, output_format, show_prompt, **kwargs):
"""Generate text and then extract knowledge from it."""
logging.info(f"Creating for {template}")

Expand All @@ -346,7 +355,7 @@ def generate_extract(model, entity, template, output, output_format, **kwargs):
ke = GPT4AllEngine(template=template, model=model_name, **kwargs)

logging.debug(f"Input entity: {entity}")
results = ke.generate_and_extract(entity)
results = ke.generate_and_extract(entity, show_prompt)
write_extraction(results, output, output_format, ke)


Expand All @@ -357,6 +366,7 @@ def generate_extract(model, entity, template, output, output_format, **kwargs):
@output_option_wb
@output_format_options
@auto_prefix_option
@show_prompt_option
@click.option("--ontology", "-r", help="Ontology to use; use oaklib selector path")
@click.option("--max-iterations", "-M", default=10, type=click.INT)
@click.option("--iteration-slot", "-I", multiple=True, help="Slots to iterate over")
Expand All @@ -376,6 +386,7 @@ def iteratively_generate_extract(
max_iterations,
clear,
ontology,
show_prompt,
**kwargs,
):
"""Iterate through generate-extract."""
Expand All @@ -402,6 +413,7 @@ def iteratively_generate_extract(
for results in ke.iteratively_generate_and_extract(
entity,
db,
show_prompt=show_prompt,
iteration_slots=list(iteration_slot),
max_iterations=max_iterations,
adapter=adapter,
Expand All @@ -416,13 +428,14 @@ def iteratively_generate_extract(
@recurse_option
@output_option_wb
@output_format_options
@show_prompt_option
@click.option(
"--get-pmc/--no-get-pmc",
default=False,
help="Attempt to parse PubMed Central full text(s) instead of abstract(s) alone.",
)
@click.argument("pmid")
def pubmed_extract(model, pmid, template, output, output_format, get_pmc, **kwargs):
def pubmed_extract(model, pmid, template, output, output_format, get_pmc, show_prompt, **kwargs):
"""Extract knowledge from a single PubMed ID."""
logging.info(f"Creating for {template}")

Expand Down Expand Up @@ -450,7 +463,7 @@ def pubmed_extract(model, pmid, template, output, output_format, get_pmc, **kwar
textlist = pmc.text(pmid)
for text in textlist:
logging.debug(f"Input text: {text}")
results = ke.extract_from_text(text)
results = ke.extract_from_text(text=text, show_prompt=show_prompt)
write_extraction(results, output, output_format)


Expand All @@ -460,6 +473,7 @@ def pubmed_extract(model, pmid, template, output, output_format, get_pmc, **kwar
@recurse_option
@output_option_wb
@output_format_options
@show_prompt_option
@click.option(
"--limit",
default=20,
Expand All @@ -471,7 +485,7 @@ def pubmed_extract(model, pmid, template, output, output_format, get_pmc, **kwar
help="Attempt to parse PubMed Central full text(s) instead of abstract(s) alone.",
)
@click.argument("search")
def pubmed_annotate(model, search, template, output, output_format, limit, get_pmc, **kwargs):
def pubmed_annotate(model, search, template, output, output_format, limit, get_pmc, show_prompt, **kwargs):
"""Retrieve a collection of PubMed IDs for a search term; annotate them using a template.
Example:
Expand Down Expand Up @@ -506,7 +520,7 @@ def pubmed_annotate(model, search, template, output, output_format, limit, get_p
textlist = pmc.text(pmids[: pubmed_annotate_limit + 1])
for text in textlist:
logging.debug(f"Input text: {text}")
results = ke.extract_from_text(text)
results = ke.extract_from_text(text=text, show_prompt=show_prompt)
write_extraction(results, output, output_format, ke)


Expand All @@ -516,9 +530,10 @@ def pubmed_annotate(model, search, template, output, output_format, limit, get_p
@recurse_option
@output_option_wb
@output_format_options
@show_prompt_option
@click.option("--auto-prefix", default="AUTO", help="Prefix to use for auto-generated classes.")
@click.argument("article")
def wikipedia_extract(model, article, template, output, output_format, **kwargs):
def wikipedia_extract(model, article, template, output, output_format, show_prompt, **kwargs):
"""Extract knowledge from a Wikipedia page."""
if not model:
model = DEFAULT_MODEL
Expand All @@ -541,7 +556,7 @@ def wikipedia_extract(model, article, template, output, output_format, **kwargs)
text = client.text(article)

logging.debug(f"Input text: {text}")
results = ke.extract_from_text(text)
results = ke.extract_from_text(text=text, show_prompt=show_prompt)
write_extraction(results, output, output_format, ke)


Expand All @@ -551,14 +566,15 @@ def wikipedia_extract(model, article, template, output, output_format, **kwargs)
@recurse_option
@output_option_wb
@output_format_options
@show_prompt_option
@click.option(
"--keyword",
"-k",
multiple=True,
help="Keyword to search for (e.g. --keyword therapy). Also obtained from schema",
)
@click.argument("topic")
def wikipedia_search(model, topic, keyword, template, output, output_format, **kwargs):
def wikipedia_search(model, topic, keyword, template, output, output_format, show_prompt, **kwargs):
"""Extract knowledge from a Wikipedia page."""
if not model:
model = DEFAULT_MODEL
Expand Down Expand Up @@ -589,7 +605,7 @@ def wikipedia_search(model, topic, keyword, template, output, output_format, **k
# TODO - expand this to fit context limits better
# or add as cli option
text = text[:4000]
results = ke.extract_from_text(text)
results = ke.extract_from_text(text=text, show_prompt=show_prompt)
write_extraction(results, output, output_format)
break

Expand All @@ -600,14 +616,15 @@ def wikipedia_search(model, topic, keyword, template, output, output_format, **k
@recurse_option
@output_option_wb
@output_format_options
@show_prompt_option
@click.option(
"--keyword",
"-k",
multiple=True,
help="Keyword to search for (e.g. --keyword therapy). Also obtained from schema",
)
@click.argument("term_tokens", nargs=-1)
def search_and_extract(model, term_tokens, keyword, template, output, output_format, **kwargs):
def search_and_extract(model, term_tokens, keyword, template, output, output_format, show_prompt, **kwargs):
"""Search for relevant literature and extract knowledge from it."""
if not model:
model = DEFAULT_MODEL
Expand Down Expand Up @@ -639,7 +656,7 @@ def search_and_extract(model, term_tokens, keyword, template, output, output_for
logging.info(f"PMID={pmid}")
text = pmc.text(pmid)
logging.info(f"Input text: {text}")
results = ke.extract_from_text(text)
results = ke.extract_from_text(text=text, show_prompt=show_prompt)
write_extraction(results, output, output_format)


Expand All @@ -649,8 +666,9 @@ def search_and_extract(model, term_tokens, keyword, template, output, output_for
@recurse_option
@output_option_wb
@output_format_options
@show_prompt_option
@click.argument("url")
def web_extract(model, template, url, output, output_format, **kwargs):
def web_extract(model, template, url, output, output_format, show_prompt, **kwargs):
"""Extract knowledge from web page."""
logging.info(f"Creating for {template}")

Expand All @@ -674,7 +692,7 @@ def web_extract(model, template, url, output, output_format, **kwargs):
text = web_client.text(url)

logging.debug(f"Input text: {text}")
results = ke.extract_from_text(text)
results = ke.extract_from_text(text=text, show_prompt=show_prompt)
write_extraction(results, output, output_format)


Expand All @@ -689,8 +707,9 @@ def web_extract(model, template, url, output, output_format, **kwargs):
)
@click.option("--auto-prefix", default="AUTO", help="Prefix to use for auto-generated classes.")
@model_option
@show_prompt_option
@click.argument("url")
def recipe_extract(model, url, recipes_urls_file, dictionary, output, output_format, **kwargs):
def recipe_extract(model, url, recipes_urls_file, dictionary, output, output_format, show_prompt, **kwargs):
"""Extract from recipe on the web."""
try:
from recipe_scrapers import scrape_me
Expand Down Expand Up @@ -737,7 +756,7 @@ def recipe_extract(model, url, recipes_urls_file, dictionary, output, output_for
Instructions:\n{instructions}
"""
logging.info(f"Input text: {text}")
results = ke.extract_from_text(text)
results = ke.extract_from_text(text=text, show_prompt=show_prompt)
logging.debug(f"Results: {results}")
results.extracted_object.url = url
write_extraction(results, output, output_format, ke)
Expand Down Expand Up @@ -839,6 +858,7 @@ def convert_geneset(input_file, output, output_format, fill, **kwargs):
@output_option_txt
@output_format_options
@model_option
@show_prompt_option
@click.option(
"--resolver", "-r", help="OAK selector for the gene ID resolver. E.g. sqlite:obo:hgnc"
)
Expand All @@ -853,12 +873,6 @@ def convert_geneset(input_file, output, output_format, fill, **kwargs):
show_default=True,
help="If set, there must be a unique mappings from labels to IDs",
)
@click.option(
"--show-prompt/--no-show-prompt",
default=True,
show_default=True,
help="If set, show prompt passed to model",
)
@click.option(
"--input-file",
"-U",
Expand Down Expand Up @@ -995,7 +1009,7 @@ def enrichment(
@click.argument("text", nargs=-1)
def embed(text, context, output, model, output_format, **kwargs):
"""Embed text.
Not currently supported for open models.
"""
if model:
Expand Down Expand Up @@ -1027,7 +1041,7 @@ def embed(text, context, output, model, output_format, **kwargs):
@click.argument("text", nargs=-1)
def text_similarity(text, context, output, model, output_format, **kwargs):
"""Embed text.
Not currently supported for open models.
"""
if model:
Expand Down Expand Up @@ -1067,7 +1081,7 @@ def text_similarity(text, context, output, model, output_format, **kwargs):
@click.argument("text", nargs=-1)
def text_distance(text, context, output, model, output_format, **kwargs):
"""Embed text and calculate euclidian distance between embeddings.
Not currently supported for open models.
"""
if model:
Expand Down Expand Up @@ -1440,8 +1454,9 @@ def eval(evaluator, num_tests, output, output_format, **kwargs):
@recurse_option
@output_option_wb
@output_format_options
@show_prompt_option
@click.argument("object")
def fill(model, template, object: str, examples, output, output_format, **kwargs):
def fill(model, template, object: str, examples, output, output_format, show_prompt, **kwargs):
"""Fill in missing values."""
logging.info(f"Creating for {template}")

Expand All @@ -1463,7 +1478,7 @@ def fill(model, template, object: str, examples, output, output_format, **kwargs
logging.info(f"Loading {examples}")
examples = yaml.safe_load(examples)
logging.debug(f"Input object: {object}")
results = ke.generalize(object, examples)
results = ke.generalize(object, examples, show_prompt)

output.write(yaml.dump(results.dict()))

Expand All @@ -1480,8 +1495,9 @@ def openai_models(**kwargs):
@model_option
@output_option_txt
@output_format_options
@show_prompt_option
@click.argument("input")
def complete(model, input, output, output_format, **kwargs):
def complete(model, input, output, output_format, show_prompt, **kwargs):
"""Prompt completion."""
if not model:
model = DEFAULT_MODEL
Expand All @@ -1493,7 +1509,7 @@ def complete(model, input, output, output_format, **kwargs):

if model_source == "OpenAI":
c = OpenAIClient(model=model_name)
results = c.complete(text)
results = c.complete(text, show_prompt)

elif model_source == "GPT4All":
c = set_up_gpt4all_model(modelname=model_name)
Expand Down Expand Up @@ -1594,6 +1610,7 @@ def halo(model, input, context, terms, output, **kwargs):
@model_option
@output_option_wb
@output_format_options
@show_prompt_option
@click.option(
"-d",
"--description",
Expand All @@ -1607,6 +1624,7 @@ def clinical_notes(
sections,
output,
model,
show_prompt,
output_format,
**kwargs,
):
Expand All @@ -1631,7 +1649,7 @@ def clinical_notes(

if model_source == "OpenAI":
c = OpenAIClient(model=model_name)
results = c.complete(prompt)
results = c.complete(prompt, show_prompt)

elif model_source == "GPT4All":
c = set_up_gpt4all_model(modelname=model_name)
Expand Down
4 changes: 3 additions & 1 deletion src/ontogpt/clients/openai_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,11 @@ def __post_init__(self):
self.api_key = get_apikey_value("openai")
openai.api_key = self.api_key

def complete(self, prompt, max_tokens=3000, **kwargs) -> str:
def complete(self, prompt, show_prompt: bool = False, max_tokens=3000, **kwargs) -> str:
engine = self.model
logger.info(f"Complete: engine={engine}, prompt[{len(prompt)}]={prompt[0:100]}...")
if show_prompt:
logger.info(f" SENDING PROMPT:\n{prompt}")
cur = self.db_connection()
res = cur.execute("SELECT payload FROM cache WHERE prompt=? AND engine=?", (prompt, engine))
payload = res.fetchone()
Expand Down
Loading

0 comments on commit a859db7

Please sign in to comment.