from ontogpt.engines.topic_classifier_engine import TopicClassifierEngine


@main.command()
@inputfile_option
@use_pdf_options
@model_option
@temperature_option
@api_base_option
@api_version_option
@model_provider_option
@system_message_option
@click.argument("topic")
def classify_by_topic(
    inputfile,
    model,
    temperature,
    api_base,
    api_version,
    model_provider,
    system_message,
    topic,
    use_pdf,
):
    """Classify input text by topic.

    Returns True if the input text is about the topic, False otherwise,
    along with the name of the input file.

    A path to a file containing input text may be passed as inputfile,
    as may a directory of input files.

    Example:

        ontogpt classify-by-topic -i temp/30091466.txt
        "clinical observations of human patients, including the diagnostic
        and therapeutic procedures used during their clinical care"

    """
    # NOTE: the docstring above doubles as the click --help text, so it is
    # kept user-facing; implementation notes live in comments instead.
    #
    # Input resolution:
    #   * no inputfile or "-"  -> read all of stdin, labelled "Input"
    #   * directory            -> every regular "*.txt" file directly inside it
    #   * existing file        -> that file (parsed with pymupdf when use_pdf)
    #   * anything else        -> FileNotFoundError
    # One line "<label>\t<response>" is printed per input.

    if not model:
        model = DEFAULT_MODEL

    # Maps a label for each input (a path, or "Input" for stdin) to its text.
    inputdict = {}

    if not inputfile or inputfile == "-":
        inputdict["Input"] = sys.stdin.read()
    elif Path(inputfile).is_dir():
        logging.info(f"Input file directory: {inputfile}")
        # read_text() closes each handle as soon as it is read. Decoding
        # matches the single-file branch (UTF-8, undecodable bytes dropped) so
        # one badly-encoded file cannot abort a whole directory run. Sorting
        # makes the output order reproducible across filesystems.
        inputdict = {
            f: f.read_text(encoding="utf-8", errors="ignore")
            for f in sorted(Path(inputfile).glob("*.txt"))
            if f.is_file()
        }
        logging.info(f"Found {len(inputdict)} input files here.")
    elif Path(inputfile).exists():
        logging.info(f"Input file: {inputfile}")
        if use_pdf:
            # Imported lazily so pymupdf is only needed when PDFs are used.
            import pymupdf

            doc = pymupdf.open(inputfile)
            try:
                text = "".join(page.get_text() for page in doc)
            finally:
                # Release the underlying file even if text extraction fails.
                doc.close()
        else:
            text = Path(inputfile).read_bytes().decode(encoding="utf-8", errors="ignore")
        logging.info(f"Input text: {text}")
        inputdict[inputfile] = text
    else:
        # inputfile was given but is neither a directory nor an existing file.
        raise FileNotFoundError(f"Cannot find input file {inputfile}")

    ke = TopicClassifierEngine(
        model=model,
        temperature=temperature,
        api_base=api_base,
        api_version=api_version,
        model_provider=model_provider,
        system_message=system_message,
    )

    for input_entry, input_text in inputdict.items():
        response = ke.binary_classify(topic=topic, text=input_text)
        print(f"{input_entry}\t{response}")