Skip to content

Commit

Permalink
FIX: update nncf
Browse files Browse the repository at this point in the history
  • Loading branch information
cyLi-Tiger authored Feb 23, 2021
1 parent 42d7296 commit 744a905
Showing 1 changed file with 31 additions and 18 deletions.
49 changes: 31 additions & 18 deletions recbole/model/general_recommender/nncf.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,14 +92,16 @@ def _init_weights(self, module):

# Unify embedding length
def Max_ner(self, lst, max_ner):
r"""Unify embedding length of neighborhood information.
r"""Unify embedding length of neighborhood information for efficiency consideration.
Truncate the list if the length is larger than max_ner.
Otherwise, pad it with 0.
Args:
lst (int): The input list contains node's neighbors.
lst (list): The input list contains node's neighbors.
max_ner (int): The number of neighbors we choose for each node.
Returns:
int: The list of a node's neighbors, shape: [number of nodes, max_ner]
list: The list of a node's community neighbors.
"""
Expand All @@ -114,32 +116,35 @@ def Max_ner(self, lst, max_ner):

# Find other nodes in the same community
def get_community_member(self, partition, community_dict, node, kind):
r"""Find other nodes in the same community.
r"""Find other nodes in the same community.
e.g. If the node starts with letter "i",
the other nodes start with letter "i" in the same community dict group are its community neighbors.
Args:
partition (int): The input dict that contains the community each node belongs.
community_dict (int): The input dict that shows the nodes each community contains.
partition (dict): The input dict that contains the community each node belongs.
community_dict (dict): The input dict that shows the nodes each community contains.
node (int): The id of the input node.
kind (char): The type of the input node.
Returns:
int: The list of a node's community neighbors.
list: The list of a node's community neighbors.
"""
comm = community_dict[partition[node]]
return [x for x in comm if x.startswith(kind)]

# Prepare neiborhood embeddings, i.e. I(u) and U(i)
def prepare_vector_element(self, partition, relation, community_dict):
r"""Prepare neiborhood embeddings, i.e. I(u) and U(i).
r"""Find the community neighbors of each node, i.e. I(u) and U(i).
Then reset the id of nodes.
Args:
partition (int): The input dict that contains the community each node belongs.
relation (int): The input list that contains the relationships of users and items.
community_dict (int): The input dict that shows the nodes each community contains.
partition (dict): The input dict that contains the community each node belongs.
relation (list): The input list that contains the relationships of users and items.
community_dict (dict): The input dict that shows the nodes each community contains.
Returns:
int: The list of a batch of nodes' neighbors.
list: The list of nodes' community neighbors.
"""
item2user_neighbor_lst = [[] for _ in range(self.n_items)]
Expand Down Expand Up @@ -170,6 +175,10 @@ def prepare_vector_element(self, partition, relation, community_dict):
# Get neighborhood embeddings using louvain method
def get_neigh_louvain(self):
r"""Get neighborhood information using louvain algorithm.
First, change the id of node,
for example, the id of user node "1" will be set to "u_1" in order to use louvain algorithm.
Second, use louvain algorithm to seperate nodes into different communities.
Finally, find the community neighbors of each node with the same type and reset the id of the nodes.
Returns:
torch.IntTensor: The neighborhood nodes of a batch of user or item, shape: [batch_size, neigh_num]
Expand Down Expand Up @@ -205,16 +214,16 @@ def get_neigh_louvain(self):

# Count the similarity of node and direct neighbors using jaccard method
def count_jaccard(self, inters, node, neigh_list, kind):
r""" Count the similarity of node and its direct neighbors using jaccard method.
r""" Count the similarity of the node and its direct neighbors using jaccard similarity.
Args:
inters (int): The input list that contains the relationships of users and items.
inters (list): The input list that contains the relationships of users and items.
node (int): The id of the input node.
neigh_list (int): The input list that contains the neighbors of the input node.
neigh_list (list): The input list that contains the neighbors of the input node.
kind (char): The type of the input node.
Returns:
float: The list of jaccard similarity score between the node and its neighbors.
list: The list of jaccard similarity score between the node and its neighbors.
"""
if kind == 'u':
Expand All @@ -241,6 +250,9 @@ def count_jaccard(self, inters, node, neigh_list, kind):
# Get neighborhood embeddings using knn method
def get_neigh_knn(self):
r"""Get neighborhood information using knn algorithm.
Find direct neighbors of each node, if the number of direct neighbors is less than neigh_num,
add other similar neighbors using jaccard similarity.
Otherwise, select random top k direct neighbors, k equals to the number of neighbors.
Returns:
torch.IntTensor: The neighborhood nodes of a batch of user or item, shape: [batch_size, neigh_num]
Expand Down Expand Up @@ -297,7 +309,8 @@ def get_neigh_knn(self):
# Get neighborhood embeddings using random method
def get_neigh_random(self):
r"""Get neighborhood information using random algorithm.
Select random top k direct neighbors, k equals to the number of neighbors.
Returns:
torch.IntTensor: The neighborhood nodes of a batch of user or item, shape: [batch_size, neigh_num]
"""
Expand Down Expand Up @@ -332,7 +345,7 @@ def get_neigh_random(self):

# Get neighborhood embeddings
def get_neigh_info(self, user, item):
r"""Get neighborhood embeddings.
r"""Get a batch of neighborhood embedding tensor according to input id.
Args:
user (torch.LongTensor): The input tensor that contains user's id, shape: [batch_size, ]
Expand Down

0 comments on commit 744a905

Please sign in to comment.