Skip to content

Commit

Permalink
Merge pull request #21 from BBVA/develop
Browse files Browse the repository at this point in the history
Develop
  • Loading branch information
sbasaldua authored Nov 5, 2024
2 parents 5eb8079 + 7826608 commit 86274e5
Show file tree
Hide file tree
Showing 30 changed files with 112,329 additions and 174 deletions.
File renamed without changes.
6 changes: 3 additions & 3 deletions mercury/graph/core/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -548,10 +548,10 @@ def _from_dataframe(self, edges, nodes, keys):
edges = edges.withColumnRenamed(weight, 'weight')

if nodes is not None:
nodes = nodes.withColumnRenamed(id, 'id')
nodes = nodes.withColumnRenamed(id, 'id').dropDuplicates(['id'])
else:
src_nodes = edges.select(src).distinct().withColumnRenamed(src, id)
dst_nodes = edges.select(dst).distinct().withColumnRenamed(dst, id)
src_nodes = edges.select('src').distinct().withColumnRenamed('src', 'id')
dst_nodes = edges.select('dst').distinct().withColumnRenamed('dst', 'id')
nodes = src_nodes.union(dst_nodes).distinct()

g = SparkInterface().graphframes.GraphFrame(nodes, edges)
Expand Down
1 change: 1 addition & 0 deletions mercury/graph/embeddings/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from mercury.graph.embeddings.embeddings import Embeddings
from mercury.graph.embeddings.graphembeddings import GraphEmbedding
from mercury.graph.embeddings.spark_node2vec import SparkNode2Vec
4 changes: 3 additions & 1 deletion mercury/graph/embeddings/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from scipy.spatial.distance import cdist
from mercury.graph.core._njit import njit

from mercury.graph.core.base import BaseClass


@njit
def _elliptic_rotate(self_em, iu, iv, cos_w, sin_w):
Expand All @@ -21,7 +23,7 @@ def _elliptic_rotate(self_em, iu, iv, cos_w, sin_w):
return self_em


class Embeddings:
class Embeddings(BaseClass):
"""
This class holds a matrix object that is interpreted as the embeddings for any list of objects, not only the nodes of a graph. You
can see this class as the internal object holding the embedding for other classes such as class GraphEmbedding.
Expand Down
19 changes: 14 additions & 5 deletions mercury/graph/embeddings/graphembeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@
import pickle

import numpy as np
import pandas as pd
import networkx as nx

from mercury.graph.core import Graph, njit, graph_i4
from mercury.graph.core.base import BaseClass
from mercury.graph.embeddings import Embeddings


Expand Down Expand Up @@ -63,7 +65,7 @@ def _random_walks(r_ini, r_len, r_sum, r_col, r_wgt, TotW, n_jmp, max_jpe):
return (convrge, diverge)


class GraphEmbedding:
class GraphEmbedding(BaseClass):
"""
Create an embedding mapping the nodes of a graph.
Expand Down Expand Up @@ -93,15 +95,21 @@ def __init__(
load_file=None,
):
"""GraphEmbedding class constructor"""
if load_file is not None:
self._load(load_file)
return
if load_file is None and (dimension is None or n_jumps is None):
raise ValueError(
"Parameters dimension and n_jumps are required when load_file is None"
)

self.dimension = dimension
self.n_jumps = n_jumps
self.max_per_epoch = max_per_epoch
self.learn_step = learn_step
self.bidirectional = bidirectional
self.load_file = load_file

if self.load_file is not None:
self._load(self.load_file)
return

def __getitem__(self, arg):
"""
Expand Down Expand Up @@ -218,6 +226,7 @@ def get_most_similar_nodes(
Returns:
(list): list of k most similar nodes and list of similarities of the most similar nodes
(DataFrame): A list of k most similar nodes as a `pd.DataFrame[word: string, similarity: double]`
"""
node_index = self.node_ids.index(node_id)

Expand All @@ -230,7 +239,7 @@ def get_most_similar_nodes(
else:
nodes = list(ordered_indices)

return nodes, ordered_similarities
return pd.DataFrame({"word": nodes, "similarity": ordered_similarities})

def save(self, file_name, save_embedding=False):
"""
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
import logging

from mercury.graph.core import Graph
from mercury.graph.graphml.base import BaseClass
from mercury.graph.core.base import BaseClass

from mercury.graph.core.spark_interface import SparkInterface, pyspark_installed, graphframes_installed
from mercury.graph.core.spark_interface import (
SparkInterface,
pyspark_installed,
graphframes_installed,
)

if pyspark_installed:
import pyspark.sql.functions as f
Expand Down Expand Up @@ -134,7 +138,7 @@ def fit(self, G: Graph):

if not self.use_cached_rw:
paths = (
self._run_rw(G, self.sampling_ratio, self.num_epochs, self.batch_size)
self._run_rw(G)
.withColumn("size", f.size("random_walks"))
.where(f.col("size") > 1)
.drop("size")
Expand All @@ -150,9 +154,7 @@ def fit(self, G: Graph):
if self.num_paths_per_node > 1:
for block_id in range(1, self.num_paths_per_node):
new_paths = (
self._run_rw(
G, self.sampling_ratio, self.num_epochs, self.batch_size
)
self._run_rw(G)
.withColumn("size", f.size("random_walks"))
.where(f.col("size") > 1)
.drop("size")
Expand All @@ -172,7 +174,8 @@ def fit(self, G: Graph):
self.paths_ = paths.persist()
else:
self.paths_ = (
SparkInterface().read_parquet(self.path_cache)
SparkInterface()
.read_parquet(self.path_cache)
.drop("block")
.repartition(self.n_partitions_cache)
.persist()
Expand Down Expand Up @@ -254,15 +257,15 @@ def _load(self, file_name):

self.node2vec_ = Word2VecModel.load(file_name)

def _start_rw(self, G: Graph, sampling_ratio):
def _start_rw(self, G: Graph):
aux_vert = (
G.graphframe.vertices.groupBy(f.col("id"))
.agg(f.collect_list(f.col("id")).alias("tmp_rw_aux_acc_path"))
.withColumn("tmp_rw_init_p", f.rand())
.withColumn(
"new_rw_acc_path",
f.when(
f.col("tmp_rw_init_p") <= sampling_ratio,
f.col("tmp_rw_init_p") <= self.sampling_ratio,
f.col("tmp_rw_aux_acc_path"),
).otherwise(f.lit(None)),
)
Expand Down Expand Up @@ -291,7 +294,7 @@ def _start_rw(self, G: Graph, sampling_ratio):
f.col("weight"),
f.col("new_rw_norm_cumsum"),
)
).persist()
)

self.gx = GraphFrame(aux_vert, aux_edges)

Expand Down Expand Up @@ -328,15 +331,16 @@ def _update_state_with_next_step(self, i):

return selected_next_step

def _run_rw(self, G: Graph, sampling_ratio, num_epochs, batch_size):
self._start_rw(G, sampling_ratio)
def _run_rw(self, G: Graph):
self._start_rw(G)

for i in range(num_epochs):
for i in range(self.num_epochs):

aux_vert = self._update_state_with_next_step(i)
aux_vert = aux_vert.checkpoint()
self.gx = GraphFrame(aux_vert, self.gx.edges)

if (i + 1) % batch_size == 0:
if (i + 1) % self.batch_size == 0:
old_aux_vert = aux_vert
aux_vert = AggregateMessages.getCachedDataFrame(aux_vert)
old_aux_vert.unpersist()
Expand Down
Loading

0 comments on commit 86274e5

Please sign in to comment.