Fixes for the BC5CDR evaluation #226

Merged: 9 commits, Oct 5, 2023
56 changes: 28 additions & 28 deletions src/ontogpt/evaluation/ctd/eval_ctd.py
@@ -26,16 +26,16 @@

 import yaml
 from bioc import biocxml
-from oaklib import BasicOntologyInterface, get_implementation_from_shorthand
+from oaklib import BasicOntologyInterface, get_adapter
 from pydantic import BaseModel

 from ontogpt.engines.knowledge_engine import chunk_text
 from ontogpt.engines.spires_engine import SPIRESEngine
 from ontogpt.evaluation.evaluation_engine import SimilarityScore, SPIRESEvaluationEngine
-from ontogpt.templates.core import Publication, Triple
 from ontogpt.templates.ctd import (
     ChemicalToDiseaseDocument,
     ChemicalToDiseaseRelationship,
+    Publication,
     TextWithTriples,
 )
@@ -49,8 +49,11 @@
 logger = logging.getLogger(__name__)


-def negated(Triple) -> bool:
-    return Triple.qualifier and Triple.qualifier.lower() == "not"
+def negated(ChemicalToDiseaseRelationship) -> bool:
+    return (
+        ChemicalToDiseaseRelationship.qualifier
+        and ChemicalToDiseaseRelationship.qualifier.lower() == "not"
+    )


 class PredictionRE(BaseModel):
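The renamed `negated` keeps the old convention of using the class name as the parameter name, so `ChemicalToDiseaseRelationship` inside the body refers to the argument and shadows the imported class; it also returns `None` rather than `False` when the qualifier is unset, because `and` short-circuits. A hypothetical restyling, not part of this PR:

```python
def negated(rel: ChemicalToDiseaseRelationship) -> bool:
    """True when the relationship carries a 'not' qualifier."""
    return bool(rel.qualifier and rel.qualifier.lower() == "not")
```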
@@ -129,7 +132,6 @@ class EvaluationObjectSetRE(BaseModel):

 @dataclass
 class EvalCTD(SPIRESEvaluationEngine):
-    # ontology: OboGraphInterface = None
     subject_prefix = "MESH"
     object_prefix = "MESH"

@@ -155,19 +157,29 @@ def load_cases(self, path: Path) -> Iterable[ChemicalToDiseaseDocument]:
                 doc[p.infons["type"]] = p.text
             title = doc["title"]
             abstract = doc["abstract"]
-            # text = f"Title: {title} Abstract: {abstract}"
+            logger.debug(f"Title: {title} Abstract: {abstract}")
             for r in document.relations:
                 i = r.infons
-                t = Triple(
-                    subject=f"{self.subject_prefix}:{i['Chemical']}",
-                    predicate=RMAP[i["relation"]],
-                    object=f"{self.object_prefix}:{i['Disease']}",
+                t = ChemicalToDiseaseRelationship.model_validate(
+                    {
+                        "subject": f"{self.subject_prefix}:{i['Chemical']}",
+                        "predicate": RMAP[i["relation"]],
+                        "object": f"{self.object_prefix}:{i['Disease']}",
+                    }
                 )
                 triples_by_text[(title, abstract)].append(t)
+        i = 0
         for (title, abstract), triples in triples_by_text.items():
-            pub = Publication(title=title, abstract=abstract)
+            i = i + 1
+            pub = Publication.model_validate(
+                {
+                    "id": str(i),
+                    "title": title,
+                    "abstract": abstract,
+                }
+            )
             logger.debug(f"Triples: {len(triples)} for Title: {title} Abstract: {abstract}")
-            yield ChemicalToDiseaseDocument(publication=pub, triples=triples)
+            yield ChemicalToDiseaseDocument.model_validate({"publication": pub, "triples": triples})

     def create_training_set(self, num=100):
         ke = self.extractor
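The `model_validate` calls in the hunk above are the Pydantic v2 spelling of dict-based construction (v1's `parse_obj`), so this change is partly a v1-to-v2 migration. A self-contained sketch of the idiom, using a stand-in model rather than the real ontogpt class:

```python
from typing import Optional

from pydantic import BaseModel


class Publication(BaseModel):
    # Stand-in for ontogpt.templates.ctd.Publication; field set assumed.
    id: Optional[str] = None
    title: Optional[str] = None
    abstract: Optional[str] = None


# Pydantic v2: validate a plain dict into a model instance.
pub = Publication.model_validate({"id": "1", "title": "t", "abstract": "a"})
```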
@@ -176,12 +188,12 @@ def create_training_set(self, num=100):
         for doc in docs[0:num]:
             text = doc.text
             prompt = ke.get_completion_prompt(None, text)
-            completion = ke.serialize_object(m)
+            completion = ke.serialize_object()
             yield dict(prompt=prompt, completion=completion)

     def eval(self) -> EvaluationObjectSetRE:
         """Evaluate the ability to extract relations."""
-        labeler = get_implementation_from_shorthand("sqlite:obo:mesh")
+        labeler = get_adapter("sqlite:obo:mesh")
         num_test = self.num_tests
         ke = self.extractor
         docs = list(self.load_test_cases())
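Two fixes land in the hunk above: the old `serialize_object(m)` call referenced `m`, which is not defined anywhere in the visible scope, so `create_training_set` would raise a NameError when run; and the MeSH labeler now comes from `get_adapter`, matching the import change at the top of the file.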
@@ -217,18 +229,6 @@ def eval(self) -> EvaluationObjectSetRE:
                 logger.debug(f"concatenated triples: {predicted_obj.triples}")
                 named_entities.extend(extraction.named_entities)

-            # title_extraction = ke.extract_from_text(doc.publication.title)
-            # logger.info(f"{len(title_extraction.extracted_object.triples)}\
-            #     triples from: Title {doc.publication.title}")
-            # abstract_extraction = ke.extract_from_text(doc.publication.abstract)
-            # logger.info(f"{len(abstract_extraction.extracted_object.triples)}\
-            #     triples from: Abstract {doc.publication.abstract}")
-            # ke.merge_resultsets([results, results2])
-            # predicted_obj = title_extraction.extracted_object
-            # predicted_obj.triples.extend(abstract_extraction.extracted_object.triples)
-            # logger.info(f"{len(predicted_obj.triples)} total triples, after concatenation")
-            # logger.debug(f"concatenated triples: {predicted_obj.triples}")
-
             def included(t: ChemicalToDiseaseRelationship):
                 if not [var for var in (t.subject, t.object, t.predicate) if var is None]:
                     return (
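The rest of `included` is collapsed by the diff view, so only the None-guard is visible. A hedged sketch of what such a filter typically does; the prefix test here is an outright assumption, not the PR's code:

```python
def included(t: ChemicalToDiseaseRelationship) -> bool:
    # Visible in the diff: drop triples with any missing field.
    if any(v is None for v in (t.subject, t.object, t.predicate)):
        return False
    # Assumption: keep only triples whose CURIEs use the expected prefix.
    return t.subject.startswith("MESH:") and t.object.startswith("MESH:")
```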
@@ -249,10 +249,10 @@ def included(t: ChemicalToDiseaseRelationship):
             pred = PredictionRE(predicted_object=predicted_obj, test_object=doc)
             pred.named_entities = named_entities
             logger.info("PRED")
-            logger.info(yaml.dump(pred.dict()))
+            logger.info(yaml.dump(data=pred.model_dump()))
             logger.info("Calc scores")
             pred.calculate_scores(labelers=[labeler])
-            logger.info(yaml.dump(pred.dict()))
+            logger.info(yaml.dump(data=pred.model_dump()))
             eos.predictions.append(pred)
         self.calc_stats(eos)
         return eos
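`pred.dict()` is another Pydantic v1 spelling; v2 renames it to `model_dump`, and passing the result through `yaml.dump` produces the YAML that gets logged. A small sketch with a stand-in model:

```python
import yaml
from pydantic import BaseModel


class PredictionStub(BaseModel):
    # Stand-in for PredictionRE; the real model holds objects and scores.
    score: float = 0.0


# v1: yaml.dump(pred.dict()); v2: yaml.dump(pred.model_dump())
print(yaml.dump(data=PredictionStub(score=0.5).model_dump()))
```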