-
Notifications
You must be signed in to change notification settings - Fork 11
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #102 from basf/rnn_branch
include tabularRNN
- Loading branch information
Showing
4 changed files
with
495 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,153 @@ | ||
import torch | ||
import torch.nn as nn | ||
from ..arch_utils.mlp_utils import MLP | ||
from ..configs.tabularnn_config import DefaultTabulaRNNConfig | ||
from .basemodel import BaseModel | ||
from ..arch_utils.embedding_layer import EmbeddingLayer | ||
from ..arch_utils.normalization_layers import ( | ||
RMSNorm, | ||
LayerNorm, | ||
LearnableLayerScaling, | ||
BatchNorm, | ||
InstanceNorm, | ||
GroupNorm, | ||
) | ||
|
||
|
||
class TabulaRNN(BaseModel): | ||
def __init__( | ||
self, | ||
cat_feature_info, | ||
num_feature_info, | ||
num_classes=1, | ||
config: DefaultTabulaRNNConfig = DefaultTabulaRNNConfig(), | ||
**kwargs, | ||
): | ||
super().__init__(**kwargs) | ||
self.save_hyperparameters(ignore=["cat_feature_info", "num_feature_info"]) | ||
|
||
self.lr = self.hparams.get("lr", config.lr) | ||
self.lr_patience = self.hparams.get("lr_patience", config.lr_patience) | ||
self.weight_decay = self.hparams.get("weight_decay", config.weight_decay) | ||
self.lr_factor = self.hparams.get("lr_factor", config.lr_factor) | ||
self.pooling_method = self.hparams.get("pooling_method", config.pooling_method) | ||
self.cat_feature_info = cat_feature_info | ||
self.num_feature_info = num_feature_info | ||
|
||
norm_layer = self.hparams.get("norm", config.norm) | ||
if norm_layer == "RMSNorm": | ||
self.norm_f = RMSNorm( | ||
self.hparams.get("dim_feedforward", config.dim_feedforward) | ||
) | ||
elif norm_layer == "LayerNorm": | ||
self.norm_f = LayerNorm( | ||
self.hparams.get("dim_feedforward", config.dim_feedforward) | ||
) | ||
elif norm_layer == "BatchNorm": | ||
self.norm_f = BatchNorm( | ||
self.hparams.get("dim_feedforward", config.dim_feedforward) | ||
) | ||
elif norm_layer == "InstanceNorm": | ||
self.norm_f = InstanceNorm( | ||
self.hparams.get("dim_feedforward", config.dim_feedforward) | ||
) | ||
elif norm_layer == "GroupNorm": | ||
self.norm_f = GroupNorm( | ||
1, self.hparams.get("dim_feedforward", config.dim_feedforward) | ||
) | ||
elif norm_layer == "LearnableLayerScaling": | ||
self.norm_f = LearnableLayerScaling( | ||
self.hparams.get("dim_feedforward", config.dim_feedforward) | ||
) | ||
else: | ||
self.norm_f = None | ||
|
||
rnn_layer = {"RNN": nn.RNN, "LSTM": nn.LSTM, "GRU": nn.GRU}[config.model_type] | ||
self.rnn = rnn_layer( | ||
input_size=self.hparams.get("d_model", config.d_model), | ||
hidden_size=self.hparams.get("dim_feedforward", config.dim_feedforward), | ||
num_layers=self.hparams.get("n_layers", config.n_layers), | ||
bidirectional=self.hparams.get("bidirectional", config.bidirectional), | ||
batch_first=True, | ||
dropout=self.hparams.get("rnn_dropout", config.rnn_dropout), | ||
bias=self.hparams.get("bias", config.bias), | ||
nonlinearity=( | ||
self.hparams.get("rnn_activation", config.rnn_activation) | ||
if config.model_type == "RNN" | ||
else None | ||
), | ||
) | ||
|
||
self.embedding_layer = EmbeddingLayer( | ||
num_feature_info=num_feature_info, | ||
cat_feature_info=cat_feature_info, | ||
d_model=self.hparams.get("d_model", config.d_model), | ||
embedding_activation=self.hparams.get( | ||
"embedding_activation", config.embedding_activation | ||
), | ||
layer_norm_after_embedding=self.hparams.get( | ||
"layer_norm_after_embedding", config.layer_norm_after_embedding | ||
), | ||
use_cls=False, | ||
cls_position=-1, | ||
cat_encoding=self.hparams.get("cat_encoding", config.cat_encoding), | ||
) | ||
|
||
head_activation = self.hparams.get("head_activation", config.head_activation) | ||
|
||
self.tabular_head = MLP( | ||
self.hparams.get("dim_feedforward", config.dim_feedforward), | ||
hidden_units_list=self.hparams.get( | ||
"head_layer_sizes", config.head_layer_sizes | ||
), | ||
dropout_rate=self.hparams.get("head_dropout", config.head_dropout), | ||
use_skip_layers=self.hparams.get( | ||
"head_skip_layers", config.head_skip_layers | ||
), | ||
activation_fn=head_activation, | ||
use_batch_norm=self.hparams.get( | ||
"head_use_batch_norm", config.head_use_batch_norm | ||
), | ||
n_output_units=num_classes, | ||
) | ||
|
||
self.linear = nn.Linear(config.d_model, config.dim_feedforward) | ||
|
||
def forward(self, num_features, cat_features): | ||
""" | ||
Defines the forward pass of the model. | ||
Parameters | ||
---------- | ||
num_features : Tensor | ||
Tensor containing the numerical features. | ||
cat_features : Tensor | ||
Tensor containing the categorical features. | ||
Returns | ||
------- | ||
Tensor | ||
The output predictions of the model. | ||
""" | ||
|
||
x = self.embedding_layer(num_features, cat_features) | ||
# RNN forward pass | ||
out, _ = self.rnn(x) | ||
z = self.linear(torch.mean(x, dim=1)) | ||
|
||
if self.pooling_method == "avg": | ||
x = torch.mean(out, dim=1) | ||
elif self.pooling_method == "max": | ||
x, _ = torch.max(out, dim=1) | ||
elif self.pooling_method == "sum": | ||
x = torch.sum(out, dim=1) | ||
elif self.pooling_method == "last": | ||
x = x[:, -1, :] | ||
else: | ||
raise ValueError(f"Invalid pooling method: {self.pooling_method}") | ||
x = x + z | ||
if self.norm_f is not None: | ||
x = self.norm_f(x) | ||
preds = self.tabular_head(x) | ||
|
||
return preds |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
from dataclasses import dataclass | ||
import torch.nn as nn | ||
|
||
|
||
@dataclass | ||
class DefaultTabulaRNNConfig: | ||
""" | ||
Configuration class for the default TabulaRNN model with predefined hyperparameters. | ||
Parameters | ||
---------- | ||
lr : float, default=1e-04 | ||
Learning rate for the optimizer. | ||
model_type : str, default="RNN" | ||
type of model, one of "RNN", "LSTM", "GRU" | ||
lr_patience : int, default=10 | ||
Number of epochs with no improvement after which learning rate will be reduced. | ||
weight_decay : float, default=1e-06 | ||
Weight decay (L2 penalty) for the optimizer. | ||
lr_factor : float, default=0.1 | ||
Factor by which the learning rate will be reduced. | ||
d_model : int, default=64 | ||
Dimensionality of the model. | ||
n_layers : int, default=8 | ||
Number of layers in the transformer. | ||
norm : str, default="RMSNorm" | ||
Normalization method to be used. | ||
activation : callable, default=nn.SELU() | ||
Activation function for the transformer. | ||
embedding_activation : callable, default=nn.Identity() | ||
Activation function for numerical embeddings. | ||
head_layer_sizes : list, default=(128, 64, 32) | ||
Sizes of the layers in the head of the model. | ||
head_dropout : float, default=0.5 | ||
Dropout rate for the head layers. | ||
head_skip_layers : bool, default=False | ||
Whether to skip layers in the head. | ||
head_activation : callable, default=nn.SELU() | ||
Activation function for the head layers. | ||
head_use_batch_norm : bool, default=False | ||
Whether to use batch normalization in the head layers. | ||
layer_norm_after_embedding : bool, default=False | ||
Whether to apply layer normalization after embedding. | ||
pooling_method : str, default="cls" | ||
Pooling method to be used ('cls', 'avg', etc.). | ||
norm_first : bool, default=False | ||
Whether to apply normalization before other operations in each transformer block. | ||
bias : bool, default=True | ||
Whether to use bias in the linear layers. | ||
rnn_activation : callable, default=nn.SELU() | ||
Activation function for the transformer layers. | ||
bidirectional : bool, default=False. | ||
Whether to process data bidirectionally | ||
cat_encoding : str, default="int" | ||
Encoding method for categorical features. | ||
""" | ||
|
||
lr: float = 1e-04 | ||
model_type: str = "RNN" | ||
lr_patience: int = 10 | ||
weight_decay: float = 1e-06 | ||
lr_factor: float = 0.1 | ||
d_model: int = 128 | ||
n_layers: int = 4 | ||
rnn_dropout: float = 0.2 | ||
norm: str = "RMSNorm" | ||
activation: callable = nn.SELU() | ||
embedding_activation: callable = nn.Identity() | ||
head_layer_sizes: list = () | ||
head_dropout: float = 0.5 | ||
head_skip_layers: bool = False | ||
head_activation: callable = nn.SELU() | ||
head_use_batch_norm: bool = False | ||
layer_norm_after_embedding: bool = False | ||
pooling_method: str = "avg" | ||
norm_first: bool = False | ||
bias: bool = True | ||
rnn_activation: str = "relu" | ||
layer_norm_eps: float = 1e-05 | ||
dim_feedforward: int = 256 | ||
numerical_embedding: str = "ple" | ||
bidirectional: bool = False | ||
cat_encoding: str = "int" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.