diff --git a/recbole/model/general_recommender/nncf.py b/recbole/model/general_recommender/nncf.py index 623d1c9b9..79692eada 100644 --- a/recbole/model/general_recommender/nncf.py +++ b/recbole/model/general_recommender/nncf.py @@ -92,14 +92,16 @@ def _init_weights(self, module): # Unify embedding length def Max_ner(self, lst, max_ner): - r"""Unify embedding length of neighborhood information. + r"""Unify embedding length of neighborhood information for efficiency consideration. + Truncate the list if the length is larger than max_ner. + Otherwise, pad it with 0. Args: - lst (int): The input list contains node's neighbors. + lst (list): The input list contains node's neighbors. max_ner (int): The number of neighbors we choose for each node. Returns: - int: The list of a node's neighbors, shape: [number of nodes, max_ner] + list: The list of a node's community neighbors. """ @@ -114,16 +116,18 @@ def Max_ner(self, lst, max_ner): # Find other nodes in the same community def get_community_member(self, partition, community_dict, node, kind): - r"""Find other nodes in the same community. + r"""Find other nodes in the same community. + e.g. If the node starts with letter "i", + the other nodes start with letter "i" in the same community dict group are its community neighbors. Args: - partition (int): The input dict that contains the community each node belongs. - community_dict (int): The input dict that shows the nodes each community contains. + partition (dict): The input dict that contains the community each node belongs. + community_dict (dict): The input dict that shows the nodes each community contains. node (int): The id of the input node. kind (char): The type of the input node. Returns: - int: The list of a node's community neighbors. + list: The list of a node's community neighbors. """ comm = community_dict[partition[node]] @@ -131,15 +135,16 @@ def get_community_member(self, partition, community_dict, node, kind): # Prepare neiborhood embeddings, i.e. 
I(u) and U(i) def prepare_vector_element(self, partition, relation, community_dict): - r"""Prepare neiborhood embeddings, i.e. I(u) and U(i). + r"""Find the community neighbors of each node, i.e. I(u) and U(i). + Then reset the id of nodes. Args: - partition (int): The input dict that contains the community each node belongs. - relation (int): The input list that contains the relationships of users and items. - community_dict (int): The input dict that shows the nodes each community contains. + partition (dict): The input dict that contains the community each node belongs. + relation (list): The input list that contains the relationships of users and items. + community_dict (dict): The input dict that shows the nodes each community contains. Returns: - int: The list of a batch of nodes' neighbors. + list: The list of nodes' community neighbors. """ item2user_neighbor_lst = [[] for _ in range(self.n_items)] @@ -170,6 +175,10 @@ def prepare_vector_element(self, partition, relation, community_dict): # Get neighborhood embeddings using louvain method def get_neigh_louvain(self): r"""Get neighborhood information using louvain algorithm. + First, change the id of node, + for example, the id of user node "1" will be set to "u_1" in order to use louvain algorithm. + Second, use louvain algorithm to separate nodes into different communities. + Finally, find the community neighbors of each node with the same type and reset the id of the nodes. Returns: torch.IntTensor: The neighborhood nodes of a batch of user or item, shape: [batch_size, neigh_num] @@ -205,16 +214,16 @@ def get_neigh_louvain(self): # Count the similarity of node and direct neighbors using jaccard method def count_jaccard(self, inters, node, neigh_list, kind): - r""" Count the similarity of node and its direct neighbors using jaccard method. + r""" Count the similarity of the node and its direct neighbors using jaccard similarity. 
Args: - inters (int): The input list that contains the relationships of users and items. + inters (list): The input list that contains the relationships of users and items. node (int): The id of the input node. - neigh_list (int): The input list that contains the neighbors of the input node. + neigh_list (list): The input list that contains the neighbors of the input node. kind (char): The type of the input node. Returns: - float: The list of jaccard similarity score between the node and its neighbors. + list: The list of jaccard similarity score between the node and its neighbors. """ if kind == 'u': @@ -241,6 +250,9 @@ def count_jaccard(self, inters, node, neigh_list, kind): # Get neighborhood embeddings using knn method def get_neigh_knn(self): r"""Get neighborhood information using knn algorithm. + Find direct neighbors of each node, if the number of direct neighbors is less than neigh_num, + add other similar neighbors using jaccard similarity. + Otherwise, select random top k direct neighbors, k equals to the number of neighbors. Returns: torch.IntTensor: The neighborhood nodes of a batch of user or item, shape: [batch_size, neigh_num] @@ -297,7 +309,8 @@ def get_neigh_knn(self): # Get neighborhood embeddings using random method def get_neigh_random(self): r"""Get neighborhood information using random algorithm. - + Select random top k direct neighbors, k equals to the number of neighbors. + Returns: torch.IntTensor: The neighborhood nodes of a batch of user or item, shape: [batch_size, neigh_num] """ @@ -332,7 +345,7 @@ def get_neigh_random(self): # Get neighborhood embeddings def get_neigh_info(self, user, item): - r"""Get neighborhood embeddings. + r"""Get a batch of neighborhood embedding tensor according to input id. Args: user (torch.LongTensor): The input tensor that contains user's id, shape: [batch_size, ]