From 9ab3ad542941ad3ff535f974ad93dc2b950d4559 Mon Sep 17 00:00:00 2001 From: Quentin Raquet Date: Tue, 2 Jun 2020 16:10:14 +0200 Subject: [PATCH] fix: sort by cat_idx into embedding generator --- pytorch_tabnet/tab_network.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/pytorch_tabnet/tab_network.py b/pytorch_tabnet/tab_network.py index 15c1131f..8e8c4531 100644 --- a/pytorch_tabnet/tab_network.py +++ b/pytorch_tabnet/tab_network.py @@ -424,6 +424,12 @@ def __init__(self, input_dim, cat_dims, cat_idxs, cat_emb_dim): self.post_embed_dim = int(input_dim + np.sum(self.cat_emb_dims) - len(self.cat_emb_dims)) self.embeddings = torch.nn.ModuleList() + + # Sort dims by cat_idx + sorted_idxs = np.argsort(cat_idxs) + cat_dims = [cat_dims[i] for i in sorted_idxs] + self.cat_emb_dims = [self.cat_emb_dims[i] for i in sorted_idxs] + for cat_dim, emb_dim in zip(cat_dims, self.cat_emb_dims): self.embeddings.append(torch.nn.Embedding(cat_dim, emb_dim))