From 71f8560e13c340174fd82cb5cf302524df14518b Mon Sep 17 00:00:00 2001 From: caufieldjh Date: Thu, 14 Sep 2023 14:54:21 -0400 Subject: [PATCH] Fifth pass at type fixes --- src/ontogpt/cli.py | 2 +- src/ontogpt/engines/enrichment.py | 8 +++--- src/ontogpt/engines/knowledge_engine.py | 12 ++++----- src/ontogpt/engines/mapping_engine.py | 2 +- src/ontogpt/engines/pheno_engine.py | 2 +- src/ontogpt/engines/spires_engine.py | 2 +- src/ontogpt/evaluation/ctd/eval_ctd.py | 14 +++++----- .../evaluation/drugmechdb/eval_drugmechdb.py | 10 +++++-- src/ontogpt/evaluation/go/eval_go.py | 27 ++++++++++--------- src/ontogpt/io/html_exporter.py | 5 ++-- src/ontogpt/io/yaml_wrapper.py | 1 + 11 files changed, 47 insertions(+), 38 deletions(-) diff --git a/src/ontogpt/cli.py b/src/ontogpt/cli.py index 5c93aa4f7..f31a74b43 100644 --- a/src/ontogpt/cli.py +++ b/src/ontogpt/cli.py @@ -99,7 +99,7 @@ def write_extraction( exporter.export(results, output) elif output_format == "html": output = _as_text_writer(output) - exporter = HTMLExporter() + exporter = HTMLExporter(output=output) exporter.export(results, output) elif output_format == "yaml": output = _as_text_writer(output) diff --git a/src/ontogpt/engines/enrichment.py b/src/ontogpt/engines/enrichment.py index 5585099ee..420227847 100644 --- a/src/ontogpt/engines/enrichment.py +++ b/src/ontogpt/engines/enrichment.py @@ -53,7 +53,7 @@ class EnrichmentPayload(BaseModel): response_text: str = "" """The response text from the summarization task (only filled for LLMs).""" - truncation_factor: float = None + truncation_factor: Optional[float] = None """Fraction of gene descriptions retained after trimming to fit token limit.""" summary: str = "" @@ -65,13 +65,13 @@ class EnrichmentPayload(BaseModel): term_ids: List[str] = [""] """The normalized terms""" - ontological_synopsis: bool = None + ontological_synopsis: Optional[bool] = None """True if the gene descriptions used the ontological synopsis""" - combined_synopsis: bool = None + combined_synopsis: Optional[bool] = None """True if the gene descriptions used both ontological and narrative synopses""" - annotations: bool = None + annotations: Optional[bool] = None """True if the gene descriptions used the annotations (vs latent KB)""" response_token_length: int = 0 diff --git a/src/ontogpt/engines/knowledge_engine.py b/src/ontogpt/engines/knowledge_engine.py index 5df2ab24a..e63937893 100644 --- a/src/ontogpt/engines/knowledge_engine.py +++ b/src/ontogpt/engines/knowledge_engine.py @@ -92,7 +92,7 @@ class KnowledgeEngine(ABC): """Python class for the template. This is derived from the template and does not need to be set manually.""" - template_module: ModuleType + template_module: Optional[ModuleType] = None """Python module for the template. This is derived from the template and does not need to be set manually.""" @@ -109,23 +109,23 @@ class KnowledgeEngine(ABC): # annotator: TextAnnotatorInterface = None # """Default annotator. TODO: deprecate?""" - annotators: Dict[str, List[TextAnnotatorInterface]] + annotators: Optional[Dict[str, List[TextAnnotatorInterface]]] = None """Annotators for each class. An annotator will ground/map labels to CURIEs. These override the annotators annotated in the template """ - skip_annotators: Optional[List[TextAnnotatorInterface]] + skip_annotators: Optional[List[TextAnnotatorInterface]] = None """Annotators to skip. This overrides any specified in the schema""" - mappers: List[BasicOntologyInterface] + mappers: Optional[List[BasicOntologyInterface]] = None """List of concept mappers, to assist in grounding to desired ID prefix""" - labelers: List[BasicOntologyInterface] + labelers: Optional[List[BasicOntologyInterface]] = None """Labelers that map CURIEs to labels""" - client: OpenAIClient + client: Optional[OpenAIClient] = None """All calls to LLMs are delegated through this client""" dictionary: Dict[str, str] = field(default_factory=dict) diff --git a/src/ontogpt/engines/mapping_engine.py b/src/ontogpt/engines/mapping_engine.py index e41e24894..0b35b02b4 100644 --- a/src/ontogpt/engines/mapping_engine.py +++ b/src/ontogpt/engines/mapping_engine.py @@ -34,7 +34,7 @@ class MappingPredicate(str, Enum): DIFFERENT_FROM = "different_from" UNCATEGORIZED = "uncategorized" - def mappings() -> Dict[str, str]: + def mappings(self) -> Dict[str, str]: """Return the mappings for this predicate.""" return { "skos:exactMatch": "exact_match", diff --git a/src/ontogpt/engines/pheno_engine.py b/src/ontogpt/engines/pheno_engine.py index ed945b224..0efd07e57 100644 --- a/src/ontogpt/engines/pheno_engine.py +++ b/src/ontogpt/engines/pheno_engine.py @@ -49,7 +49,7 @@ def mondo(self): return self._mondo def predict_disease( - self, phenopacket: PHENOPACKET, template_path: Union[str, Path] = None + self, phenopacket: PHENOPACKET, template_path: Optional[Union[str, Path]] = None ) -> List[DIAGNOSIS]: if template_path is None: template_path = DEFAULT_PHENOPACKET_PROMPT diff --git a/src/ontogpt/engines/spires_engine.py b/src/ontogpt/engines/spires_engine.py index d2682c46f..e7d436411 100644 --- a/src/ontogpt/engines/spires_engine.py +++ b/src/ontogpt/engines/spires_engine.py @@ -290,7 +290,7 @@ def map_terms( prompt += "===\n\n" payload = self.client.complete(prompt, show_prompt) # outer parse - best_results = [] + best_results: List[str] = [] for sep in ["\n", "; "]: results = payload.split(sep) if len(results) > len(best_results): diff --git a/src/ontogpt/evaluation/ctd/eval_ctd.py b/src/ontogpt/evaluation/ctd/eval_ctd.py index e7362de4e..7819e4b80 100644 --- a/src/ontogpt/evaluation/ctd/eval_ctd.py +++ b/src/ontogpt/evaluation/ctd/eval_ctd.py @@ -109,15 +109,15 @@ def pairs(dm: TextWithTriples) -> Set: class EvaluationObjectSetRE(BaseModel): - """A result of predicting relationextractions.""" + """A result of predicting relation extractions.""" - precision: float = None - recall: float = None - f1: float = None + precision: float = 0 + recall: float = 0 + f1: float = 0 - training: List[TextWithTriples] = None - predictions: List[PredictionRE] = None - test: List[TextWithTriples] = None + training: Optional[List[TextWithTriples]] = None + predictions: Optional[List[PredictionRE]] = None + test: Optional[List[TextWithTriples]] = None @dataclass diff --git a/src/ontogpt/evaluation/drugmechdb/eval_drugmechdb.py b/src/ontogpt/evaluation/drugmechdb/eval_drugmechdb.py index 07d7a2506..7f3f7d770 100644 --- a/src/ontogpt/evaluation/drugmechdb/eval_drugmechdb.py +++ b/src/ontogpt/evaluation/drugmechdb/eval_drugmechdb.py @@ -222,7 +222,10 @@ def is_candidate(m: target_datamodel.DrugMechanism): return False if len(m.references) != 1: return False - ref = m.references[0] + if m.references is not None: + ref = m.references[0] + else: + ref = "" if ref.startswith("https://go.drugbank.com/drugs/") and ref.endswith( "#mechanism-of-action" ): @@ -272,7 +275,10 @@ def eval_path_prediction(self) -> EvaluationObjectSetDrugMechDB: "drug": test_obj.drug, } results = ke.generalize(stub, eos.training) - predicted_obj = results.extracted_object[0] + if results.extracted_object is not None: + predicted_obj = results.extracted_object[0] + else: + logging.warning(f"No extracted object found for {test_obj.disease}, {test_obj.drug}") pred = PredictionDrugMechDB(predicted_object=predicted_obj, test_object=test_obj) pred.calculate_scores() eos.predictions.append(pred) diff --git a/src/ontogpt/evaluation/go/eval_go.py b/src/ontogpt/evaluation/go/eval_go.py index d3224420b..e699768e5 100644 --- a/src/ontogpt/evaluation/go/eval_go.py +++ b/src/ontogpt/evaluation/go/eval_go.py @@ -2,7 +2,7 @@ from dataclasses import dataclass from pathlib import Path from random import shuffle -from typing import Dict, List +from typing import Dict, List, Optional import yaml from oaklib import get_implementation_from_shorthand @@ -23,9 +23,9 @@ class PredictionGO(BaseModel): - predicted_object: MetabolicProcess = None - test_object: MetabolicProcess = None - scores: Dict[str, SimilarityScore] = None + predicted_object: Optional[MetabolicProcess] = None + test_object: Optional[MetabolicProcess] = None + scores: Optional[Dict[str, SimilarityScore]] = None def calculate_scores(self): self.scores = {} @@ -44,9 +44,9 @@ def calculate_scores(self): class EvaluationObjectSetGO(BaseModel): """A result of extracting knowledge on text.""" - test: List[MetabolicProcess] = None - training: List[MetabolicProcess] = None - predictions: List[PredictionGO] = None + test: Optional[List[MetabolicProcess]] = None + training: Optional[List[MetabolicProcess]] = None + predictions: Optional[List[PredictionGO]] = None @dataclass @@ -143,10 +143,11 @@ def eval(self) -> EvaluationObjectSetGO: eos = self.create_test_and_training() eos.predictions = [] print(yaml.dump(eos.dict())) - for test_obj in eos.test[0:10]: - print(yaml.dump(test_obj.dict())) - predicted_obj = ke.generalize({"label": test_obj.label}, eos.training[0:4]) - pred = PredictionGO(predicted_object=predicted_obj, test_object=test_obj) - pred.calculate_scores() - eos.predictions.append(pred) + if eos.test is not None: + for test_obj in eos.test[0:10]: + print(yaml.dump(test_obj.dict())) + predicted_obj = ke.generalize({"label": test_obj.label}, eos.training[0:4]) + pred = PredictionGO(predicted_object=predicted_obj, test_object=test_obj) + pred.calculate_scores() + eos.predictions.append(pred) return eos diff --git a/src/ontogpt/io/html_exporter.py b/src/ontogpt/io/html_exporter.py index 6a830b9ec..21e1860b9 100644 --- a/src/ontogpt/io/html_exporter.py +++ b/src/ontogpt/io/html_exporter.py @@ -1,6 +1,7 @@ """HTML Exporter.""" import html from dataclasses import dataclass +from io import BytesIO from pathlib import Path from typing import Any, TextIO, Union @@ -19,9 +20,9 @@ class HTMLExporter(Exporter): TODO: rewrite to use bootstrap """ - output: TextIO + output: Union[BytesIO, TextIO] - def export(self, extraction_output: ExtractionResult, output: Union[str, Path, TextIO]): + def export(self, extraction_output: ExtractionResult, output: Union[str, Path, TextIO, BytesIO]): if isinstance(output, Path): output = str(output) if isinstance(output, str): diff --git a/src/ontogpt/io/yaml_wrapper.py b/src/ontogpt/io/yaml_wrapper.py index 916d79e24..04e76b614 100644 --- a/src/ontogpt/io/yaml_wrapper.py +++ b/src/ontogpt/io/yaml_wrapper.py @@ -51,3 +51,4 @@ def dump_minimal_yaml(obj: Any, minimize=True, file: Optional[TextIO] = None) -> return file.getvalue() else: yaml.dump(eliminate_empty(obj, not minimize), file) + return ""