diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 7685cfb2..fc3cdae4 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -18,7 +18,7 @@ jobs: strategy: matrix: python-version: [3.6,3.7] - torch-version: [1.1.0,1.2.0,1.3.0,1.4.0,1.5.0,1.6.0,1.7.0] + torch-version: [1.1.0,1.2.0,1.3.0,1.4.0,1.5.0,1.6.0,1.7.0,1.8.1] # exclude: # - python-version: 3.5 diff --git a/README.md b/README.md index dfb7be53..a50aeb93 100644 --- a/README.md +++ b/README.md @@ -38,7 +38,9 @@ Let's [**Get Started!**](https://deepctr-torch.readthedocs.io/en/latest/Quick-St | AutoInt | [CIKM 2019][AutoInt: Automatic Feature Interaction Learning via Self-Attentive Neural Networks](https://arxiv.org/abs/1810.11921) | | ONN | [arxiv 2019][Operation-aware Neural Networks for User Response Prediction](https://arxiv.org/pdf/1904.12579.pdf) | | FiBiNET | [RecSys 2019][FiBiNET: Combining Feature Importance and Bilinear feature Interaction for Click-Through Rate Prediction](https://arxiv.org/pdf/1905.09433.pdf) | +| IFM | [IJCAI 2019][An Input-aware Factorization Machine for Sparse Prediction](https://www.ijcai.org/Proceedings/2019/0203.pdf) | | DCN V2 | [arxiv 2020][DCN V2: Improved Deep & Cross Network and Practical Lessons for Web-scale Learning to Rank Systems](https://arxiv.org/abs/2008.13535) | +| DIFM | [IJCAI 2020][A Dual Input-aware Factorization Machine for CTR Prediction](https://www.ijcai.org/Proceedings/2020/0434.pdf) | ## DisscussionGroup & Related Projects @@ -82,6 +84,11 @@ Let's [**Get Started!**](https://deepctr-torch.readthedocs.io/en/latest/Quick-St ​ Shen Weichen

Core Dev
Zhejiang University
+ + pic
+ Zan Shuxun +
Core Dev
Beijing University
of Posts and
Telecommunications
+ pic
Wang Ze ​ @@ -92,11 +99,6 @@ Let's [**Get Started!**](https://deepctr-torch.readthedocs.io/en/latest/Quick-St Zhang Wutong

Core Dev
Beijing University
of Posts and
Telecommunications
- - pic
- Zan Shuxun -
Core Dev
Beijing University
of Posts and
Telecommunications
- pic
Zhang Yuefeng diff --git a/deepctr_torch/__init__.py b/deepctr_torch/__init__.py index b3ae817b..b780468d 100644 --- a/deepctr_torch/__init__.py +++ b/deepctr_torch/__init__.py @@ -2,5 +2,5 @@ from . import models from .utils import check_version -__version__ = '0.2.5' +__version__ = '0.2.6' check_version(__version__) \ No newline at end of file diff --git a/deepctr_torch/layers/activation.py b/deepctr_torch/layers/activation.py index 4ba8758e..01624a05 100644 --- a/deepctr_torch/layers/activation.py +++ b/deepctr_torch/layers/activation.py @@ -12,7 +12,7 @@ class Dice(nn.Module): Output shape: - Same shape as input. - + References - [Zhou G, Zhu X, Song C, et al. Deep interest network for click-through rate prediction[C]//Proceedings of the 24th ACM SIGKDD International Conference on Knowledge Discovery & Data Mining. ACM, 2018: 1059-1068.](https://arxiv.org/pdf/1706.06978.pdf) - https://github.com/zhougr1993/DeepInterestNetwork, https://github.com/fanoping/DIN-pytorch diff --git a/deepctr_torch/layers/interaction.py b/deepctr_torch/layers/interaction.py index af7d945e..edbfa88b 100644 --- a/deepctr_torch/layers/interaction.py +++ b/deepctr_torch/layers/interaction.py @@ -106,10 +106,11 @@ class BilinearInteraction(nn.Module): Input shape - A list of 3D tensor with shape: ``(batch_size,filed_size, embedding_size)``. Output shape - - 3D tensor with shape: ``(batch_size,filed_size, embedding_size)``. + - 3D tensor with shape: ``(batch_size,filed_size*(filed_size-1)/2, embedding_size)``. Arguments - **filed_size** : Positive integer, number of feature groups. - - **str** : String, types of bilinear functions used in this layer. + - **embedding_size** : Positive integer, embedding size of sparse features. + - **bilinear_type** : String, types of bilinear functions used in this layer. - **seed** : A Python integer to use as random seed. References - [FiBiNET: Combining Feature Importance and Bilinear feature Interaction for Click-Through Rate Prediction @@ -125,7 +126,7 @@ def __init__(self, filed_size, embedding_size, bilinear_type="interaction", seed self.bilinear = nn.Linear( embedding_size, embedding_size, bias=False) elif self.bilinear_type == "each": - for i in range(filed_size): + for _ in range(filed_size): self.bilinear.append( nn.Linear(embedding_size, embedding_size, bias=False)) elif self.bilinear_type == "interaction": @@ -340,13 +341,14 @@ class InteractingLayer(nn.Module): - [Song W, Shi C, Xiao Z, et al. AutoInt: Automatic Feature Interaction Learning via Self-Attentive Neural Networks[J]. 
arXiv preprint arXiv:1810.11921, 2018.](https://arxiv.org/abs/1810.11921) """ - def __init__(self, in_features, att_embedding_size=8, head_num=2, use_res=True, seed=1024, device='cpu'): + def __init__(self, in_features, att_embedding_size=8, head_num=2, use_res=True, scaling=False, seed=1024, device='cpu'): super(InteractingLayer, self).__init__() if head_num <= 0: raise ValueError('head_num must be a int > 0') self.att_embedding_size = att_embedding_size self.head_num = head_num self.use_res = use_res + self.scaling = scaling self.seed = seed embedding_size = in_features @@ -388,7 +390,8 @@ def forward(self, inputs): values, self.att_embedding_size, dim=2)) inner_product = torch.einsum( 'bnik,bnjk->bnij', querys, keys) # head_num None F F - + if self.scaling: + inner_product /= self.att_embedding_size ** 0.5 self.normalized_att_scores = F.softmax( inner_product, dim=-1) # head_num None F F result = torch.matmul(self.normalized_att_scores, @@ -428,17 +431,20 @@ def __init__(self, in_features, layer_num=2, parameterization='vector', seed=102 self.parameterization = parameterization if self.parameterization == 'vector': # weight in DCN. (in_features, 1) - self.kernels = torch.nn.ParameterList( - [nn.Parameter(nn.init.xavier_normal_(torch.empty(in_features, 1))) for i in range(self.layer_num)]) + self.kernels = nn.Parameter(torch.Tensor(self.layer_num, in_features, 1)) elif self.parameterization == 'matrix': # weight matrix in DCN-M. (in_features, in_features) - self.kernels = torch.nn.ParameterList([nn.Parameter(nn.init.xavier_normal_( - torch.empty(in_features, in_features))) for i in range(self.layer_num)]) + self.kernels = nn.Parameter(torch.Tensor(self.layer_num, in_features, in_features)) else: # error raise ValueError("parameterization should be 'vector' or 'matrix'") - self.bias = torch.nn.ParameterList( - [nn.Parameter(nn.init.zeros_(torch.empty(in_features, 1))) for i in range(self.layer_num)]) + self.bias = nn.Parameter(torch.Tensor(self.layer_num, in_features, 1)) + + for i in range(self.kernels.shape[0]): + nn.init.xavier_normal_(self.kernels[i]) + for i in range(self.bias.shape[0]): + nn.init.zeros_(self.bias[i]) + self.to(device) def forward(self, inputs): @@ -483,18 +489,23 @@ def __init__(self, in_features, low_rank=32, num_experts=4, layer_num=2, device= self.num_experts = num_experts # U: (in_features, low_rank) - self.U_list = torch.nn.ParameterList([nn.Parameter(nn.init.xavier_normal_( - torch.empty(num_experts, in_features, low_rank))) for i in range(self.layer_num)]) + self.U_list = nn.Parameter(torch.Tensor(self.layer_num, num_experts, in_features, low_rank)) # V: (in_features, low_rank) - self.V_list = torch.nn.ParameterList([nn.Parameter(nn.init.xavier_normal_( - torch.empty(num_experts, in_features, low_rank))) for i in range(self.layer_num)]) + self.V_list = nn.Parameter(torch.Tensor(self.layer_num, num_experts, in_features, low_rank)) # C: (low_rank, low_rank) - self.C_list = torch.nn.ParameterList([nn.Parameter(nn.init.xavier_normal_( - torch.empty(num_experts, low_rank, low_rank))) for i in range(self.layer_num)]) + self.C_list = nn.Parameter(torch.Tensor(self.layer_num, num_experts, low_rank, low_rank)) self.gating = nn.ModuleList([nn.Linear(in_features, 1, bias=False) for i in range(self.num_experts)]) - self.bias = torch.nn.ParameterList([nn.Parameter(nn.init.zeros_( - torch.empty(in_features, 1))) for i in range(self.layer_num)]) + self.bias = nn.Parameter(torch.Tensor(self.layer_num, in_features, 1)) + + init_para_list = [self.U_list, self.V_list, 
self.C_list] + for i in range(len(init_para_list)): + for j in range(self.layer_num): + nn.init.xavier_normal_(init_para_list[i][j]) + + for i in range(len(self.bias)): + nn.init.zeros_(self.bias[i]) + self.to(device) def forward(self, inputs): diff --git a/deepctr_torch/layers/sequence.py b/deepctr_torch/layers/sequence.py index 64736b5b..550e5878 100644 --- a/deepctr_torch/layers/sequence.py +++ b/deepctr_torch/layers/sequence.py @@ -39,7 +39,7 @@ def _sequence_mask(self, lengths, maxlen=None, dtype=torch.bool): # Returns a mask tensor representing the first N positions of each cell. if maxlen is None: maxlen = lengths.max() - row_vector = torch.arange(0, maxlen, 1).to(self.device) + row_vector = torch.arange(0, maxlen, 1).to(lengths.device) matrix = torch.unsqueeze(lengths, dim=-1) mask = row_vector < matrix @@ -70,6 +70,7 @@ def forward(self, seq_value_len_list): hist = torch.sum(hist, dim=1, keepdim=False) if self.mode == 'mean': + self.eps = self.eps.to(user_behavior_length.device) hist = torch.div(hist, user_behavior_length.type(torch.float32) + self.eps) hist = torch.unsqueeze(hist, dim=1) diff --git a/deepctr_torch/models/__init__.py b/deepctr_torch/models/__init__.py index 09f1d7c3..43381369 100644 --- a/deepctr_torch/models/__init__.py +++ b/deepctr_torch/models/__init__.py @@ -2,6 +2,8 @@ from .deepfm import DeepFM from .xdeepfm import xDeepFM from .afm import AFM +from .difm import DIFM +from .ifm import IFM from .autoint import AutoInt from .dcn import DCN from .dcnmix import DCNMix diff --git a/deepctr_torch/models/afm.py b/deepctr_torch/models/afm.py index 4d015e72..ae1556d4 100644 --- a/deepctr_torch/models/afm.py +++ b/deepctr_torch/models/afm.py @@ -27,16 +27,17 @@ class AFM(BaseModel): :param seed: integer ,to use as random seed. :param task: str, ``"binary"`` for binary logloss or ``"regression"`` for regression loss :param device: str, ``"cpu"`` or ``"cuda:0"`` + :param gpus: list of int or torch.device for multiple gpus. If None, run on `device`. `gpus[0]` should be the same gpu with `device`. :return: A PyTorch model instance. """ def __init__(self, linear_feature_columns, dnn_feature_columns, use_attention=True, attention_factor=8, l2_reg_linear=1e-5, l2_reg_embedding=1e-5, l2_reg_att=1e-5, afm_dropout=0, init_std=0.0001, seed=1024, - task='binary', device='cpu'): + task='binary', device='cpu', gpus=None): super(AFM, self).__init__(linear_feature_columns, dnn_feature_columns, l2_reg_linear=l2_reg_linear, l2_reg_embedding=l2_reg_embedding, init_std=init_std, seed=seed, task=task, - device=device) + device=device, gpus=gpus) self.use_attention = use_attention diff --git a/deepctr_torch/models/autoint.py b/deepctr_torch/models/autoint.py index a2001183..c39effb4 100644 --- a/deepctr_torch/models/autoint.py +++ b/deepctr_torch/models/autoint.py @@ -32,19 +32,20 @@ class AutoInt(BaseModel): :param seed: integer ,to use as random seed. :param task: str, ``"binary"`` for binary logloss or ``"regression"`` for regression loss :param device: str, ``"cpu"`` or ``"cuda:0"`` + :param gpus: list of int or torch.device for multiple gpus. If None, run on `device`. `gpus[0]` should be the same gpu with `device`. :return: A PyTorch model instance. 
- + """ def __init__(self, linear_feature_columns, dnn_feature_columns, att_layer_num=3, att_embedding_size=8, att_head_num=2, att_res=True, dnn_hidden_units=(256, 128), dnn_activation='relu', l2_reg_dnn=0, l2_reg_embedding=1e-5, dnn_use_bn=False, dnn_dropout=0, init_std=0.0001, seed=1024, - task='binary', device='cpu'): + task='binary', device='cpu', gpus=None): super(AutoInt, self).__init__(linear_feature_columns, dnn_feature_columns, l2_reg_linear=0, l2_reg_embedding=l2_reg_embedding, init_std=init_std, seed=seed, task=task, - device=device) + device=device, gpus=gpus) if len(dnn_hidden_units) <= 0 and att_layer_num <= 0: raise ValueError("Either hidden_layer or att_layer_num must > 0") diff --git a/deepctr_torch/models/basemodel.py b/deepctr_torch/models/basemodel.py index 865c0bfc..bb9d1f7a 100644 --- a/deepctr_torch/models/basemodel.py +++ b/deepctr_torch/models/basemodel.py @@ -59,7 +59,7 @@ def __init__(self, feature_columns, feature_index, init_std=0.0001, device='cpu' device)) torch.nn.init.normal_(self.weight, mean=0, std=init_std) - def forward(self, X): + def forward(self, X, sparse_feat_refine_weight=None): sparse_embedding_list = [self.embedding_dict[feat.embedding_name]( X[:, self.feature_index[feat.name][0]:self.feature_index[feat.name][1]].long()) for @@ -73,26 +73,25 @@ def forward(self, X): sparse_embedding_list += varlen_embedding_list - if len(sparse_embedding_list) > 0 and len(dense_value_list) > 0: - linear_sparse_logit = torch.sum( - torch.cat(sparse_embedding_list, dim=-1), dim=-1, keepdim=False) - linear_dense_logit = torch.cat( + linear_logit = torch.zeros([X.shape[0], 1]).to(sparse_embedding_list[0].device) + if len(sparse_embedding_list) > 0: + sparse_embedding_cat = torch.cat(sparse_embedding_list, dim=-1) + if sparse_feat_refine_weight is not None: + # w_{x,i}=m_{x,i} * w_i (in IFM and DIFM) + sparse_embedding_cat = sparse_embedding_cat * sparse_feat_refine_weight.unsqueeze(1) + sparse_feat_logit = torch.sum(sparse_embedding_cat, dim=-1, keepdim=False) + linear_logit += sparse_feat_logit + if len(dense_value_list) > 0: + dense_value_logit = torch.cat( dense_value_list, dim=-1).matmul(self.weight) - linear_logit = linear_sparse_logit + linear_dense_logit - elif len(sparse_embedding_list) > 0: - linear_logit = torch.sum( - torch.cat(sparse_embedding_list, dim=-1), dim=-1, keepdim=False) - elif len(dense_value_list) > 0: - linear_logit = torch.cat( - dense_value_list, dim=-1).matmul(self.weight) - else: - linear_logit = torch.zeros([X.shape[0], 1]) + linear_logit += dense_value_logit + return linear_logit class BaseModel(nn.Module): def __init__(self, linear_feature_columns, dnn_feature_columns, l2_reg_linear=1e-5, l2_reg_embedding=1e-5, - init_std=0.0001, seed=1024, task='binary', device='cpu'): + init_std=0.0001, seed=1024, task='binary', device='cpu', gpus=None): super(BaseModel, self).__init__() torch.manual_seed(seed) @@ -100,7 +99,11 @@ def __init__(self, linear_feature_columns, dnn_feature_columns, l2_reg_linear=1e self.reg_loss = torch.zeros((1,), device=device) self.aux_loss = torch.zeros((1,), device=device) - self.device = device # device + self.device = device + self.gpus = gpus + if gpus and str(self.gpus[0]) not in self.device: + raise ValueError( + "`gpus[0]` should be the same gpu with `device`") self.feature_index = build_input_features( linear_feature_columns + dnn_feature_columns) @@ -192,14 +195,21 @@ def fit(self, x=None, y=None, batch_size=None, epochs=1, verbose=1, initial_epoc torch.from_numpy(y)) if batch_size is None: batch_size = 256 
- train_loader = DataLoader( - dataset=train_tensor_data, shuffle=shuffle, batch_size=batch_size) - print(self.device, end="\n") model = self.train() loss_func = self.loss_func optim = self.optim + if self.gpus: + print('parallel running on these gpus:', self.gpus) + model = torch.nn.DataParallel(model, device_ids=self.gpus) + batch_size *= len(self.gpus) # input `batch_size` is batch_size per gpu + else: + print(self.device) + + train_loader = DataLoader( + dataset=train_tensor_data, shuffle=shuffle, batch_size=batch_size) + sample_num = len(train_tensor_data) steps_per_epoch = (sample_num - 1) // batch_size + 1 @@ -224,7 +234,7 @@ def fit(self, x=None, y=None, batch_size=None, epochs=1, verbose=1, initial_epoc train_result = {} try: with tqdm(enumerate(train_loader), disable=verbose != 1) as t: - for index, (x_train, y_train) in t: + for _, (x_train, y_train) in t: x = x_train.to(self.device).float() y = y_train.to(self.device).float() @@ -323,7 +333,7 @@ def predict(self, x, batch_size=256): pred_ans = [] with torch.no_grad(): - for index, x_test in enumerate(test_loader): + for _, x_test in enumerate(test_loader): x = x_test[0].to(self.device).float() y_pred = model(x).cpu().data.numpy() # .squeeze() diff --git a/deepctr_torch/models/ccpm.py b/deepctr_torch/models/ccpm.py index 73272b66..7ab098ae 100644 --- a/deepctr_torch/models/ccpm.py +++ b/deepctr_torch/models/ccpm.py @@ -34,6 +34,7 @@ class CCPM(BaseModel): :param seed: integer ,to use as random seed. :param task: str, ``"binary"`` for binary logloss or ``"regression"`` for regression loss :param device: str, ``"cpu"`` or ``"cuda:0"`` + :param gpus: list of int or torch.device for multiple gpus. If None, run on `device`. `gpus[0]` should be the same gpu with `device`. :return: A PyTorch model instance. """ @@ -41,11 +42,11 @@ class CCPM(BaseModel): def __init__(self, linear_feature_columns, dnn_feature_columns, conv_kernel_width=(6, 5), conv_filters=(4, 4), dnn_hidden_units=(256,), l2_reg_linear=1e-5, l2_reg_embedding=1e-5, l2_reg_dnn=0, dnn_dropout=0, - init_std=0.0001, seed=1024, task='binary', device='cpu', dnn_use_bn=False, dnn_activation='relu'): + init_std=0.0001, seed=1024, task='binary', device='cpu', dnn_use_bn=False, dnn_activation='relu', gpus=None): super(CCPM, self).__init__(linear_feature_columns, dnn_feature_columns, l2_reg_linear=l2_reg_linear, l2_reg_embedding=l2_reg_embedding, init_std=init_std, seed=seed, task=task, - device=device) + device=device, gpus=gpus) if len(conv_kernel_width) != len(conv_filters): raise ValueError( diff --git a/deepctr_torch/models/dcn.py b/deepctr_torch/models/dcn.py index 4528b9a7..f5ef03bf 100644 --- a/deepctr_torch/models/dcn.py +++ b/deepctr_torch/models/dcn.py @@ -36,18 +36,19 @@ class DCN(BaseModel): :param dnn_activation: Activation function to use in DNN :param task: str, ``"binary"`` for binary logloss or ``"regression"`` for regression loss :param device: str, ``"cpu"`` or ``"cuda:0"`` + :param gpus: list of int or torch.device for multiple gpus. If None, run on `device`. `gpus[0]` should be the same gpu with `device`. :return: A PyTorch model instance. 
- + """ def __init__(self, linear_feature_columns, dnn_feature_columns, cross_num=2, cross_parameterization='vector', dnn_hidden_units=(128, 128), l2_reg_linear=0.00001, l2_reg_embedding=0.00001, l2_reg_cross=0.00001, l2_reg_dnn=0, init_std=0.0001, seed=1024, dnn_dropout=0, dnn_activation='relu', dnn_use_bn=False, - task='binary', device='cpu'): + task='binary', device='cpu', gpus=None): super(DCN, self).__init__(linear_feature_columns=linear_feature_columns, dnn_feature_columns=dnn_feature_columns, l2_reg_embedding=l2_reg_embedding, - init_std=init_std, seed=seed, task=task, device=device) + init_std=init_std, seed=seed, task=task, device=device, gpus=gpus) self.dnn_hidden_units = dnn_hidden_units self.cross_num = cross_num self.dnn = DNN(self.compute_input_dim(dnn_feature_columns), dnn_hidden_units, diff --git a/deepctr_torch/models/dcnmix.py b/deepctr_torch/models/dcnmix.py index c01fd44c..9b0e97d4 100644 --- a/deepctr_torch/models/dcnmix.py +++ b/deepctr_torch/models/dcnmix.py @@ -36,8 +36,9 @@ class DCNMix(BaseModel): :param num_experts: Positive integer, number of experts. :param task: str, ``"binary"`` for binary logloss or ``"regression"`` for regression loss :param device: str, ``"cpu"`` or ``"cuda:0"`` + :param gpus: list of int or torch.device for multiple gpus. If None, run on `device`. `gpus[0]` should be the same gpu with `device`. :return: A PyTorch model instance. - + """ def __init__(self, linear_feature_columns, @@ -45,11 +46,11 @@ def __init__(self, linear_feature_columns, dnn_hidden_units=(128, 128), l2_reg_linear=0.00001, l2_reg_embedding=0.00001, l2_reg_cross=0.00001, l2_reg_dnn=0, init_std=0.0001, seed=1024, dnn_dropout=0, low_rank=32, num_experts=4, - dnn_activation='relu', dnn_use_bn=False, task='binary', device='cpu'): + dnn_activation='relu', dnn_use_bn=False, task='binary', device='cpu', gpus=None): super(DCNMix, self).__init__(linear_feature_columns=linear_feature_columns, dnn_feature_columns=dnn_feature_columns, l2_reg_embedding=l2_reg_embedding, - init_std=init_std, seed=seed, task=task, device=device) + init_std=init_std, seed=seed, task=task, device=device, gpus=gpus) self.dnn_hidden_units = dnn_hidden_units self.cross_num = cross_num self.dnn = DNN(self.compute_input_dim(dnn_feature_columns), dnn_hidden_units, diff --git a/deepctr_torch/models/deepfm.py b/deepctr_torch/models/deepfm.py index f0dfb411..7f90faf7 100644 --- a/deepctr_torch/models/deepfm.py +++ b/deepctr_torch/models/deepfm.py @@ -30,8 +30,9 @@ class DeepFM(BaseModel): :param dnn_use_bn: bool. Whether use BatchNormalization before activation or not in DNN :param task: str, ``"binary"`` for binary logloss or ``"regression"`` for regression loss :param device: str, ``"cpu"`` or ``"cuda:0"`` + :param gpus: list of int or torch.device for multiple gpus. If None, run on `device`. `gpus[0]` should be the same gpu with `device`. :return: A PyTorch model instance. 
- + """ def __init__(self, @@ -39,11 +40,11 @@ def __init__(self, dnn_hidden_units=(256, 128), l2_reg_linear=0.00001, l2_reg_embedding=0.00001, l2_reg_dnn=0, init_std=0.0001, seed=1024, dnn_dropout=0, - dnn_activation='relu', dnn_use_bn=False, task='binary', device='cpu'): + dnn_activation='relu', dnn_use_bn=False, task='binary', device='cpu', gpus=None): super(DeepFM, self).__init__(linear_feature_columns, dnn_feature_columns, l2_reg_linear=l2_reg_linear, l2_reg_embedding=l2_reg_embedding, init_std=init_std, seed=seed, task=task, - device=device) + device=device, gpus=gpus) self.use_fm = use_fm self.use_dnn = len(dnn_feature_columns) > 0 and len( diff --git a/deepctr_torch/models/dien.py b/deepctr_torch/models/dien.py index b86a9897..6f37c1aa 100644 --- a/deepctr_torch/models/dien.py +++ b/deepctr_torch/models/dien.py @@ -16,25 +16,27 @@ class DIEN(BaseModel): """Instantiates the Deep Interest Evolution Network architecture. - :param dnn_feature_columns: An iterable containing all the features used by deep part of the model. - :param history_feature_list: list,to indicate sequence sparse field - :param gru_type: str,can be GRU AIGRU AUGRU AGRU - :param use_negsampling: bool, whether or not use negtive sampling - :param alpha: float ,weight of auxiliary_loss - :param use_bn: bool. Whether use BatchNormalization before activation or not in deep net - :param dnn_hidden_units: list,list of positive integer or empty list, the layer number and units in each layer of DNN - :param dnn_activation: Activation function to use in DNN - :param att_hidden_units: list,list of positive integer , the layer number and units in each layer of attention net - :param att_activation: Activation function to use in attention net - :param att_weight_normalization: bool.Whether normalize the attention score of local activation unit. - :param l2_reg_dnn: float. L2 regularizer strength applied to DNN - :param l2_reg_embedding: float. L2 regularizer strength applied to embedding vector - :param dnn_dropout: float in [0,1), the probability we will drop out a given DNN coordinate. - :param init_std: float,to use as the initialize std of embedding vector - :param seed: integer ,to use as random seed. - :param task: str, ``"binary"`` for binary logloss or ``"regression"`` for regression loss - :param device: str, ``"cpu"`` or ``"cuda:0"`` - :return: A PyTorch model instance. + :param dnn_feature_columns: An iterable containing all the features used by deep part of the model. + :param history_feature_list: list,to indicate sequence sparse field + :param gru_type: str,can be GRU AIGRU AUGRU AGRU + :param use_negsampling: bool, whether or not use negtive sampling + :param alpha: float ,weight of auxiliary_loss + :param use_bn: bool. Whether use BatchNormalization before activation or not in deep net + :param dnn_hidden_units: list,list of positive integer or empty list, the layer number and units in each layer of DNN + :param dnn_activation: Activation function to use in DNN + :param att_hidden_units: list,list of positive integer , the layer number and units in each layer of attention net + :param att_activation: Activation function to use in attention net + :param att_weight_normalization: bool.Whether normalize the attention score of local activation unit. + :param l2_reg_dnn: float. L2 regularizer strength applied to DNN + :param l2_reg_embedding: float. L2 regularizer strength applied to embedding vector + :param dnn_dropout: float in [0,1), the probability we will drop out a given DNN coordinate. 
+ :param init_std: float,to use as the initialize std of embedding vector + :param seed: integer ,to use as random seed. + :param task: str, ``"binary"`` for binary logloss or ``"regression"`` for regression loss + :param device: str, ``"cpu"`` or ``"cuda:0"`` + :param gpus: list of int or torch.device for multiple gpus. If None, run on `device`. `gpus[0]` should be the same gpu with `device`. + :return: A PyTorch model instance. + """ def __init__(self, @@ -43,9 +45,9 @@ def __init__(self, dnn_activation='relu', att_hidden_units=(64, 16), att_activation="relu", att_weight_normalization=True, l2_reg_dnn=0, l2_reg_embedding=1e-6, dnn_dropout=0, init_std=0.0001, seed=1024, task='binary', - device='cpu'): + device='cpu', gpus=None): super(DIEN, self).__init__([], dnn_feature_columns, l2_reg_linear=0, l2_reg_embedding=l2_reg_embedding, - init_std=init_std, seed=seed, task=task, device=device) + init_std=init_std, seed=seed, task=task, device=device, gpus=gpus) self.item_features = history_feature_list self.use_negsampling = use_negsampling diff --git a/deepctr_torch/models/difm.py b/deepctr_torch/models/difm.py new file mode 100644 index 00000000..13a3aaab --- /dev/null +++ b/deepctr_torch/models/difm.py @@ -0,0 +1,106 @@ +# -*- coding:utf-8 -*- +""" +Author: + zanshuxun, zanshuxun@aliyun.com +Reference: + [1] Lu W, Yu Y, Chang Y, et al. A Dual Input-aware Factorization Machine for CTR Prediction[C]//IJCAI. 2020: 3139-3145.(https://www.ijcai.org/Proceedings/2020/0434.pdf) +""" +import torch +import torch.nn as nn + +from .basemodel import BaseModel +from ..inputs import combined_dnn_input, SparseFeat, VarLenSparseFeat +from ..layers import FM, DNN, InteractingLayer, concat_fun + + +class DIFM(BaseModel): + """Instantiates the DIFM Network architecture. + + :param linear_feature_columns: An iterable containing all the features used by linear part of the model. + :param dnn_feature_columns: An iterable containing all the features used by deep part of the model. + :param dnn_hidden_units: list,list of positive integer or empty list, the layer number and units in each layer of DNN + :param l2_reg_linear: float. L2 regularizer strength applied to linear part + :param l2_reg_embedding: float. L2 regularizer strength applied to embedding vector + :param l2_reg_dnn: float. L2 regularizer strength applied to DNN + :param init_std: float,to use as the initialize std of embedding vector + :param seed: integer ,to use as random seed. + :param dnn_dropout: float in [0,1), the probability we will drop out a given DNN coordinate. + :param dnn_activation: Activation function to use in DNN + :param dnn_use_bn: bool. Whether use BatchNormalization before activation or not in DNN + :param task: str, ``"binary"`` for binary logloss or ``"regression"`` for regression loss + :param device: str, ``"cpu"`` or ``"cuda:0"`` + :param gpus: list of int or torch.device for multiple gpus. If None, run on ``device`` . ``gpus[0]`` should be the same gpu with ``device`` . + :return: A PyTorch model instance. 
+ + """ + + def __init__(self, + linear_feature_columns, dnn_feature_columns, att_embedding_size=8, att_head_num=8, + att_res=True, dnn_hidden_units=(256, 128), + l2_reg_linear=0.00001, l2_reg_embedding=0.00001, l2_reg_dnn=0, init_std=0.0001, seed=1024, + dnn_dropout=0, + dnn_activation='relu', dnn_use_bn=False, task='binary', device='cpu', gpus=None): + super(DIFM, self).__init__(linear_feature_columns, dnn_feature_columns, l2_reg_linear=l2_reg_linear, + l2_reg_embedding=l2_reg_embedding, init_std=init_std, seed=seed, task=task, + device=device, gpus=gpus) + + if not len(dnn_hidden_units) > 0: + raise ValueError("dnn_hidden_units is null!") + + self.use_dnn = len(dnn_feature_columns) > 0 and len( + dnn_hidden_units) > 0 + self.fm = FM() + + # InteractingLayer (used in AutoInt) = multi-head self-attention + Residual Network + self.vector_wise_net = InteractingLayer(self.embedding_size, att_embedding_size, + att_head_num, att_res, scaling=True, device=device) + + self.bit_wise_net = DNN(self.compute_input_dim(dnn_feature_columns, include_dense=False), + dnn_hidden_units, activation=dnn_activation, l2_reg=l2_reg_dnn, + dropout_rate=dnn_dropout, + use_bn=dnn_use_bn, init_std=init_std, device=device) + self.sparse_feat_num = len(list(filter(lambda x: isinstance(x, SparseFeat) or isinstance(x, VarLenSparseFeat), + dnn_feature_columns))) + + self.transform_matrix_P_vec = nn.Linear( + self.sparse_feat_num*att_embedding_size*att_head_num, self.sparse_feat_num, bias=False).to(device) + self.transform_matrix_P_bit = nn.Linear( + dnn_hidden_units[-1], self.sparse_feat_num, bias=False).to(device) + + self.add_regularization_weight( + filter(lambda x: 'weight' in x[0] and 'bn' not in x[0], self.vector_wise_net.named_parameters()), + l2=l2_reg_dnn) + self.add_regularization_weight( + filter(lambda x: 'weight' in x[0] and 'bn' not in x[0], self.bit_wise_net.named_parameters()), + l2=l2_reg_dnn) + self.add_regularization_weight(self.transform_matrix_P_vec.weight, l2=l2_reg_dnn) + self.add_regularization_weight(self.transform_matrix_P_bit.weight, l2=l2_reg_dnn) + + self.to(device) + + def forward(self, X): + sparse_embedding_list, _ = self.input_from_feature_columns(X, self.dnn_feature_columns, + self.embedding_dict) + if not len(sparse_embedding_list) > 0: + raise ValueError("there are no sparse features") + + att_input = concat_fun(sparse_embedding_list, axis=1) + att_out = self.vector_wise_net(att_input) + att_out = att_out.reshape(att_out.shape[0], -1) + m_vec = self.transform_matrix_P_vec(att_out) + + dnn_input = combined_dnn_input(sparse_embedding_list, []) + dnn_output = self.bit_wise_net(dnn_input) + m_bit = self.transform_matrix_P_bit(dnn_output) + + m_x = m_vec + m_bit # m_x is the complete input-aware factor + + logit = self.linear_model(X, sparse_feat_refine_weight=m_x) + + fm_input = torch.cat(sparse_embedding_list, dim=1) + refined_fm_input = fm_input * m_x.unsqueeze(-1) # \textbf{v}_{x,i}=m_{x,i} * \textbf{v}_i + logit += self.fm(refined_fm_input) + + y_pred = self.out(logit) + + return y_pred diff --git a/deepctr_torch/models/din.py b/deepctr_torch/models/din.py index 0a8e46af..8bac0383 100644 --- a/deepctr_torch/models/din.py +++ b/deepctr_torch/models/din.py @@ -29,6 +29,8 @@ class DIN(BaseModel): :param init_std: float,to use as the initialize std of embedding vector :param seed: integer ,to use as random seed. 
:param task: str, ``"binary"`` for binary logloss or ``"regression"`` for regression loss + :param device: str, ``"cpu"`` or ``"cuda:0"`` + :param gpus: list of int or torch.device for multiple gpus. If None, run on `device`. `gpus[0]` should be the same gpu with `device`. :return: A PyTorch model instance. """ @@ -37,9 +39,9 @@ def __init__(self, dnn_feature_columns, history_feature_list, dnn_use_bn=False, dnn_hidden_units=(256, 128), dnn_activation='relu', att_hidden_size=(64, 16), att_activation='Dice', att_weight_normalization=False, l2_reg_dnn=0.0, l2_reg_embedding=1e-6, dnn_dropout=0, init_std=0.0001, - seed=1024, task='binary', device='cpu'): + seed=1024, task='binary', device='cpu', gpus=None): super(DIN, self).__init__([], dnn_feature_columns, l2_reg_linear=0, l2_reg_embedding=l2_reg_embedding, - init_std=init_std, seed=seed, task=task, device=device) + init_std=init_std, seed=seed, task=task, device=device, gpus=gpus) self.sparse_feature_columns = list( filter(lambda x: isinstance(x, SparseFeat), dnn_feature_columns)) if dnn_feature_columns else [] diff --git a/deepctr_torch/models/fibinet.py b/deepctr_torch/models/fibinet.py index f3e1b436..67ec4783 100644 --- a/deepctr_torch/models/fibinet.py +++ b/deepctr_torch/models/fibinet.py @@ -31,17 +31,18 @@ class FiBiNET(BaseModel): :param dnn_activation: Activation function to use in DNN :param task: str, ``"binary"`` for binary logloss or ``"regression"`` for regression loss :param device: str, ``"cpu"`` or ``"cuda:0"`` + :param gpus: list of int or torch.device for multiple gpus. If None, run on `device`. `gpus[0]` should be the same gpu with `device`. :return: A PyTorch model instance. - + """ def __init__(self, linear_feature_columns, dnn_feature_columns, bilinear_type='interaction', reduction_ratio=3, dnn_hidden_units=(128, 128), l2_reg_linear=1e-5, l2_reg_embedding=1e-5, l2_reg_dnn=0, init_std=0.0001, seed=1024, dnn_dropout=0, dnn_activation='relu', - task='binary', device='cpu'): + task='binary', device='cpu', gpus=None): super(FiBiNET, self).__init__(linear_feature_columns, dnn_feature_columns, l2_reg_linear=l2_reg_linear, l2_reg_embedding=l2_reg_embedding, init_std=init_std, seed=seed, task=task, - device=device) + device=device, gpus=gpus) self.linear_feature_columns = linear_feature_columns self.dnn_feature_columns = dnn_feature_columns self.filed_size = len(self.embedding_dict) diff --git a/deepctr_torch/models/ifm.py b/deepctr_torch/models/ifm.py new file mode 100644 index 00000000..4f057833 --- /dev/null +++ b/deepctr_torch/models/ifm.py @@ -0,0 +1,89 @@ +# -*- coding:utf-8 -*- +""" +Author: + zanshuxun, zanshuxun@aliyun.com +Reference: + [1] Yu Y, Wang Z, Yuan B. An Input-aware Factorization Machine for Sparse Prediction[C]//IJCAI. 2019: 1466-1472.(https://www.ijcai.org/Proceedings/2019/0203.pdf) +""" +import torch +import torch.nn as nn + +from .basemodel import BaseModel +from ..inputs import combined_dnn_input, SparseFeat, VarLenSparseFeat +from ..layers import FM, DNN + + +class IFM(BaseModel): + """Instantiates the IFM Network architecture. + + :param linear_feature_columns: An iterable containing all the features used by linear part of the model. + :param dnn_feature_columns: An iterable containing all the features used by deep part of the model. + :param dnn_hidden_units: list,list of positive integer or empty list, the layer number and units in each layer of DNN + :param l2_reg_linear: float. L2 regularizer strength applied to linear part + :param l2_reg_embedding: float. 
L2 regularizer strength applied to embedding vector + :param l2_reg_dnn: float. L2 regularizer strength applied to DNN + :param init_std: float,to use as the initialize std of embedding vector + :param seed: integer ,to use as random seed. + :param dnn_dropout: float in [0,1), the probability we will drop out a given DNN coordinate. + :param dnn_activation: Activation function to use in DNN + :param dnn_use_bn: bool. Whether use BatchNormalization before activation or not in DNN + :param task: str, ``"binary"`` for binary logloss or ``"regression"`` for regression loss + :param device: str, ``"cpu"`` or ``"cuda:0"`` + :param gpus: list of int or torch.device for multiple gpus. If None, run on ``device`` . ``gpus[0]`` should be the same gpu with ``device`` . + :return: A PyTorch model instance. + + """ + + def __init__(self, + linear_feature_columns, dnn_feature_columns, + dnn_hidden_units=(256, 128), + l2_reg_linear=0.00001, l2_reg_embedding=0.00001, l2_reg_dnn=0, init_std=0.0001, seed=1024, + dnn_dropout=0, + dnn_activation='relu', dnn_use_bn=False, task='binary', device='cpu', gpus=None): + super(IFM, self).__init__(linear_feature_columns, dnn_feature_columns, l2_reg_linear=l2_reg_linear, + l2_reg_embedding=l2_reg_embedding, init_std=init_std, seed=seed, task=task, + device=device, gpus=gpus) + + if not len(dnn_hidden_units) > 0: + raise ValueError("dnn_hidden_units is null!") + + self.use_dnn = len(dnn_feature_columns) > 0 and len( + dnn_hidden_units) > 0 + self.fm = FM() + + self.factor_estimating_net = DNN(self.compute_input_dim(dnn_feature_columns, include_dense=False), + dnn_hidden_units, activation=dnn_activation, l2_reg=l2_reg_dnn, + dropout_rate=dnn_dropout, + use_bn=dnn_use_bn, init_std=init_std, device=device) + self.sparse_feat_num = len(list(filter(lambda x: isinstance(x, SparseFeat) or isinstance(x, VarLenSparseFeat), + dnn_feature_columns))) + self.transform_weight_matrix_P = nn.Linear( + dnn_hidden_units[-1], self.sparse_feat_num, bias=False).to(device) + + self.add_regularization_weight( + filter(lambda x: 'weight' in x[0] and 'bn' not in x[0], self.factor_estimating_net.named_parameters()), + l2=l2_reg_dnn) + self.add_regularization_weight(self.transform_weight_matrix_P.weight, l2=l2_reg_dnn) + + self.to(device) + + def forward(self, X): + sparse_embedding_list, dense_value_list = self.input_from_feature_columns(X, self.dnn_feature_columns, + self.embedding_dict) + if not len(sparse_embedding_list) > 0: + raise ValueError("there are no sparse features") + + dnn_input = combined_dnn_input(sparse_embedding_list, []) # (batch_size, feat_num * embedding_size) + dnn_output = self.factor_estimating_net(dnn_input) + dnn_output = self.transform_weight_matrix_P(dnn_output) # m'_{x} + input_aware_factor = self.sparse_feat_num * dnn_output.softmax(1) # input_aware_factor m_{x,i} + + logit = self.linear_model(X, sparse_feat_refine_weight=input_aware_factor) + + fm_input = torch.cat(sparse_embedding_list, dim=1) + refined_fm_input = fm_input * input_aware_factor.unsqueeze(-1) # \textbf{v}_{x,i}=m_{x,i}\textbf{v}_i + logit += self.fm(refined_fm_input) + + y_pred = self.out(logit) + + return y_pred diff --git a/deepctr_torch/models/mlr.py b/deepctr_torch/models/mlr.py index 9259694a..8cb60090 100644 --- a/deepctr_torch/models/mlr.py +++ b/deepctr_torch/models/mlr.py @@ -26,14 +26,15 @@ class MLR(BaseModel): :param task: str, ``"binary"`` for binary logloss or ``"regression"`` for regression loss :param bias_feature_columns: An iterable containing all the features used by bias part of 
the model. :param device: str, ``"cpu"`` or ``"cuda:0"`` + :param gpus: list of int or torch.device for multiple gpus. If None, run on `device`. `gpus[0]` should be the same gpu with `device`. :return: A PyTorch model instance. - + """ def __init__(self, region_feature_columns, base_feature_columns=None, bias_feature_columns=None, - region_num=4, l2_reg_linear=1e-5, init_std=0.0001, seed=1024, task='binary', device='cpu' + region_num=4, l2_reg_linear=1e-5, init_std=0.0001, seed=1024, task='binary', device='cpu', gpus=None ): - super(MLR, self).__init__(region_feature_columns, region_feature_columns, task=task, device=device) + super(MLR, self).__init__(region_feature_columns, region_feature_columns, task=task, device=device, gpus=gpus) if region_num <= 1: raise ValueError("region_num must > 1") diff --git a/deepctr_torch/models/nfm.py b/deepctr_torch/models/nfm.py index 4120d5de..f01613c7 100644 --- a/deepctr_torch/models/nfm.py +++ b/deepctr_torch/models/nfm.py @@ -29,17 +29,18 @@ class NFM(BaseModel): :param dnn_activation: Activation function to use in deep net :param task: str, ``"binary"`` for binary logloss or ``"regression"`` for regression loss :param device: str, ``"cpu"`` or ``"cuda:0"`` + :param gpus: list of int or torch.device for multiple gpus. If None, run on `device`. `gpus[0]` should be the same gpu with `device`. :return: A PyTorch model instance. - + """ def __init__(self, linear_feature_columns, dnn_feature_columns, dnn_hidden_units=(128, 128), l2_reg_embedding=1e-5, l2_reg_linear=1e-5, l2_reg_dnn=0, init_std=0.0001, seed=1024, bi_dropout=0, - dnn_dropout=0, dnn_activation='relu', task='binary', device='cpu'): + dnn_dropout=0, dnn_activation='relu', task='binary', device='cpu', gpus=None): super(NFM, self).__init__(linear_feature_columns, dnn_feature_columns, l2_reg_linear=l2_reg_linear, l2_reg_embedding=l2_reg_embedding, init_std=init_std, seed=seed, task=task, - device=device) + device=device, gpus=gpus) self.dnn = DNN(self.compute_input_dim(dnn_feature_columns, include_sparse=False) + self.embedding_size, dnn_hidden_units, diff --git a/deepctr_torch/models/onn.py b/deepctr_torch/models/onn.py index b4d4d085..49f59cca 100644 --- a/deepctr_torch/models/onn.py +++ b/deepctr_torch/models/onn.py @@ -50,18 +50,19 @@ class ONN(BaseModel): :param reduce_sum: bool,whether apply reduce_sum on cross vector :param task: str, ``"binary"`` for binary logloss or ``"regression"`` for regression loss :param device: str, ``"cpu"`` or ``"cuda:0"`` + :param gpus: list of int or torch.device for multiple gpus. If None, run on `device`. `gpus[0]` should be the same gpu with `device`. :return: A PyTorch model instance. 
- + """ def __init__(self, linear_feature_columns, dnn_feature_columns, dnn_hidden_units=(128, 128), l2_reg_embedding=1e-5, l2_reg_linear=1e-5, l2_reg_dnn=0, dnn_dropout=0, init_std=0.0001, seed=1024, dnn_use_bn=False, dnn_activation='relu', - task='binary', device='cpu'): + task='binary', device='cpu', gpus=None): super(ONN, self).__init__(linear_feature_columns, dnn_feature_columns, l2_reg_linear=l2_reg_linear, l2_reg_embedding=l2_reg_embedding, init_std=init_std, seed=seed, task=task, - device=device) + device=device, gpus=gpus) # second order part embedding_size = self.embedding_size diff --git a/deepctr_torch/models/pnn.py b/deepctr_torch/models/pnn.py index 2cdeff0a..d72b2d5c 100644 --- a/deepctr_torch/models/pnn.py +++ b/deepctr_torch/models/pnn.py @@ -30,16 +30,17 @@ class PNN(BaseModel): :param kernel_type: str,kernel_type used in outter-product,can be ``'mat'`` , ``'vec'`` or ``'num'`` :param task: str, ``"binary"`` for binary logloss or ``"regression"`` for regression loss :param device: str, ``"cpu"`` or ``"cuda:0"`` + :param gpus: list of int or torch.device for multiple gpus. If None, run on `device`. `gpus[0]` should be the same gpu with `device`. :return: A PyTorch model instance. - + """ def __init__(self, dnn_feature_columns, dnn_hidden_units=(128, 128), l2_reg_embedding=1e-5, l2_reg_dnn=0, init_std=0.0001, seed=1024, dnn_dropout=0, dnn_activation='relu', use_inner=True, use_outter=False, - kernel_type='mat', task='binary', device='cpu', ): + kernel_type='mat', task='binary', device='cpu', gpus=None): super(PNN, self).__init__([], dnn_feature_columns, l2_reg_linear=0, l2_reg_embedding=l2_reg_embedding, - init_std=init_std, seed=seed, task=task, device=device) + init_std=init_std, seed=seed, task=task, device=device, gpus=gpus) if kernel_type not in ['mat', 'vec', 'num']: raise ValueError("kernel_type must be mat,vec or num") diff --git a/deepctr_torch/models/wdl.py b/deepctr_torch/models/wdl.py index 322b0920..6016eb0a 100644 --- a/deepctr_torch/models/wdl.py +++ b/deepctr_torch/models/wdl.py @@ -28,8 +28,9 @@ class WDL(BaseModel): :param dnn_activation: Activation function to use in DNN :param task: str, ``"binary"`` for binary logloss or ``"regression"`` for regression loss :param device: str, ``"cpu"`` or ``"cuda:0"`` + :param gpus: list of int or torch.device for multiple gpus. If None, run on `device`. `gpus[0]` should be the same gpu with `device`. :return: A PyTorch model instance. - + """ def __init__(self, @@ -37,11 +38,11 @@ def __init__(self, l2_reg_linear=1e-5, l2_reg_embedding=1e-5, l2_reg_dnn=0, init_std=0.0001, seed=1024, dnn_dropout=0, dnn_activation='relu', dnn_use_bn=False, - task='binary', device='cpu'): + task='binary', device='cpu', gpus=None): super(WDL, self).__init__(linear_feature_columns, dnn_feature_columns, l2_reg_linear=l2_reg_linear, l2_reg_embedding=l2_reg_embedding, init_std=init_std, seed=seed, task=task, - device=device) + device=device, gpus=gpus) self.use_dnn = len(dnn_feature_columns) > 0 and len( dnn_hidden_units) > 0 diff --git a/deepctr_torch/models/xdeepfm.py b/deepctr_torch/models/xdeepfm.py index 87cac472..7d5efa01 100644 --- a/deepctr_torch/models/xdeepfm.py +++ b/deepctr_torch/models/xdeepfm.py @@ -34,18 +34,19 @@ class xDeepFM(BaseModel): :param dnn_use_bn: bool. Whether use BatchNormalization before activation or not in DNN :param task: str, ``"binary"`` for binary logloss or ``"regression"`` for regression loss :param device: str, ``"cpu"`` or ``"cuda:0"`` + :param gpus: list of int or torch.device for multiple gpus. 
If None, run on `device`. `gpus[0]` should be the same gpu with `device`. :return: A PyTorch model instance. - + """ def __init__(self, linear_feature_columns, dnn_feature_columns, dnn_hidden_units=(256, 256), cin_layer_size=(256, 128,), cin_split_half=True, cin_activation='relu', l2_reg_linear=0.00001, l2_reg_embedding=0.00001, l2_reg_dnn=0, l2_reg_cin=0, init_std=0.0001, seed=1024, dnn_dropout=0, - dnn_activation='relu', dnn_use_bn=False, task='binary', device='cpu'): + dnn_activation='relu', dnn_use_bn=False, task='binary', device='cpu', gpus=None): super(xDeepFM, self).__init__(linear_feature_columns, dnn_feature_columns, l2_reg_linear=l2_reg_linear, l2_reg_embedding=l2_reg_embedding, init_std=init_std, seed=seed, task=task, - device=device) + device=device, gpus=gpus) self.dnn_hidden_units = dnn_hidden_units self.use_dnn = len(dnn_feature_columns) > 0 and len(dnn_hidden_units) > 0 if self.use_dnn: diff --git a/docs/pics/DIFM.png b/docs/pics/DIFM.png new file mode 100644 index 00000000..76a983b2 Binary files /dev/null and b/docs/pics/DIFM.png differ diff --git a/docs/pics/IFM.png b/docs/pics/IFM.png new file mode 100644 index 00000000..5adf940c Binary files /dev/null and b/docs/pics/IFM.png differ diff --git a/docs/source/FAQ.md b/docs/source/FAQ.md index 3399bb06..a7a4eb6e 100644 --- a/docs/source/FAQ.md +++ b/docs/source/FAQ.md @@ -60,6 +60,7 @@ model.fit(model_input,label) ``` ## 4. How to run the demo with GPU ? + ```python import torch device = 'cpu' @@ -70,3 +71,9 @@ if use_cuda and torch.cuda.is_available(): model = DeepFM(...,device=device) ``` + +## 5. How to run the demo with multiple GPUs ? + +```python +model = DeepFM(..., device=device, gpus=[0, 1]) +``` diff --git a/docs/source/Features.md b/docs/source/Features.md index cce54b89..2aaf6787 100644 --- a/docs/source/Features.md +++ b/docs/source/Features.md @@ -241,6 +241,27 @@ Feature Importance and Bilinear feature Interaction NETwork is proposed to dynam [Huang T, Zhang Z, Zhang J. FiBiNET: Combining Feature Importance and Bilinear feature Interaction for Click-Through Rate Prediction[J]. arXiv preprint arXiv:1905.09433, 2019.](https://arxiv.org/pdf/1905.09433.pdf) +### IFM(Input-aware Factorization Machine) + +Input-aware Factorization Machine (IFM) learns a unique input-aware factor for the same feature in different instances via a neural network. + +[**IFM Model API**](./deepctr_torch.models.ifm.html) + +![IFM](../pics/IFM.png) + +[Yu Y, Wang Z, Yuan B. An Input-aware Factorization Machine for Sparse Prediction[C]//IJCAI. 2019: 1466-1472.](https://www.ijcai.org/Proceedings/2019/0203.pdf) + +### DIFM(Dual Input-aware Factorization Machine) + +Dual Inputaware Factorization Machines (DIFM) can adaptively reweight the original feature representations at the bit-wise and vector-wise levels simultaneously.Furthermore, DIFMs strategically integrate various components including Multi-Head Self-Attention, Residual Networks and DNNs into a unified end-to-end model. + +[**DFM Model API**](./deepctr_torch.models.difm.html) + +![DIFM](../pics/DIFM.png) + +[Lu W, Yu Y, Chang Y, et al. A Dual Input-aware Factorization Machine for CTR Prediction[C]//IJCAI. 
2020: 3139-3145.](https://www.ijcai.org/Proceedings/2020/0434.pdf) + + ## Layers diff --git a/docs/source/History.md b/docs/source/History.md index 78f4b463..eef2f07b 100644 --- a/docs/source/History.md +++ b/docs/source/History.md @@ -1,4 +1,5 @@ # History +- 04/04/2021 : [v0.2.6](https://github.com/shenweichen/DeepCTR-Torch/releases/tag/v0.2.6) released.Add add [IFM](./Features.html#ifm-input-aware-factorization-machine) and [DIFM](./Features.html#difm-dual-input-aware-factorization-machine);Support multi-gpus running([example](./FAQ.html#how-to-run-the-demo-with-multiple-gpus)). - 02/12/2021 : [v0.2.5](https://github.com/shenweichen/DeepCTR-Torch/releases/tag/v0.2.5) released.Fix bug in DCN-M. - 12/05/2020 : [v0.2.4](https://github.com/shenweichen/DeepCTR-Torch/releases/tag/v0.2.4) released.Imporve compatibility & fix issues.Add History callback.([example](https://deepctr-torch.readthedocs.io/en/latest/FAQ.html#set-learning-rate-and-use-earlystopping)). - 10/18/2020 : [v0.2.3](https://github.com/shenweichen/DeepCTR-Torch/releases/tag/v0.2.3) released.Add [DCN-M](./Features.html#dcn-deep-cross-network)&[DCN-Mix](./Features.html#dcn-mix-improved-deep-cross-network-with-mix-of-experts-and-matrix-kernel).Add EarlyStopping and ModelCheckpoint callbacks([example](https://deepctr-torch.readthedocs.io/en/latest/FAQ.html#set-learning-rate-and-use-earlystopping)). diff --git a/docs/source/Models.rst b/docs/source/Models.rst index 52d96c28..a5eeb102 100644 --- a/docs/source/Models.rst +++ b/docs/source/Models.rst @@ -21,3 +21,5 @@ DeepCTR-Torch Models API ONN FGCNN FiBiNET + IFM + DIFM \ No newline at end of file diff --git a/docs/source/conf.py b/docs/source/conf.py index 1dd328ae..d43d0eea 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -26,7 +26,7 @@ # The short X.Y version version = '' # The full version, including alpha/beta/rc tags -release = '0.2.5' +release = '0.2.6' # -- General configuration --------------------------------------------------- diff --git a/docs/source/deepctr_torch.models.difm.rst b/docs/source/deepctr_torch.models.difm.rst new file mode 100644 index 00000000..ae16a5b7 --- /dev/null +++ b/docs/source/deepctr_torch.models.difm.rst @@ -0,0 +1,7 @@ +deepctr\_torch.models.difm module +================================ + +.. automodule:: deepctr_torch.models.difm + :members: + :no-undoc-members: + :no-show-inheritance: diff --git a/docs/source/deepctr_torch.models.ifm.rst b/docs/source/deepctr_torch.models.ifm.rst new file mode 100644 index 00000000..e625757b --- /dev/null +++ b/docs/source/deepctr_torch.models.ifm.rst @@ -0,0 +1,7 @@ +deepctr\_torch.models.ifm module +================================ + +.. 
automodule:: deepctr_torch.models.ifm + :members: + :no-undoc-members: + :no-show-inheritance: diff --git a/docs/source/deepctr_torch.models.rst b/docs/source/deepctr_torch.models.rst index 25041a55..599710b6 100644 --- a/docs/source/deepctr_torch.models.rst +++ b/docs/source/deepctr_torch.models.rst @@ -10,6 +10,7 @@ Submodules deepctr_torch.models.autoint deepctr_torch.models.basemodel deepctr_torch.models.dcn + deepctr_torch.models.dcnmix deepctr_torch.models.deepfm deepctr_torch.models.fibinet deepctr_torch.models.mlr @@ -20,6 +21,8 @@ Submodules deepctr_torch.models.xdeepfm deepctr_torch.models.din deepctr_torch.models.dien + deepctr_torch.models.ifm + deepctr_torch.models.difm Module contents --------------- diff --git a/docs/source/index.rst b/docs/source/index.rst index 2205e11d..bc4d2b1d 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -34,13 +34,12 @@ You can read the latest code at https://github.com/shenweichen/DeepCTR-Torch and News ----- +04/04/2021 : Add `IFM <./Features.html#ifm-input-aware-factorization-machine>`_ and `DIFM <./Features.html#difm-dual-input-aware-factorization-machine>`_ . Support multi-gpus running(`example <./FAQ.html#how-to-run-the-demo-with-multiple-gpus>`_). `Changelog `_ + 02/12/2021 : Fix bug in DCN-M. `Changelog `_ 12/05/2020 : Imporve compatibility & fix issues.Add History callback(`example `_). `Changelog `_ -10/18/2020 : Add `DCN-M <./Features.html#dcn-deep-cross-network>`_ and `DCN-Mix <./Features.html#dcn-mix-improved-deep-cross-network-with-mix-of-experts-and-matrix-kernel>`_ . Add EarlyStopping and ModelCheckpoint callbacks(`example `_). `Changelog `_ - - DisscussionGroup ----------------------- diff --git a/setup.py b/setup.py index 4d77342e..7060df42 100644 --- a/setup.py +++ b/setup.py @@ -9,7 +9,7 @@ setuptools.setup( name="deepctr-torch", - version="0.2.5", + version="0.2.6", author="Weichen Shen", author_email="weichenswc@163.com", description="Easy-to-use,Modular and Extendible package of deep learning based CTR(Click Through Rate) prediction models with PyTorch", diff --git a/tests/models/DIFM_test.py b/tests/models/DIFM_test.py new file mode 100644 index 00000000..0960232d --- /dev/null +++ b/tests/models/DIFM_test.py @@ -0,0 +1,25 @@ +# -*- coding: utf-8 -*- +import pytest + +from deepctr_torch.models import DIFM +from ..utils import get_test_data, SAMPLE_SIZE, check_model, get_device + + +@pytest.mark.parametrize( + 'att_head_num,dnn_hidden_units,sparse_feature_num', + [(1, (4,), 2), (2, (4, 4,), 2), (1, (4,), 1)] +) +def test_DIFM(att_head_num, dnn_hidden_units, sparse_feature_num): + model_name = "DIFM" + sample_size = SAMPLE_SIZE + x, y, feature_columns = get_test_data( + sample_size, sparse_feature_num=sparse_feature_num, dense_feature_num=sparse_feature_num) + + model = DIFM(linear_feature_columns=feature_columns, dnn_feature_columns=feature_columns, + att_head_num=att_head_num, + dnn_hidden_units=dnn_hidden_units, dnn_dropout=0.5, device=get_device()) + check_model(model, model_name, x, y) + + +if __name__ == "__main__": + pass diff --git a/tests/models/IFM_test.py b/tests/models/IFM_test.py new file mode 100644 index 00000000..44dd9d89 --- /dev/null +++ b/tests/models/IFM_test.py @@ -0,0 +1,26 @@ +# -*- coding: utf-8 -*- +import pytest + +from deepctr_torch.models import IFM +from ..utils import get_test_data, SAMPLE_SIZE, check_model, get_device + + +@pytest.mark.parametrize( + 'hidden_size,sparse_feature_num', + [((32,), 3), + ((32,), 2), ((32,), 1), + ] +) +def test_IFM(hidden_size, 
sparse_feature_num): + model_name = "IFM" + sample_size = SAMPLE_SIZE + x, y, feature_columns = get_test_data( + sample_size, sparse_feature_num=sparse_feature_num, dense_feature_num=sparse_feature_num) + + model = IFM(feature_columns, feature_columns, + dnn_hidden_units=hidden_size, dnn_dropout=0.5, device=get_device()) + check_model(model, model_name, x, y) + + +if __name__ == "__main__": + pass diff --git a/tests/utils.py b/tests/utils.py index d8a8d6cb..4c79631e 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -70,7 +70,7 @@ def layer_test(layer_cls, kwargs = {}, input_shape=None, input_dtype=torch.float32, input_data=None, expected_output=None, expected_output_shape=None, expected_output_dtype=None, fixed_batch_size=False): '''check layer is valid or not - + :param layer_cls: :param input_shape: :param input_dtype:
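
For context, the sketch below shows how the two models introduced in this diff (IFM and DIFM) and the new `gpus` argument are meant to be used together with the package's existing `SparseFeat` / `compile` / `fit` API. It is an illustrative example only, not part of the diff: the feature names and the random data are invented.

```python
# Illustrative sketch (not part of the diff): exercising the IFM/DIFM models and
# the new `gpus` argument added in v0.2.6. Feature names and data are invented.
import numpy as np
import torch

from deepctr_torch.inputs import SparseFeat
from deepctr_torch.models import IFM, DIFM

# Two toy sparse features; the linear and DNN parts reuse the same columns here.
feature_columns = [
    SparseFeat('user_id', vocabulary_size=100, embedding_dim=8),
    SparseFeat('item_id', vocabulary_size=200, embedding_dim=8),
]

# Random training data: one integer id array per sparse feature, binary labels.
n_samples = 256
x = {
    'user_id': np.random.randint(0, 100, n_samples),
    'item_id': np.random.randint(0, 200, n_samples),
}
y = np.random.randint(0, 2, n_samples)

device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

# IFM and DIFM follow the same constructor pattern as the other models in the package.
for model_cls in (IFM, DIFM):
    model = model_cls(feature_columns, feature_columns, task='binary', device=device)
    # For multi-GPU training, `gpus[0]` must be the same GPU as `device`, e.g.:
    # model = model_cls(feature_columns, feature_columns, device='cuda:0', gpus=[0, 1])
    model.compile('adam', 'binary_crossentropy', metrics=['auc'])
    model.fit(x, y, batch_size=64, epochs=1, verbose=1, validation_split=0.2)
```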