Skip to content

Commit

Permalink
Speed up Wang's semantic similarity calculations.
Browse files Browse the repository at this point in the history
  • Loading branch information
dvklopfenstein committed Nov 22, 2020
1 parent 9181bb8 commit b7b63bb
Showing 1 changed file with 36 additions and 27 deletions.
63 changes: 36 additions & 27 deletions goatools/semsim/termwise/dag_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,21 @@
__copyright__ = "Copyright (C) 2020-present, DV Klopfenstein. All rights reserved."
__author__ = "DV Klopfenstein"

## import timeit
## from goatools.godag.prttime import prt_hms


class DagA:
"""A GO term, A, can be represented as DAG_a = (A, T_a, E_a), aka a GoSubDag"""

def __init__(self, go_a, ancestors, go2depth, w_e, godag):
    """Store GO term A with its ancestors and precompute its GO IDs and S-values.

    Args:
        go_a: GO ID of term A.
        ancestors: collection of the GO IDs of term A's ancestors; may be empty.
        go2depth: mapping from GO ID to depth, used to order the ancestors.
        w_e: mapping from relationship name to edge weight; must include 'is_a'.
        godag: mapping from GO ID to its GO term object.
    """
    self.go_a = go_a
    self.ancestors = ancestors
    # goids must be assigned first: _init_go2svalue reads self.goids.
    self.goids = self._init_goids()
    self.go2svalue = self._init_go2svalue(go2depth, w_e, godag)

def get_sv(self):
"""Get the semantic value of GO term A"""
Expand All @@ -23,24 +28,20 @@ def get_svalues(self, goids):
s_go2svalue = self.go2svalue
return [s_go2svalue[go] for go in goids]

def _init_go2svalue(self, rel2scf):
def _init_go2svalue(self, go2depth, w_e, godag):
"""S-value: the contribution of GO term, t, to the semantics of GO term, A"""
#tic = timeit.default_timer()
go2svalue = {self.go_a: 1.0}
s_go2obj = self.gosubdag.go2obj
s_rels = self.gosubdag.relationships
terms_a = set(self.gosubdag.rcntobj.go2ancestors[self.go_a])
ancestors_sorted = self._get_sorted(terms_a)
terms_a.add(self.go_a)
for ancestor_id, _ in ancestors_sorted:
## print(ancestor_id, ntd)
goterm = s_go2obj[ancestor_id]
svals = []
weight = rel2scf['is_a']
for cobj in goterm.children:
if cobj.item_id in terms_a:
svals.append(weight*go2svalue[cobj.item_id])
weight = rel2scf['part_of']
for rel in s_rels:
if not self.ancestors:
return go2svalue
terms_a = self.goids
w_r = {r:v for r, v in w_e.items() if r != 'is_a'}
#prt_hms(tic, 'DagA edge weights wo/is_a')
for ancestor_id in self._get_sorted(go2depth):
goterm = godag[ancestor_id]
weight = w_e['is_a']
svals = [weight*go2svalue[o.item_id] for o in goterm.children if o.item_id in terms_a]
for rel, weight in w_r.items():
if rel in goterm.relationship_rev:
for cobj in goterm.relationship_rev[rel]:
if cobj.item_id in terms_a:
Expand All @@ -50,12 +51,20 @@ def _init_go2svalue(self, rel2scf):
## print(ancestor_id, max(svals))
return go2svalue

def _get_sorted(self, go2depth):
    """Return the ancestors of GO term A ordered from deepest to shallowest.

    Args:
        go2depth: mapping from GO ID to a sortable depth value; it must hold
            every ID in ``self.ancestors`` (a missing ID raises ``KeyError``).

    Returns:
        A tuple of GO IDs sorted by descending depth. IDs of equal depth keep
        the order in which ``self.ancestors`` yields them. An empty ancestor
        collection gives an empty tuple.
    """
    # Building the dict first de-duplicates the IDs and fails fast on any
    # ancestor that has no depth.
    go2dep = {go: go2depth[go] for go in self.ancestors}
    # Sorting the keys directly avoids the ``a, b = zip(*pairs)`` idiom, which
    # raises ValueError when there are no pairs to unpack.
    return tuple(sorted(go2dep, key=go2dep.__getitem__, reverse=True))

def _init_goids(self):
    """Return the GO IDs in GO term A's DAG: term A itself plus its ancestors.

    Returns:
        A new set holding ``self.go_a`` and every ID in ``self.ancestors``.
        When ``self.ancestors`` is empty or None the result is ``{self.go_a}``.
    """
    # ``or ()`` matches the falsy-ancestors guard in _init_go2svalue, so a term
    # given no ancestors (including None) does not fail inside set().
    goids = set(self.ancestors or ())
    goids.add(self.go_a)
    return goids


# Copyright (C) 2020-present, DV Klopfenstein. All rights reserved.

0 comments on commit b7b63bb

Please sign in to comment.