diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 7685cfb2..fc3cdae4 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -18,7 +18,7 @@ jobs:
strategy:
matrix:
python-version: [3.6,3.7]
- torch-version: [1.1.0,1.2.0,1.3.0,1.4.0,1.5.0,1.6.0,1.7.0]
+ torch-version: [1.1.0,1.2.0,1.3.0,1.4.0,1.5.0,1.6.0,1.7.0,1.8.1]
# exclude:
# - python-version: 3.5
diff --git a/README.md b/README.md
index dfb7be53..a50aeb93 100644
--- a/README.md
+++ b/README.md
@@ -38,7 +38,9 @@ Let's [**Get Started!**](https://deepctr-torch.readthedocs.io/en/latest/Quick-St
| AutoInt | [CIKM 2019][AutoInt: Automatic Feature Interaction Learning via Self-Attentive Neural Networks](https://arxiv.org/abs/1810.11921) |
| ONN | [arxiv 2019][Operation-aware Neural Networks for User Response Prediction](https://arxiv.org/pdf/1904.12579.pdf) |
| FiBiNET | [RecSys 2019][FiBiNET: Combining Feature Importance and Bilinear feature Interaction for Click-Through Rate Prediction](https://arxiv.org/pdf/1905.09433.pdf) |
+| IFM | [IJCAI 2019][An Input-aware Factorization Machine for Sparse Prediction](https://www.ijcai.org/Proceedings/2019/0203.pdf) |
| DCN V2 | [arxiv 2020][DCN V2: Improved Deep & Cross Network and Practical Lessons for Web-scale Learning to Rank Systems](https://arxiv.org/abs/2008.13535) |
+| DIFM | [IJCAI 2020][A Dual Input-aware Factorization Machine for CTR Prediction](https://www.ijcai.org/Proceedings/2020/0434.pdf) |
## DisscussionGroup & Related Projects
@@ -82,6 +84,11 @@ Let's [**Get Started!**](https://deepctr-torch.readthedocs.io/en/latest/Quick-St
Shen Weichen
Core Dev
Zhejiang University
+
+
+ Zan Shuxun
+ Core Dev Beijing University of Posts and Telecommunications
+ |
Wang Ze
@@ -92,11 +99,6 @@ Let's [**Get Started!**](https://deepctr-torch.readthedocs.io/en/latest/Quick-St
Zhang Wutong
Core Dev Beijing University of Posts and Telecommunications
|
-
-
- Zan Shuxun
- Core Dev Beijing University of Posts and Telecommunications
- |
Zhang Yuefeng
diff --git a/deepctr_torch/__init__.py b/deepctr_torch/__init__.py
index b3ae817b..b780468d 100644
--- a/deepctr_torch/__init__.py
+++ b/deepctr_torch/__init__.py
@@ -2,5 +2,5 @@
from . import models
from .utils import check_version
-__version__ = '0.2.5'
+__version__ = '0.2.6'
check_version(__version__)
\ No newline at end of file
diff --git a/deepctr_torch/layers/activation.py b/deepctr_torch/layers/activation.py
index 4ba8758e..01624a05 100644
--- a/deepctr_torch/layers/activation.py
+++ b/deepctr_torch/layers/activation.py
@@ -12,7 +12,7 @@ class Dice(nn.Module):
Output shape:
- Same shape as input.
-
+
References
- [Zhou G, Zhu X, Song C, et al. Deep interest network for click-through rate prediction[C]//Proceedings of the 24th ACM SIGKDD International Conference on Knowledge Discovery & Data Mining. ACM, 2018: 1059-1068.](https://arxiv.org/pdf/1706.06978.pdf)
- https://github.com/zhougr1993/DeepInterestNetwork, https://github.com/fanoping/DIN-pytorch
diff --git a/deepctr_torch/layers/interaction.py b/deepctr_torch/layers/interaction.py
index af7d945e..edbfa88b 100644
--- a/deepctr_torch/layers/interaction.py
+++ b/deepctr_torch/layers/interaction.py
@@ -106,10 +106,11 @@ class BilinearInteraction(nn.Module):
Input shape
- A list of 3D tensor with shape: ``(batch_size,filed_size, embedding_size)``.
Output shape
- - 3D tensor with shape: ``(batch_size,filed_size, embedding_size)``.
+ - 3D tensor with shape: ``(batch_size,filed_size*(filed_size-1)/2, embedding_size)``.
Arguments
- **filed_size** : Positive integer, number of feature groups.
- - **str** : String, types of bilinear functions used in this layer.
+ - **embedding_size** : Positive integer, embedding size of sparse features.
+ - **bilinear_type** : String, types of bilinear functions used in this layer.
- **seed** : A Python integer to use as random seed.
References
- [FiBiNET: Combining Feature Importance and Bilinear feature Interaction for Click-Through Rate Prediction
@@ -125,7 +126,7 @@ def __init__(self, filed_size, embedding_size, bilinear_type="interaction", seed
self.bilinear = nn.Linear(
embedding_size, embedding_size, bias=False)
elif self.bilinear_type == "each":
- for i in range(filed_size):
+ for _ in range(filed_size):
self.bilinear.append(
nn.Linear(embedding_size, embedding_size, bias=False))
elif self.bilinear_type == "interaction":
@@ -340,13 +341,14 @@ class InteractingLayer(nn.Module):
- [Song W, Shi C, Xiao Z, et al. AutoInt: Automatic Feature Interaction Learning via Self-Attentive Neural Networks[J]. arXiv preprint arXiv:1810.11921, 2018.](https://arxiv.org/abs/1810.11921)
"""
- def __init__(self, in_features, att_embedding_size=8, head_num=2, use_res=True, seed=1024, device='cpu'):
+ def __init__(self, in_features, att_embedding_size=8, head_num=2, use_res=True, scaling=False, seed=1024, device='cpu'):
super(InteractingLayer, self).__init__()
if head_num <= 0:
raise ValueError('head_num must be a int > 0')
self.att_embedding_size = att_embedding_size
self.head_num = head_num
self.use_res = use_res
+ self.scaling = scaling
self.seed = seed
embedding_size = in_features
@@ -388,7 +390,8 @@ def forward(self, inputs):
values, self.att_embedding_size, dim=2))
inner_product = torch.einsum(
'bnik,bnjk->bnij', querys, keys) # head_num None F F
-
+ if self.scaling:
+ inner_product /= self.att_embedding_size ** 0.5
self.normalized_att_scores = F.softmax(
inner_product, dim=-1) # head_num None F F
result = torch.matmul(self.normalized_att_scores,
@@ -428,17 +431,20 @@ def __init__(self, in_features, layer_num=2, parameterization='vector', seed=102
self.parameterization = parameterization
if self.parameterization == 'vector':
# weight in DCN. (in_features, 1)
- self.kernels = torch.nn.ParameterList(
- [nn.Parameter(nn.init.xavier_normal_(torch.empty(in_features, 1))) for i in range(self.layer_num)])
+ self.kernels = nn.Parameter(torch.Tensor(self.layer_num, in_features, 1))
elif self.parameterization == 'matrix':
# weight matrix in DCN-M. (in_features, in_features)
- self.kernels = torch.nn.ParameterList([nn.Parameter(nn.init.xavier_normal_(
- torch.empty(in_features, in_features))) for i in range(self.layer_num)])
+ self.kernels = nn.Parameter(torch.Tensor(self.layer_num, in_features, in_features))
else: # error
raise ValueError("parameterization should be 'vector' or 'matrix'")
- self.bias = torch.nn.ParameterList(
- [nn.Parameter(nn.init.zeros_(torch.empty(in_features, 1))) for i in range(self.layer_num)])
+ self.bias = nn.Parameter(torch.Tensor(self.layer_num, in_features, 1))
+
+ for i in range(self.kernels.shape[0]):
+ nn.init.xavier_normal_(self.kernels[i])
+ for i in range(self.bias.shape[0]):
+ nn.init.zeros_(self.bias[i])
+
self.to(device)
def forward(self, inputs):
@@ -483,18 +489,23 @@ def __init__(self, in_features, low_rank=32, num_experts=4, layer_num=2, device=
self.num_experts = num_experts
# U: (in_features, low_rank)
- self.U_list = torch.nn.ParameterList([nn.Parameter(nn.init.xavier_normal_(
- torch.empty(num_experts, in_features, low_rank))) for i in range(self.layer_num)])
+ self.U_list = nn.Parameter(torch.Tensor(self.layer_num, num_experts, in_features, low_rank))
# V: (in_features, low_rank)
- self.V_list = torch.nn.ParameterList([nn.Parameter(nn.init.xavier_normal_(
- torch.empty(num_experts, in_features, low_rank))) for i in range(self.layer_num)])
+ self.V_list = nn.Parameter(torch.Tensor(self.layer_num, num_experts, in_features, low_rank))
# C: (low_rank, low_rank)
- self.C_list = torch.nn.ParameterList([nn.Parameter(nn.init.xavier_normal_(
- torch.empty(num_experts, low_rank, low_rank))) for i in range(self.layer_num)])
+ self.C_list = nn.Parameter(torch.Tensor(self.layer_num, num_experts, low_rank, low_rank))
self.gating = nn.ModuleList([nn.Linear(in_features, 1, bias=False) for i in range(self.num_experts)])
- self.bias = torch.nn.ParameterList([nn.Parameter(nn.init.zeros_(
- torch.empty(in_features, 1))) for i in range(self.layer_num)])
+ self.bias = nn.Parameter(torch.Tensor(self.layer_num, in_features, 1))
+
+ init_para_list = [self.U_list, self.V_list, self.C_list]
+ for i in range(len(init_para_list)):
+ for j in range(self.layer_num):
+ nn.init.xavier_normal_(init_para_list[i][j])
+
+ for i in range(len(self.bias)):
+ nn.init.zeros_(self.bias[i])
+
self.to(device)
def forward(self, inputs):
diff --git a/deepctr_torch/layers/sequence.py b/deepctr_torch/layers/sequence.py
index 64736b5b..550e5878 100644
--- a/deepctr_torch/layers/sequence.py
+++ b/deepctr_torch/layers/sequence.py
@@ -39,7 +39,7 @@ def _sequence_mask(self, lengths, maxlen=None, dtype=torch.bool):
# Returns a mask tensor representing the first N positions of each cell.
if maxlen is None:
maxlen = lengths.max()
- row_vector = torch.arange(0, maxlen, 1).to(self.device)
+ row_vector = torch.arange(0, maxlen, 1).to(lengths.device)
matrix = torch.unsqueeze(lengths, dim=-1)
mask = row_vector < matrix
@@ -70,6 +70,7 @@ def forward(self, seq_value_len_list):
hist = torch.sum(hist, dim=1, keepdim=False)
if self.mode == 'mean':
+ self.eps = self.eps.to(user_behavior_length.device)
hist = torch.div(hist, user_behavior_length.type(torch.float32) + self.eps)
hist = torch.unsqueeze(hist, dim=1)
diff --git a/deepctr_torch/models/__init__.py b/deepctr_torch/models/__init__.py
index 09f1d7c3..43381369 100644
--- a/deepctr_torch/models/__init__.py
+++ b/deepctr_torch/models/__init__.py
@@ -2,6 +2,8 @@
from .deepfm import DeepFM
from .xdeepfm import xDeepFM
from .afm import AFM
+from .difm import DIFM
+from .ifm import IFM
from .autoint import AutoInt
from .dcn import DCN
from .dcnmix import DCNMix
diff --git a/deepctr_torch/models/afm.py b/deepctr_torch/models/afm.py
index 4d015e72..ae1556d4 100644
--- a/deepctr_torch/models/afm.py
+++ b/deepctr_torch/models/afm.py
@@ -27,16 +27,17 @@ class AFM(BaseModel):
:param seed: integer ,to use as random seed.
:param task: str, ``"binary"`` for binary logloss or ``"regression"`` for regression loss
:param device: str, ``"cpu"`` or ``"cuda:0"``
+ :param gpus: list of int or torch.device for multiple gpus. If None, run on `device`. `gpus[0]` should be the same gpu as `device`.
:return: A PyTorch model instance.
"""
def __init__(self, linear_feature_columns, dnn_feature_columns, use_attention=True, attention_factor=8,
l2_reg_linear=1e-5, l2_reg_embedding=1e-5, l2_reg_att=1e-5, afm_dropout=0, init_std=0.0001, seed=1024,
- task='binary', device='cpu'):
+ task='binary', device='cpu', gpus=None):
super(AFM, self).__init__(linear_feature_columns, dnn_feature_columns, l2_reg_linear=l2_reg_linear,
l2_reg_embedding=l2_reg_embedding, init_std=init_std, seed=seed, task=task,
- device=device)
+ device=device, gpus=gpus)
self.use_attention = use_attention
diff --git a/deepctr_torch/models/autoint.py b/deepctr_torch/models/autoint.py
index a2001183..c39effb4 100644
--- a/deepctr_torch/models/autoint.py
+++ b/deepctr_torch/models/autoint.py
@@ -32,19 +32,20 @@ class AutoInt(BaseModel):
:param seed: integer ,to use as random seed.
:param task: str, ``"binary"`` for binary logloss or ``"regression"`` for regression loss
:param device: str, ``"cpu"`` or ``"cuda:0"``
+ :param gpus: list of int or torch.device for multiple gpus. If None, run on `device`. `gpus[0]` should be the same gpu as `device`.
:return: A PyTorch model instance.
-
+
"""
def __init__(self, linear_feature_columns, dnn_feature_columns, att_layer_num=3, att_embedding_size=8, att_head_num=2,
att_res=True,
dnn_hidden_units=(256, 128), dnn_activation='relu',
l2_reg_dnn=0, l2_reg_embedding=1e-5, dnn_use_bn=False, dnn_dropout=0, init_std=0.0001, seed=1024,
- task='binary', device='cpu'):
+ task='binary', device='cpu', gpus=None):
super(AutoInt, self).__init__(linear_feature_columns, dnn_feature_columns, l2_reg_linear=0,
l2_reg_embedding=l2_reg_embedding, init_std=init_std, seed=seed, task=task,
- device=device)
+ device=device, gpus=gpus)
if len(dnn_hidden_units) <= 0 and att_layer_num <= 0:
raise ValueError("Either hidden_layer or att_layer_num must > 0")
diff --git a/deepctr_torch/models/basemodel.py b/deepctr_torch/models/basemodel.py
index 865c0bfc..bb9d1f7a 100644
--- a/deepctr_torch/models/basemodel.py
+++ b/deepctr_torch/models/basemodel.py
@@ -59,7 +59,7 @@ def __init__(self, feature_columns, feature_index, init_std=0.0001, device='cpu'
device))
torch.nn.init.normal_(self.weight, mean=0, std=init_std)
- def forward(self, X):
+ def forward(self, X, sparse_feat_refine_weight=None):
sparse_embedding_list = [self.embedding_dict[feat.embedding_name](
X[:, self.feature_index[feat.name][0]:self.feature_index[feat.name][1]].long()) for
@@ -73,26 +73,25 @@ def forward(self, X):
sparse_embedding_list += varlen_embedding_list
- if len(sparse_embedding_list) > 0 and len(dense_value_list) > 0:
- linear_sparse_logit = torch.sum(
- torch.cat(sparse_embedding_list, dim=-1), dim=-1, keepdim=False)
- linear_dense_logit = torch.cat(
+ linear_logit = torch.zeros([X.shape[0], 1]).to(sparse_embedding_list[0].device)
+ if len(sparse_embedding_list) > 0:
+ sparse_embedding_cat = torch.cat(sparse_embedding_list, dim=-1)
+ if sparse_feat_refine_weight is not None:
+ # w_{x,i}=m_{x,i} * w_i (in IFM and DIFM)
+ sparse_embedding_cat = sparse_embedding_cat * sparse_feat_refine_weight.unsqueeze(1)
+ sparse_feat_logit = torch.sum(sparse_embedding_cat, dim=-1, keepdim=False)
+ linear_logit += sparse_feat_logit
+ if len(dense_value_list) > 0:
+ dense_value_logit = torch.cat(
dense_value_list, dim=-1).matmul(self.weight)
- linear_logit = linear_sparse_logit + linear_dense_logit
- elif len(sparse_embedding_list) > 0:
- linear_logit = torch.sum(
- torch.cat(sparse_embedding_list, dim=-1), dim=-1, keepdim=False)
- elif len(dense_value_list) > 0:
- linear_logit = torch.cat(
- dense_value_list, dim=-1).matmul(self.weight)
- else:
- linear_logit = torch.zeros([X.shape[0], 1])
+ linear_logit += dense_value_logit
+
return linear_logit
class BaseModel(nn.Module):
def __init__(self, linear_feature_columns, dnn_feature_columns, l2_reg_linear=1e-5, l2_reg_embedding=1e-5,
- init_std=0.0001, seed=1024, task='binary', device='cpu'):
+ init_std=0.0001, seed=1024, task='binary', device='cpu', gpus=None):
super(BaseModel, self).__init__()
torch.manual_seed(seed)
@@ -100,7 +99,11 @@ def __init__(self, linear_feature_columns, dnn_feature_columns, l2_reg_linear=1e
self.reg_loss = torch.zeros((1,), device=device)
self.aux_loss = torch.zeros((1,), device=device)
- self.device = device # device
+ self.device = device
+ self.gpus = gpus
+ if gpus and str(self.gpus[0]) not in self.device:
+ raise ValueError(
+ "`gpus[0]` should be the same gpu as `device`")
self.feature_index = build_input_features(
linear_feature_columns + dnn_feature_columns)
@@ -192,14 +195,21 @@ def fit(self, x=None, y=None, batch_size=None, epochs=1, verbose=1, initial_epoc
torch.from_numpy(y))
if batch_size is None:
batch_size = 256
- train_loader = DataLoader(
- dataset=train_tensor_data, shuffle=shuffle, batch_size=batch_size)
- print(self.device, end="\n")
model = self.train()
loss_func = self.loss_func
optim = self.optim
+ if self.gpus:
+ print('parallel running on these gpus:', self.gpus)
+ model = torch.nn.DataParallel(model, device_ids=self.gpus)
+ batch_size *= len(self.gpus) # input `batch_size` is batch_size per gpu
+ else:
+ print(self.device)
+
+ train_loader = DataLoader(
+ dataset=train_tensor_data, shuffle=shuffle, batch_size=batch_size)
+
sample_num = len(train_tensor_data)
steps_per_epoch = (sample_num - 1) // batch_size + 1
@@ -224,7 +234,7 @@ def fit(self, x=None, y=None, batch_size=None, epochs=1, verbose=1, initial_epoc
train_result = {}
try:
with tqdm(enumerate(train_loader), disable=verbose != 1) as t:
- for index, (x_train, y_train) in t:
+ for _, (x_train, y_train) in t:
x = x_train.to(self.device).float()
y = y_train.to(self.device).float()
@@ -323,7 +333,7 @@ def predict(self, x, batch_size=256):
pred_ans = []
with torch.no_grad():
- for index, x_test in enumerate(test_loader):
+ for _, x_test in enumerate(test_loader):
x = x_test[0].to(self.device).float()
y_pred = model(x).cpu().data.numpy() # .squeeze()
diff --git a/deepctr_torch/models/ccpm.py b/deepctr_torch/models/ccpm.py
index 73272b66..7ab098ae 100644
--- a/deepctr_torch/models/ccpm.py
+++ b/deepctr_torch/models/ccpm.py
@@ -34,6 +34,7 @@ class CCPM(BaseModel):
:param seed: integer ,to use as random seed.
:param task: str, ``"binary"`` for binary logloss or ``"regression"`` for regression loss
:param device: str, ``"cpu"`` or ``"cuda:0"``
+ :param gpus: list of int or torch.device for multiple gpus. If None, run on `device`. `gpus[0]` should be the same gpu as `device`.
:return: A PyTorch model instance.
"""
@@ -41,11 +42,11 @@ class CCPM(BaseModel):
def __init__(self, linear_feature_columns, dnn_feature_columns, conv_kernel_width=(6, 5),
conv_filters=(4, 4),
dnn_hidden_units=(256,), l2_reg_linear=1e-5, l2_reg_embedding=1e-5, l2_reg_dnn=0, dnn_dropout=0,
- init_std=0.0001, seed=1024, task='binary', device='cpu', dnn_use_bn=False, dnn_activation='relu'):
+ init_std=0.0001, seed=1024, task='binary', device='cpu', dnn_use_bn=False, dnn_activation='relu', gpus=None):
super(CCPM, self).__init__(linear_feature_columns, dnn_feature_columns, l2_reg_linear=l2_reg_linear,
l2_reg_embedding=l2_reg_embedding, init_std=init_std, seed=seed, task=task,
- device=device)
+ device=device, gpus=gpus)
if len(conv_kernel_width) != len(conv_filters):
raise ValueError(
diff --git a/deepctr_torch/models/dcn.py b/deepctr_torch/models/dcn.py
index 4528b9a7..f5ef03bf 100644
--- a/deepctr_torch/models/dcn.py
+++ b/deepctr_torch/models/dcn.py
@@ -36,18 +36,19 @@ class DCN(BaseModel):
:param dnn_activation: Activation function to use in DNN
:param task: str, ``"binary"`` for binary logloss or ``"regression"`` for regression loss
:param device: str, ``"cpu"`` or ``"cuda:0"``
+ :param gpus: list of int or torch.device for multiple gpus. If None, run on `device`. `gpus[0]` should be the same gpu as `device`.
:return: A PyTorch model instance.
-
+
"""
def __init__(self, linear_feature_columns, dnn_feature_columns, cross_num=2, cross_parameterization='vector',
dnn_hidden_units=(128, 128), l2_reg_linear=0.00001, l2_reg_embedding=0.00001, l2_reg_cross=0.00001,
l2_reg_dnn=0, init_std=0.0001, seed=1024, dnn_dropout=0, dnn_activation='relu', dnn_use_bn=False,
- task='binary', device='cpu'):
+ task='binary', device='cpu', gpus=None):
super(DCN, self).__init__(linear_feature_columns=linear_feature_columns,
dnn_feature_columns=dnn_feature_columns, l2_reg_embedding=l2_reg_embedding,
- init_std=init_std, seed=seed, task=task, device=device)
+ init_std=init_std, seed=seed, task=task, device=device, gpus=gpus)
self.dnn_hidden_units = dnn_hidden_units
self.cross_num = cross_num
self.dnn = DNN(self.compute_input_dim(dnn_feature_columns), dnn_hidden_units,
diff --git a/deepctr_torch/models/dcnmix.py b/deepctr_torch/models/dcnmix.py
index c01fd44c..9b0e97d4 100644
--- a/deepctr_torch/models/dcnmix.py
+++ b/deepctr_torch/models/dcnmix.py
@@ -36,8 +36,9 @@ class DCNMix(BaseModel):
:param num_experts: Positive integer, number of experts.
:param task: str, ``"binary"`` for binary logloss or ``"regression"`` for regression loss
:param device: str, ``"cpu"`` or ``"cuda:0"``
+ :param gpus: list of int or torch.device for multiple gpus. If None, run on `device`. `gpus[0]` should be the same gpu as `device`.
:return: A PyTorch model instance.
-
+
"""
def __init__(self, linear_feature_columns,
@@ -45,11 +46,11 @@ def __init__(self, linear_feature_columns,
dnn_hidden_units=(128, 128), l2_reg_linear=0.00001,
l2_reg_embedding=0.00001, l2_reg_cross=0.00001, l2_reg_dnn=0, init_std=0.0001, seed=1024,
dnn_dropout=0, low_rank=32, num_experts=4,
- dnn_activation='relu', dnn_use_bn=False, task='binary', device='cpu'):
+ dnn_activation='relu', dnn_use_bn=False, task='binary', device='cpu', gpus=None):
super(DCNMix, self).__init__(linear_feature_columns=linear_feature_columns,
dnn_feature_columns=dnn_feature_columns, l2_reg_embedding=l2_reg_embedding,
- init_std=init_std, seed=seed, task=task, device=device)
+ init_std=init_std, seed=seed, task=task, device=device, gpus=gpus)
self.dnn_hidden_units = dnn_hidden_units
self.cross_num = cross_num
self.dnn = DNN(self.compute_input_dim(dnn_feature_columns), dnn_hidden_units,
diff --git a/deepctr_torch/models/deepfm.py b/deepctr_torch/models/deepfm.py
index f0dfb411..7f90faf7 100644
--- a/deepctr_torch/models/deepfm.py
+++ b/deepctr_torch/models/deepfm.py
@@ -30,8 +30,9 @@ class DeepFM(BaseModel):
:param dnn_use_bn: bool. Whether use BatchNormalization before activation or not in DNN
:param task: str, ``"binary"`` for binary logloss or ``"regression"`` for regression loss
:param device: str, ``"cpu"`` or ``"cuda:0"``
+ :param gpus: list of int or torch.device for multiple gpus. If None, run on `device`. `gpus[0]` should be the same gpu as `device`.
:return: A PyTorch model instance.
-
+
"""
def __init__(self,
@@ -39,11 +40,11 @@ def __init__(self,
dnn_hidden_units=(256, 128),
l2_reg_linear=0.00001, l2_reg_embedding=0.00001, l2_reg_dnn=0, init_std=0.0001, seed=1024,
dnn_dropout=0,
- dnn_activation='relu', dnn_use_bn=False, task='binary', device='cpu'):
+ dnn_activation='relu', dnn_use_bn=False, task='binary', device='cpu', gpus=None):
super(DeepFM, self).__init__(linear_feature_columns, dnn_feature_columns, l2_reg_linear=l2_reg_linear,
l2_reg_embedding=l2_reg_embedding, init_std=init_std, seed=seed, task=task,
- device=device)
+ device=device, gpus=gpus)
self.use_fm = use_fm
self.use_dnn = len(dnn_feature_columns) > 0 and len(
diff --git a/deepctr_torch/models/dien.py b/deepctr_torch/models/dien.py
index b86a9897..6f37c1aa 100644
--- a/deepctr_torch/models/dien.py
+++ b/deepctr_torch/models/dien.py
@@ -16,25 +16,27 @@
class DIEN(BaseModel):
"""Instantiates the Deep Interest Evolution Network architecture.
- :param dnn_feature_columns: An iterable containing all the features used by deep part of the model.
- :param history_feature_list: list,to indicate sequence sparse field
- :param gru_type: str,can be GRU AIGRU AUGRU AGRU
- :param use_negsampling: bool, whether or not use negtive sampling
- :param alpha: float ,weight of auxiliary_loss
- :param use_bn: bool. Whether use BatchNormalization before activation or not in deep net
- :param dnn_hidden_units: list,list of positive integer or empty list, the layer number and units in each layer of DNN
- :param dnn_activation: Activation function to use in DNN
- :param att_hidden_units: list,list of positive integer , the layer number and units in each layer of attention net
- :param att_activation: Activation function to use in attention net
- :param att_weight_normalization: bool.Whether normalize the attention score of local activation unit.
- :param l2_reg_dnn: float. L2 regularizer strength applied to DNN
- :param l2_reg_embedding: float. L2 regularizer strength applied to embedding vector
- :param dnn_dropout: float in [0,1), the probability we will drop out a given DNN coordinate.
- :param init_std: float,to use as the initialize std of embedding vector
- :param seed: integer ,to use as random seed.
- :param task: str, ``"binary"`` for binary logloss or ``"regression"`` for regression loss
- :param device: str, ``"cpu"`` or ``"cuda:0"``
- :return: A PyTorch model instance.
+ :param dnn_feature_columns: An iterable containing all the features used by deep part of the model.
+ :param history_feature_list: list,to indicate sequence sparse field
+ :param gru_type: str,can be GRU AIGRU AUGRU AGRU
+ :param use_negsampling: bool, whether or not to use negative sampling
+ :param alpha: float ,weight of auxiliary_loss
+ :param use_bn: bool. Whether use BatchNormalization before activation or not in deep net
+ :param dnn_hidden_units: list,list of positive integer or empty list, the layer number and units in each layer of DNN
+ :param dnn_activation: Activation function to use in DNN
+ :param att_hidden_units: list,list of positive integer , the layer number and units in each layer of attention net
+ :param att_activation: Activation function to use in attention net
+ :param att_weight_normalization: bool.Whether normalize the attention score of local activation unit.
+ :param l2_reg_dnn: float. L2 regularizer strength applied to DNN
+ :param l2_reg_embedding: float. L2 regularizer strength applied to embedding vector
+ :param dnn_dropout: float in [0,1), the probability we will drop out a given DNN coordinate.
+ :param init_std: float,to use as the initialize std of embedding vector
+ :param seed: integer ,to use as random seed.
+ :param task: str, ``"binary"`` for binary logloss or ``"regression"`` for regression loss
+ :param device: str, ``"cpu"`` or ``"cuda:0"``
+ :param gpus: list of int or torch.device for multiple gpus. If None, run on `device`. `gpus[0]` should be the same gpu as `device`.
+ :return: A PyTorch model instance.
+
"""
def __init__(self,
@@ -43,9 +45,9 @@ def __init__(self,
dnn_activation='relu',
att_hidden_units=(64, 16), att_activation="relu", att_weight_normalization=True,
l2_reg_dnn=0, l2_reg_embedding=1e-6, dnn_dropout=0, init_std=0.0001, seed=1024, task='binary',
- device='cpu'):
+ device='cpu', gpus=None):
super(DIEN, self).__init__([], dnn_feature_columns, l2_reg_linear=0, l2_reg_embedding=l2_reg_embedding,
- init_std=init_std, seed=seed, task=task, device=device)
+ init_std=init_std, seed=seed, task=task, device=device, gpus=gpus)
self.item_features = history_feature_list
self.use_negsampling = use_negsampling
diff --git a/deepctr_torch/models/difm.py b/deepctr_torch/models/difm.py
new file mode 100644
index 00000000..13a3aaab
--- /dev/null
+++ b/deepctr_torch/models/difm.py
@@ -0,0 +1,106 @@
+# -*- coding:utf-8 -*-
+"""
+Author:
+ zanshuxun, zanshuxun@aliyun.com
+Reference:
+ [1] Lu W, Yu Y, Chang Y, et al. A Dual Input-aware Factorization Machine for CTR Prediction[C]//IJCAI. 2020: 3139-3145.(https://www.ijcai.org/Proceedings/2020/0434.pdf)
+"""
+import torch
+import torch.nn as nn
+
+from .basemodel import BaseModel
+from ..inputs import combined_dnn_input, SparseFeat, VarLenSparseFeat
+from ..layers import FM, DNN, InteractingLayer, concat_fun
+
+
+class DIFM(BaseModel):
+ """Instantiates the DIFM Network architecture.
+
+ :param linear_feature_columns: An iterable containing all the features used by linear part of the model.
+ :param dnn_feature_columns: An iterable containing all the features used by deep part of the model.
+ :param dnn_hidden_units: list,list of positive integer or empty list, the layer number and units in each layer of DNN
+ :param l2_reg_linear: float. L2 regularizer strength applied to linear part
+ :param l2_reg_embedding: float. L2 regularizer strength applied to embedding vector
+ :param l2_reg_dnn: float. L2 regularizer strength applied to DNN
+ :param init_std: float,to use as the initialize std of embedding vector
+ :param seed: integer ,to use as random seed.
+ :param dnn_dropout: float in [0,1), the probability we will drop out a given DNN coordinate.
+ :param dnn_activation: Activation function to use in DNN
+ :param dnn_use_bn: bool. Whether use BatchNormalization before activation or not in DNN
+ :param task: str, ``"binary"`` for binary logloss or ``"regression"`` for regression loss
+ :param device: str, ``"cpu"`` or ``"cuda:0"``
+ :param gpus: list of int or torch.device for multiple gpus. If None, run on ``device``. ``gpus[0]`` should be the same gpu as ``device``.
+ :return: A PyTorch model instance.
+
+ """
+
+ def __init__(self,
+ linear_feature_columns, dnn_feature_columns, att_embedding_size=8, att_head_num=8,
+ att_res=True, dnn_hidden_units=(256, 128),
+ l2_reg_linear=0.00001, l2_reg_embedding=0.00001, l2_reg_dnn=0, init_std=0.0001, seed=1024,
+ dnn_dropout=0,
+ dnn_activation='relu', dnn_use_bn=False, task='binary', device='cpu', gpus=None):
+ super(DIFM, self).__init__(linear_feature_columns, dnn_feature_columns, l2_reg_linear=l2_reg_linear,
+ l2_reg_embedding=l2_reg_embedding, init_std=init_std, seed=seed, task=task,
+ device=device, gpus=gpus)
+
+ if not len(dnn_hidden_units) > 0:
+ raise ValueError("dnn_hidden_units is null!")
+
+ self.use_dnn = len(dnn_feature_columns) > 0 and len(
+ dnn_hidden_units) > 0
+ self.fm = FM()
+
+ # InteractingLayer (used in AutoInt) = multi-head self-attention + Residual Network
+ self.vector_wise_net = InteractingLayer(self.embedding_size, att_embedding_size,
+ att_head_num, att_res, scaling=True, device=device)
+
+ self.bit_wise_net = DNN(self.compute_input_dim(dnn_feature_columns, include_dense=False),
+ dnn_hidden_units, activation=dnn_activation, l2_reg=l2_reg_dnn,
+ dropout_rate=dnn_dropout,
+ use_bn=dnn_use_bn, init_std=init_std, device=device)
+ self.sparse_feat_num = len(list(filter(lambda x: isinstance(x, SparseFeat) or isinstance(x, VarLenSparseFeat),
+ dnn_feature_columns)))
+
+ self.transform_matrix_P_vec = nn.Linear(
+ self.sparse_feat_num*att_embedding_size*att_head_num, self.sparse_feat_num, bias=False).to(device)
+ self.transform_matrix_P_bit = nn.Linear(
+ dnn_hidden_units[-1], self.sparse_feat_num, bias=False).to(device)
+
+ self.add_regularization_weight(
+ filter(lambda x: 'weight' in x[0] and 'bn' not in x[0], self.vector_wise_net.named_parameters()),
+ l2=l2_reg_dnn)
+ self.add_regularization_weight(
+ filter(lambda x: 'weight' in x[0] and 'bn' not in x[0], self.bit_wise_net.named_parameters()),
+ l2=l2_reg_dnn)
+ self.add_regularization_weight(self.transform_matrix_P_vec.weight, l2=l2_reg_dnn)
+ self.add_regularization_weight(self.transform_matrix_P_bit.weight, l2=l2_reg_dnn)
+
+ self.to(device)
+
+ def forward(self, X):
+ sparse_embedding_list, _ = self.input_from_feature_columns(X, self.dnn_feature_columns,
+ self.embedding_dict)
+ if not len(sparse_embedding_list) > 0:
+ raise ValueError("there are no sparse features")
+
+ att_input = concat_fun(sparse_embedding_list, axis=1)
+ att_out = self.vector_wise_net(att_input)
+ att_out = att_out.reshape(att_out.shape[0], -1)
+ m_vec = self.transform_matrix_P_vec(att_out)
+
+ dnn_input = combined_dnn_input(sparse_embedding_list, [])
+ dnn_output = self.bit_wise_net(dnn_input)
+ m_bit = self.transform_matrix_P_bit(dnn_output)
+
+ m_x = m_vec + m_bit # m_x is the complete input-aware factor
+
+ logit = self.linear_model(X, sparse_feat_refine_weight=m_x)
+
+ fm_input = torch.cat(sparse_embedding_list, dim=1)
+ refined_fm_input = fm_input * m_x.unsqueeze(-1) # \textbf{v}_{x,i}=m_{x,i} * \textbf{v}_i
+ logit += self.fm(refined_fm_input)
+
+ y_pred = self.out(logit)
+
+ return y_pred
diff --git a/deepctr_torch/models/din.py b/deepctr_torch/models/din.py
index 0a8e46af..8bac0383 100644
--- a/deepctr_torch/models/din.py
+++ b/deepctr_torch/models/din.py
@@ -29,6 +29,8 @@ class DIN(BaseModel):
:param init_std: float,to use as the initialize std of embedding vector
:param seed: integer ,to use as random seed.
:param task: str, ``"binary"`` for binary logloss or ``"regression"`` for regression loss
+ :param device: str, ``"cpu"`` or ``"cuda:0"``
+ :param gpus: list of int or torch.device for multiple gpus. If None, run on `device`. `gpus[0]` should be the same gpu as `device`.
:return: A PyTorch model instance.
"""
@@ -37,9 +39,9 @@ def __init__(self, dnn_feature_columns, history_feature_list, dnn_use_bn=False,
dnn_hidden_units=(256, 128), dnn_activation='relu', att_hidden_size=(64, 16),
att_activation='Dice', att_weight_normalization=False, l2_reg_dnn=0.0,
l2_reg_embedding=1e-6, dnn_dropout=0, init_std=0.0001,
- seed=1024, task='binary', device='cpu'):
+ seed=1024, task='binary', device='cpu', gpus=None):
super(DIN, self).__init__([], dnn_feature_columns, l2_reg_linear=0, l2_reg_embedding=l2_reg_embedding,
- init_std=init_std, seed=seed, task=task, device=device)
+ init_std=init_std, seed=seed, task=task, device=device, gpus=gpus)
self.sparse_feature_columns = list(
filter(lambda x: isinstance(x, SparseFeat), dnn_feature_columns)) if dnn_feature_columns else []
diff --git a/deepctr_torch/models/fibinet.py b/deepctr_torch/models/fibinet.py
index f3e1b436..67ec4783 100644
--- a/deepctr_torch/models/fibinet.py
+++ b/deepctr_torch/models/fibinet.py
@@ -31,17 +31,18 @@ class FiBiNET(BaseModel):
:param dnn_activation: Activation function to use in DNN
:param task: str, ``"binary"`` for binary logloss or ``"regression"`` for regression loss
:param device: str, ``"cpu"`` or ``"cuda:0"``
+ :param gpus: list of int or torch.device for multiple gpus. If None, run on `device`. `gpus[0]` should be the same gpu as `device`.
:return: A PyTorch model instance.
-
+
"""
def __init__(self, linear_feature_columns, dnn_feature_columns, bilinear_type='interaction',
reduction_ratio=3, dnn_hidden_units=(128, 128), l2_reg_linear=1e-5,
l2_reg_embedding=1e-5, l2_reg_dnn=0, init_std=0.0001, seed=1024, dnn_dropout=0, dnn_activation='relu',
- task='binary', device='cpu'):
+ task='binary', device='cpu', gpus=None):
super(FiBiNET, self).__init__(linear_feature_columns, dnn_feature_columns, l2_reg_linear=l2_reg_linear,
l2_reg_embedding=l2_reg_embedding, init_std=init_std, seed=seed, task=task,
- device=device)
+ device=device, gpus=gpus)
self.linear_feature_columns = linear_feature_columns
self.dnn_feature_columns = dnn_feature_columns
self.filed_size = len(self.embedding_dict)
diff --git a/deepctr_torch/models/ifm.py b/deepctr_torch/models/ifm.py
new file mode 100644
index 00000000..4f057833
--- /dev/null
+++ b/deepctr_torch/models/ifm.py
@@ -0,0 +1,89 @@
+# -*- coding:utf-8 -*-
+"""
+Author:
+ zanshuxun, zanshuxun@aliyun.com
+Reference:
+ [1] Yu Y, Wang Z, Yuan B. An Input-aware Factorization Machine for Sparse Prediction[C]//IJCAI. 2019: 1466-1472.(https://www.ijcai.org/Proceedings/2019/0203.pdf)
+"""
+import torch
+import torch.nn as nn
+
+from .basemodel import BaseModel
+from ..inputs import combined_dnn_input, SparseFeat, VarLenSparseFeat
+from ..layers import FM, DNN
+
+
+class IFM(BaseModel):
+ """Instantiates the IFM Network architecture.
+
+ :param linear_feature_columns: An iterable containing all the features used by linear part of the model.
+ :param dnn_feature_columns: An iterable containing all the features used by deep part of the model.
+ :param dnn_hidden_units: list,list of positive integer or empty list, the layer number and units in each layer of DNN
+ :param l2_reg_linear: float. L2 regularizer strength applied to linear part
+ :param l2_reg_embedding: float. L2 regularizer strength applied to embedding vector
+ :param l2_reg_dnn: float. L2 regularizer strength applied to DNN
+ :param init_std: float,to use as the initialize std of embedding vector
+ :param seed: integer ,to use as random seed.
+ :param dnn_dropout: float in [0,1), the probability we will drop out a given DNN coordinate.
+ :param dnn_activation: Activation function to use in DNN
+ :param dnn_use_bn: bool. Whether use BatchNormalization before activation or not in DNN
+ :param task: str, ``"binary"`` for binary logloss or ``"regression"`` for regression loss
+ :param device: str, ``"cpu"`` or ``"cuda:0"``
+ :param gpus: list of int or torch.device for multiple gpus. If None, run on ``device``. ``gpus[0]`` should be the same gpu as ``device``.
+ :return: A PyTorch model instance.
+
+ """
+
+ def __init__(self,
+ linear_feature_columns, dnn_feature_columns,
+ dnn_hidden_units=(256, 128),
+ l2_reg_linear=0.00001, l2_reg_embedding=0.00001, l2_reg_dnn=0, init_std=0.0001, seed=1024,
+ dnn_dropout=0,
+ dnn_activation='relu', dnn_use_bn=False, task='binary', device='cpu', gpus=None):
+ super(IFM, self).__init__(linear_feature_columns, dnn_feature_columns, l2_reg_linear=l2_reg_linear,
+ l2_reg_embedding=l2_reg_embedding, init_std=init_std, seed=seed, task=task,
+ device=device, gpus=gpus)
+
+ if not len(dnn_hidden_units) > 0:
+ raise ValueError("dnn_hidden_units is null!")
+
+ self.use_dnn = len(dnn_feature_columns) > 0 and len(
+ dnn_hidden_units) > 0
+ self.fm = FM()
+
+ self.factor_estimating_net = DNN(self.compute_input_dim(dnn_feature_columns, include_dense=False),
+ dnn_hidden_units, activation=dnn_activation, l2_reg=l2_reg_dnn,
+ dropout_rate=dnn_dropout,
+ use_bn=dnn_use_bn, init_std=init_std, device=device)
+ self.sparse_feat_num = len(list(filter(lambda x: isinstance(x, SparseFeat) or isinstance(x, VarLenSparseFeat),
+ dnn_feature_columns)))
+ self.transform_weight_matrix_P = nn.Linear(
+ dnn_hidden_units[-1], self.sparse_feat_num, bias=False).to(device)
+
+ self.add_regularization_weight(
+ filter(lambda x: 'weight' in x[0] and 'bn' not in x[0], self.factor_estimating_net.named_parameters()),
+ l2=l2_reg_dnn)
+ self.add_regularization_weight(self.transform_weight_matrix_P.weight, l2=l2_reg_dnn)
+
+ self.to(device)
+
+ def forward(self, X):
+ sparse_embedding_list, dense_value_list = self.input_from_feature_columns(X, self.dnn_feature_columns,
+ self.embedding_dict)
+ if not len(sparse_embedding_list) > 0:
+ raise ValueError("there are no sparse features")
+
+ dnn_input = combined_dnn_input(sparse_embedding_list, []) # (batch_size, feat_num * embedding_size)
+ dnn_output = self.factor_estimating_net(dnn_input)
+ dnn_output = self.transform_weight_matrix_P(dnn_output) # m'_{x}
+ input_aware_factor = self.sparse_feat_num * dnn_output.softmax(1) # input_aware_factor m_{x,i}
+
+ logit = self.linear_model(X, sparse_feat_refine_weight=input_aware_factor)
+
+ fm_input = torch.cat(sparse_embedding_list, dim=1)
+ refined_fm_input = fm_input * input_aware_factor.unsqueeze(-1) # \textbf{v}_{x,i}=m_{x,i}\textbf{v}_i
+ logit += self.fm(refined_fm_input)
+
+ y_pred = self.out(logit)
+
+ return y_pred
diff --git a/deepctr_torch/models/mlr.py b/deepctr_torch/models/mlr.py
index 9259694a..8cb60090 100644
--- a/deepctr_torch/models/mlr.py
+++ b/deepctr_torch/models/mlr.py
@@ -26,14 +26,15 @@ class MLR(BaseModel):
:param task: str, ``"binary"`` for binary logloss or ``"regression"`` for regression loss
:param bias_feature_columns: An iterable containing all the features used by bias part of the model.
:param device: str, ``"cpu"`` or ``"cuda:0"``
+ :param gpus: list of int or torch.device for multiple gpus. If None, run on `device`. `gpus[0]` should be the same gpu as `device`.
:return: A PyTorch model instance.
-
+
"""
def __init__(self, region_feature_columns, base_feature_columns=None, bias_feature_columns=None,
- region_num=4, l2_reg_linear=1e-5, init_std=0.0001, seed=1024, task='binary', device='cpu'
+ region_num=4, l2_reg_linear=1e-5, init_std=0.0001, seed=1024, task='binary', device='cpu', gpus=None
):
- super(MLR, self).__init__(region_feature_columns, region_feature_columns, task=task, device=device)
+ super(MLR, self).__init__(region_feature_columns, region_feature_columns, task=task, device=device, gpus=gpus)
if region_num <= 1:
raise ValueError("region_num must > 1")
diff --git a/deepctr_torch/models/nfm.py b/deepctr_torch/models/nfm.py
index 4120d5de..f01613c7 100644
--- a/deepctr_torch/models/nfm.py
+++ b/deepctr_torch/models/nfm.py
@@ -29,17 +29,18 @@ class NFM(BaseModel):
:param dnn_activation: Activation function to use in deep net
:param task: str, ``"binary"`` for binary logloss or ``"regression"`` for regression loss
:param device: str, ``"cpu"`` or ``"cuda:0"``
+ :param gpus: list of int or torch.device for multiple gpus. If None, run on `device`. `gpus[0]` should be the same gpu as `device`.
:return: A PyTorch model instance.
-
+
"""
def __init__(self,
linear_feature_columns, dnn_feature_columns, dnn_hidden_units=(128, 128),
l2_reg_embedding=1e-5, l2_reg_linear=1e-5, l2_reg_dnn=0, init_std=0.0001, seed=1024, bi_dropout=0,
- dnn_dropout=0, dnn_activation='relu', task='binary', device='cpu'):
+ dnn_dropout=0, dnn_activation='relu', task='binary', device='cpu', gpus=None):
super(NFM, self).__init__(linear_feature_columns, dnn_feature_columns, l2_reg_linear=l2_reg_linear,
l2_reg_embedding=l2_reg_embedding, init_std=init_std, seed=seed, task=task,
- device=device)
+ device=device, gpus=gpus)
self.dnn = DNN(self.compute_input_dim(dnn_feature_columns, include_sparse=False) + self.embedding_size,
dnn_hidden_units,
diff --git a/deepctr_torch/models/onn.py b/deepctr_torch/models/onn.py
index b4d4d085..49f59cca 100644
--- a/deepctr_torch/models/onn.py
+++ b/deepctr_torch/models/onn.py
@@ -50,18 +50,19 @@ class ONN(BaseModel):
:param reduce_sum: bool,whether apply reduce_sum on cross vector
:param task: str, ``"binary"`` for binary logloss or ``"regression"`` for regression loss
:param device: str, ``"cpu"`` or ``"cuda:0"``
+ :param gpus: list of int or torch.device for multiple gpus. If None, run on `device`. `gpus[0]` should be the same gpu as `device`.
:return: A PyTorch model instance.
-
+
"""
def __init__(self, linear_feature_columns, dnn_feature_columns,
dnn_hidden_units=(128, 128),
l2_reg_embedding=1e-5, l2_reg_linear=1e-5, l2_reg_dnn=0,
dnn_dropout=0, init_std=0.0001, seed=1024, dnn_use_bn=False, dnn_activation='relu',
- task='binary', device='cpu'):
+ task='binary', device='cpu', gpus=None):
super(ONN, self).__init__(linear_feature_columns, dnn_feature_columns, l2_reg_linear=l2_reg_linear,
l2_reg_embedding=l2_reg_embedding, init_std=init_std, seed=seed, task=task,
- device=device)
+ device=device, gpus=gpus)
# second order part
embedding_size = self.embedding_size
diff --git a/deepctr_torch/models/pnn.py b/deepctr_torch/models/pnn.py
index 2cdeff0a..d72b2d5c 100644
--- a/deepctr_torch/models/pnn.py
+++ b/deepctr_torch/models/pnn.py
@@ -30,16 +30,17 @@ class PNN(BaseModel):
:param kernel_type: str,kernel_type used in outter-product,can be ``'mat'`` , ``'vec'`` or ``'num'``
:param task: str, ``"binary"`` for binary logloss or ``"regression"`` for regression loss
:param device: str, ``"cpu"`` or ``"cuda:0"``
+ :param gpus: list of int or torch.device for multiple gpus. If None, run on `device`. `gpus[0]` should be the same gpu as `device`.
:return: A PyTorch model instance.
-
+
"""
def __init__(self, dnn_feature_columns, dnn_hidden_units=(128, 128), l2_reg_embedding=1e-5, l2_reg_dnn=0,
init_std=0.0001, seed=1024, dnn_dropout=0, dnn_activation='relu', use_inner=True, use_outter=False,
- kernel_type='mat', task='binary', device='cpu', ):
+ kernel_type='mat', task='binary', device='cpu', gpus=None):
super(PNN, self).__init__([], dnn_feature_columns, l2_reg_linear=0, l2_reg_embedding=l2_reg_embedding,
- init_std=init_std, seed=seed, task=task, device=device)
+ init_std=init_std, seed=seed, task=task, device=device, gpus=gpus)
if kernel_type not in ['mat', 'vec', 'num']:
raise ValueError("kernel_type must be mat,vec or num")
diff --git a/deepctr_torch/models/wdl.py b/deepctr_torch/models/wdl.py
index 322b0920..6016eb0a 100644
--- a/deepctr_torch/models/wdl.py
+++ b/deepctr_torch/models/wdl.py
@@ -28,8 +28,9 @@ class WDL(BaseModel):
:param dnn_activation: Activation function to use in DNN
:param task: str, ``"binary"`` for binary logloss or ``"regression"`` for regression loss
:param device: str, ``"cpu"`` or ``"cuda:0"``
+ :param gpus: list of int or torch.device for multiple gpus. If None, run on `device`. `gpus[0]` should be the same gpu as `device`.
:return: A PyTorch model instance.
-
+
"""
def __init__(self,
@@ -37,11 +38,11 @@ def __init__(self,
l2_reg_linear=1e-5,
l2_reg_embedding=1e-5, l2_reg_dnn=0, init_std=0.0001, seed=1024, dnn_dropout=0, dnn_activation='relu',
dnn_use_bn=False,
- task='binary', device='cpu'):
+ task='binary', device='cpu', gpus=None):
super(WDL, self).__init__(linear_feature_columns, dnn_feature_columns, l2_reg_linear=l2_reg_linear,
l2_reg_embedding=l2_reg_embedding, init_std=init_std, seed=seed, task=task,
- device=device)
+ device=device, gpus=gpus)
self.use_dnn = len(dnn_feature_columns) > 0 and len(
dnn_hidden_units) > 0
diff --git a/deepctr_torch/models/xdeepfm.py b/deepctr_torch/models/xdeepfm.py
index 87cac472..7d5efa01 100644
--- a/deepctr_torch/models/xdeepfm.py
+++ b/deepctr_torch/models/xdeepfm.py
@@ -34,18 +34,19 @@ class xDeepFM(BaseModel):
:param dnn_use_bn: bool. Whether use BatchNormalization before activation or not in DNN
:param task: str, ``"binary"`` for binary logloss or ``"regression"`` for regression loss
:param device: str, ``"cpu"`` or ``"cuda:0"``
+ :param gpus: list of int or torch.device for multiple gpus. If None, run on `device`. `gpus[0]` should be the same gpu as `device`.
:return: A PyTorch model instance.
-
+
"""
def __init__(self, linear_feature_columns, dnn_feature_columns, dnn_hidden_units=(256, 256),
cin_layer_size=(256, 128,), cin_split_half=True, cin_activation='relu', l2_reg_linear=0.00001,
l2_reg_embedding=0.00001, l2_reg_dnn=0, l2_reg_cin=0, init_std=0.0001, seed=1024, dnn_dropout=0,
- dnn_activation='relu', dnn_use_bn=False, task='binary', device='cpu'):
+ dnn_activation='relu', dnn_use_bn=False, task='binary', device='cpu', gpus=None):
super(xDeepFM, self).__init__(linear_feature_columns, dnn_feature_columns, l2_reg_linear=l2_reg_linear,
l2_reg_embedding=l2_reg_embedding, init_std=init_std, seed=seed, task=task,
- device=device)
+ device=device, gpus=gpus)
self.dnn_hidden_units = dnn_hidden_units
self.use_dnn = len(dnn_feature_columns) > 0 and len(dnn_hidden_units) > 0
if self.use_dnn:
diff --git a/docs/pics/DIFM.png b/docs/pics/DIFM.png
new file mode 100644
index 00000000..76a983b2
Binary files /dev/null and b/docs/pics/DIFM.png differ
diff --git a/docs/pics/IFM.png b/docs/pics/IFM.png
new file mode 100644
index 00000000..5adf940c
Binary files /dev/null and b/docs/pics/IFM.png differ
diff --git a/docs/source/FAQ.md b/docs/source/FAQ.md
index 3399bb06..a7a4eb6e 100644
--- a/docs/source/FAQ.md
+++ b/docs/source/FAQ.md
@@ -60,6 +60,7 @@ model.fit(model_input,label)
```
## 4. How to run the demo with GPU ?
+
```python
import torch
device = 'cpu'
@@ -70,3 +71,9 @@ if use_cuda and torch.cuda.is_available():
model = DeepFM(...,device=device)
```
+
+## 5. How to run the demo with multiple GPUs ?
+
+```python
+model = DeepFM(..., device=device, gpus=[0, 1])
+```
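+
+A fuller sketch, assuming two visible GPUs and that `linear_feature_columns`, `dnn_feature_columns`, `model_input` and `label` have been prepared as in the Quick-Start guide (`gpus[0]` must be the same gpu as `device`):
+
+```python
+import torch
+from deepctr_torch.models import DeepFM
+
+device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
+gpus = [0, 1] if torch.cuda.device_count() > 1 else None  # fall back to a single device
+
+# when `gpus` is set, the `batch_size` passed to fit() is the per-gpu batch size
+model = DeepFM(linear_feature_columns, dnn_feature_columns,
+               task='binary', device=device, gpus=gpus)
+model.compile("adam", "binary_crossentropy", metrics=["auc"])
+model.fit(model_input, label, batch_size=256, epochs=1, verbose=1)
+```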
diff --git a/docs/source/Features.md b/docs/source/Features.md
index cce54b89..2aaf6787 100644
--- a/docs/source/Features.md
+++ b/docs/source/Features.md
@@ -241,6 +241,27 @@ Feature Importance and Bilinear feature Interaction NETwork is proposed to dynam
[Huang T, Zhang Z, Zhang J. FiBiNET: Combining Feature Importance and Bilinear feature Interaction for Click-Through Rate Prediction[J]. arXiv preprint arXiv:1905.09433, 2019.](https://arxiv.org/pdf/1905.09433.pdf)
+### IFM(Input-aware Factorization Machine)
+
+Input-aware Factorization Machine (IFM) learns a unique input-aware factor for the same feature in different instances via a neural network.
+
+[**IFM Model API**](./deepctr_torch.models.ifm.html)
+
+![IFM](../pics/IFM.png)
+
+[Yu Y, Wang Z, Yuan B. An Input-aware Factorization Machine for Sparse Prediction[C]//IJCAI. 2019: 1466-1472.](https://www.ijcai.org/Proceedings/2019/0203.pdf)
+
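+A minimal usage sketch, assuming `linear_feature_columns` and `dnn_feature_columns` have been built as in the Quick-Start guide:
+
+```python
+from deepctr_torch.models import IFM
+
+# IFM adds a factor-estimating DNN on top of FM to produce per-instance, per-feature weights
+model = IFM(linear_feature_columns, dnn_feature_columns,
+            dnn_hidden_units=(256, 128), task='binary', device='cpu')
+```
+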
+### DIFM(Dual Input-aware Factorization Machine)
+
+Dual Input-aware Factorization Machine (DIFM) can adaptively reweight the original feature representations at the bit-wise and vector-wise levels simultaneously. Furthermore, DIFM strategically integrates various components, including Multi-Head Self-Attention, Residual Networks and DNNs, into a unified end-to-end model.
+
+[**DIFM Model API**](./deepctr_torch.models.difm.html)
+
+![DIFM](../pics/DIFM.png)
+
+[Lu W, Yu Y, Chang Y, et al. A Dual Input-aware Factorization Machine for CTR Prediction[C]//IJCAI. 2020: 3139-3145.](https://www.ijcai.org/Proceedings/2020/0434.pdf)
+
+
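+A minimal usage sketch for DIFM, with the same assumptions on the feature columns as above:
+
+```python
+from deepctr_torch.models import DIFM
+
+# the vector-wise (multi-head self-attention) net and the bit-wise (DNN) net jointly produce the input-aware factors
+model = DIFM(linear_feature_columns, dnn_feature_columns,
+             att_head_num=8, dnn_hidden_units=(256, 128), task='binary', device='cpu')
+```
+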
## Layers
diff --git a/docs/source/History.md b/docs/source/History.md
index 78f4b463..eef2f07b 100644
--- a/docs/source/History.md
+++ b/docs/source/History.md
@@ -1,4 +1,5 @@
# History
+- 04/04/2021 : [v0.2.6](https://github.com/shenweichen/DeepCTR-Torch/releases/tag/v0.2.6) released. Add [IFM](./Features.html#ifm-input-aware-factorization-machine) and [DIFM](./Features.html#difm-dual-input-aware-factorization-machine); support multi-gpu running ([example](./FAQ.html#how-to-run-the-demo-with-multiple-gpus)).
- 02/12/2021 : [v0.2.5](https://github.com/shenweichen/DeepCTR-Torch/releases/tag/v0.2.5) released.Fix bug in DCN-M.
- 12/05/2020 : [v0.2.4](https://github.com/shenweichen/DeepCTR-Torch/releases/tag/v0.2.4) released.Imporve compatibility & fix issues.Add History callback.([example](https://deepctr-torch.readthedocs.io/en/latest/FAQ.html#set-learning-rate-and-use-earlystopping)).
- 10/18/2020 : [v0.2.3](https://github.com/shenweichen/DeepCTR-Torch/releases/tag/v0.2.3) released.Add [DCN-M](./Features.html#dcn-deep-cross-network)&[DCN-Mix](./Features.html#dcn-mix-improved-deep-cross-network-with-mix-of-experts-and-matrix-kernel).Add EarlyStopping and ModelCheckpoint callbacks([example](https://deepctr-torch.readthedocs.io/en/latest/FAQ.html#set-learning-rate-and-use-earlystopping)).
diff --git a/docs/source/Models.rst b/docs/source/Models.rst
index 52d96c28..a5eeb102 100644
--- a/docs/source/Models.rst
+++ b/docs/source/Models.rst
@@ -21,3 +21,5 @@ DeepCTR-Torch Models API
ONN
FGCNN
FiBiNET
+ IFM
+ DIFM
\ No newline at end of file
diff --git a/docs/source/conf.py b/docs/source/conf.py
index 1dd328ae..d43d0eea 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -26,7 +26,7 @@
# The short X.Y version
version = ''
# The full version, including alpha/beta/rc tags
-release = '0.2.5'
+release = '0.2.6'
# -- General configuration ---------------------------------------------------
diff --git a/docs/source/deepctr_torch.models.difm.rst b/docs/source/deepctr_torch.models.difm.rst
new file mode 100644
index 00000000..ae16a5b7
--- /dev/null
+++ b/docs/source/deepctr_torch.models.difm.rst
@@ -0,0 +1,7 @@
+deepctr\_torch.models.difm module
+=================================
+
+.. automodule:: deepctr_torch.models.difm
+ :members:
+ :no-undoc-members:
+ :no-show-inheritance:
diff --git a/docs/source/deepctr_torch.models.ifm.rst b/docs/source/deepctr_torch.models.ifm.rst
new file mode 100644
index 00000000..e625757b
--- /dev/null
+++ b/docs/source/deepctr_torch.models.ifm.rst
@@ -0,0 +1,7 @@
+deepctr\_torch.models.ifm module
+================================
+
+.. automodule:: deepctr_torch.models.ifm
+ :members:
+ :no-undoc-members:
+ :no-show-inheritance:
diff --git a/docs/source/deepctr_torch.models.rst b/docs/source/deepctr_torch.models.rst
index 25041a55..599710b6 100644
--- a/docs/source/deepctr_torch.models.rst
+++ b/docs/source/deepctr_torch.models.rst
@@ -10,6 +10,7 @@ Submodules
deepctr_torch.models.autoint
deepctr_torch.models.basemodel
deepctr_torch.models.dcn
+ deepctr_torch.models.dcnmix
deepctr_torch.models.deepfm
deepctr_torch.models.fibinet
deepctr_torch.models.mlr
@@ -20,6 +21,8 @@ Submodules
deepctr_torch.models.xdeepfm
deepctr_torch.models.din
deepctr_torch.models.dien
+ deepctr_torch.models.ifm
+ deepctr_torch.models.difm
Module contents
---------------
diff --git a/docs/source/index.rst b/docs/source/index.rst
index 2205e11d..bc4d2b1d 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -34,13 +34,12 @@ You can read the latest code at https://github.com/shenweichen/DeepCTR-Torch and
News
-----
+04/04/2021 : Add `IFM <./Features.html#ifm-input-aware-factorization-machine>`_ and `DIFM <./Features.html#difm-dual-input-aware-factorization-machine>`_ . Support multi-gpu running (`example <./FAQ.html#how-to-run-the-demo-with-multiple-gpus>`_). `Changelog `_
+
02/12/2021 : Fix bug in DCN-M. `Changelog `_
12/05/2020 : Imporve compatibility & fix issues.Add History callback(`example `_). `Changelog `_
-10/18/2020 : Add `DCN-M <./Features.html#dcn-deep-cross-network>`_ and `DCN-Mix <./Features.html#dcn-mix-improved-deep-cross-network-with-mix-of-experts-and-matrix-kernel>`_ . Add EarlyStopping and ModelCheckpoint callbacks(`example `_). `Changelog `_
-
-
DisscussionGroup
-----------------------
diff --git a/setup.py b/setup.py
index 4d77342e..7060df42 100644
--- a/setup.py
+++ b/setup.py
@@ -9,7 +9,7 @@
setuptools.setup(
name="deepctr-torch",
- version="0.2.5",
+ version="0.2.6",
author="Weichen Shen",
author_email="weichenswc@163.com",
description="Easy-to-use,Modular and Extendible package of deep learning based CTR(Click Through Rate) prediction models with PyTorch",
diff --git a/tests/models/DIFM_test.py b/tests/models/DIFM_test.py
new file mode 100644
index 00000000..0960232d
--- /dev/null
+++ b/tests/models/DIFM_test.py
@@ -0,0 +1,25 @@
+# -*- coding: utf-8 -*-
+import pytest
+
+from deepctr_torch.models import DIFM
+from ..utils import get_test_data, SAMPLE_SIZE, check_model, get_device
+
+
+@pytest.mark.parametrize(
+ 'att_head_num,dnn_hidden_units,sparse_feature_num',
+ [(1, (4,), 2), (2, (4, 4,), 2), (1, (4,), 1)]
+)
+def test_DIFM(att_head_num, dnn_hidden_units, sparse_feature_num):
+ model_name = "DIFM"
+ sample_size = SAMPLE_SIZE
+ x, y, feature_columns = get_test_data(
+ sample_size, sparse_feature_num=sparse_feature_num, dense_feature_num=sparse_feature_num)
+
+ model = DIFM(linear_feature_columns=feature_columns, dnn_feature_columns=feature_columns,
+ att_head_num=att_head_num,
+ dnn_hidden_units=dnn_hidden_units, dnn_dropout=0.5, device=get_device())
+ check_model(model, model_name, x, y)
+
+
+if __name__ == "__main__":
+ pass
diff --git a/tests/models/IFM_test.py b/tests/models/IFM_test.py
new file mode 100644
index 00000000..44dd9d89
--- /dev/null
+++ b/tests/models/IFM_test.py
@@ -0,0 +1,26 @@
+# -*- coding: utf-8 -*-
+import pytest
+
+from deepctr_torch.models import IFM
+from ..utils import get_test_data, SAMPLE_SIZE, check_model, get_device
+
+
+@pytest.mark.parametrize(
+ 'hidden_size,sparse_feature_num',
+ [((32,), 3),
+ ((32,), 2), ((32,), 1),
+ ]
+)
+def test_IFM(hidden_size, sparse_feature_num):
+ model_name = "IFM"
+ sample_size = SAMPLE_SIZE
+ x, y, feature_columns = get_test_data(
+ sample_size, sparse_feature_num=sparse_feature_num, dense_feature_num=sparse_feature_num)
+
+ model = IFM(feature_columns, feature_columns,
+ dnn_hidden_units=hidden_size, dnn_dropout=0.5, device=get_device())
+ check_model(model, model_name, x, y)
+
+
+if __name__ == "__main__":
+ pass
diff --git a/tests/utils.py b/tests/utils.py
index d8a8d6cb..4c79631e 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -70,7 +70,7 @@ def layer_test(layer_cls, kwargs = {}, input_shape=None,
input_dtype=torch.float32, input_data=None, expected_output=None,
expected_output_shape=None, expected_output_dtype=None, fixed_batch_size=False):
'''check layer is valid or not
-
+
:param layer_cls:
:param input_shape:
:param input_dtype:
|